Merged
12 changes: 6 additions & 6 deletions QEfficient/transformers/models/modeling_auto.py
@@ -1036,7 +1036,7 @@ def __init__(self, model, qaic_config: Optional[dict] = None, **kwargs):
if self.model.qaic_config is not None and self.model.qaic_config.get("num_kv_blocks", None) is not None:
BlockedKVAttentionTransform.apply(self.model, num_kv_blocks=self.model.qaic_config.get("num_kv_blocks"))

-def prefill(
+def __update_prefill_transform(
self,
enable: Optional[bool] = True,
enable_chunking: Optional[bool] = False,
@@ -1096,10 +1096,10 @@ def export(
"Looks like you are trying to run prefix-caching without chunking, this feature is not available yet!"
)
self.hash_params["prefill_only"] = True
-self.prefill(enable=True, enable_chunking=enable_chunking)
+self.__update_prefill_transform(enable=True, enable_chunking=enable_chunking)
else:
self.hash_params["prefill_only"] = False
-self.prefill(False, retain_full_kv=kwargs.get("retain_full_kv", False))
+self.__update_prefill_transform(False, retain_full_kv=kwargs.get("retain_full_kv", False))

return self._export(
inputs,
@@ -2699,7 +2699,7 @@ class QEFFAutoModelForCausalLM(QEFFBaseModel):

_onnx_transforms = [FP16ClipTransform, SplitTensorsTransform]

-def prefill(
+def __update_prefill_transform(
self,
enable: Optional[bool] = True,
enable_chunking: Optional[bool] = False,
@@ -2997,7 +2997,7 @@ def export(
raise NotImplementedError(
"Looks like you are trying to run prefix-caching without chunking, this feature is not available yet!"
)
-self.prefill(enable=True, enable_chunking=enable_chunking)
+self.__update_prefill_transform(enable=True, enable_chunking=enable_chunking)
self.hash_params.pop("retain_full_kv", None)
seq_len = self.get_seq_len_and_handle_specialized_prefill_model(
prefill_seq_len=prefill_seq_len, enable_chunking=enable_chunking
@@ -3008,7 +3008,7 @@
else seq_len
)
else:
-self.prefill(False, retain_full_kv=kwargs.get("retain_full_kv", False))
+self.__update_prefill_transform(False, retain_full_kv=kwargs.get("retain_full_kv", False))
self.hash_params.pop("prefill_only", None)
self.hash_params.pop("NUM_Q_BLOCKS", None)
self.hash_params.pop("NUM_FFN_BLOCKS", None)
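A side effect of this rename worth flagging for reviewers: the leading double underscore makes the method subject to Python name mangling, so it is effectively private to the defining class, and external callers that previously used `model.prefill(...)` can only reach it through the mangled name (e.g. `_QEFFAutoModelForCausalLM__update_prefill_transform` for the causal-LM class). A minimal sketch of the mechanics; the `Exporter` class and `__update` method below are illustrative only, not from this PR:

# Minimal sketch of Python name mangling, stdlib only.
class Exporter:
    def __update(self):            # stored on the class as _Exporter__update
        return "transformed"

    def export(self):
        return self.__update()     # intra-class calls are rewritten automatically


e = Exporter()
print(e.export())                  # -> "transformed"
print(e._Exporter__update())       # external access requires the mangled name
# e.__update()                     # would raise AttributeError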
@@ -628,29 +628,32 @@ def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:

router_logits = self.gate(x) # [T, E]
prob = F.softmax(router_logits, dim=-1, dtype=hidden_states.dtype)
-top_w, top_i = torch.topk(prob, self.top_k, dim=-1)  # [T, k], [T, k]
+top_w, top_i = torch.topk(prob, self.top_k, dim=-1)
top_w = top_w / torch.einsum("bi->b", top_w)[:, None]
top_w = top_w.to(hidden_states.dtype)

-# gate_up_proj: [E, H, 2I], down_proj: [E, I, H]
-W_up = self.experts.gate_up_proj
-W_dn = self.experts.down_proj
-E, H_w, twoI = W_up.shape
-I2 = twoI // 2
-routing_weights = torch.zeros_like(prob, dtype=hidden_states.dtype)  # [T, E]
+routing_weights = torch.zeros((T, self.num_experts), dtype=x.dtype)
routing_weights.scatter_(1, top_i, top_w)
-expert_out = x.new_zeros((T, H))
-for e in range(E):
-    rw = routing_weights[:, e].unsqueeze(-1)  # [T, 1]
-    # Split fused [H, 2I] -> [H, I] + [H, I]
-    W_gate_e = W_up[e, :, :I2]
-    W_up_e = W_up[e, :, I2:]
-    W_dn_e = W_dn[e, :, :]
-    gate = x @ W_gate_e
-    up = x @ W_up_e
-    down = (up * act(gate)) @ W_dn_e
-    expert_out.add_(down * rw)
-return expert_out.view(B, S, H), router_logits

+expert_out = torch.zeros_like(x, dtype=x.dtype)
+
+for e in range(self.num_experts):
+    routing_weight = routing_weights[:, e].unsqueeze(-1)
+
+    W_gate_up_e = self.experts.gate_up_proj[e]  # [H, 2I]
+    W_dn_e = self.experts.down_proj[e]  # [I, H]
+    gate_up = x @ W_gate_up_e  # [T, 2I]
+
+    I2 = gate_up.shape[-1] // 2
+    gate = gate_up[:, :I2]  # [T, I]
+    up = gate_up[:, I2:]  # [T, I]
+    intermediate = up * act(gate)
+    down = intermediate @ W_dn_e
+    masked_down = torch.where(
+        routing_weight > 0, down * routing_weight, torch.zeros_like(expert_out, dtype=down.dtype)
+    )  # TODO: verify and remove
+    expert_out += masked_down
+expert_out = expert_out.to(x.dtype).view(B, S, H)
+return expert_out, router_logits


class QEffQwen3VLMoeModel(Qwen3VLMoeModel):
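To sanity-check the rewritten expert loop outside the model, here is a self-contained sketch of the same masked dense-MoE pattern. The toy sizes, SiLU activation, and all variable names are assumptions for illustration, not taken from the QEfficient module. Every expert processes every token, and torch.where zeroes the contribution of tokens not routed to that expert, so the result matches a gather-style top-k computation while every tensor keeps a static shape:

import torch
import torch.nn.functional as F

T, H, I, E, K = 8, 16, 32, 4, 2  # tokens, hidden, intermediate, experts, top-k

x = torch.randn(T, H)
gate_up_proj = torch.randn(E, H, 2 * I) * 0.1  # fused per-expert gate/up weights
down_proj = torch.randn(E, I, H) * 0.1
router_logits = torch.randn(T, E)

prob = F.softmax(router_logits, dim=-1)
top_w, top_i = torch.topk(prob, K, dim=-1)       # [T, K]
top_w = top_w / top_w.sum(dim=-1, keepdim=True)  # renormalize over the top-k

# Scatter top-k weights into a dense [T, E] map; non-routed entries stay 0.
routing_weights = torch.zeros(T, E).scatter_(1, top_i, top_w)

expert_out = torch.zeros_like(x)
for e in range(E):
    w = routing_weights[:, e].unsqueeze(-1)      # [T, 1]
    gate, up = (x @ gate_up_proj[e]).chunk(2, dim=-1)
    down = (up * F.silu(gate)) @ down_proj[e]    # [T, H]
    # Dense but masked: tokens not routed to expert e contribute exactly zero.
    expert_out += torch.where(w > 0, down * w, torch.zeros_like(down))

# Reference: explicit gather over each token's top-k experts.
ref = torch.zeros_like(x)
for t in range(T):
    for k in range(K):
        e = top_i[t, k].item()
        gate, up = (x[t] @ gate_up_proj[e]).chunk(2)
        ref[t] += top_w[t, k] * ((up * F.silu(gate)) @ down_proj[e])

assert torch.allclose(expert_out, ref, atol=1e-5)

The dense form trades extra compute for fixed shapes, which is what makes it compile cleanly to a static-graph target.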
@@ -23,7 +23,7 @@

# For faster execution user can run with lesser layers, For Testing Purpose Only
# config.vision_config.depth = 9
-# config.text_config.num_hidden_layers = 1
+# config.text_config.num_hidden_layers = 6
# config.vision_config.deepstack_visual_indexes = [8]

qeff_model = QEFFAutoModelForImageTextToText.from_pretrained(
@@ -47,8 +47,10 @@
num_cores=16,
num_devices=1,
mos=1,
mxfp6_matmul=True,
aic_enable_depth_first=True,
skip_vision=skip_vision,
+split_retained_state_io=True,
skip_lang=True,
use_onnx_subfunctions=True,
)
@@ -57,6 +59,8 @@
batch_size=BS,
prefill_seq_len=PREFILL_SEQ_LEN,
ctx_len=CTX_LEN,
+height=354,
+width=536,
num_cores=16,
num_devices=1,
mxfp6_matmul=True,
@@ -76,6 +80,8 @@
batch_size=BS,
prefill_seq_len=1,
ctx_len=CTX_LEN,
+height=354,
+width=536,
num_cores=16,
num_devices=1,
mxfp6_matmul=True,
@@ -111,10 +117,11 @@
"role": "user",
"content": [
{"type": "image", "image": image},
{"type": "text", "text": "Descibe all the colors seen in the image."},
{"type": "text", "text": "Describe all the colors seen in the image."},
],
},
]
+vision_session = QAICInferenceSession(vision_qpc_path.get("vision_qpc_path"))


messages = [messages] * BS
@@ -163,28 +170,12 @@
vision_inputs_fp16 = {"pixel_values", "image_masks"}
vision_inputs.update({k: vision_inputs[k].astype("float16") for k in vision_inputs_fp16 if k in vision_inputs})

-if not skip_vision:
-    vision_session = QAICInferenceSession(vision_qpc_path.get("vision_qpc_path"))

vision_start = perf_counter()
vision_outputs = {}
if vision_inputs:
vision_outputs = vision_session.run(vision_inputs)
vision_end = perf_counter()

-if not skip_vision:
-    vision_session.deactivate()

lang_prefill_session.activate()
-# Skip inputs/outputs
-lang_prefill_session.skip_buffers(
-    [
-        x
-        for x in lang_prefill_session.input_names + lang_prefill_session.output_names
-        if x.startswith("past_") or x.endswith("_RetainedState")
-    ]
-)

lang_inputs = {k: v for k, v in inputs.items() if k not in vision_inputs}
if "position_ids" in inputs:
lang_inputs["position_ids"] = inputs["position_ids"]
@@ -196,61 +187,73 @@

lang_inputs["image_idx"] = np.array([[0]])

-if not skip_vision:
-    lang_inputs["vision_embeds"] = vision_outputs["vision_embeds"]
-    lang_inputs["deepstack_features"] = vision_outputs["deepstack_features"]

# RUN prefill
lang_start = perf_counter()
+lang_prefill_session.set_buffers(vision_outputs)
all_outputs = []
chunk_inputs = lang_inputs.copy()
for i in range(num_chunks):
chunk_inputs["input_ids"] = lang_inputs["input_ids"][:, i * PREFILL_SEQ_LEN : (i + 1) * PREFILL_SEQ_LEN]
chunk_inputs["position_ids"] = lang_inputs["position_ids"][..., i * PREFILL_SEQ_LEN : (i + 1) * PREFILL_SEQ_LEN]
outputs = lang_prefill_session.run(chunk_inputs)
for i in range(config.text_config.num_hidden_layers):
chunk_inputs[f"past_key.{i}"] = outputs[f"past_key.{i}_RetainedState"]
chunk_inputs[f"past_value.{i}"] = outputs[f"past_value.{i}_RetainedState"]
chunk_inputs["image_idx"] = outputs["image_idx_output"]
prefill_time = perf_counter() - lang_start + vision_end - vision_start
print(f"Prefill time : {prefill_time:.2f} secs")

lang_prefill_session.deactivate()
lang_decode_session.activate()
-# Skip inputs/outputs
-lang_decode_session.skip_buffers(
-    [
-        x
-        for x in lang_decode_session.input_names + lang_decode_session.output_names
-        if x.startswith("past_") or x.endswith("_RetainedState")
-    ]
-)

-lang_decode_session.set_buffers(outputs)

all_outputs.append(np.argmax(outputs["logits"]))
decode_inputs = {
"input_ids": np.argmax(outputs["logits"]).reshape(1, 1),
"position_ids": np.max(lang_inputs["position_ids"], axis=-1, keepdims=True) + 1,
}

for i in range(config.text_config.num_hidden_layers):
decode_inputs[f"past_key.{i}"] = outputs[f"past_key.{i}_RetainedState"]
decode_inputs[f"past_value.{i}"] = outputs[f"past_value.{i}_RetainedState"]

decode_inputs["image_idx"] = outputs["image_idx_output"]
decode_inputs["vision_embeds"] = outputs["vision_embeds_RetainedState"]
decode_inputs["deepstack_features"] = outputs["deepstack_features_RetainedState"]

st = perf_counter()
decode_out = lang_decode_session.run(decode_inputs)
print(f"time for first run of decode with KV as input = {perf_counter() - st} sec\n")

all_outputs.append(np.argmax(decode_out["logits"]))
-pos_id = np.max(lang_inputs["position_ids"], axis=-1, keepdims=True) + 1
+pos_id = np.max(decode_inputs["position_ids"], axis=-1, keepdims=True) + 1
loop_decode_inputs = {
"input_ids": np.argmax(decode_out["logits"]).reshape(1, 1),
"position_ids": pos_id,
}

for i in range(config.text_config.num_hidden_layers):
loop_decode_inputs[f"past_key.{i}"] = decode_out[f"past_key.{i}_RetainedState"]
loop_decode_inputs[f"past_value.{i}"] = decode_out[f"past_value.{i}_RetainedState"]
loop_decode_inputs["image_idx"] = decode_out["image_idx_output"]
loop_decode_inputs["vision_embeds"] = decode_out["vision_embeds_RetainedState"]
loop_decode_inputs["deepstack_features"] = decode_out["deepstack_features_RetainedState"]


st = perf_counter()
for i in range(generation_len - 2):
decode_out = lang_decode_session.run(loop_decode_inputs)
all_outputs.append(np.argmax(decode_out["logits"]))
pos_id += 1
for j in range(config.text_config.num_hidden_layers):
loop_decode_inputs[f"past_key.{j}"] = decode_out[f"past_key.{j}_RetainedState"]
loop_decode_inputs[f"past_value.{j}"] = decode_out[f"past_value.{j}_RetainedState"]
loop_decode_inputs.update(
{
"input_ids": np.argmax(decode_out["logits"]).reshape(1, 1),
"position_ids": pos_id,
}
)
ft = perf_counter()

print(f"decode tok/sec={(generation_len - 2) / (ft - st)}")
print(f"\noutput\n{tokenizer.decode(all_outputs)}")