Merged
Changes from all commits · 65 commits
5d14a87
dump
zucchini-nlp Jun 4, 2025
5c5825b
push other models
zucchini-nlp Jun 5, 2025
051fe7f
fix simple greedy generation
zucchini-nlp Jun 5, 2025
b04ddbc
xmod
zucchini-nlp Jun 5, 2025
6a289a7
add fmst and clean up some mentions of old cache format
zucchini-nlp Jun 6, 2025
b3be72b
gpt-bigcode now follows standards
zucchini-nlp Jun 9, 2025
85061bc
delete tuple cache reference in generation
zucchini-nlp Jun 9, 2025
1424600
fix some models
zucchini-nlp Jun 9, 2025
02fb0d2
fix some models
zucchini-nlp Jun 9, 2025
f7494bc
fix mambas and support cache in tapas
zucchini-nlp Jun 9, 2025
576fb7b
fix some more tests
zucchini-nlp Jun 10, 2025
8757e84
fix copies
zucchini-nlp Jun 10, 2025
bcf0cc7
delete `_reorder_cache`
zucchini-nlp Jun 10, 2025
91d92f1
another fix copies
zucchini-nlp Jun 10, 2025
edf5f6e
fix typos and delete unnecessary test
zucchini-nlp Jun 10, 2025
b236e90
fix rag generate, needs special cache reordering
zucchini-nlp Jun 10, 2025
1893f8a
fix tapas and superglue
zucchini-nlp Jun 10, 2025
46e50b5
reformer create special cache
zucchini-nlp Jun 10, 2025
204ed55
recurrent gemma `reorder_cache` was a no-op, delete
zucchini-nlp Jun 10, 2025
7b61dfd
fix-copies
zucchini-nlp Jun 10, 2025
69c20ae
fix blio and musicgen pipeline tests
zucchini-nlp Jun 10, 2025
d281a6c
Merge branch 'main' into cache-class-finalize
zucchini-nlp Jun 10, 2025
b508814
fix reformer
zucchini-nlp Jun 11, 2025
b7deae6
fix reformer, again...
zucchini-nlp Jun 11, 2025
ae88ecc
delete `_supports_cache_class`
zucchini-nlp Jun 11, 2025
f1ec0ba
delete `supports_quantized_cache`
zucchini-nlp Jun 11, 2025
8f5d8a0
fix failing tests
zucchini-nlp Jun 11, 2025
08ad1b0
fix copies
zucchini-nlp Jun 12, 2025
dfdf50b
some minor clean up
zucchini-nlp Jun 12, 2025
e9a281f
style
zucchini-nlp Jun 12, 2025
e735b87
merge main, so many conflicts
zucchini-nlp Jul 1, 2025
e1a3fc4
style
zucchini-nlp Jul 1, 2025
3190a9e
fix copies
zucchini-nlp Jul 2, 2025
2f942f8
fix tests
zucchini-nlp Jul 7, 2025
0fc0159
merge main
zucchini-nlp Jul 7, 2025
ccdd784
fix copies
zucchini-nlp Jul 7, 2025
63f1bd3
create causal mask now needs positions?
zucchini-nlp Jul 7, 2025
2dc3b01
fixc copies
zucchini-nlp Jul 7, 2025
f050810
merge main
zucchini-nlp Jul 10, 2025
e945d2f
style
zucchini-nlp Jul 10, 2025
8a7a05b
Update tests/test_modeling_common.py
zucchini-nlp Jul 10, 2025
ce665c0
clean-up of non-generative model after merging main
zucchini-nlp Jul 10, 2025
d0f68d0
check `is_decoder` for cache
zucchini-nlp Jul 10, 2025
72bc51a
delete transpose for scores
zucchini-nlp Jul 10, 2025
ac73959
remove tuple cache from docs everywhere
zucchini-nlp Jul 10, 2025
fd84e67
fix tests
zucchini-nlp Jul 10, 2025
7c2d5b6
fix copies
zucchini-nlp Jul 10, 2025
122564e
fix copies once more
zucchini-nlp Jul 10, 2025
b265287
properly deprecate `encoder_attention_mask` in Bert-like models
zucchini-nlp Jul 10, 2025
8218c5b
import `deprecate_kwarg` where needed
zucchini-nlp Jul 10, 2025
bb1866c
fix copies again
zucchini-nlp Jul 10, 2025
b9fe72c
Merge branch 'main' into cache-class-finalize
zucchini-nlp Jul 14, 2025
b67a4c3
fix copies
zucchini-nlp Jul 14, 2025
5eeeeb3
delete `nex_decoder_cache`
zucchini-nlp Jul 15, 2025
5231ed5
fix copies asks to update for PLM
zucchini-nlp Jul 15, 2025
27e9539
merge main
zucchini-nlp Jul 15, 2025
5a74509
fix copies
zucchini-nlp Jul 15, 2025
011ee19
rebasing had a few new models, fix them and merge asap!
zucchini-nlp Jul 15, 2025
91f8072
fix copies once more
zucchini-nlp Jul 15, 2025
ec311c3
fix slow tests
zucchini-nlp Jul 16, 2025
78d44b4
Merge branch 'main' into cache-class-finalize
zucchini-nlp Jul 16, 2025
f9592f6
fix tests and updare PLM checkpoint
zucchini-nlp Jul 16, 2025
abccbee
add read token and revert accidentally removed line
zucchini-nlp Jul 16, 2025
327b9e1
oh com -on, style
zucchini-nlp Jul 16, 2025
651febb
just skip it, read token has no access to PLM yet
zucchini-nlp Jul 16, 2025
2 changes: 1 addition & 1 deletion docs/source/en/cache_explanation.md
@@ -141,7 +141,7 @@ The legacy format is essentially the same data structure but organized differently
- The tensors have the same shape `[batch_size, num_heads, seq_len, head_dim]`.
- The format is less flexible and doesn't support features like quantization or offloading.

If your project depends on this legacy format, you can convert between [`DynamicCache`] and a tuple of tuples as shown below with the [`~DynamicCache.from_legacy_cache`] and [`DynamicCache.to_legacy_cache`] functions. This is helpful if you have custom logic for manipulating a cache in a specific format.
If your project depends on this legacy format, we recommend converting to [`DynamicCache`] with [`~DynamicCache.from_legacy_cache`]. Note that the legacy cache format is deprecated and no longer used in `Transformers`. You can convert back to the tuple format with the [`DynamicCache.to_legacy_cache`] function, which is helpful if you have custom logic for manipulating a cache in a specific format.

```py
import torch
# … (rest of the example truncated in the diff view)
```
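For reference, here is a minimal sketch of the round-trip the updated paragraph describes. The checkpoint name and prompt are illustrative assumptions; any decoder-only model with `Cache` support works the same way:

```py
from transformers import AutoModelForCausalLM, AutoTokenizer, DynamicCache

# "gpt2" is only an illustrative checkpoint choice
model = AutoModelForCausalLM.from_pretrained("gpt2")
tokenizer = AutoTokenizer.from_pretrained("gpt2")
inputs = tokenizer("The quick brown fox", return_tensors="pt")

# forward pass that fills a DynamicCache
outputs = model(**inputs, past_key_values=DynamicCache(), use_cache=True)

# convert to the legacy tuple-of-tuples format for custom cache logic ...
legacy_cache = outputs.past_key_values.to_legacy_cache()
# ... and back to a DynamicCache before resuming generation
cache = DynamicCache.from_legacy_cache(legacy_cache)
```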
46 changes: 1 addition & 45 deletions src/transformers/generation/candidate_generator.py
@@ -28,7 +28,6 @@
if is_sklearn_available():
from sklearn.metrics import roc_curve

from ..cache_utils import Cache
from ..pytorch_utils import isin_mps_friendly
from .logits_process import LogitsProcessorList, MinLengthLogitsProcessor, SuppressTokensLogitsProcessor

@@ -295,9 +294,7 @@ def _update_past_and_masks(
has_past_key_values = self.assistant_kwargs.get("past_key_values", None) is not None
if has_past_key_values:
new_cache_size = input_ids.shape[-1] - 1 - remove_from_pkv
self.assistant_kwargs["past_key_values"] = _crop_past_key_values(
self.assistant_model, self.assistant_kwargs["past_key_values"], new_cache_size - num_added_tokens
)
self.assistant_kwargs["past_key_values"].crop(new_cache_size - num_added_tokens)
self.assistant_kwargs = _prepare_attention_mask(
self.assistant_kwargs, input_ids.shape[-1], self.assistant_model.config.is_encoder_decoder
)
@@ -1180,47 +1177,6 @@ def get_candidates(self, input_ids: torch.LongTensor) -> tuple[torch.LongTensor,
return candidate_ids, candidate_logits


def _crop_past_key_values(model, past_key_values, max_length):
"""Crops the past key values up to a certain maximum length."""
new_past = []
if isinstance(past_key_values, Cache):
past_key_values.crop(max_length)
elif model.config.is_encoder_decoder:
for idx in range(len(past_key_values)):
new_past.append(
(
past_key_values[idx][0][:, :, :max_length, :],
past_key_values[idx][1][:, :, :max_length, :],
past_key_values[idx][2],
past_key_values[idx][3],
)
)
past_key_values = tuple(new_past)
# gptbigcode is special and stores kv in shape (batch_size, seq_len, dim), if it's a multi_query model
elif "gptbigcode" in model.__class__.__name__.lower() or (
model.config.architectures is not None and "gptbigcode" in model.config.architectures[0].lower()
):
if model.config.multi_query:
for idx in range(len(past_key_values)):
past_key_values[idx] = past_key_values[idx][:, :max_length, :]
else:
for idx in range(len(past_key_values)):
past_key_values[idx] = past_key_values[idx][:, :, :max_length, :]
elif past_key_values is not None:
for idx in range(len(past_key_values)):
if past_key_values[idx] != ([], []):
new_past.append(
(
past_key_values[idx][0][:, :, :max_length, :],
past_key_values[idx][1][:, :, :max_length, :],
)
)
else:
new_past.append((past_key_values[idx][0], past_key_values[idx][1]))
past_key_values = tuple(new_past)
return past_key_values


def _prepare_attention_mask(model_kwargs: dict[str, Any], new_length: int, is_encoder_decoder: bool) -> dict[str, Any]:
"""Expands or crops the model's mask for decoding purposes, to the defined length"""

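The deleted `_crop_past_key_values` helper, with its per-architecture tuple branches, collapses into a single `Cache.crop` call now that models return a `Cache` instance. A minimal sketch of the method on a `DynamicCache`; the tensor shapes are dummy values for illustration:

```py
import torch
from transformers import DynamicCache

cache = DynamicCache()
# fill one layer with dummy states shaped [batch_size, num_heads, seq_len, head_dim]
key_states = torch.zeros(1, 4, 10, 8)
value_states = torch.zeros(1, 4, 10, 8)
cache.update(key_states, value_states, layer_idx=0)

cache.crop(6)  # drop everything past position 6, in place
print(cache.get_seq_length())  # 6
```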
110 changes: 35 additions & 75 deletions src/transformers/generation/utils.py
@@ -68,7 +68,6 @@
EarlyExitCandidateGenerator,
PromptLookupCandidateGenerator,
UniversalSpeculativeDecodingGenerator,
_crop_past_key_values,
_prepare_attention_mask,
_prepare_token_type_ids,
)
@@ -567,15 +566,7 @@ def prepare_inputs_for_generation(

# 1. Handle BC:
model_inputs = {}
# - some models don't have `Cache` support (which implies they don't expect `cache_position` in `forward`)
if self._supports_cache_class:
model_inputs["cache_position"] = cache_position
# - `cache_position` was not a mandatory input in `prepare_inputs_for_generation` for those models, and this
# function may be called outside of `generate`. Handle most use cases by creating `cache_position` on the fly
# (this alternative is not as robust as calling `generate` and letting it create `cache_position`)
elif cache_position is None:
past_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
cache_position = torch.arange(past_length, input_ids.shape[1], dtype=torch.long, device=input_ids.device)
model_inputs["cache_position"] = cache_position

# 2. Generic cache-dependent input preparation
if past_key_values is not None:
@@ -1014,12 +1005,6 @@ def _update_model_kwargs_for_generation(
model_kwargs["cache_position"] = torch.cat((past_positions, new_positions))
return model_kwargs

def _reorder_cache(self, past_key_values, beam_idx):
raise NotImplementedError(
f"Make sure that a `_reorder_cache` function is correctly implemented in {self.__class__.__module__} to"
f" enable beam search for {self.__class__}"
)

def _get_candidate_generator(
self,
generation_config: GenerationConfig,
@@ -1559,13 +1544,6 @@ def _validate_assistant(self, assistant_model, tokenizer, assistant_tokenizer):

def _validate_model_kwargs(self, model_kwargs: dict[str, Any]):
"""Validates model kwargs for generation. Generate argument typos will also be caught here."""
# If a `Cache` instance is passed, checks whether the model is compatible with it
if isinstance(model_kwargs.get("past_key_values", None), Cache) and not self._supports_cache_class:
raise ValueError(
f"{self.__class__.__name__} does not support an instance of `Cache` as `past_key_values`. Please "
"check the model documentation for supported cache formats."
)

# Excludes arguments that are handled before calling any model function
if self.config.is_encoder_decoder:
for key in ["decoder_input_ids"]:
@@ -1975,21 +1953,23 @@ def _get_cache(
self._cache.reset()
return self._cache

def _supports_default_dynamic_cache(self) -> bool:
@classmethod
def _supports_default_dynamic_cache(cls) -> bool:
"""
Return `True` if current model can use a `DynamicCache` instance when initializing the `past_key_values`.
This is mostly the same as `_supports_cache_class` attribute, but add exception for `Jamba` model which
uses its own `HybridMambaAttentionDynamicCache` and do not need to initialize the Cache in advance in
order to save memory (because no back and forth `to_legacy_cache` and `from_legacy_cache` will be performed
for `HybridMambaAttentionDynamicCache`).
This adds an exception for models like `Jamba`, which use their own `HybridMambaAttentionDynamicCache`
and do not need to initialize the cache in advance, which saves memory (no back-and-forth `to_legacy_cache`
and `from_legacy_cache` conversions are performed for `HybridMambaAttentionDynamicCache`).
"""
return (
self._supports_cache_class
and "jamba" not in self.__class__.__name__.lower()
and "zamba" not in self.__class__.__name__.lower()
and "bamba" not in self.__class__.__name__.lower()
and "minimax" not in self.__class__.__name__.lower()
and "lfm2" not in self.__class__.__name__.lower()
# NOTE: remove xlnet/reformer when the models are deprecated, non-standard model architecture/cache name
return not cls._is_stateful and all(
special_model_name not in cls.__name__.lower()
for special_model_name in [
"reformer",
"minimax",
"xlnet",
"lfm2",
]
)
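
Since `_supports_default_dynamic_cache` is now a classmethod driven by `_is_stateful` and the class name, it can be queried without instantiating the model. A hedged sketch (the model classes are arbitrary examples, and this remains a private helper rather than public API):

```py
from transformers import LlamaForCausalLM, ReformerModelWithLMHead

# classmethods work straight off the class, no weights needed
print(LlamaForCausalLM._supports_default_dynamic_cache())         # True
print(ReformerModelWithLMHead._supports_default_dynamic_cache())  # False: "reformer" is special-cased
```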

def _prepare_cache_for_generation(
@@ -2076,7 +2056,7 @@ def _prepare_cache_for_generation(
model_kwargs=model_kwargs,
)
elif generation_config.cache_implementation == "quantized":
if not self._supports_quantized_cache:
if self.config.is_encoder_decoder or not self._supports_default_dynamic_cache():
raise ValueError(
"This model does not support the quantized cache. If you want your model to support quantized "
"cache, please open an issue and tag @zucchini-nlp."
@@ -3708,33 +3688,6 @@ def _sample(
else:
return input_ids

# Auxiliary functions for beam search
def _temporary_reorder_cache(self, past_key_values, beam_idx):
"""
Temporary function to handle the different types of cache reordering processes while we roll out `Cache`.

TODO: standardize cache formats and make all models compatible with `Cache`. It would remove the need
for this function, with `Cache.reorder_cache` being the sole remaining code path
"""
model_class = self.__class__.__name__.lower()
# Exception 1: code path for models using the legacy cache format
if isinstance(past_key_values, (tuple, list)):
past_key_values = self._reorder_cache(past_key_values, beam_idx)
# Exception 2: models with different cache formats. These are limited to `DynamicCache` until their
# cache format is standardized, to avoid adding complexity to the codebase.
elif "gptbigcode" in model_class:
if not isinstance(past_key_values, (DynamicCache, EncoderDecoderCache)):
raise ValueError(
f"Using an unsupported cache format with {model_class}. Currently, it only supports the "
"legacy tuple format or `DynamicCache`"
)
past_key_values = self._reorder_cache(past_key_values, beam_idx)
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
# Standard code path: use the `Cache.reorder_cache`
else:
past_key_values.reorder_cache(beam_idx)
return past_key_values

@staticmethod
def _flatten_beam_dim(tensor: torch.Tensor) -> torch.Tensor:
"""[batch_size, num_beams, ...] -> [batch_size * num_beams, ...]"""
@@ -4230,11 +4183,13 @@ def _beam_search(
# beam search as a whole (as opposed to individual beams, i.e. `stopping_criteria`)

# pluck the cache from the beam indices that will be used in the next iteration
# NOTE: we need to check if `self._reorder_cache` exists for special models like RAG, RecurrentGemma etc.
if model_kwargs.get("past_key_values", None) is not None:
model_kwargs["past_key_values"] = self._temporary_reorder_cache(
past_key_values=model_kwargs["past_key_values"],
beam_idx=self._flatten_beam_dim(running_beam_indices[..., cur_len - decoder_prompt_len]),
)
beam_idx = self._flatten_beam_dim(running_beam_indices[..., cur_len - decoder_prompt_len])
if hasattr(self, "_reorder_cache"):
model_kwargs["past_key_values"] = self._reorder_cache(model_kwargs["past_key_values"], beam_idx)
Member Author comment: models like rag or reformer have their own special cache reorder logic, which I didn't remove. I don't think it is worth aligning these models with `past_key_values.reorder_cache` because they're pretty low usage.

else:
model_kwargs["past_key_values"].reorder_cache(beam_idx)

cur_len = cur_len + 1
is_early_stop_heuristic_unsatisfied = self._check_early_stop_heuristic(
@@ -4537,10 +4492,14 @@ def _group_beam_search(
# (that way the memory peak does not include outputs.logits)
del outputs

# NOTE: we need to check if `self._reorder_cache` exists for special models like RAG, RecurrentGemma etc.
if model_kwargs.get("past_key_values", None) is not None:
model_kwargs["past_key_values"] = self._temporary_reorder_cache(
model_kwargs["past_key_values"], reordering_indices
)
if hasattr(self, "_reorder_cache"):
model_kwargs["past_key_values"] = self._reorder_cache(
model_kwargs["past_key_values"], reordering_indices
)
else:
model_kwargs["past_key_values"].reorder_cache(reordering_indices)

# increase cur_len
cur_len = cur_len + 1
@@ -4774,10 +4733,12 @@ def _constrained_beam_search(
# (that way the memory peak does not include outputs.logits)
del outputs

# NOTE: we need to check if `self._reorder_cache` exists for special models like RAG, RecurrentGemma etc.
if model_kwargs.get("past_key_values", None) is not None:
model_kwargs["past_key_values"] = self._temporary_reorder_cache(
model_kwargs["past_key_values"], beam_idx
)
if hasattr(self, "_reorder_cache"):
model_kwargs["past_key_values"] = self._reorder_cache(model_kwargs["past_key_values"], beam_idx)
else:
model_kwargs["past_key_values"].reorder_cache(beam_idx)

if return_dict_in_generate and output_scores:
beam_indices = tuple(beam_indices[beam_idx[i]] + (beam_idx[i],) for i in range(len(beam_indices)))
@@ -5002,8 +4963,7 @@ def _assisted_decoding(
new_cur_len = input_ids.shape[1]

# 4.2. Discard past key values relative to unused assistant tokens
new_cache_size = new_cur_len - 1
outputs.past_key_values = _crop_past_key_values(self, outputs.past_key_values, new_cache_size)
outputs.past_key_values.crop(new_cur_len - 1)

# 5. Update the candidate generation strategy if needed
candidate_generator.update_candidate_strategy(input_ids, new_logits, n_matches)
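With `_temporary_reorder_cache` gone, beam search now reorders the cache either through a model-specific `_reorder_cache` (kept only for special models such as RAG and Reformer) or through `Cache.reorder_cache` directly. A sketch of the standard path; shapes and beam indices are made-up values:

```py
import torch
from transformers import DynamicCache

num_beams, num_heads, seq_len, head_dim = 4, 2, 5, 8
cache = DynamicCache()
cache.update(
    torch.randn(num_beams, num_heads, seq_len, head_dim),
    torch.randn(num_beams, num_heads, seq_len, head_dim),
    layer_idx=0,
)

beam_idx = torch.tensor([2, 0, 0, 3])  # beams selected for the next step
cache.reorder_cache(beam_idx)  # in-place reorder along the beam dimension
```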
6 changes: 1 addition & 5 deletions src/transformers/modeling_utils.py
@@ -1971,13 +1971,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMixin):
# Flex Attention support
_supports_flex_attn = False

# Has support for a `Cache` instance as `past_key_values`? Does it support a `StaticCache`?
_supports_cache_class = False
# Has support `torch.compile(fullgraph=True)`
Contributor comment: +1 to renaming this flag into `_can_compile_fullgraph`


_supports_static_cache = False

# Has support for a `QuantoQuantizedCache` instance as `past_key_values`
_supports_quantized_cache = False

# A tensor parallel plan to be applied to the model when TP is enabled. For
# top-level models, this attribute is currently defined in respective model
# code. For base models, this attribute comes from
42 changes: 26 additions & 16 deletions src/transformers/models/albert/modeling_albert.py
@@ -271,12 +271,6 @@ def __init__(self, config: AlbertConfig):
self.max_position_embeddings = config.max_position_embeddings
self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)

# Copied from transformers.models.bert.modeling_bert.BertSelfAttention.transpose_for_scores
def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
x = x.view(new_x_shape)
return x.permute(0, 2, 1, 3)

def prune_heads(self, heads: list[int]) -> None:
if len(heads) == 0:
return
@@ -302,13 +296,17 @@ def forward(
head_mask: Optional[torch.FloatTensor] = None,
output_attentions: bool = False,
) -> Union[tuple[torch.Tensor], tuple[torch.Tensor, torch.Tensor]]:
mixed_query_layer = self.query(hidden_states)
mixed_key_layer = self.key(hidden_states)
mixed_value_layer = self.value(hidden_states)

query_layer = self.transpose_for_scores(mixed_query_layer)
key_layer = self.transpose_for_scores(mixed_key_layer)
value_layer = self.transpose_for_scores(mixed_value_layer)
batch_size, seq_length, _ = hidden_states.shape
query_layer = self.query(hidden_states)
key_layer = self.key(hidden_states)
value_layer = self.value(hidden_states)
query_layer = query_layer.view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose(
1, 2
)
key_layer = key_layer.view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose(1, 2)
value_layer = value_layer.view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose(
1, 2
)

# Take the dot product between "query" and "key" to get the raw attention scores.
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
@@ -378,9 +376,21 @@ def forward(
return super().forward(hidden_states, attention_mask, output_attentions=output_attentions)

batch_size, seq_len, _ = hidden_states.size()
query_layer = self.transpose_for_scores(self.query(hidden_states))
key_layer = self.transpose_for_scores(self.key(hidden_states))
value_layer = self.transpose_for_scores(self.value(hidden_states))
query_layer = (
self.query(hidden_states)
.view(batch_size, -1, self.num_attention_heads, self.attention_head_size)
.transpose(1, 2)
)
key_layer = (
self.key(hidden_states)
.view(batch_size, -1, self.num_attention_heads, self.attention_head_size)
.transpose(1, 2)
)
value_layer = (
self.value(hidden_states)
.view(batch_size, -1, self.num_attention_heads, self.attention_head_size)
.transpose(1, 2)
)

# SDPA with memory-efficient backend is broken in torch==2.1.2 when using non-contiguous inputs and a custom
# attn_mask, so we need to call `.contiguous()` here. This was fixed in torch==2.2.0.
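The deleted `transpose_for_scores` helper and its inline replacement produce identical layouts: for a 4-D tensor, `permute(0, 2, 1, 3)` and `transpose(1, 2)` are interchangeable. A quick equivalence check with arbitrary shapes:

```py
import torch

batch_size, seq_len, num_heads, head_dim = 2, 5, 4, 8
hidden_states = torch.randn(batch_size, seq_len, num_heads * head_dim)

# old helper: explicit view to [batch, seq, heads, head_dim], then permute
old = hidden_states.view(batch_size, seq_len, num_heads, head_dim).permute(0, 2, 1, 3)
# new inline form: view with an inferred sequence dimension, then transpose
new = hidden_states.view(batch_size, -1, num_heads, head_dim).transpose(1, 2)

assert torch.equal(old, new)  # both yield [batch, heads, seq, head_dim]
```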
3 changes: 1 addition & 2 deletions src/transformers/models/arcee/modeling_arcee.py
@@ -316,8 +316,7 @@ class ArceePreTrainedModel(PreTrainedModel):
_supports_flash_attn = True
_supports_sdpa = True
_supports_flex_attn = True
_supports_cache_class = True
_supports_quantized_cache = True

_supports_static_cache = True
_supports_attention_backend = True
_can_record_outputs = {
4 changes: 1 addition & 3 deletions src/transformers/models/aria/modeling_aria.py
@@ -631,7 +631,7 @@ class AriaTextPreTrainedModel(PreTrainedModel):
_skip_keys_device_placement = "past_key_values"
_supports_flash_attn = False
_supports_sdpa = True
_supports_cache_class = True

_supports_attention_backend = True
_can_record_outputs = {
"hidden_states": AriaTextDecoderLayer,
@@ -664,8 +664,6 @@ class AriaPreTrainedModel(PreTrainedModel):
_supports_flash_attn = True
_supports_sdpa = True
_supports_flex_attn = True
_supports_cache_class = True
_supports_quantized_cache = True
_supports_static_cache = False # MoE models don't work with torch.compile (dynamic slicing)
_supports_attention_backend = True
_can_record_outputs = {
2 changes: 1 addition & 1 deletion src/transformers/models/aria/modular_aria.py
@@ -1286,7 +1286,7 @@ class AriaTextPreTrainedModel(PreTrainedModel):
_skip_keys_device_placement = "past_key_values"
_supports_flash_attn = False
_supports_sdpa = True
_supports_cache_class = True

_supports_attention_backend = True
_can_record_outputs = {
"hidden_states": AriaTextDecoderLayer,