diff --git a/QEfficient/transformers/models/modeling_auto.py b/QEfficient/transformers/models/modeling_auto.py
index 11dea4ec5..654b3862d 100644
--- a/QEfficient/transformers/models/modeling_auto.py
+++ b/QEfficient/transformers/models/modeling_auto.py
@@ -1036,7 +1036,7 @@ def __init__(self, model, qaic_config: Optional[dict] = None, **kwargs):
         if self.model.qaic_config is not None and self.model.qaic_config.get("num_kv_blocks", None) is not None:
             BlockedKVAttentionTransform.apply(self.model, num_kv_blocks=self.model.qaic_config.get("num_kv_blocks"))

-    def prefill(
+    def __update_prefill_transform(
         self,
         enable: Optional[bool] = True,
         enable_chunking: Optional[bool] = False,
@@ -1096,10 +1096,10 @@ def export(
                     "Looks like you are trying to run prefix-caching without chunking, this feature is not available yet!"
                 )
             self.hash_params["prefill_only"] = True
-            self.prefill(enable=True, enable_chunking=enable_chunking)
+            self.__update_prefill_transform(enable=True, enable_chunking=enable_chunking)
         else:
             self.hash_params["prefill_only"] = False
-            self.prefill(False, retain_full_kv=kwargs.get("retain_full_kv", False))
+            self.__update_prefill_transform(False, retain_full_kv=kwargs.get("retain_full_kv", False))

         return self._export(
             inputs,
@@ -2699,7 +2699,7 @@ class QEFFAutoModelForCausalLM(QEFFBaseModel):

     _onnx_transforms = [FP16ClipTransform, SplitTensorsTransform]

-    def prefill(
+    def __update_prefill_transform(
         self,
         enable: Optional[bool] = True,
         enable_chunking: Optional[bool] = False,
@@ -2997,7 +2997,7 @@ def export(
                 raise NotImplementedError(
                     "Looks like you are trying to run prefix-caching without chunking, this feature is not available yet!"
                 )
-            self.prefill(enable=True, enable_chunking=enable_chunking)
+            self.__update_prefill_transform(enable=True, enable_chunking=enable_chunking)
             self.hash_params.pop("retain_full_kv", None)
             seq_len = self.get_seq_len_and_handle_specialized_prefill_model(
                 prefill_seq_len=prefill_seq_len, enable_chunking=enable_chunking
             )
@@ -3008,7 +3008,7 @@ def export(
                 else seq_len
             )
         else:
-            self.prefill(False, retain_full_kv=kwargs.get("retain_full_kv", False))
+            self.__update_prefill_transform(False, retain_full_kv=kwargs.get("retain_full_kv", False))
             self.hash_params.pop("prefill_only", None)
             self.hash_params.pop("NUM_Q_BLOCKS", None)
             self.hash_params.pop("NUM_FFN_BLOCKS", None)
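Reviewer note on the rename above: with the double leading underscore, Python name-mangles `__update_prefill_transform` to `_<ClassName>__update_prefill_transform`, so the former public `prefill()` entry point disappears from the API surface and the helper is no longer trivially reachable or overridable from outside the defining class. A minimal sketch of the mangling behavior, using a hypothetical `Exporter` class rather than QEfficient code:

    class Exporter:
        def __update_prefill_transform(self, enable=True):
            # Mangled at class-definition time to _Exporter__update_prefill_transform.
            return f"transform(enable={enable})"

        def export(self, prefill_only=False):
            # Inside the defining class, the mangled name resolves normally.
            return self.__update_prefill_transform(enable=prefill_only)

    e = Exporter()
    print(e.export(prefill_only=True))              # -> transform(enable=True)
    print(e._Exporter__update_prefill_transform())  # reachable only via the mangled name
    # e.__update_prefill_transform()                # would raise AttributeError

Any caller that previously invoked `prefill(...)` directly will now hit an AttributeError, which appears to be the intent of turning this into an internal transform toggle.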
diff --git a/QEfficient/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py b/QEfficient/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py
index 076e4cdbd..67c7daf5e 100644
--- a/QEfficient/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py
+++ b/QEfficient/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py
@@ -628,29 +628,32 @@ def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tens
         router_logits = self.gate(x)  # [T, E]
         prob = F.softmax(router_logits, dim=-1, dtype=hidden_states.dtype)
-        top_w, top_i = torch.topk(prob, self.top_k, dim=-1)  # [T, k], [T, k]
+        top_w, top_i = torch.topk(prob, self.top_k, dim=-1)
         top_w = top_w / torch.einsum("bi->b", top_w)[:, None]
         top_w = top_w.to(hidden_states.dtype)
-
-        # gate_up_proj: [E, H, 2I], down_proj: [E, I, H]
-        W_up = self.experts.gate_up_proj
-        W_dn = self.experts.down_proj
-        E, H_w, twoI = W_up.shape
-        I2 = twoI // 2
-        routing_weights = torch.zeros_like(prob, dtype=hidden_states.dtype)  # [T, E]
+        routing_weights = torch.zeros((T, self.num_experts), dtype=x.dtype)
         routing_weights.scatter_(1, top_i, top_w)
-        expert_out = x.new_zeros((T, H))
-        for e in range(E):
-            rw = routing_weights[:, e].unsqueeze(-1)  # [T, 1]
-            # Split fused [H, 2I] -> [H, I] + [H, I]
-            W_gate_e = W_up[e, :, :I2]
-            W_up_e = W_up[e, :, I2:]
-            W_dn_e = W_dn[e, :, :]
-            gate = x @ W_gate_e
-            up = x @ W_up_e
-            down = (up * act(gate)) @ W_dn_e
-            expert_out.add_(down * rw)
-        return expert_out.view(B, S, H), router_logits
+
+        expert_out = torch.zeros_like(x, dtype=x.dtype)
+
+        for e in range(self.num_experts):
+            routing_weight = routing_weights[:, e].unsqueeze(-1)
+
+            W_gate_up_e = self.experts.gate_up_proj[e]  # [H, 2I]
+            W_dn_e = self.experts.down_proj[e]  # [I, H]
+            gate_up = x @ W_gate_up_e  # [T, 2I]
+
+            I2 = gate_up.shape[-1] // 2
+            gate = gate_up[:, :I2]  # [T, I]
+            up = gate_up[:, I2:]  # [T, I]
+            intermediate = up * act(gate)
+            down = intermediate @ W_dn_e
+            masked_down = torch.where(
+                routing_weight > 0, down * routing_weight, torch.zeros_like(expert_out, dtype=down.dtype)
+            )  # TODO: verify and remove
+            expert_out += masked_down
+        expert_out = expert_out.to(x.dtype).view(B, S, H)
+        return expert_out, router_logits


 class QEffQwen3VLMoeModel(Qwen3VLMoeModel):
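On the MoE rewrite above: the new forward runs every expert over every token and masks non-routed contributions to zero with `torch.where`, instead of slicing per-expert gate/up weights ahead of the loop. This dense-over-experts pattern spends extra FLOPs but keeps the graph static (no data-dependent gather/scatter), which is generally friendlier to ONNX export and ahead-of-time compilation. A NumPy sketch of why the masked dense loop is numerically equivalent to true top-k routing; shapes are toy values and `act` is assumed to be SiLU as in Qwen-style MLPs:

    import numpy as np

    T, H, I, E, K = 4, 8, 16, 3, 2  # tokens, hidden, intermediate, experts, top-k
    rng = np.random.default_rng(0)
    x = rng.standard_normal((T, H))
    W_gate_up = rng.standard_normal((E, H, 2 * I))  # fused gate/up, as in the diff
    W_dn = rng.standard_normal((E, I, H))

    logits = rng.standard_normal((T, E))
    prob = np.exp(logits) / np.exp(logits).sum(-1, keepdims=True)
    top_i = np.argsort(prob, axis=-1)[:, -K:]         # [T, K] top-k expert ids
    top_w = np.take_along_axis(prob, top_i, axis=-1)
    top_w = top_w / top_w.sum(-1, keepdims=True)

    routing = np.zeros((T, E))
    np.put_along_axis(routing, top_i, top_w, axis=-1)  # scatter; zeros elsewhere

    def silu(v):
        return v / (1.0 + np.exp(-v))

    def expert(e, toks):
        gate_up = toks @ W_gate_up[e]
        gate, up = gate_up[:, :I], gate_up[:, I:]
        return (up * silu(gate)) @ W_dn[e]

    # Masked dense loop: every expert runs; zero weights kill non-routed terms.
    dense = sum(routing[:, [e]] * expert(e, x) for e in range(E))

    # Reference: route each token only through its top-k experts.
    sparse = np.zeros((T, H))
    for t in range(T):
        for k in range(K):
            sparse[t] += top_w[t, k] * expert(top_i[t, k], x[t : t + 1])[0]

    assert np.allclose(dense, sparse)

Given this equivalence, the `torch.where` mask in the diff (already flagged "TODO: verify and remove") is redundant with the multiply by `routing_weight`: a zero weight zeroes the product anyway, so the mask only changes behavior if unrouted expert outputs can contain NaN/Inf.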
diff --git a/examples/image_text_to_text/models/qwen3_vl_moe/qwen3_vl_disagg_mode.py b/examples/image_text_to_text/models/qwen3_vl_moe/qwen3_vl_disagg_mode.py
index 897eea350..6e3c43951 100644
--- a/examples/image_text_to_text/models/qwen3_vl_moe/qwen3_vl_disagg_mode.py
+++ b/examples/image_text_to_text/models/qwen3_vl_moe/qwen3_vl_disagg_mode.py
@@ -23,7 +23,7 @@
 # For faster execution user can run with lesser layers, For Testing Purpose Only
 # config.vision_config.depth = 9
-# config.text_config.num_hidden_layers = 1
+# config.text_config.num_hidden_layers = 6
 # config.vision_config.deepstack_visual_indexes = [8]

 qeff_model = QEFFAutoModelForImageTextToText.from_pretrained(
@@ -47,8 +47,10 @@
     num_cores=16,
     num_devices=1,
     mos=1,
+    mxfp6_matmul=True,
     aic_enable_depth_first=True,
     skip_vision=skip_vision,
+    split_retained_state_io=True,
     skip_lang=True,
     use_onnx_subfunctions=True,
 )
@@ -57,6 +59,8 @@
     batch_size=BS,
     prefill_seq_len=PREFILL_SEQ_LEN,
     ctx_len=CTX_LEN,
+    height=354,
+    width=536,
     num_cores=16,
     num_devices=1,
     mxfp6_matmul=True,
@@ -76,6 +80,8 @@
     batch_size=BS,
     prefill_seq_len=1,
     ctx_len=CTX_LEN,
+    height=354,
+    width=536,
     num_cores=16,
     num_devices=1,
     mxfp6_matmul=True,
@@ -111,10 +117,11 @@
         "role": "user",
         "content": [
             {"type": "image", "image": image},
-            {"type": "text", "text": "Descibe all the colors seen in the image."},
+            {"type": "text", "text": "Describe all the colors seen in the image."},
         ],
     },
 ]
+vision_session = QAICInferenceSession(vision_qpc_path.get("vision_qpc_path"))

 messages = [messages] * BS
@@ -163,28 +170,12 @@
 vision_inputs_fp16 = {"pixel_values", "image_masks"}
 vision_inputs.update({k: vision_inputs[k].astype("float16") for k in vision_inputs_fp16 if k in vision_inputs})

-if not skip_vision:
-    vision_session = QAICInferenceSession(vision_qpc_path.get("vision_qpc_path"))
-
 vision_start = perf_counter()
 vision_outputs = {}
 if vision_inputs:
     vision_outputs = vision_session.run(vision_inputs)
 vision_end = perf_counter()

-if not skip_vision:
-    vision_session.deactivate()
-
-lang_prefill_session.activate()
-# Skip inputs/outputs
-lang_prefill_session.skip_buffers(
-    [
-        x
-        for x in lang_prefill_session.input_names + lang_prefill_session.output_names
-        if x.startswith("past_") or x.endswith("_RetainedState")
-    ]
-)
-
 lang_inputs = {k: v for k, v in inputs.items() if k not in vision_inputs}
 if "position_ids" in inputs:
     lang_inputs["position_ids"] = inputs["position_ids"]
@@ -196,54 +187,67 @@
 lang_inputs["image_idx"] = np.array([[0]])
+if not skip_vision:
+    lang_inputs["vision_embeds"] = vision_outputs["vision_embeds"]
+    lang_inputs["deepstack_features"] = vision_outputs["deepstack_features"]

 # RUN prefill
 lang_start = perf_counter()
+lang_prefill_session.set_buffers(vision_outputs)
 all_outputs = []
 chunk_inputs = lang_inputs.copy()
 for i in range(num_chunks):
     chunk_inputs["input_ids"] = lang_inputs["input_ids"][:, i * PREFILL_SEQ_LEN : (i + 1) * PREFILL_SEQ_LEN]
     chunk_inputs["position_ids"] = lang_inputs["position_ids"][..., i * PREFILL_SEQ_LEN : (i + 1) * PREFILL_SEQ_LEN]
     outputs = lang_prefill_session.run(chunk_inputs)
+    for j in range(config.text_config.num_hidden_layers):
+        chunk_inputs[f"past_key.{j}"] = outputs[f"past_key.{j}_RetainedState"]
+        chunk_inputs[f"past_value.{j}"] = outputs[f"past_value.{j}_RetainedState"]
     chunk_inputs["image_idx"] = outputs["image_idx_output"]

 prefill_time = perf_counter() - lang_start + vision_end - vision_start
 print(f"Prefill time : {prefill_time:.2f} secs")

-lang_prefill_session.deactivate()
-lang_decode_session.activate()
-# Skip inputs/outputs
-lang_decode_session.skip_buffers(
-    [
-        x
-        for x in lang_decode_session.input_names + lang_decode_session.output_names
-        if x.startswith("past_") or x.endswith("_RetainedState")
-    ]
-)
-
-lang_decode_session.set_buffers(outputs)
-
 all_outputs.append(np.argmax(outputs["logits"]))
 decode_inputs = {
     "input_ids": np.argmax(outputs["logits"]).reshape(1, 1),
     "position_ids": np.max(lang_inputs["position_ids"], axis=-1, keepdims=True) + 1,
 }
+for i in range(config.text_config.num_hidden_layers):
+    decode_inputs[f"past_key.{i}"] = outputs[f"past_key.{i}_RetainedState"]
+    decode_inputs[f"past_value.{i}"] = outputs[f"past_value.{i}_RetainedState"]
+
+decode_inputs["image_idx"] = outputs["image_idx_output"]
+decode_inputs["vision_embeds"] = outputs["vision_embeds_RetainedState"]
+decode_inputs["deepstack_features"] = outputs["deepstack_features_RetainedState"]
+
 st = perf_counter()
 decode_out = lang_decode_session.run(decode_inputs)
 print(f"time for first run of decode with KV as input = {perf_counter() - st} sec\n")

 all_outputs.append(np.argmax(decode_out["logits"]))
-pos_id = np.max(lang_inputs["position_ids"], axis=-1, keepdims=True) + 1
+pos_id = np.max(decode_inputs["position_ids"], axis=-1, keepdims=True) + 1
 loop_decode_inputs = {
     "input_ids": np.argmax(decode_out["logits"]).reshape(1, 1),
     "position_ids": pos_id,
 }
+for i in range(config.text_config.num_hidden_layers):
+    loop_decode_inputs[f"past_key.{i}"] = decode_out[f"past_key.{i}_RetainedState"]
+    loop_decode_inputs[f"past_value.{i}"] = decode_out[f"past_value.{i}_RetainedState"]
+loop_decode_inputs["image_idx"] = decode_out["image_idx_output"]
+loop_decode_inputs["vision_embeds"] = decode_out["vision_embeds_RetainedState"]
+loop_decode_inputs["deepstack_features"] = decode_out["deepstack_features_RetainedState"]
+
+
 st = perf_counter()
 for i in range(generation_len - 2):
     decode_out = lang_decode_session.run(loop_decode_inputs)
     all_outputs.append(np.argmax(decode_out["logits"]))
     pos_id += 1
+    for j in range(config.text_config.num_hidden_layers):
+        loop_decode_inputs[f"past_key.{j}"] = decode_out[f"past_key.{j}_RetainedState"]
+        loop_decode_inputs[f"past_value.{j}"] = decode_out[f"past_value.{j}_RetainedState"]
     loop_decode_inputs.update(
         {
             "input_ids": np.argmax(decode_out["logits"]).reshape(1, 1),
@@ -251,6 +255,5 @@
         }
     )
 ft = perf_counter()
-
 print(f"decode tok/sec={(generation_len - 2) / (ft - st)}")
 print(f"\noutput\n{tokenizer.decode(all_outputs)}")
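End-to-end, the example now threads the KV cache through the host: every `run()` returns `past_key.{i}_RetainedState` / `past_value.{i}_RetainedState` buffers that are copied back into the matching `past_key.{i}` / `past_value.{i}` inputs of the next call (enabled by compiling with `split_retained_state_io=True`), which is what lets the separate prefill and decode sessions drop the old activate/deactivate and `skip_buffers` bookkeeping. A condensed sketch of that feedback loop; the session object and dict-in/dict-out `run()` mirror the example's `QAICInferenceSession` usage rather than a documented API:

    import numpy as np

    def feed_back_retained_state(next_inputs: dict, outputs: dict, num_layers: int) -> None:
        # Copy past_key/past_value *_RetainedState outputs into the next run's inputs.
        for i in range(num_layers):
            next_inputs[f"past_key.{i}"] = outputs[f"past_key.{i}_RetainedState"]
            next_inputs[f"past_value.{i}"] = outputs[f"past_value.{i}_RetainedState"]

    def greedy_decode(decode_session, prefill_outputs, start_pos, num_layers, steps):
        # Seed decode with the last prefill logits and the prefill-produced KV cache.
        tokens = [int(np.argmax(prefill_outputs["logits"]))]
        inputs = {
            "input_ids": np.array([[tokens[-1]]]),
            "position_ids": np.array([[start_pos]]),
        }
        feed_back_retained_state(inputs, prefill_outputs, num_layers)
        for _ in range(steps):
            out = decode_session.run(inputs)  # one token per call
            tokens.append(int(np.argmax(out["logits"])))
            inputs["input_ids"] = np.array([[tokens[-1]]])
            inputs["position_ids"] = inputs["position_ids"] + 1
            feed_back_retained_state(inputs, out, num_layers)
        return tokens

The cost is an extra host round-trip of the full KV cache per step, so this pattern trades some decode latency for the flexibility of running prefill and decode as independent QPCs.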