120 changes: 120 additions & 0 deletions QEfficient/transformers/cache_utils.py
@@ -630,6 +630,126 @@ def update(
# This is a hack for now, until this code is merged with the HybridCache class.
# We don't really need to inherit the transformers cache classes: theirs are made to work with PyTorch,
# and ours are made to work with AIC.
class QEffSlidingWindowCache:
def __init__(self, config, batch_size, max_cache_len, sliding_window_len):
self.max_cache_len = max_cache_len
self.batch_size = batch_size
self.sliding_window_len = sliding_window_len
self.key_cache: List[torch.Tensor] = []
self.value_cache: List[torch.Tensor] = []

@classmethod
def from_legacy_cache(
cls, config, past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
) -> "HybridCache":
"""Converts a cache in the legacy cache format into an equivalent `DynamicCache`. Used for
backward compatibility."""
        if past_key_values is None:
            raise ValueError("Cannot build a `QEffSlidingWindowCache` from `past_key_values=None`.")
        cache = cls(
            config,
            batch_size=past_key_values[0][0].shape[0],
            max_cache_len=past_key_values[config.sliding_window_pattern - 1][0].shape[2],
            sliding_window_len=past_key_values[0][0].shape[2],
        )
        for layer_idx in range(len(past_key_values)):
            key_states, value_states = past_key_values[layer_idx]
            cache.update(key_states, value_states, layer_idx)
        return cache

def __len__(self):
"""
Support for backwards-compatible `past_key_value` length, e.g. `len(past_key_value)`. This value corresponds
to the number of layers in the model.
"""
return len(self.key_cache)

def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
"""Returns the sequence length of the cached states. A layer index can be optionally passed."""
# TODO: deprecate this function in favor of `cache_position`
is_empty_layer = (
len(self.key_cache) == 0 # no cache in any layer
or len(self.key_cache) <= layer_idx # skipped `layer_idx` and hasn't run a layer with cache after it
or len(self.key_cache[layer_idx]) == 0 # the layer has no cache
)
layer_seq_length = self.key_cache[layer_idx].shape[-2] if not is_empty_layer else 0
return layer_seq_length

def to_legacy_cache(self) -> Tuple[Tuple[torch.Tensor], Tuple[torch.Tensor]]:
"""Converts the `DynamicCache` instance into the its equivalent in the legacy cache format. Used for
backward compatibility."""
legacy_cache = ()
for layer_idx in range(len(self)):
legacy_cache += ((self.key_cache[layer_idx], self.value_cache[layer_idx]),)
return legacy_cache

def update(
self,
key_states: torch.Tensor,
value_states: torch.Tensor,
layer_idx: int,
cache_kwargs: Optional[Dict[str, Any]] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
if len(self.key_cache) <= layer_idx:
self.key_cache.append(key_states)
self.value_cache.append(value_states)
k_out, v_out = key_states, value_states
else:
position_ids = cache_kwargs.get("position_ids")
is_sliding_layer = cache_kwargs.get("is_sliding")
batch_index = cache_kwargs.get("batch_index", None) # Check and fetch batch index value from the kwargs

if is_sliding_layer:
sliding_window_len = self.key_cache[layer_idx].shape[2]
kv_position_ids = torch.where(position_ids == -1, position_ids, position_ids % sliding_window_len)
else:
kv_position_ids = position_ids

if batch_index is not None:
if torch.onnx.is_in_onnx_export():
invalid_scatter_index = torch.iinfo(torch.int32).max
scatter_position_ids = torch.where(kv_position_ids < 0, invalid_scatter_index, kv_position_ids)
else:
scatter_position_ids = kv_position_ids
self.key_cache[layer_idx] = CtxScatterFuncCB.apply(
self.key_cache[layer_idx], batch_index, scatter_position_ids, key_states
)
self.value_cache[layer_idx] = CtxScatterFuncCB.apply(
self.value_cache[layer_idx], batch_index, scatter_position_ids, value_states
)
else:
self.key_cache[layer_idx] = CtxScatterFunc.apply(self.key_cache[layer_idx], kv_position_ids, key_states)
self.value_cache[layer_idx] = CtxScatterFunc.apply(
self.value_cache[layer_idx], kv_position_ids, value_states
)

k_out, v_out = self.key_cache[layer_idx], self.value_cache[layer_idx]

            # Gather the cached states that are valid for the current positions
if is_sliding_layer:
ctx_len = self.key_cache[layer_idx].shape[2]
else:
ctx_len = cache_kwargs.get("CCL", self.key_cache[layer_idx].shape[2])

ctx_indices = torch.arange(ctx_len)[None, None, ...]
gather_limit = position_ids.max(1, keepdim=True).values.unsqueeze(1)
invalid_mask = ctx_indices > gather_limit
if torch.onnx.is_in_onnx_export():
invalid_idx_value = torch.iinfo(torch.int32).max
else:
invalid_idx_value = 0
ctx_indices = torch.where(invalid_mask, invalid_idx_value, ctx_indices)

if batch_index is not None:
k_out = CtxGatherFuncCB.apply(k_out, batch_index, ctx_indices, ctx_len)
v_out = CtxGatherFuncCB.apply(v_out, batch_index, ctx_indices, ctx_len)
else:
k_out = CtxGatherFunc.apply(k_out, ctx_indices, ctx_len)
v_out = CtxGatherFunc.apply(v_out, ctx_indices, ctx_len)

v_out = torch.where(invalid_mask.unsqueeze(-1), torch.tensor(0.0, dtype=torch.float32), v_out)
return k_out, v_out
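
For intuition, here is a minimal sketch of the position mapping `update` applies on sliding layers, in plain PyTorch rather than the CtxScatterFunc/CtxGatherFunc custom ops; the tensor names and shapes are illustrative assumptions, not part of the class API. Sliding layers wrap position_ids modulo the window length so the fixed-size cache acts as a ring buffer, while padded positions marked -1 are preserved so a later scatter can route them to an invalid index during ONNX export:

import torch

sliding_window_len = 4
position_ids = torch.tensor([[5, 6, -1]])  # -1 marks a padded slot

# Wrap real positions into ring-buffer slots; leave padding untouched.
kv_position_ids = torch.where(
    position_ids == -1, position_ids, position_ids % sliding_window_len
)
print(kv_position_ids)  # tensor([[1, 2, -1]]) -> positions 5 and 6 land in slots 1 and 2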


class QEffHybridCacheForGPTOSS:
def __init__(self, config, batch_size, max_cache_len, sliding_window_len):
self.max_cache_len = max_cache_len
13 changes: 9 additions & 4 deletions QEfficient/transformers/models/gemma3/modeling_gemma3.py
@@ -28,7 +28,7 @@
)

from QEfficient.customop.rms_norm import CustomRMSNorm
from QEfficient.transformers.cache_utils import QEffDynamicCache
from QEfficient.transformers.cache_utils import QEffSlidingWindowCache
from QEfficient.transformers.modeling_attn_mask_utils import _create_causal_mask
from QEfficient.utils import constants
from QEfficient.utils._utils import IOInfo
@@ -254,6 +254,7 @@ def forward(
"position_ids": position_ids,
"is_sliding": self.is_sliding,
"sliding_window_pattern": self.config.sliding_window_pattern,
"sliding_window": past_key_value.sliding_window_len,
}
if comp_ctx_lengths is not None:
attention_mask = attention_mask[:, :, :, : comp_ctx_lengths.shape[-1]]
@@ -311,10 +312,12 @@
) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]:
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
past_seen_tokens = past_key_value.get_seq_length() if past_key_value is not None else 0
# past_seen_tokens = past_key_value.get_seq_length() if past_key_value is not None else 0
if self.self_attn.is_sliding:
attention_mask = _create_causal_mask(
position_ids=position_ids, target_length=past_seen_tokens, sliding_window=self.config.sliding_window
position_ids=position_ids,
target_length=past_key_value.sliding_window_len,
sliding_window=past_key_value.sliding_window_len,
)
else:
attention_mask = _create_causal_mask(
@@ -401,7 +404,9 @@ def forward(

if use_cache and not isinstance(past_key_values, Cache): # kept for BC (non `Cache` `past_key_values` inputs)
# return_legacy_cache = True
past_key_values = QEffDynamicCache.from_legacy_cache(past_key_values)
past_key_values = QEffSlidingWindowCache.from_legacy_cache(
config=self.config, past_key_values=past_key_values
)
if cache_position is None:
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
cache_position = torch.arange(
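As a rough sketch of the sliding-window semantics `_create_causal_mask` is asked to produce above (ignoring the ring-buffer wrap the real helper must also handle; the standalone function below is an assumption for illustration, not the library implementation), a query at position p may attend to key position k only when k <= p and p - k < sliding_window:

import torch

def sliding_causal_mask(position_ids: torch.Tensor, target_length: int, window: int) -> torch.Tensor:
    # position_ids: [batch, seq] -> boolean mask of shape [batch, 1, seq, target_length]
    key_pos = torch.arange(target_length)[None, None, :]  # [1, 1, T]
    query_pos = position_ids[:, :, None]  # [B, S, 1]
    return ((key_pos <= query_pos) & (query_pos - key_pos < window)).unsqueeze(1)

mask = sliding_causal_mask(torch.tensor([[4]]), target_length=8, window=4)
print(mask.int())  # only key positions 1..4 are visible to the query at position 4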
26 changes: 16 additions & 10 deletions examples/image_text_to_text/models/gemma_vision/gemma3_example.py
@@ -5,24 +5,30 @@
#
# -----------------------------------------------------------------------------

import os

import torch
import transformers
from transformers import AutoConfig, AutoProcessor

from QEfficient import QEFFAutoModelForImageTextToText

# Change model_id to "google/gemma-3-4b-it" for the 4B model
model_id = "google/gemma-3-4b-it"
model_id = "google/gemma-3-27b-it"

config = AutoConfig.from_pretrained(model_id)

# For Testing Purpose Only
config.text_config.num_hidden_layers = 1
config.vision_config.num_hidden_layers = 2
# For testing purposes only; at least 6 layers are required
# config.text_config.num_hidden_layers = 6
# config.vision_config.num_hidden_layers = 6

tokenizer = transformers.AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
processor = AutoProcessor.from_pretrained(model_id)

# Path to Node Precision Info YAML file
npi_file_path = "configs/fp32_nodes_gemma3_27b.yaml"
npi_file_full_path = os.path.join(os.getcwd(), npi_file_path)

# For single QPC: kv_offload=False, For dual QPC: kv_offload=True
qeff_model = QEFFAutoModelForImageTextToText.from_pretrained(
model_id, config=config, attn_implementation="eager", kv_offload=True
@@ -44,7 +50,7 @@
aic_enable_depth_first=True,
skip_vision=True,
mos=1,
node_precision_info="examples/gemma3_example/fp32_nodes_gemma3_4b.yaml", # Change to fp32_nodes_gemma3_27b.yaml for 27B model
node_precision_info=npi_file_full_path,
)

messages = [
Expand All @@ -64,7 +70,7 @@
return_tensors="pt",
)

output = qeff_model.generate(inputs=inputs, generation_len=100)
output = qeff_model.generate(inputs=inputs, generation_len=2000)
print(tokenizer.batch_decode(output.generated_ids))
print(output)

@@ -75,12 +81,12 @@
ctx_len=3072,
img_size=896,
num_cores=16,
num_devices=1,
num_devices=4,
mxfp6_matmul=False,
mxint8_kv_cache=False,
aic_enable_depth_first=True,
mos=1,
node_precision_info="examples/gemma3_example/fp32_nodes_gemma3_4b.yaml", # Change to fp32_nodes_gemma3_27b.yaml for 27B model
node_precision_info=npi_file_full_path,
)

### IMAGE + TEXT ###
@@ -93,7 +99,7 @@
"role": "user",
"content": [
{"type": "image", "url": image_url},
{"type": "text", "text": "Can you describe the image in detail."},
{"type": "text", "text": "Describe this image in details."},
],
},
]
@@ -106,6 +112,6 @@
return_tensors="pt",
)
inputs["pixel_values"] = inputs["pixel_values"].to(torch.float32)
output = qeff_model.generate(inputs=inputs, generation_len=100)
output = qeff_model.generate(inputs=inputs, generation_len=2000)
print(tokenizer.batch_decode(output.generated_ids, skip_special_tokens=True))
print(output)
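
One caveat in this example: npi_file_full_path is resolved against the current working directory, so compilation fails if the script is launched from elsewhere. A small guard like the following (an illustrative addition, not part of the example) makes that failure mode explicit:

import os

# Hypothetical pre-flight check for the node precision info YAML.
npi_file_full_path = os.path.join(os.getcwd(), "configs/fp32_nodes_gemma3_27b.yaml")
if not os.path.isfile(npi_file_full_path):
    raise FileNotFoundError(
        f"Node precision info YAML not found at {npi_file_full_path}; "
        "run this script from the directory containing configs/."
    )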
@@ -100,11 +100,11 @@
],
[
"Can you describe the image in detail?",
"What are the objects in the image?",
"What is the main subject of the image?",
"What colors are predominant in the image?",
"Can you describe the image in detail?",
"Can you describe the image in detail?",
"Can you describe the image in detail?",
],
1,
6,
4,
),
(
@@ -99,7 +99,7 @@
896,
"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/datasets/cat_style_layout.png",
"Can you describe the image in detail.",
1,
6,
),
(
"google/gemma-3-4b-it",
@@ -110,7 +110,7 @@
896,
"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/datasets/cat_style_layout.png",
"Can you describe the image in detail.",
1,
6,
),
(
"mistralai/Mistral-Small-3.1-24B-Instruct-2503",