Closed

58 commits
d37ef9d
added modeling for kimik2
ochougul Jan 28, 2026
426eed1
able to run kimi model, need to check accuracy
ochougul Jan 28, 2026
f3de723
bugfix
ochougul Feb 2, 2026
a3abf05
experimentation branch commit
ochougul Feb 6, 2026
a27e272
added MLA with/WO fusion, the caching for different config needs to b…
ochougul Feb 9, 2026
4e025c6
Add prefill only moe changes from kimik2 branch
mamtsing Feb 10, 2026
2c5358d
Change Cache for compressed KV and k_rope
mamtsing Feb 16, 2026
30c57e9
fix dynamic axis and output mismatch
mamtsing Feb 18, 2026
0517d93
Split kv_a_proj_with_mqa weights to get ckv and k_pe
quic-mamta Feb 24, 2026
1398388
update min_masked_attention_value, head blocking, attn in fused_forwa…
mamtsing Apr 1, 2026
8aaa56a
Add replicatekvhead transform
mamtsing Apr 1, 2026
b05ca6e
Add KV blocking and update head blocking
mamtsing Apr 1, 2026
2e44321
Update modeling_deepseek_qeff.py
mamtsing Apr 1, 2026
8c99ee0
update kv heads in example_inputs during export and kv_blocking
mamtsing Apr 1, 2026
2859c19
Add example script
mamtsing Apr 1, 2026
13da8ef
update fused forward for loops
mamtsing Apr 7, 2026
9e432dc
fix reshapes in fused_forward
mamtsing Apr 7, 2026
b625113
write_only and read_only for blocked_kv
mamtsing Apr 7, 2026
768ce85
removed commented code
ochougul Apr 8, 2026
4f11f58
update compressed-tensors and tiktoken version
mamtsing Apr 8, 2026
4cbb376
add all forward for h and kv blocking
mamtsing Apr 9, 2026
04fa702
removing blocking config changes
mamtsing Apr 9, 2026
900ee0b
fixed basic forward
ochougul Apr 9, 2026
f763671
fixed for kimi k2
ochougul Apr 9, 2026
dc20495
fix lint format
mamtsing Apr 10, 2026
877a2de
remove redundancies
mamtsing Apr 10, 2026
38910d3
fix no absorption for head blocking
mamtsing Apr 10, 2026
a0f0e61
Add head blocking in full pkv forward
mamtsing Apr 13, 2026
0965b5e
Merge branch 'main' into mla_perf
quic-mamta Apr 13, 2026
5b56df4
fix tests
mamtsing Apr 13, 2026
f76aeb2
Merge branch 'main' into mla_perf
quic-mamta Apr 13, 2026
4d4dd0c
fix CI errors
mamtsing Apr 13, 2026
91526ab
Merge branch 'main' into mla_perf
quic-mamta Apr 15, 2026
37dd1bb
added orig forward from snippet
mamtsing Apr 15, 2026
210d953
add support for subfunctions
mamtsing Apr 16, 2026
09fd219
Merge branch 'main' into mla_perf
quic-mamta Apr 21, 2026
6b04a57
use generalized kv blocking
mamtsing Apr 21, 2026
2126672
use generalised infra for h_blocking with ckv and full kv
mamtsing Apr 22, 2026
6d070ac
Merge branch 'main' into mla_perf
quic-mamta Apr 22, 2026
c9fd13e
remove redundant files
mamtsing Apr 23, 2026
5c23f53
remove env variables
mamtsing Apr 26, 2026
d05bbf8
Merge branch 'main' into mla_perf
quic-mamta Apr 26, 2026
8e7ddbe
simplify mla_absorption_config
mamtsing Apr 26, 2026
c966cc9
address review comments
mamtsing Apr 27, 2026
01867ea
fix mla_absorption_config
mamtsing Apr 27, 2026
4f7d28c
added head expansion for kv blocking
ochougul Apr 27, 2026
fb20758
Merge branch 'mla_perf' of github.com:quic/efficient-transformers int…
ochougul Apr 27, 2026
b05aa06
fixed kv with kv replication
ochougul Apr 27, 2026
c24d851
minor fix
ochougul Apr 27, 2026
bdbf50c
Added all layer-wise changes for Kimi model
abhishek-singh591 Apr 28, 2026
4c069c3
Made minor fix
abhishek-singh591 Apr 28, 2026
bf1b9e3
Update run.py
abhishek-singh591 Apr 28, 2026
ab312d0
Update modeling_qeff.py
abhishek-singh591 Apr 28, 2026
7cd93e3
Made minor fix
abhishek-singh591 Apr 28, 2026
067ce7d
Made minor fix
abhishek-singh591 Apr 28, 2026
30c8e24
Made minor fix
abhishek-singh591 Apr 28, 2026
b211a9a
Made minor fix
abhishek-singh591 Apr 28, 2026
106f065
Added thread pool for loading the QPC
abhishek-singh591 Apr 29, 2026
453 changes: 259 additions & 194 deletions QEfficient/base/modeling_qeff.py

Large diffs are not rendered by default.

88 changes: 78 additions & 10 deletions QEfficient/base/onnx_transforms.py
@@ -129,17 +129,80 @@ def apply(cls, model: ModelProto) -> bool:
return op_applied


class RemovePrefix(BaseOnnxTransform):
@classmethod
def apply(cls, model: ModelProto) -> bool:
graph = model.graph
renamed = False

def strip_prefix(name: str) -> str:
parts = name.rsplit("/", 1)
return parts[1] if len(parts) == 2 else parts[0]

input_names = []
for i, inputs in enumerate(graph.input):
original = inputs.name
new = strip_prefix(original)
if new != original:
renamed = True
inputs.name = new
graph.input[i].name = new
input_names.append(new)

input_name_set = set(input_names)
output_rename_map = {}

# Rename model graph outputs and keep mapping so producer/consumer edges can be fixed.
for out in graph.output:
original = out.name
new = strip_prefix(original)
if new != original:
out.name = new
output_rename_map[original] = new
renamed = True

for node in graph.node:
for i, out in enumerate(node.output):
if out in output_rename_map and output_rename_map[out] != out:
node.output[i] = output_rename_map[out]
renamed = True

new_inputs = []
for s in node.input:
# Keep node inputs in sync for renamed model outputs.
if s in output_rename_map:
new_inputs.append(output_rename_map[s])
continue

if s in input_name_set:
new_inputs.append(s)
continue

replaced = s
if "/" in s:
tail = s.rsplit("/", 1)[1]
if tail in input_name_set:
replaced = tail
new_inputs.append(replaced)

for idx in range(len(node.input)):
if node.input[idx] != new_inputs[idx]:
node.input[idx] = new_inputs[idx]
renamed = True

return renamed


class RenameFunctionOutputsTransform(BaseOnnxTransform):
"""Rename outputs of decoder-related functions for better clarity."""

@classmethod
def apply(cls, model: ModelProto) -> bool:
def apply(cls, model: ModelProto, layer_idx=0) -> bool:
graph = model.graph
op_type_to_func = {f.name: f for f in model.functions}
decoder_patterns = ["DecoderLayer", "Block", "Layer"]
renamed = False
model_out_map = {v.name: i for i, v in enumerate(graph.output)}
layer_idx = 0

for node in graph.node:
if any(p in node.name or p in node.op_type for p in decoder_patterns):
@@ -150,13 +213,16 @@ def apply(cls, model: ModelProto) -> bool:
if "_InternalRetainedState" in out_name:
renamed = True
orig = node.output[i]
new = (
f"past_key.{layer_idx}_RetainedState"
if "key" in out_name
else f"past_value.{layer_idx}_RetainedState"
if "value" in out_name
else orig
)
if "key" in out_name:
new = f"past_key.{layer_idx}_RetainedState"
elif "value" in out_name:
new = f"past_value.{layer_idx}_RetainedState"
elif "compressed_kv" in out_name:
new = f"compressed_kv.{layer_idx}_RetainedState"
elif "k_pe" in out_name:
new = f"k_pe.{layer_idx}_RetainedState"
else:
new = orig
node.output[i] = new
if orig in model_out_map:
graph.output[model_out_map[orig]].name = new
@@ -275,7 +341,9 @@ def _set_external_data(tensor, file_name):
applied[CustomOpTransform] = CustomOpTransform.apply(model)

if RenameFunctionOutputsTransform in requested:
applied[RenameFunctionOutputsTransform] = RenameFunctionOutputsTransform.apply(model)
applied[RenameFunctionOutputsTransform] = RenameFunctionOutputsTransform.apply(
model, layer_idx=kwargs.get("layer_idx", 0)
)

if AdapterWeightsToInputsTransform in requested:
applied[AdapterWeightsToInputsTransform] = AdapterWeightsToInputsTransform.apply(model, **kwargs)
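For reference, both transforms above mutate the ONNX ModelProto in place and report whether anything changed. A minimal usage sketch in Python (the model paths and the layer_idx value are placeholders, not taken from this PR):

# Hedged sketch of applying the transforms from this diff to an exported ONNX graph.
import onnx

from QEfficient.base.onnx_transforms import RemovePrefix, RenameFunctionOutputsTransform

model = onnx.load("model.onnx")  # placeholder path

# Strip subgraph path prefixes from graph inputs/outputs (e.g. "/model/input_ids" -> "input_ids")
# and keep node edges consistent with the renamed names.
prefix_removed = RemovePrefix.apply(model)

# Rename retained-state outputs of decoder-layer functions, now covering the MLA-specific
# compressed_kv and k_pe outputs, starting from the given layer index.
outputs_renamed = RenameFunctionOutputsTransform.apply(model, layer_idx=0)

if prefix_removed or outputs_renamed:
    onnx.save(model, "model_transformed.onnx")  # placeholder path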
72 changes: 71 additions & 1 deletion QEfficient/blocking/attention_blocking.py
@@ -9,16 +9,18 @@

from dataclasses import dataclass
from enum import Enum
from typing import Callable, Dict, Optional
from typing import Any, Callable, Dict, Optional

import torch
from transformers.cache_utils import Cache

from QEfficient.blocking.blocked_attention_forwards import (
blocked_bhqkv_attention_forward,
blocked_h_attention_forward,
blocked_h_mla_attention_forward,
blocked_hqkv_attention_forward,
blocked_kv_attention_forward,
blocked_kv_mla_attention_forward,
blocked_q_attention_forward,
blocked_qkv_attention_forward,
)
@@ -57,6 +59,11 @@ def supports_blocked_kv(past_key_value: Optional[Cache]) -> bool:
BlockingMode.BHQKV: blocked_bhqkv_attention_forward,
}

_STRATEGIES_MLA: Dict[BlockingMode, Callable] = {
BlockingMode.KV: blocked_kv_mla_attention_forward,
BlockingMode.H: blocked_h_mla_attention_forward,
}


# helper function needed both in generic blocked approach and in other modeling files for non-blocked approach
def past_key_value_update(
@@ -160,3 +167,66 @@ def generic_blocked_attention_interface(
)

return attn_output, attn_weights


def generic_blocked_mla_attention_interface(
module,
attention_mask: Optional[torch.Tensor],
scaling: float,
mla_absorption: Dict[str, Any],
blocking_config: AttentionBlockingConfig,
query: Optional[torch.Tensor] = None,
q_a_proj_out: Optional[torch.Tensor] = None,
fusedqk: Optional[torch.Tensor] = None,
q_nope: Optional[torch.Tensor] = None,
q_pe: Optional[torch.Tensor] = None,
kva: Optional[torch.Tensor] = None,
k_pe: Optional[torch.Tensor] = None,
per_head_q_up: Optional[torch.Tensor] = None,
per_head_k_up: Optional[torch.Tensor] = None,
per_head_v_up: Optional[torch.Tensor] = None,
per_head_k_up_normal: Optional[torch.Tensor] = None,
layer_idx: Optional[int] = None,
compressed_kvs: Optional[torch.Tensor] = None,
comp_ctx_lengths: Optional[torch.LongTensor] = None,
batch_index: Optional[torch.LongTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_seen_tokens: Optional[int] = None,
non_blocked_forward: Callable = None,
score_mod: Optional[Callable] = None,
position_bias: Optional[torch.Tensor] = None,
sinks: Optional[torch.Tensor] = None,
sliding_window: Optional[int] = None,
**kwargs,
):
cache_kwargs = {"position_ids": position_ids, "batch_index": batch_index}
mla_blocking_strategy = _STRATEGIES_MLA.get(blocking_config.mode)
attn_output, attn_weights = mla_blocking_strategy(
module=module,
query=query,
q_a_proj_out=q_a_proj_out,
fusedqk=fusedqk,
q_nope=q_nope,
q_pe=q_pe,
kva=kva,
k_pe=k_pe,
per_head_q_up=per_head_q_up,
per_head_k_up=per_head_k_up,
per_head_v_up=per_head_v_up,
per_head_k_up_normal=per_head_k_up_normal,
attention_mask=attention_mask,
scaling=scaling,
cache_kwargs=cache_kwargs,
layer_idx=layer_idx,
compressed_kvs=compressed_kvs,
mla_absorption=mla_absorption,
num_kv_blocks=blocking_config.num_kv_blocks,
num_q_blocks=blocking_config.num_q_blocks,
head_block_size=blocking_config.head_block_size,
num_batch_blocks=blocking_config.num_batch_blocks,
score_mod=score_mod,
position_bias=position_bias,
sinks=sinks,
)

return attn_output, attn_weights
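For reference, the new MLA interface selects the blocked forward from _STRATEGIES_MLA based on the blocking mode. A minimal sketch in Python of how a config drives that dispatch; the AttentionBlockingConfig constructor call and the import path are assumptions inferred from the usage above, and all values are placeholders:

# Hedged sketch: how blocking_config.mode picks the MLA strategy.
# AttentionBlockingConfig fields are inferred from generic_blocked_mla_attention_interface;
# the exact constructor signature is an assumption.
from QEfficient.blocking.attention_blocking import AttentionBlockingConfig, BlockingMode

blocking_config = AttentionBlockingConfig(
    mode=BlockingMode.KV,   # BlockingMode.H would select blocked_h_mla_attention_forward
    num_kv_blocks=4,        # placeholder block counts/sizes
    num_q_blocks=1,
    head_block_size=8,
    num_batch_blocks=1,
)

# Inside generic_blocked_mla_attention_interface the dispatch reduces to:
# strategy = _STRATEGIES_MLA.get(blocking_config.mode)
# attn_output, attn_weights = strategy(..., num_kv_blocks=blocking_config.num_kv_blocks, ...)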