From 5d8ae39f3250ff004b16fc9f6bb36bcb9272f149 Mon Sep 17 00:00:00 2001 From: Mohit Mehta Date: Mon, 20 Apr 2026 23:45:55 -0700 Subject: [PATCH 1/3] Added MDP generation to QEff Compile Signed-off-by: Mohit Mehta --- QEfficient/base/modeling_qeff.py | 52 ++++-- QEfficient/compile/mdp_generator.py | 249 ++++++++++++++++++++++++++++ 2 files changed, 291 insertions(+), 10 deletions(-) create mode 100644 QEfficient/compile/mdp_generator.py diff --git a/QEfficient/base/modeling_qeff.py b/QEfficient/base/modeling_qeff.py index 6c3cc7993..c6e0cfcac 100644 --- a/QEfficient/base/modeling_qeff.py +++ b/QEfficient/base/modeling_qeff.py @@ -462,6 +462,7 @@ def _compile( specializations: Optional[List[Dict[str, int]]] = None, custom_io: Optional[Dict[str, str]] = None, mdp_ts_num_devices: int = 1, + mdp_num_partitions: int = 1, num_speculative_tokens: Optional[int] = None, enable_qnn: Optional[bool] = False, qnn_config: Optional[str] = None, @@ -483,6 +484,11 @@ def _compile( :specializations (list): List of specializations to compile for :custom_io (dict): Custom IO to specify the input and outputs in different formats than default :mdp_ts_num_devices (int): Number of devices to partition to use Multi-Device Partitioning with tensor-slicing. + :mdp_num_partitions (int): Number of pipeline-parallel partitions for disaggregated prefill serving. + When > 1, the ONNX graph is read directly to generate a fully-populated MDP partition + config (nodeList per partition) without requiring a compiler round-trip. + Ignored when ``mdp_load_partition_config`` is already provided in compiler_options. + Defaults to 1 (template / tensor-slice MDP, existing behaviour). :num_speculative_tokens (int, optional): Number of speculative tokens to take as input for Speculative Decoding Target Language Model. :enable_qnn (bool): Enables QNN Compilation. ``Defaults to False.`` :qnn_config (str): Path of QNN Config parameters file. Any extra parameters for QNN compilation can be passed via this file. ``Defaults to None.`` @@ -546,22 +552,47 @@ def _compile( + [f"-m={onnx_path}"] ) - # MDP partition config: prioritize dump over load - mdp_dump_json_path = compiler_options.pop("mdp_dump_partition_config", None) + # MDP partition config selection (three priorities, highest first): + # 1. User explicitly provides a pre-built MDP JSON to load. + # 2. Disaggregated (pipeline-parallel) MDP — generate from ONNX topsort. + # 3. Template (tensor-slice) MDP — single partition, nodeList absent. mdp_ts_json_path = compiler_options.pop("mdp_load_partition_config", None) + # Silently discard any stale mdp_dump_partition_config key that callers + # may still pass; the compiler-round-trip dump path is no longer supported. + compiler_options.pop("mdp_dump_partition_config", None) mdp_ts_json = None - if mdp_dump_json_path: - if mdp_ts_json_path: - logger.warning( - "Loading and Dumping partition is not supported at the same time. Prioritizing dump config over load config!" - ) - command.append(f"-mdp-dump-partition-config={mdp_dump_json_path}") - elif mdp_ts_json_path: + if mdp_ts_json_path: command.append(f"-mdp-load-partition-config={mdp_ts_json_path}") mdp_ts_json = load_json(str(mdp_ts_json_path)) + elif mdp_num_partitions > 1: + # Disaggregated (pipeline-parallel) MDP: generate a fully-populated + # nodeList per partition directly from the ONNX graph — no compiler + # round-trip required. + from QEfficient.compile.mdp_generator import generate_disagg_mdp_partition_config + + num_layers = getattr(self, "num_layers", None) + if num_layers is None: + raise AttributeError("Model does not expose 'num_layers'. Cannot generate disagg MDP partition config.") + num_cores = compiler_options.get("aic_num_cores", constants.DEFAULT_AIC_NUM_CORES) + logger.info( + f"Generating disagg MDP partition config from ONNX: " + f"num_devices={mdp_ts_num_devices}, num_partitions={mdp_num_partitions}, " + f"num_layers={num_layers}, num_cores={num_cores}" + ) + mdp_ts_json = generate_disagg_mdp_partition_config( + onnx_path=str(onnx_path), + num_devices=mdp_ts_num_devices, + num_partitions=mdp_num_partitions, + num_layers=num_layers, + num_cores=num_cores, + ) + mdp_ts_json_path = compile_dir / f"mdp_disagg_{mdp_ts_num_devices}d_{mdp_num_partitions}p.json" + create_json(str(mdp_ts_json_path), mdp_ts_json) + command.append(f"-mdp-load-partition-config={mdp_ts_json_path}") elif mdp_ts_num_devices > 1: - # Generate mdp config only if neither dump nor load is provided and num_devices > 1 + # Template (tensor-slice) MDP: single partition, empty nodeList. + # Used when PP is disabled (stages=1). Compiler fills the nodeList. mdp_ts_json = generate_mdp_partition_config( mdp_ts_num_devices, compiler_options.get("aic_num_cores", constants.DEFAULT_AIC_NUM_CORES) ) @@ -586,6 +617,7 @@ def _compile( "specializations": specializations, "custom_io": custom_io, "mdp_ts_num_devices": mdp_ts_num_devices, + "mdp_num_partitions": mdp_num_partitions, "mdp_ts_json": mdp_ts_json, "num_speculative_tokens": num_speculative_tokens, "prefill_only": prefill_only, diff --git a/QEfficient/compile/mdp_generator.py b/QEfficient/compile/mdp_generator.py new file mode 100644 index 000000000..58a9ffe71 --- /dev/null +++ b/QEfficient/compile/mdp_generator.py @@ -0,0 +1,249 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- +"""MDP generator for disaggregated prefill serving (PP-enabled, TS-enabled, stages>1).""" + +from typing import Any, Dict, List, Optional, Set +import onnx +import logging + +logger = logging.getLogger(__name__) + + +def _get_compiler_folded_nodes(graph) -> Set[str]: + """Return node names the compiler will fold away during ONNX import. + + Mirrors computeIsConstantFoldable() in ONNXModelLoader.cpp: a node is + foldable if every one of its inputs is a compile-time constant (initializer, + Constant op output, or output of another foldable node). Folded nodes are + absent from the compiler IR, so including them in nodeList is harmless but + excluding them produces a cleaner MDP closer to the compiler dump. + + Op types that the compiler never folds (ProtobufLoader.cpp:68): + Loop, Const, Identity, If, DequantizeLinear + """ + # const_values: output tensor names whose value is known at compile time. + # Seeded with all initializer names (model weights / constants). + const_values: Set[str] = {init.name for init in graph.initializer} + + # Constant op outputs are trivially compile-time constants; collect them + # upfront so the fixed-point loop below only needs one pass for everything else. + for node in graph.node: + if node.op_type == "Constant": + const_values.update(out for out in node.output if out) + + # Never-folded op types (compiler explicitly skips these - ProtobufLoader.cpp:68). + _NEVER_FOLD = frozenset({"Loop", "Const", "Identity", "If", "DequantizeLinear"}) + + # Keep marking nodes foldable until no new ones are found. + foldable_nodes: Set[str] = set() + while(True): + changed = False + for node in graph.node: + if not node.name or node.name in foldable_nodes: + continue + if node.op_type in _NEVER_FOLD or not node.input: + continue + if all(inp in const_values for inp in node.input if inp): + foldable_nodes.add(node.name) + const_values.update(out for out in node.output if out) + changed = True + if not changed: + break + + return foldable_nodes + + +def _get_layer_num(node_name: str) -> Optional[int]: + """Return transformer layer index from node name, or None. + + Supports layers.N (Llama/Mistral/Qwen/Gemma/Granite) and h.N (GPT-2). + """ + for part in node_name.split("/"): + if part.startswith("layers."): + suffix = part[len("layers.") :] + if suffix.isdigit(): + return int(suffix) + elif part.startswith("h."): + suffix = part[len("h.") :] + if suffix.isdigit(): + return int(suffix) + return None + + +def _get_inlined_node_map(model) -> tuple: + """Classify ONNX local functions and build inlined sub-node names. + + The compiler inlines a local function body into the parent graph during + ONNX import if it has < 100 nodes AND is not a known custom op + (ONNXModelLoaderSubFuns.cpp). Inlined call-sites do not appear in the + compiler IR; their sub-nodes are named /. + Known custom ops (registered via DEFINEKNOWNCUSTOMOP) keep their + call-site name in the IR and must be included in nodeList as-is. + + Returns: + inlined_node_map: dict mapping call-site name -> list of inlined + sub-node names (/). + non_inlined_funcs: set of function names that are NOT inlined + (known custom ops or >= 100 nodes); their + call-site names are valid nodeList entries. + """ + # Registered with DEFINEKNOWNCUSTOMOP in ONNXModelLoader.cpp + _KNOWN_CUSTOM_OPS = frozenset({"CustomRMSNorm"}) + + local_functions = {f.name: f for f in model.functions} + logger.info(f"Found {len(local_functions)} local function types: {set(local_functions.keys())}") + + inlined_funcs: Set[str] = set() + non_inlined_funcs: Set[str] = set() + for func_name, func in local_functions.items(): + if func_name in _KNOWN_CUSTOM_OPS or len(func.node) >= 100: + non_inlined_funcs.add(func_name) + logger.info(f" {func_name}: not inlined") + else: + inlined_funcs.add(func_name) + logger.info(f" {func_name}: {len(func.node)} nodes, will inline") + + inlined_node_map: Dict[str, List[str]] = {} + for node in model.graph.node: + if node.op_type in inlined_funcs: + func = local_functions[node.op_type] + inlined_node_map[node.name] = [ + f"{node.name}/{fn.name}" for fn in func.node if fn.name + ] + + logger.info(f"Inlined sub-nodes mapped for {len(inlined_node_map)} call-sites") + return inlined_node_map, non_inlined_funcs + + +def generate_disagg_mdp_partition_config( + onnx_path: str, + num_devices: int, + num_partitions: int, + num_layers: int, + num_cores: int = 16, +) -> Dict[str, Any]: + """Generate a pipeline-partitioned MDP config from an exported ONNX graph. + + Assigns nodes to partitions by transformer layer index. Non-layer nodes + (embeddings, lm_head) follow the nearest layer in topological order. + nodeList is a superset of the compiler dump; the compiler silently ignores + optimized-away names. Inlined local function call-sites (CtxScatterCB, + CtxGatherCB) are excluded; their /nNN sub-nodes are assigned automatically. + Known custom ops (CustomRMSNorm) are included by call-site name. + + For PP+TS: num_devices // num_partitions devices per partition; the + compiler applies tensor-slicing within each stage. + + Args: + onnx_path: Path to the exported ONNX file. + num_devices: Total devices (num_partitions * ts_per_stage). + num_partitions: Number of pipeline stages. + num_layers: Number of transformer layers. + num_cores: NSP cores per device (default 16). + + Returns: + dict with keys 'connections' and 'partitions'. + """ + assert num_partitions <= num_devices, f"num_partitions ({num_partitions}) must be <= num_devices ({num_devices})" + + layers_per_partition = num_layers // num_partitions + model = onnx.load(onnx_path, load_external_data=False) + + # Verify topological order (ONNX spec §3.3). Fails loudly on malformed exports. + # Graph inputs and initializers are excluded — they are not produced by any node. + graph_input_names: Set[str] = {inp.name for inp in model.graph.input} + initializer_names: Set[str] = {init.name for init in model.graph.initializer} + external_names: Set[str] = graph_input_names | initializer_names + + output_to_node: Dict[str, str] = {} + for node in model.graph.node: + for out in node.output: + if out: # "" marks optional unused outputs + output_to_node[out] = node.name + + seen_outputs: Set[str] = set() + for node in model.graph.node: + for inp in node.input: + if not inp: + continue + if inp in external_names: + continue + if inp in output_to_node and inp not in seen_outputs: + raise ValueError( + f"ONNX graph has a cycle or violates topological order: " + f"node '{node.name}' consumes '{inp}' produced by " + f"'{output_to_node[inp]}', but that producer has not appeared yet." + ) + for out in node.output: + if out: + seen_outputs.add(out) + + logger.info("Computing constant-foldable nodes...") + folded_nodes = _get_compiler_folded_nodes(model.graph) + logger.info(f"Found {len(folded_nodes)} compiler-folded nodes (excluded from nodeList)") + + inlined_node_map, non_inlined_functions = _get_inlined_node_map(model) + inlined_functions = {f.name for f in model.functions} - non_inlined_functions + + # First pass: assign main graph nodes to partitions by layer index. + partitions: List[List[str]] = [[] for _ in range(num_partitions)] + current_layer_partition = 0 + seen_first_layer = False + max_layer_seen = -1 + + for node in model.graph.node: + if not node.name.startswith("/"): + continue + if node.name in folded_nodes: + continue + if node.op_type in inlined_functions: + continue # inlined; sub-nodes added in second pass + + layer_num = _get_layer_num(node.name) + if layer_num is not None: + max_layer_seen = max(max_layer_seen, layer_num) + seen_first_layer = True + partition_idx = min(layer_num // layers_per_partition, num_partitions - 1) + current_layer_partition = partition_idx + partitions[partition_idx].append(node.name) + else: + if not seen_first_layer: + partitions[0].append(node.name) + else: + partitions[current_layer_partition].append(node.name) + + # Second pass: add inlined sub-nodes, inheriting their call-site's partition. + for call_site_name, inlined_nodes in inlined_node_map.items(): + layer_num = _get_layer_num(call_site_name) + if layer_num is not None: + partition_idx = min(layer_num // layers_per_partition, num_partitions - 1) + else: + partition_idx = current_layer_partition + partitions[partition_idx].extend(inlined_nodes) + + for i, partition in enumerate(partitions): + logger.info(f"Partition {i}: {len(partition)} nodes") + logger.info(f"Total nodes in MDP: {sum(len(p) for p in partitions)}") + + # PP-only: 1 device/partition; PP+TS: num_devices//num_partitions devices/partition. + device_ids = list(range(num_devices)) + devices_per_partition = num_devices // num_partitions + partition_objs = [] + for i, node_list in enumerate(partitions): + assigned_devices = device_ids[i * devices_per_partition : (i + 1) * devices_per_partition] + partition_objs.append( + { + "name": f"Partition{i}", + "nodeList": node_list, + "devices": [{"deviceId": dev_id, "numCores": num_cores} for dev_id in assigned_devices], + } + ) + + return { + "connections": [{"devices": device_ids, "type": "p2p"}], + "partitions": partition_objs, + } From b7717e61738953b0c145b72ea7328969a73fa47e Mon Sep 17 00:00:00 2001 From: Mohit Mehta Date: Mon, 20 Apr 2026 23:50:47 -0700 Subject: [PATCH 2/3] Formatting and Linting Signed-off-by: Mohit Mehta --- QEfficient/compile/mdp_generator.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/QEfficient/compile/mdp_generator.py b/QEfficient/compile/mdp_generator.py index 58a9ffe71..cbc0f606e 100644 --- a/QEfficient/compile/mdp_generator.py +++ b/QEfficient/compile/mdp_generator.py @@ -6,9 +6,10 @@ # ----------------------------------------------------------------------------- """MDP generator for disaggregated prefill serving (PP-enabled, TS-enabled, stages>1).""" +import logging from typing import Any, Dict, List, Optional, Set + import onnx -import logging logger = logging.getLogger(__name__) @@ -40,7 +41,7 @@ def _get_compiler_folded_nodes(graph) -> Set[str]: # Keep marking nodes foldable until no new ones are found. foldable_nodes: Set[str] = set() - while(True): + while True: changed = False for node in graph.node: if not node.name or node.name in foldable_nodes: @@ -111,9 +112,7 @@ def _get_inlined_node_map(model) -> tuple: for node in model.graph.node: if node.op_type in inlined_funcs: func = local_functions[node.op_type] - inlined_node_map[node.name] = [ - f"{node.name}/{fn.name}" for fn in func.node if fn.name - ] + inlined_node_map[node.name] = [f"{node.name}/{fn.name}" for fn in func.node if fn.name] logger.info(f"Inlined sub-nodes mapped for {len(inlined_node_map)} call-sites") return inlined_node_map, non_inlined_funcs From f393d6e548fe09b3f56a73ef93190ecda3c41b5c Mon Sep 17 00:00:00 2001 From: Mohit Mehta Date: Wed, 22 Apr 2026 01:41:18 -0700 Subject: [PATCH 3/3] Add compiler options - 'stages' Signed-off-by: Mohit Mehta --- QEfficient/base/modeling_qeff.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/QEfficient/base/modeling_qeff.py b/QEfficient/base/modeling_qeff.py index c6e0cfcac..88d7a4056 100644 --- a/QEfficient/base/modeling_qeff.py +++ b/QEfficient/base/modeling_qeff.py @@ -565,7 +565,8 @@ def _compile( if mdp_ts_json_path: command.append(f"-mdp-load-partition-config={mdp_ts_json_path}") mdp_ts_json = load_json(str(mdp_ts_json_path)) - elif mdp_num_partitions > 1: + elif mdp_num_partitions > 1 or "stages" in compiler_options: + mdp_num_partitions = compiler_options.pop("stages", mdp_num_partitions) # Disaggregated (pipeline-parallel) MDP: generate a fully-populated # nodeList per partition directly from the ONNX graph — no compiler # round-trip required.