Merged
Changes from all commits (29 commits)
9e1cb31
feat(rebase): transformers: bump to 4.57.3 with cache/kv compatibility
vbaddi Mar 17, 2026
217aaa1
nit: update qwen25 and move the resolve_kv_seq to modeling utils
vbaddi Mar 17, 2026
7e675d9
nit: move imports for resolve_kv_seq_len modeling_utils to _utils
vbaddi Mar 17, 2026
09445ae
nit: rebase to mainline and fix tests, disable gptoss w/subfunction
vbaddi Mar 17, 2026
0a37e98
test(subfunctions): validate decoder-block subfunction count and remo…
vbaddi Mar 25, 2026
eac98d0
Added a few changes
abhishek-singh591 Mar 25, 2026
323f40d
Added a few changes
abhishek-singh591 Mar 25, 2026
02cdf56
simplified qwen2 modeling file
abhishek-singh591 Mar 25, 2026
5269c30
simplified gemma2,granitemoe,qwen2.5 modeling file
abhishek-singh591 Mar 25, 2026
e56015c
simplified gemma2,granitemoe,qwen2.5 modeling file
abhishek-singh591 Mar 25, 2026
22351f7
Modified llama shiftkv modeling file
abhishek-singh591 Mar 26, 2026
2cfd44e
lint
abhishek-singh591 Mar 26, 2026
8cfa10a
Fix for quantizer error
qcdipankar Mar 26, 2026
27e069e
Changed past_key_value to past_key_values
abhishek-singh591 Mar 26, 2026
7029ad3
Merge branch 'main' into feat/rebase_transformers_unify_subfunctions
qcdipankar Mar 27, 2026
678b671
Fix for granite and skipped whisper test from CI
abhishek-singh591 Mar 27, 2026
283cf88
Update conftest.py
abhishek-singh591 Mar 27, 2026
447b5d2
Merge branch 'main' into feat/rebase_transformers_unify_subfunctions
qcdipankar Mar 30, 2026
bc149b6
Merge branch 'quic:main' into feat/rebase_transformers_unify_subfunct…
qcdipankar Mar 30, 2026
35817ce
Fix for whisper and adding modeling auto changes for fp8 quantizers
qcdipankar Mar 30, 2026
ea4162b
fix fp8 llama model loading
mamtsing Mar 30, 2026
9b032ef
Fix for disaggregated serving
qcdipankar Mar 30, 2026
500fc83
Updated T5 modeling, randomness issue in diffuser tests
tv-karthikeya Mar 31, 2026
2f5d920
Skip qaic feature test for sampler to pass CI
qcdipankar Apr 1, 2026
f17adf1
EOS Added
qcdipankar Apr 1, 2026
e43e31e
Merge branch 'main' into feat/rebase_transformers_unify_subfunctions
qcdipankar Apr 1, 2026
5d86d39
Merge branch 'main' into feat/rebase_transformers_unify_subfunctions
qcdipankar Apr 2, 2026
53f5512
Merge branch 'main' into feat/rebase_transformers_unify_subfunctions
abhishek-singh591 Apr 3, 2026
11cb7bd
Update Jenkinsfile
abhishek-singh591 Apr 3, 2026
120 changes: 105 additions & 15 deletions QEfficient/transformers/cache_utils.py
@@ -10,7 +10,7 @@
from typing import Any, Dict, List, Optional, Tuple

import torch
from transformers.cache_utils import DynamicCache, DynamicLayer, EncoderDecoderCache, HybridCache, HybridChunkedCache
from transformers.cache_utils import Cache, CacheLayerMixin, EncoderDecoderCache, HybridCache, HybridChunkedCache

from QEfficient.customop import (
CtxGatherFunc,
@@ -54,7 +54,47 @@ def _get_invalid_idx_value(cls):
return 0


class QEffDynamicLayer(DynamicLayer):
class QEffDynamicLayer(CacheLayerMixin):
is_sliding = False

def __init__(self):
super().__init__()

def lazy_initialization(self, key_states: torch.Tensor):
self.dtype = key_states.dtype
self.device = key_states.device
self.keys = torch.tensor([], dtype=self.dtype, device=self.device)
self.values = torch.tensor([], dtype=self.dtype, device=self.device)
self.is_initialized = True

def get_mask_sizes(self, cache_position: torch.Tensor) -> tuple[int, int]:
kv_offset = 0
query_length = cache_position.shape[0]
kv_length = self.get_seq_length() + query_length
return kv_length, kv_offset

def get_seq_length(self) -> int:
if self.keys is None or self.keys.numel() == 0:
return 0
return self.keys.shape[-2]

def get_max_cache_shape(self) -> int:
return -1

@classmethod
def from_tensors(cls, key_states: torch.Tensor, value_states: torch.Tensor) -> "QEffDynamicLayer":
layer = cls()
layer.keys = key_states
layer.values = value_states
layer._mark_initialized(key_states)
return layer

def _mark_initialized(self, reference_states: torch.Tensor) -> None:
if not self.is_initialized:
self.dtype = reference_states.dtype
self.device = reference_states.device
self.is_initialized = True

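Taken together, these helpers reimplement on `CacheLayerMixin` what the removed `DynamicLayer` base used to provide. A minimal sketch of the lifecycle, using only the methods added in this hunk (shapes are made up, not part of the PR):

```python
import torch

# Sketch only: exercises QEffDynamicLayer as defined above, with made-up shapes.
key = torch.zeros(1, 8, 16, 64)    # (batch, kv_heads, seq_len, head_dim)
value = torch.zeros(1, 8, 16, 64)

layer = QEffDynamicLayer.from_tensors(key, value)  # adopts tensors, marks initialized
assert layer.get_seq_length() == 16                # read from keys.shape[-2]

# For a 4-token query, the mask must cover cached + new tokens at offset 0.
kv_length, kv_offset = layer.get_mask_sizes(torch.arange(4))
assert (kv_length, kv_offset) == (20, 0)
```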
def read_only(self, cache_kwargs):
"""
Reads the `key_states` and `value_states` for the layer.
@@ -68,6 +108,8 @@ def read_only(self, cache_kwargs):
"""
# Gather
k_out, v_out = self.keys, self.values
if k_out is not None:
self._mark_initialized(k_out)
position_ids = cache_kwargs.get("position_ids")
batch_index = cache_kwargs.get("batch_index", None)
ctx_len = cache_kwargs.get("CCL", k_out.shape[2])
@@ -109,6 +151,8 @@ def read_only_blockedKV(self, start_index, end_index, cache_kwargs):
"""
# Gather
k_out, v_out = self.keys, self.values
if k_out is not None:
self._mark_initialized(k_out)
position_ids = cache_kwargs.get("position_ids")
batch_index = cache_kwargs.get("batch_index", None)
batch, num_kv_heads, _, _ = k_out.shape
@@ -150,7 +194,9 @@ def write_only(self, key_states, value_states, cache_kwargs):
if self.keys is None:
self.keys = key_states
self.values = value_states
self._mark_initialized(self.keys)
else:
self._mark_initialized(self.keys)
position_ids = cache_kwargs.get("position_ids")
batch_index = cache_kwargs.get("batch_index", None) # Check and fetch batch index value from the kwargs

@@ -189,8 +235,10 @@ def update(
if self.keys is None:
self.keys = key_states
self.values = value_states
self._mark_initialized(self.keys)
k_out, v_out = self.keys, self.values
else:
self._mark_initialized(self.keys)
position_ids = cache_kwargs.get("position_ids")
batch_index = cache_kwargs.get("batch_index", None) # Check and fetch batch index value from the kwargs

@@ -252,8 +300,10 @@ def update3D(
if self.keys is None:
self.keys = key_states
self.values = value_states
self._mark_initialized(self.keys)
k_out, v_out = self.keys, self.values
else:
self._mark_initialized(self.keys)
position_ids = cache_kwargs.get("position_ids")
batch_index = cache_kwargs.get("batch_index", None)

@@ -293,7 +343,7 @@ def update3D(
return k_out, v_out


class QEffDynamicCache(DynamicCache):
class QEffDynamicCache(Cache):
"""
A cache that grows dynamically as more tokens are generated. This is the default for generative models.

@@ -307,15 +357,46 @@ class QEffDynamicCache(DynamicCache):
"""

def __init__(self, ddp_cache_data: Optional[Iterable[tuple[torch.Tensor, torch.Tensor]]] = None, *args, **kwargs):
# Remove layer_classes if present to avoid duplicate argument
# Remove cache-layer construction args if present to avoid duplicate arguments.
kwargs.pop("layer_classes", None)
from transformers.cache_utils import Cache # Import here to avoid circular import

Cache.__init__(self, layer_classes=QEffDynamicLayer, *args, **kwargs)
kwargs.pop("layers", None)
kwargs.pop("layer_class_to_replicate", None)

try:
# transformers>=4.57
Cache.__init__(self, *args, layer_class_to_replicate=QEffDynamicLayer, **kwargs)
except TypeError:
# transformers<=4.56
Cache.__init__(self, *args, layer_classes=QEffDynamicLayer, **kwargs)
Reviewer comment (Contributor) on lines +365 to +370: Better to check version explicitly than try-except.
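Following that suggestion, one way to gate on the installed version explicitly rather than catching `TypeError` might be the sketch below. It assumes only what the inline comments in this hunk state, namely that transformers 4.57 renamed the `layer_classes` argument to `layer_class_to_replicate`:

```python
import transformers
from packaging import version

# Hypothetical alternative to the try/except above: branch on the installed
# transformers version instead of catching the TypeError.
if version.parse(transformers.__version__) >= version.parse("4.57.0"):
    Cache.__init__(self, *args, layer_class_to_replicate=QEffDynamicLayer, **kwargs)
else:
    Cache.__init__(self, *args, layer_classes=QEffDynamicLayer, **kwargs)
```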

if ddp_cache_data is not None:
for key_states, value_states in ddp_cache_data:
self.layers.append(QEffDynamicLayer.from_tensors(key_states, value_states))

def append_new_layers(self, layer_idx: int) -> None:
while len(self.layers) <= layer_idx:
self.layers.append(QEffDynamicLayer())

@classmethod
def from_legacy_cache(cls, past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None) -> "QEffDynamicCache":
cache = cls()
if past_key_values is not None:
for layer_idx in range(len(past_key_values)):
key_states, value_states = past_key_values[layer_idx]
cache.update(key_states, value_states, layer_idx)
return cache

def to_legacy_cache(self) -> Tuple[Tuple[torch.Tensor, torch.Tensor]]:
legacy_cache = ()
for layer in self.layers:
legacy_cache += ((layer.keys, layer.values),)
return legacy_cache

def get_seq_length(self, layer_idx: Optional[int] = 0, cache_position: Optional[torch.LongTensor] = None) -> int:
"""
Keep backward-compatible call shape while deferring to upstream implementation.
"""
return super().get_seq_length(layer_idx)

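These overrides restore the legacy tuple-of-tuples conversions that newer transformers releases dropped from the base `Cache`. A round-trip sketch under that assumption, with hypothetical shapes:

```python
import torch

# Sketch only: round trip through the legacy ((keys, values), ...) format.
legacy = tuple(
    (torch.zeros(1, 8, 16, 64), torch.zeros(1, 8, 16, 64))  # one (K, V) pair per layer
    for _ in range(2)
)

cache = QEffDynamicCache.from_legacy_cache(legacy)  # one QEffDynamicLayer per entry
assert len(cache.layers) == 2
assert cache.get_seq_length() == 16

restored = cache.to_legacy_cache()                  # back to the tuple format
assert restored[0][0].shape == legacy[0][0].shape
```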
def read_only(self, layer_idx, cache_kwargs):
"""
Reads the `key_states` and `value_states` for the layer `layer_idx`.
@@ -405,10 +486,7 @@ def from_legacy_cache(
cls, past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
) -> "EncoderDecoderCache":
"""Converts a cache in the legacy cache format into an equivalent `EncoderDecoderCache`."""
cache = cls(
self_attention_cache=QEffDynamicCache(),
cross_attention_cache=QEffDynamicCache(),
)
cache = cls(QEffDynamicCache(), QEffDynamicCache())
if past_key_values is not None:
for layer_idx in range(len(past_key_values)):
key_states, value_states = past_key_values[layer_idx][:2]
@@ -419,6 +497,18 @@ def from_legacy_cache(
cache.is_updated[layer_idx] = True
return cache

def to_legacy_cache(self):
self_attn_legacy = self.self_attention_cache.to_legacy_cache()
cross_attn_legacy = self.cross_attention_cache.to_legacy_cache()

legacy_cache = ()
for layer_idx, self_attn_layer in enumerate(self_attn_legacy):
if layer_idx < len(cross_attn_legacy):
legacy_cache += (self_attn_layer + cross_attn_legacy[layer_idx],)
else:
legacy_cache += (self_attn_layer,)
return legacy_cache

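Note the resulting per-layer layout: self-attention tensors first, with cross-attention tensors appended only for layers that have them. A consumer of that format might unpack it like this (sketch only; `legacy_cache` is assumed to come from `QEffEncoderDecoderCache.to_legacy_cache()`):

```python
# Sketch only: unpacking the per-layer layout produced by to_legacy_cache() above.
for layer_entry in legacy_cache:
    if len(layer_entry) == 4:        # self-attention + cross-attention tensors
        self_k, self_v, cross_k, cross_v = layer_entry
    else:                            # self-attention only
        self_k, self_v = layer_entry
```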

# TODO: This function will be deprecated in the future.
class QEffHybridCache(HybridCache):
@@ -447,7 +537,7 @@ def __len__(self):
"""
return len(self.key_cache)

def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
def get_seq_length(self, layer_idx: Optional[int] = 0, cache_position: Optional[torch.LongTensor] = None) -> int:
"""Returns the sequence length of the cached states. A layer index can be optionally passed."""
# TODO: deprecate this function in favor of `cache_position`
is_empty_layer = (
@@ -531,7 +621,7 @@ def __len__(self):
"""
return len(self.key_cache)

def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
def get_seq_length(self, layer_idx: Optional[int] = 0, cache_position: Optional[torch.LongTensor] = None) -> int:
"""Returns the sequence length of the cached states. A layer index can be optionally passed."""
# TODO: deprecate this function in favor of `cache_position`
is_empty_layer = (
@@ -663,7 +753,7 @@ def __len__(self):
"""
return len(self.key_cache)

def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
def get_seq_length(self, layer_idx: Optional[int] = 0, cache_position: Optional[torch.LongTensor] = None) -> int:
"""Returns the sequence length of the cached states. A layer index can be optionally passed."""
# TODO: deprecate this function in favor of `cache_position`
is_empty_layer = (
@@ -783,7 +873,7 @@ def __len__(self):
"""
return len(self.key_cache)

def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
def get_seq_length(self, layer_idx: Optional[int] = 0, cache_position: Optional[torch.LongTensor] = None) -> int:
"""Returns the sequence length of the cached states. A layer index can be optionally passed."""
# TODO: deprecate this function in favor of `cache_position`
is_empty_layer = (
1 change: 1 addition & 0 deletions QEfficient/transformers/modeling_utils.py
@@ -191,6 +191,7 @@
]
)


# This is for supporting different seq_len for different layers for Sliding window attn, chunked attn etc.
DYNAMIC_SEQ_LEN_SUPPORTED_MODEL_ARCH = {"gemma3", "llama4", "gemma3_text", "llama4_text"}

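This set is keyed on `config.model_type`; a hypothetical membership check (the real call sites live elsewhere in the repo) could look like:

```python
def uses_dynamic_seq_len(config) -> bool:
    # Hypothetical helper: these architectures may need a different KV
    # sequence length per layer (sliding-window vs. full attention).
    return getattr(config, "model_type", None) in DYNAMIC_SEQ_LEN_SUPPORTED_MODEL_ARCH
```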
4 changes: 2 additions & 2 deletions QEfficient/transformers/models/falcon/modeling_falcon.py
@@ -104,7 +104,7 @@ def forward(
alibi: Optional[torch.Tensor],
attention_mask: torch.Tensor,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Cache] = None,
past_key_values: Optional[Cache] = None,
comp_ctx_lengths: Optional[torch.LongTensor] = None,
batch_index: Optional[torch.LongTensor] = None,
layer_past: Optional[Cache] = None,
@@ -190,7 +190,7 @@ def forward(
layer_past=layer_past,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
past_key_values=past_key_value,
comp_ctx_lengths=comp_ctx_lengths,
batch_index=batch_index,
alibi=alibi,
8 changes: 4 additions & 4 deletions QEfficient/transformers/models/gemma/modeling_gemma.py
@@ -123,7 +123,7 @@ def forward(
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor],
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Cache] = None,
past_key_values: Optional[Cache] = None,
comp_ctx_lengths: Optional[torch.LongTensor] = None,
batch_index: Optional[torch.LongTensor] = None,
cache_position: Optional[torch.LongTensor] = None,
@@ -143,12 +143,12 @@ def forward(
query_states, key_states, cos_cached, sin_cached, position_ids
)

if past_key_value is not None:
if past_key_values is not None:
cache_kwargs = {"batch_index": batch_index, "position_ids": position_ids}
if comp_ctx_lengths is not None:
attention_mask = attention_mask[:, :, :, : comp_ctx_lengths.shape[-1]]
cache_kwargs["CCL"] = attention_mask.shape[-1]
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)

attention_interface = eager_attention_forward

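The `comp_ctx_lengths` branch above trims the attention mask to the compute context length and passes the same bound to the cache update as `CCL`, so mask and cache reads stay in sync. A standalone sketch with made-up shapes:

```python
import torch

# Sketch only: the CCL trimming performed above, with hypothetical shapes.
attention_mask = torch.zeros(1, 1, 1, 1024, dtype=torch.bool)  # (b, 1, q, ctx)
comp_ctx_lengths = torch.zeros(512, dtype=torch.long)          # bound is its length

attention_mask = attention_mask[:, :, :, : comp_ctx_lengths.shape[-1]]
cache_kwargs = {"CCL": attention_mask.shape[-1]}  # 512, consumed by the cache update
assert cache_kwargs["CCL"] == 512
```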
@@ -210,7 +210,7 @@ def forward(
hidden_states=hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
past_key_values=past_key_value,
comp_ctx_lengths=comp_ctx_lengths,
batch_index=batch_index,
use_cache=use_cache,
10 changes: 5 additions & 5 deletions QEfficient/transformers/models/gemma2/modeling_gemma2.py
@@ -130,7 +130,7 @@ def forward(
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor],
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Cache] = None,
past_key_values: Optional[Cache] = None,
comp_ctx_lengths: Optional[torch.LongTensor] = None,
batch_index: Optional[torch.LongTensor] = None,
cache_position: Optional[torch.LongTensor] = None,
@@ -150,7 +150,7 @@ def forward(
query_states, key_states, cos_cached, sin_cached, position_ids
)

if past_key_value is not None:
if past_key_values is not None:
# sin and cos are specific to RoPE models; cache_position needed for the static cache
cache_kwargs = {
"sin": sin_cached,
@@ -161,7 +161,7 @@
if comp_ctx_lengths is not None:
attention_mask = attention_mask[:, :, :, : comp_ctx_lengths.shape[-1]]
cache_kwargs["CCL"] = attention_mask.shape[-1]
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)

attention_interface: Callable = eager_attention_forward

@@ -177,7 +177,7 @@

attn_output = attn_output.reshape(*input_shape, -1).contiguous()
attn_output = self.o_proj(attn_output)
return attn_output, attn_weights, past_key_value
return attn_output, attn_weights, past_key_values


class QEffGemma2DecoderLayer(Gemma2DecoderLayer):
@@ -227,7 +227,7 @@ def forward(
hidden_states=hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
past_key_values=past_key_value,
comp_ctx_lengths=comp_ctx_lengths,
batch_index=batch_index,
output_attentions=output_attentions,
12 changes: 6 additions & 6 deletions QEfficient/transformers/models/gemma3/modeling_gemma3.py
@@ -214,7 +214,7 @@ def forward(
position_embeddings: Optional[torch.Tensor],
attention_mask: Optional[torch.Tensor],
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Cache] = None,
past_key_values: Optional[Cache] = None,
comp_ctx_lengths: Optional[torch.LongTensor] = None,
batch_index: Optional[torch.LongTensor] = None,
cache_position: Optional[torch.LongTensor] = None,
@@ -232,7 +232,7 @@ def forward(
query_states = self.q_norm(query_states)
key_states = self.k_norm(key_states)

if past_key_value is not None:
if past_key_values is not None:
if self.layer_idx is None:
raise ValueError(
f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
@@ -245,7 +245,7 @@ def forward(
cos, sin = self.rotary_emb(value_states, seq_len=self.config.max_position_embeddings)

query_states, key_states = qeff_apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
if past_key_value is not None:
if past_key_values is not None:
# sin and cos are specific to RoPE models; cache_position needed for the static cache
cache_kwargs = {
"sin": sin,
@@ -254,12 +254,12 @@ def forward(
"position_ids": position_ids,
"is_sliding": self.is_sliding,
"sliding_window_pattern": self.config.sliding_window_pattern,
"sliding_window": past_key_value.sliding_window_len,
"sliding_window": past_key_values.sliding_window_len,
}
if comp_ctx_lengths is not None:
attention_mask = attention_mask[:, :, :, : comp_ctx_lengths.shape[-1]]
cache_kwargs["CCL"] = attention_mask.shape[-1]
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)

key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)
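`repeat_kv` expands the grouped KV heads so eager attention can match the query head count. A minimal equivalent of the transformers helper, assuming the usual (batch, kv_heads, seq, head_dim) layout:

```python
import torch

def repeat_kv_sketch(x: torch.Tensor, n_rep: int) -> torch.Tensor:
    # Minimal equivalent of transformers' repeat_kv: replicate each KV head
    # n_rep times so grouped-query attention lines up with the query heads.
    batch, kv_heads, seq_len, head_dim = x.shape
    if n_rep == 1:
        return x
    x = x[:, :, None, :, :].expand(batch, kv_heads, n_rep, seq_len, head_dim)
    return x.reshape(batch, kv_heads * n_rep, seq_len, head_dim)
```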
@@ -330,7 +330,7 @@ def forward(
position_embeddings=None,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
past_key_values=past_key_value,
comp_ctx_lengths=comp_ctx_lengths,
batch_index=batch_index,
output_attentions=output_attentions,