Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
176 changes: 136 additions & 40 deletions QEfficient/transformers/cache_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -442,49 +442,145 @@ def update(
k_out, v_out = key_states, value_states

else:
position_ids = cache_kwargs.get("position_ids")
is_sliding_layer = torch.tensor(bool(self.is_sliding[layer_idx]))
self.is_sliding_layer = torch.tensor(bool(self.is_sliding[layer_idx]))

# Update the position_ids to handle the sliding window
layer_ctx_len = self.key_cache[layer_idx].shape[2]
kv_position_ids = torch.where(
(~is_sliding_layer | (position_ids == -1)), position_ids, position_ids % (layer_ctx_len - 1)
)
# Update both caches with their respective position IDs
k_dynamic, v_dynamic = self.dynamic_cache(key_states, value_states, layer_idx, cache_kwargs)
k_hybrid, v_hybrid = self.hybrid_cache(key_states, value_states, layer_idx, cache_kwargs)

kv_position_ids = torch.where(
is_sliding_layer & (position_ids.max() >= (layer_ctx_len - 1) * 2),
(position_ids + 1) % layer_ctx_len,
kv_position_ids,
)
# Selection logic using torch.where
position_ids = cache_kwargs.get("position_ids", torch.tensor([0]))
use_hybrid = position_ids.max() >= 8192

valid_mask = (kv_position_ids != -1).unsqueeze(1).unsqueeze(-1)
key_states = torch.where(valid_mask == 1, key_states, torch.zeros_like(key_states))
value_states = torch.where(valid_mask == 1, value_states, torch.zeros_like(value_states))
self.key_cache[layer_idx] = CtxScatterFunc.apply(self.key_cache[layer_idx], kv_position_ids, key_states)
self.value_cache[layer_idx] = CtxScatterFunc.apply(
self.value_cache[layer_idx], kv_position_ids, value_states
)
k_out, v_out = self.key_cache[layer_idx], self.value_cache[layer_idx]
k_out = torch.where(use_hybrid, k_hybrid, k_dynamic)
v_out = torch.where(use_hybrid, v_hybrid, v_dynamic)

# Original Gather
ctx_len = min(layer_ctx_len, k_out.shape[2])
ctx_indices = torch.arange(ctx_len)[None, None, ...]
gather_limit = kv_position_ids.max(1, keepdim=True).values.unsqueeze(1)
invalid_mask = ctx_indices > gather_limit
if torch.onnx.is_in_onnx_export():
invalid_idx_value = torch.iinfo(torch.int32).max
else:
invalid_idx_value = 0
ctx_indices = torch.where(invalid_mask, invalid_idx_value, ctx_indices)
return k_out, v_out

def hybrid_cache(
    self,
    key_states: torch.Tensor,
    value_states: torch.Tensor,
    layer_idx: int,
    cache_kwargs: Optional[Dict[str, Any]] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Update layer `layer_idx` using sliding-window ("hybrid") semantics.

    Position ids are remapped into the layer's context window before the
    scatter, so that writes wrap around once the sliding window is full.
    The gather then either reads the valid prefix (non-sliding case) or a
    rolled view of the window (sliding case).

    Parameters:
        key_states (`torch.Tensor`): New key states to cache.
        value_states (`torch.Tensor`): New value states to cache.
        layer_idx (`int`): Index of the layer to update.
        cache_kwargs (`Dict[str, Any]`, `optional`):
            Must contain `position_ids` — the absolute write positions.

    Return:
        A tuple containing the updated key and value states.

    NOTE(review): relies on `self.is_sliding_layer` having been set by the
    caller (`update`) for the current layer — confirm ordering if this is
    ever called directly.
    """
    position_ids = cache_kwargs.get("position_ids")

    # Remap positions into the sliding window; -1 positions (padding) are
    # left untouched so the valid-mask below can zero them out.
    layer_ctx_len = self.key_cache[layer_idx].shape[2]
    kv_position_ids = torch.where(
        (~self.is_sliding_layer | (position_ids == -1)), position_ids, position_ids % (layer_ctx_len - 1)
    )

    # Once we are past two full windows, shift by one so the wrap-around
    # write lands on the correct slot.
    kv_position_ids = torch.where(
        self.is_sliding_layer & (position_ids.max() >= (layer_ctx_len - 1) * 2),
        (position_ids + 1) % layer_ctx_len,
        kv_position_ids,
    )

    # Zero out states at invalid (-1) positions before scattering.
    valid_mask = (kv_position_ids != -1).unsqueeze(1).unsqueeze(-1)
    key_states = torch.where(valid_mask == 1, key_states, torch.zeros_like(key_states))
    value_states = torch.where(valid_mask == 1, value_states, torch.zeros_like(value_states))
    self.key_cache[layer_idx] = CtxScatterFunc.apply(self.key_cache[layer_idx], kv_position_ids, key_states)
    self.value_cache[layer_idx] = CtxScatterFunc.apply(self.value_cache[layer_idx], kv_position_ids, value_states)
    k_out, v_out = self.key_cache[layer_idx], self.value_cache[layer_idx]

    # Original Gather: indices beyond the newest written position are
    # invalid; during ONNX export they are marked with INT32_MAX so the
    # runtime can skip them, otherwise they harmlessly read slot 0.
    ctx_len = min(layer_ctx_len, k_out.shape[2])
    ctx_indices = torch.arange(ctx_len)[None, None, ...]
    gather_limit = kv_position_ids.max(1, keepdim=True).values.unsqueeze(1)
    invalid_mask = ctx_indices > gather_limit
    if torch.onnx.is_in_onnx_export():
        invalid_idx_value = torch.iinfo(torch.int32).max
    else:
        invalid_idx_value = 0
    ctx_indices = torch.where(invalid_mask, invalid_idx_value, ctx_indices)

    # Rolling indices for sliding window: once the window has wrapped, read
    # it in chronological order starting just past the newest write.
    all_indices = torch.arange(layer_ctx_len) + kv_position_ids.max() + 1
    rolling_indices = torch.where(all_indices > layer_ctx_len - 1, all_indices % layer_ctx_len, all_indices)
    final_indices = torch.where(
        (self.is_sliding_layer & (position_ids.max() >= (layer_ctx_len - 1))), rolling_indices, ctx_indices
    )
    k_out = CtxGatherFunc.apply(k_out, final_indices)
    v_out = CtxGatherFunc.apply(v_out, final_indices)
    # Zero values at invalid slots only in the non-rolling case; a wrapped
    # sliding window has no invalid tail.
    ctx_v_out = torch.where(invalid_mask.unsqueeze(-1), torch.tensor(0.0, dtype=torch.float32), v_out)
    v_out = torch.where((self.is_sliding_layer & (position_ids.max() >= (layer_ctx_len - 1))), v_out, ctx_v_out)
    return k_out, v_out

def dynamic_cache(
    self,
    key_states: torch.Tensor,
    value_states: torch.Tensor,
    layer_idx: int,
    cache_kwargs: Optional[Dict[str, Any]] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Scatter new key/value states into the full-context cache for layer
    `layer_idx` and gather back the valid prefix.

    Parameters:
        key_states (`torch.Tensor`):
            The new key states to cache.
        value_states (`torch.Tensor`):
            The new value states to cache.
        layer_idx (`int`):
            The index of the layer to cache the states for.
        cache_kwargs (`Dict[str, Any]`, `optional`):
            Additional arguments for the cache subclass. No additional arguments are used in `DynamicCache`.

    Return:
        A tuple containing the updated key and value states.
    """
    position_ids = cache_kwargs.get("position_ids")

    # Write the incoming states at their absolute sequence positions.
    self.key_cache[layer_idx] = CtxScatterFunc.apply(self.key_cache[layer_idx], position_ids, key_states)
    self.value_cache[layer_idx] = CtxScatterFunc.apply(self.value_cache[layer_idx], position_ids, value_states)
    keys, values = self.key_cache[layer_idx], self.value_cache[layer_idx]

    # Every slot past the newest written position is invalid for attention.
    total_ctx = keys.shape[2]
    gather_idx = torch.arange(total_ctx)[None, None, ...]
    newest_pos = position_ids.max(1, keepdim=True).values.unsqueeze(1)
    past_limit = gather_idx > newest_pos

    # During ONNX export, flag invalid slots with INT32_MAX so the runtime
    # can recognize them; otherwise redirect them to slot 0 (their values
    # are zeroed below anyway).
    sentinel = torch.iinfo(torch.int32).max if torch.onnx.is_in_onnx_export() else 0
    gather_idx = torch.where(past_limit, sentinel, gather_idx)

    keys = CtxGatherFunc.apply(keys, gather_idx)
    values = CtxGatherFunc.apply(values, gather_idx)
    values = torch.where(past_limit.unsqueeze(-1), torch.tensor(0.0, dtype=torch.float32), values)

    return keys, values


# class QEffDualCacheWrapper:
# def __init__(self, dynamic_cache, hybrid_cache):
# self.dynamic_cache = dynamic_cache
# self.hybrid_cache = hybrid_cache

# def update(self, key_states, value_states, layer_idx, cache_kwargs):
# # Prepare separate cache_kwargs for each cache type
# cache_kwargs = {
# "batch_index": cache_kwargs.get("batch_index"),
# "position_ids": cache_kwargs.get("position_ids")
# }

# # Update both caches with their respective position IDs
# k_dynamic, v_dynamic = self.dynamic_cache.update(key_states, value_states, layer_idx, cache_kwargs)
# k_hybrid, v_hybrid = self.hybrid_cache.update(key_states, value_states, layer_idx, cache_kwargs)

# # Selection logic using torch.where
# position_ids = cache_kwargs.get("position_ids", torch.tensor([0]))
# use_hybrid = position_ids.max() >= 8192

# k_out = torch.where(use_hybrid, k_hybrid, k_dynamic)
# v_out = torch.where(use_hybrid, v_hybrid, v_dynamic)

# return k_out, v_out
13 changes: 13 additions & 0 deletions QEfficient/transformers/models/modeling_auto.py
Original file line number Diff line number Diff line change
Expand Up @@ -840,9 +840,22 @@ def kv_offload_generate(

if vision_inputs:
vision_inputs["pixel_values"] = vision_inputs["pixel_values"].astype("float16")

expected_patches = constants.LLAMA4_MAX_NUM_TILES
if vision_inputs["pixel_values"].shape[0] != expected_patches:
logger.info(
f"Padding pixel_values from {vision_inputs['pixel_values'].shape[0]} to {expected_patches} patches"
)
single_patch = np.expand_dims(vision_inputs["pixel_values"][0], axis=0)
while vision_inputs["pixel_values"].shape[0] < expected_patches:
vision_inputs["pixel_values"] = np.concatenate(
(vision_inputs["pixel_values"], single_patch), axis=0
)

vision_start = perf_counter()

vision_outputs = {}

if vision_inputs:
vision_outputs = vision_session.run(vision_inputs)
vision_end = perf_counter()
Expand Down
1 change: 1 addition & 0 deletions QEfficient/utils/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@ def get_models_dir():
# Llama4 Constants
# Size of the attention chunk (sliding window) used by Llama4 layers.
LLAMA4_ATTENTION_CHUNK_SIZE = 8192
# Maximum sequence length supported for Llama4 position embeddings.
LLAMA4_MAX_POSITION_EMBEDDINGS = 65536
# Fixed number of image patches/tiles the vision session expects;
# pixel_values are padded up to this count before inference.
LLAMA4_MAX_NUM_TILES = 17

# Gemma3 Constant
# Maximum sequence length supported for Gemma3 position embeddings.
GEMMA3_MAX_POSITION_EMBEDDINGS = 32768
Expand Down
Loading