From a7f99f62b3c3247bef3b39f516615e73aeebca41 Mon Sep 17 00:00:00 2001
From: vbaddi
Date: Fri, 10 Apr 2026 12:06:56 +0530
Subject: [PATCH 1/2] revert(export): restore default split and FP16 ONNX
 transforms

Revert proxy-only ONNX transform gating and restore the previous export
behavior, with SplitTensorsTransform and FP16ClipTransform enabled by
default. Keep the quickcheck workflow and quickcheck test path unchanged.

Signed-off-by: vbaddi
---
 QEfficient/base/modeling_qeff.py              | 29 +++----
 QEfficient/base/onnx_transforms.py            | 19 ++---
 QEfficient/transformers/modeling_utils.py     | 32 +-------
 .../transformers/models/modeling_auto.py      | 55 +++++++++----
 .../unit_test/models/test_model_quickcheck.py | 78 +------------------
 5 files changed, 61 insertions(+), 152 deletions(-)

diff --git a/QEfficient/base/modeling_qeff.py b/QEfficient/base/modeling_qeff.py
index 6f22e867ef..9ae6057d7c 100644
--- a/QEfficient/base/modeling_qeff.py
+++ b/QEfficient/base/modeling_qeff.py
@@ -20,9 +20,7 @@
 from QEfficient.base.onnx_transforms import (
     BaseOnnxTransform,
-    FP16ClipTransform,
     OnnxTransformPipeline,
-    SplitTensorsTransform,
 )
 from QEfficient.base.pytorch_transforms import PytorchTransform
 from QEfficient.compile.qnn_compiler import compile as qnn_compile
 
@@ -54,8 +52,9 @@ class QEFFBaseModel(ABC):
     _pytorch_transforms: List[PytorchTransform]
     _onnx_transforms = [BaseOnnxTransform]
 
-    def _transform_names(self) -> List[str]:
-        return [x.__name__ for x in self._pytorch_transforms + self._onnx_transforms]
+    @classmethod
+    def _transform_names(cls) -> List[str]:
+        return [x.__name__ for x in cls._pytorch_transforms + cls._onnx_transforms]
 
     def __init__(self, model: torch.nn.Module, **kwargs) -> None:
         super().__init__()
@@ -246,7 +245,10 @@ def _export(
         # check if the model is in meta state or weights are offloaded
         self._model_offloaded_check()
 
-        export_dir.mkdir(parents=True, exist_ok=True)
+        # Set up temporary export paths
+        tmp_onnx_dir = export_dir / "onnx_tmp"
+        tmp_onnx_path = tmp_onnx_dir / f"{self.model_name}.onnx"
+        tmp_onnx_dir.mkdir(parents=True, exist_ok=True)
 
         # Create input_names from example_inputs
         input_names = []
@@ -276,7 +278,7 @@ def _export(
             torch.onnx.export(
                 self.model,
                 (example_inputs,),
-                str(onnx_path),
+                str(tmp_onnx_path),
                 input_names=input_names,
                 output_names=output_names,
                 dynamic_axes=dynamic_axes,
@@ -285,13 +287,11 @@ def _export(
             )
             logger.info("PyTorch export successful")
             _ = self._offload_model_weights(offload_pt_weights)
-            model = onnx.load(onnx_path, load_external_data=False)
+            model = onnx.load(tmp_onnx_path, load_external_data=False)
 
-            needs_external_tensor_data = any(
-                transform in self._onnx_transforms for transform in (FP16ClipTransform, SplitTensorsTransform)
-            )
+            # Build kwargs for the ONNX transform pipeline
             transform_kwargs = {
-                "onnx_base_dir": str(export_dir) if needs_external_tensor_data else None,
+                "onnx_base_dir": str(tmp_onnx_dir),
                 "model_name": self.model_name,
             }
             if onnx_transform_kwargs is not None:
@@ -306,9 +306,7 @@ def _export(
             )
             logger.info("ONNX transforms applied")
 
-            onnx_path_tmp = onnx_path.with_suffix(onnx_path.suffix + ".tmp")
-            onnx.save(model, onnx_path_tmp)
-            onnx_path_tmp.replace(onnx_path)
+            onnx.save(model, onnx_path)
             del model
             gc.collect()
             logger.info("Transformed ONNX saved")
@@ -317,6 +315,9 @@ def _export(
             logger.error(f"ONNX export or transforms failed: {e}")
             raise e
 
+        finally:
+            shutil.rmtree(tmp_onnx_dir, ignore_errors=True)
+
         self.onnx_path = onnx_path
         return onnx_path
 
diff --git a/QEfficient/base/onnx_transforms.py b/QEfficient/base/onnx_transforms.py
index 2ba53829a4..16697cec9b 100644
--- a/QEfficient/base/onnx_transforms.py
+++ b/QEfficient/base/onnx_transforms.py
@@ -7,6 +7,7 @@
 import logging
 import os
+import warnings
 from concurrent.futures import ThreadPoolExecutor, as_completed
 from typing import Any, Dict, List, Optional, Tuple, Type
 
 import numpy as np
@@ -105,27 +106,16 @@ class CustomOpTransform(BaseOnnxTransform):
 
     @classmethod
     def apply(cls, model: ModelProto) -> bool:
         op_applied = False
-
-        # Register with PyTorch ONNX exporter (for export time)
         for op_name, (func_class, _) in cls._custom_ops.items():
             if hasattr(func_class, "symbolic"):
                 torch.onnx.register_custom_op_symbolic(f"::{op_name}", func_class.symbolic, ONNX_EXPORT_OPSET)
 
-        used_op_types = {node.op_type for node in model.graph.node}
-        for function_proto in model.functions:
-            used_op_types.update(node.op_type for node in function_proto.node)
-
-        # Add function prototypes to model
         existing = {f.name for f in model.functions}
-
-        for func_name, onnxscript_func in cls._custom_ops.values():
+        for _, onnxscript_func in cls._custom_ops.values():
             proto = onnxscript_func.to_function_proto()
-            if proto.name not in used_op_types:
-                continue
             if proto.name not in existing:
                 model.functions.append(proto)
                 op_applied = True
-
         return op_applied
 
@@ -212,6 +202,8 @@ class OnnxTransformPipeline(BaseOnnxTransform):
     """Pipeline to apply multiple ONNX transformations in sequence."""
 
     def __init__(self, transforms: List[Type[BaseOnnxTransform]]):
+        if not transforms:
+            warnings.warn("Transform list is empty. No transformations will be applied.")
         self.transforms = transforms
 
     def apply(
@@ -236,8 +228,7 @@ def apply(
         do_split = SplitTensorsTransform in requested
         fp16_min, fp16_max = np.finfo(np.float16).min, np.finfo(np.float16).max
         file_num_tracker = {"num": 0, "size": 0}
-        if onnx_base_dir is not None:
-            external_data_helper.load_external_data_for_model(model, onnx_base_dir)
+        external_data_helper.load_external_data_for_model(model, onnx_base_dir)
 
         if do_fp16 or do_split:
             for tensor in external_data_helper._get_all_tensors(model):
diff --git a/QEfficient/transformers/modeling_utils.py b/QEfficient/transformers/modeling_utils.py
index a29d0e0966..162bfa74d6 100644
--- a/QEfficient/transformers/modeling_utils.py
+++ b/QEfficient/transformers/modeling_utils.py
@@ -6,7 +6,7 @@
 # -----------------------------------------------------------------------------
 
 from collections import namedtuple
-from typing import TYPE_CHECKING, Dict, Optional, Tuple, Type
+from typing import Dict, Optional, Tuple, Type
 
 import torch
 import torch.nn as nn
@@ -88,14 +88,8 @@
     WhisperPositionalEmbedding,
 )
 
-from QEfficient.base.onnx_transforms import FP16ClipTransform, SplitTensorsTransform
 from QEfficient.customop import CustomRMSNormAIC
-from QEfficient.proxy.pytorch_transform import QeffProxyModuleTransform
 from QEfficient.utils.constants import MIN_MASKED_ATTENTION_VALUE
-from QEfficient.utils.logging_utils import logger
-
-if TYPE_CHECKING:
-    from QEfficient.base.modeling_qeff import QEFFBaseModel
 
 # Placeholder for all non-transformer models
 from .models.codegen.modeling_codegen import (
@@ -197,30 +191,6 @@
 # This is for supporting different modelling classes specially written for prefill-only model
 SPECIALIZED_DISAGG_SERVING_MODEL_ARCH = {"gpt_oss"}
 
-
-_PROXY_ONLY_ONNX_TRANSFORMS = (FP16ClipTransform, SplitTensorsTransform)
-
-
-def _configure_proxy_for_model(instance: "QEFFBaseModel", enable_proxy: bool) -> None:
-    """
-    Configure per-instance transform lists based on proxy mode.
-
-    Keep class-defined ONNX transforms by default.
-    Proxy flow appends additional proxy-only transforms.
-    """
-    instance._pytorch_transforms = list(instance._pytorch_transforms)
-    instance._onnx_transforms = list(instance._onnx_transforms)
-    instance._enable_proxy = enable_proxy
-
-    if enable_proxy:
-        if QeffProxyModuleTransform not in instance._pytorch_transforms:
-            instance._pytorch_transforms.append(QeffProxyModuleTransform)
-        for transform in _PROXY_ONLY_ONNX_TRANSFORMS:
-            if transform not in instance._onnx_transforms:
-                instance._onnx_transforms.append(transform)
-        logger.info("Proxy Model Enabled for QEfficient Model")
-
-
 # Define a transformers layers to QEff layers dictionary
 # While onboarding new models make sure to add the new layer maps to this dictionary.
 TransformersToQEffModulesDict: Dict[Type[nn.Module], Type[nn.Module]] = {
diff --git a/QEfficient/transformers/models/modeling_auto.py b/QEfficient/transformers/models/modeling_auto.py
index 2f8e971e34..ec6ab84fb2 100644
--- a/QEfficient/transformers/models/modeling_auto.py
+++ b/QEfficient/transformers/models/modeling_auto.py
@@ -29,7 +29,7 @@
 
 import QEfficient
 from QEfficient.base.modeling_qeff import QEFFBaseModel
-from QEfficient.base.onnx_transforms import FP16ClipTransform
+from QEfficient.base.onnx_transforms import FP16ClipTransform, SplitTensorsTransform
 from QEfficient.base.pytorch_transforms import SplitGateUpWeightsTransform
 from QEfficient.generation.cloud_infer import QAICInferenceSession
 from QEfficient.generation.text_generation_inference import (
@@ -40,10 +40,10 @@
     write_io_files,
 )
 from QEfficient.generation.vlm_generation import VisionLanguageGeneration
+from QEfficient.proxy.pytorch_transform import QeffProxyModuleTransform
 from QEfficient.transformers.modeling_utils import (
     DYNAMIC_SEQ_LEN_SUPPORTED_MODEL_ARCH,
     SPECIALIZED_DISAGG_SERVING_MODEL_ARCH,
-    _configure_proxy_for_model,
 )
 from QEfficient.transformers.models.pytorch_transforms import (
     BlockedKVAttentionTransform,
@@ -91,7 +91,9 @@ class QEFFTransformersBase(QEFFBaseModel):
     _hf_auto_class: type
 
     def __init__(self, model: nn.Module, **kwargs) -> None:
-        _configure_proxy_for_model(self, kwargs.pop("enable_proxy", False))
+        if kwargs.pop("enable_proxy", False):
+            self._pytorch_transforms.append(QeffProxyModuleTransform)
+            logger.info("Proxy Model Enabled for QEfficient Model")
 
         if (
             hasattr(model, "config")
@@ -231,7 +233,7 @@ class QEFFAutoModel(QEFFTransformersBase):
 
     _hf_auto_class = AutoModel
    _pytorch_transforms = [CustomOpsTransform, AwqToMatmulNbitsTransform, GPTQToMatmulNbitsTransform]
-    _onnx_transforms = [FP16ClipTransform]
+    _onnx_transforms = [FP16ClipTransform, SplitTensorsTransform]
 
     def __init__(self, model: nn.Module, pooling=None, **kwargs):
         """
@@ -248,6 +250,10 @@ def __init__(self, model: nn.Module, pooling=None, **kwargs):
         **kwargs :
             Additional keyword arguments passed to the base class constructor.
""" + if kwargs.pop("enable_proxy", False): + self._pytorch_transforms.append(QeffProxyModuleTransform) + logger.info("Proxy Model Enabled for QEfficient Model") + super().__init__(model, **kwargs) # Make Embedding specific transforms like appending pooling @@ -619,7 +625,7 @@ class QEFFAutoModelForSequenceClassification(QEFFTransformersBase): _hf_auto_class = AutoModelForSequenceClassification _pytorch_transforms = [CustomOpsTransform, TextClassificationTransform] - _onnx_transforms = [] + _onnx_transforms = [FP16ClipTransform, SplitTensorsTransform] def __init__(self, model: nn.Module, **kwargs): """ @@ -662,8 +668,6 @@ def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs): QEFFAutoModelForSequenceClassification An instance initialized with the pretrained weights. """ - enable_proxy = kwargs.pop("enable_proxy", False) - if kwargs.get("attn_implementation", None) not in {None, "eager"}: logger.warning('Updating attn_implementation="eager"') @@ -673,7 +677,6 @@ def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs): kwargs.update({"attn_implementation": "eager", "low_cpu_mem_usage": False}) model = cls._hf_auto_class.from_pretrained(pretrained_model_name_or_path, *args, **kwargs) - kwargs.update({"enable_proxy": enable_proxy} if enable_proxy else {}) return cls(model, pretrained_model_name_or_path=pretrained_model_name_or_path, **kwargs) @property @@ -861,7 +864,7 @@ class QEffVisionEncoderForTextImageToTextModel(QEFFBaseModel): KVCacheTransform, KVCacheExternalModuleMapperTransform, ] - _onnx_transforms = [] + _onnx_transforms = [FP16ClipTransform, SplitTensorsTransform] def __init__(self, model: nn.modules, **kwargs): """ @@ -874,7 +877,9 @@ def __init__(self, model: nn.modules, **kwargs): **kwargs : Additional keyword arguments passed to the base class constructor. """ - _configure_proxy_for_model(self, kwargs.pop("enable_proxy", False)) + if kwargs.pop("enable_proxy", False): + self._pytorch_transforms.append(QeffProxyModuleTransform) + logger.info("Proxy Model Enabled for QEfficient Model") super().__init__(model, **kwargs) self.model = model.get_qeff_vision_encoder() self.hash_params["qeff_auto_class"] = self.__class__.__name__ @@ -1002,7 +1007,7 @@ class QEffCausalLMForTextImageToTextModel(QEFFBaseModel): VlmKVOffloadTransform, SplitGateUpWeightsTransform, ] - _onnx_transforms = [] + _onnx_transforms = [FP16ClipTransform, SplitTensorsTransform] def __init__(self, model, qaic_config: Optional[dict] = None, **kwargs): """ @@ -1018,7 +1023,9 @@ def __init__(self, model, qaic_config: Optional[dict] = None, **kwargs): **kwargs : Additional keyword arguments passed to the base class constructor. 
""" - _configure_proxy_for_model(self, kwargs.pop("enable_proxy", False)) + if kwargs.pop("enable_proxy", False): + self._pytorch_transforms.append(QeffProxyModuleTransform) + logger.info("Proxy Model Enabled for QEfficient Model") super().__init__(model, **kwargs) self.model = model.get_qeff_language_decoder() self.model.qaic_config = qaic_config @@ -1937,7 +1944,7 @@ class _QEFFAutoModelForImageTextToTextSingleQPC(QEFFTransformersBase, Multimodal VlmNoKVOffloadTransform, SplitGateUpWeightsTransform, ] - _onnx_transforms = [] + _onnx_transforms = [FP16ClipTransform, SplitTensorsTransform] def __init__( self, @@ -1971,6 +1978,10 @@ def __init__( if qaic_config is not None and qaic_config.pop("include_sampler", False): raise NotImplementedError("On-device sampling is not supported for single QPC multimodal models yet.") + if kwargs.pop("enable_proxy", False): + self._pytorch_transforms.append(QeffProxyModuleTransform) + logger.info("Proxy Model Enabled for QEfficient Model") + super().__init__(model, **kwargs) self.model.qaic_config = qaic_config @@ -2689,7 +2700,7 @@ class QEFFAutoModelForCausalLM(QEFFBaseModel): KVCacheExternalModuleMapperTransform, ] - _onnx_transforms = [] + _onnx_transforms = [FP16ClipTransform, SplitTensorsTransform] def prefill( self, @@ -2766,7 +2777,9 @@ def __init__( model_class_name = model.__class__.__name__ if not (model_class_name.endswith("ForCausalLM") or model_class_name.endswith("LMHeadModel")): raise TypeError(f"Required pytorch module for CausalLM or LMHeadModel, got {model_class_name}") - _configure_proxy_for_model(self, kwargs.pop("enable_proxy", False)) + if kwargs.pop("enable_proxy", False): + self._pytorch_transforms.append(QeffProxyModuleTransform) + logger.info("Proxy Model Enabled for QEfficient Model") # TODO: remove from version 1.20 if kwargs.pop("full_batch_size", None): @@ -3654,7 +3667,7 @@ class QEFFAutoModelForSpeechSeq2Seq(QEFFTransformersBase, MultimodalUtilityMixin _hf_auto_class = AutoModelForSpeechSeq2Seq _pytorch_transforms = [CustomOpsTransform, AwqToMatmulNbitsTransform, GPTQToMatmulNbitsTransform, KVCacheTransform] - _onnx_transforms = [] + _onnx_transforms = [FP16ClipTransform, SplitTensorsTransform] def __init__(self, model: nn.Module, **kwargs): """ @@ -3674,6 +3687,10 @@ def __init__(self, model: nn.Module, **kwargs): """ model_class_name = model.__class__.__name__ + if kwargs.pop("enable_proxy", False): + self._pytorch_transforms.append(QeffProxyModuleTransform) + logger.info("Proxy Model Enabled for QEfficient Model") + if not (model_class_name.endswith("ForConditionalGeneration")): raise TypeError(f"Required pytorch module with ForConditionalGeneration, got {model_class_name}") @@ -4013,9 +4030,13 @@ class QEFFAutoModelForCTC(QEFFTransformersBase): _hf_auto_class = AutoModelForCTC _pytorch_transforms = [CustomOpsTransform, AwqToMatmulNbitsTransform, GPTQToMatmulNbitsTransform] - _onnx_transforms = [] + _onnx_transforms = [FP16ClipTransform, SplitTensorsTransform] def __init__(self, model: nn.Module, **kwargs): + if kwargs.pop("enable_proxy", False): + self._pytorch_transforms.append(QeffProxyModuleTransform) + logger.info("Proxy Model Enabled for QEfficient Model") + super().__init__(model, **kwargs) self.model.base_model.config.use_cache = True diff --git a/tests/unit_test/models/test_model_quickcheck.py b/tests/unit_test/models/test_model_quickcheck.py index c0b5c20525..312fdea6b1 100644 --- a/tests/unit_test/models/test_model_quickcheck.py +++ b/tests/unit_test/models/test_model_quickcheck.py @@ -25,7 +25,7 @@ from 
 from contextlib import contextmanager, redirect_stderr, redirect_stdout
 from io import StringIO
 from pathlib import Path
-from typing import Dict, Optional, Set
+from typing import Dict
 
 import numpy as np
 import onnx
@@ -213,21 +213,6 @@ def _run_whisper_export_smoke(qeff_model: QEFFAutoModelForSpeechSeq2Seq, out_dir
     return onnx_path
 
-def _assert_proxy_only_onnx_transform_policy(
-    qeff_model, enable_proxy: bool, always_on_transforms: Optional[Set[str]] = None
-) -> None:
-    transform_names = {transform.__name__ for transform in qeff_model._onnx_transforms}
-    proxy_only_transforms = {"FP16ClipTransform", "SplitTensorsTransform"}
-    always_on_transforms = always_on_transforms or set()
-    conditional_proxy_transforms = proxy_only_transforms - always_on_transforms
-
-    if enable_proxy:
-        assert proxy_only_transforms.issubset(transform_names)
-    else:
-        assert conditional_proxy_transforms.isdisjoint(transform_names)
-        assert always_on_transforms.issubset(transform_names)
-
-
 def _skip_on_model_fetch_error(exc: Exception, model_id: str) -> None:
     pytest.skip(
         f"Skipping {model_id}: model unavailable or unsupported in this environment ({type(exc).__name__}: {exc})"
     )
@@ -394,7 +379,7 @@ def test_text_embedding_fp16_clip_transform_and_export(tmp_path):
 
     transform_names = {transform.__name__ for transform in qeff_model._onnx_transforms}
     assert "FP16ClipTransform" in transform_names
-    assert "SplitTensorsTransform" not in transform_names
+    assert "SplitTensorsTransform" in transform_names
 
     inputs = tokenizer("hello world", return_tensors="pt")
     onnx_path = _exported_onnx_path(qeff_model.export(tmp_path / "embedding-ai100"))
@@ -607,7 +592,6 @@ def test_causal_subfunction_and_proxy_export_smoke_gpt2(tmp_path):
     except Exception as exc:
         _skip_on_model_fetch_error(exc, model_id)
 
-    _assert_proxy_only_onnx_transform_policy(qeff_model, enable_proxy=True)
     onnx_path = _exported_onnx_path(
         qeff_model.export(tmp_path / "with-subfunctions-and-proxy", use_onnx_subfunctions=True)
     )
@@ -642,61 +626,3 @@ def test_awq_export_smoke(tmp_path):
 
     onnx_model = onnx.load(onnx_path, load_external_data=False)
     assert any(node.op_type == "MatMulNBits" for node in onnx_model.graph.node)
-
-
-@pytest.mark.llm_model
-def test_proxy_toggle_onnx_transform_policy_for_causal_lm():
-    model_id = CAUSAL_RUNTIME_MODEL_IDS["gpt2"]
-    try:
-        qeff_default = QEFFAutoModelForCausalLM.from_pretrained(model_id, trust_remote_code=True)
-        qeff_proxy = QEFFAutoModelForCausalLM.from_pretrained(model_id, trust_remote_code=True, enable_proxy=True)
-    except Exception as exc:
-        _skip_on_model_fetch_error(exc, model_id)
-
-    _assert_proxy_only_onnx_transform_policy(qeff_default, enable_proxy=False)
-    _assert_proxy_only_onnx_transform_policy(qeff_proxy, enable_proxy=True)
-
-
-@pytest.mark.llm_model
-def test_proxy_toggle_onnx_transform_policy_for_embedding():
-    model_id = TINY_TEXT_EMBEDDING_MODEL_ID
-    try:
-        qeff_default = QEFFAutoModel.from_pretrained(model_id)
-        qeff_proxy = QEFFAutoModel.from_pretrained(model_id, enable_proxy=True)
-    except Exception as exc:
-        _skip_on_model_fetch_error(exc, model_id)
-
-    _assert_proxy_only_onnx_transform_policy(
-        qeff_default, enable_proxy=False, always_on_transforms={"FP16ClipTransform"}
-    )
-    _assert_proxy_only_onnx_transform_policy(qeff_proxy, enable_proxy=True)
-
-
-@pytest.mark.llm_model
-def test_proxy_toggle_onnx_transform_policy_for_whisper():
-    model_id = TINY_WHISPER_MODEL_ID
-    try:
-        qeff_default = QEFFAutoModelForSpeechSeq2Seq.from_pretrained(model_id, trust_remote_code=True)
-        qeff_proxy = QEFFAutoModelForSpeechSeq2Seq.from_pretrained(model_id, trust_remote_code=True, enable_proxy=True)
-    except Exception as exc:
-        _skip_on_model_fetch_error(exc, model_id)
-
-    _assert_proxy_only_onnx_transform_policy(qeff_default, enable_proxy=False)
-    _assert_proxy_only_onnx_transform_policy(qeff_proxy, enable_proxy=True)
-
-
-@pytest.mark.llm_model
-def test_proxy_toggle_onnx_transform_policy_for_vlm():
-    model_id = VLM_TEXT_RUNTIME_MODEL_ID
-    try:
-        qeff_default = QEFFAutoModelForImageTextToText.from_pretrained(
-            model_id, trust_remote_code=True, kv_offload=False
-        )
-        qeff_proxy = QEFFAutoModelForImageTextToText.from_pretrained(
-            model_id, trust_remote_code=True, enable_proxy=True, kv_offload=False
-        )
-    except Exception as exc:
-        _skip_on_model_fetch_error(exc, model_id)
-
-    _assert_proxy_only_onnx_transform_policy(qeff_default, enable_proxy=False)
-    _assert_proxy_only_onnx_transform_policy(qeff_proxy, enable_proxy=True)

From cec8c60f246d729d71b61dcc6759e6fd6560be81 Mon Sep 17 00:00:00 2001
From: vbaddi
Date: Fri, 10 Apr 2026 12:22:05 +0530
Subject: [PATCH 2/2] nit: add only the function prototypes actually used by
 the exported graph (customop)

Signed-off-by: vbaddi
---
 QEfficient/base/onnx_transforms.py | 9 +++++++++
 1 file changed, 9 insertions(+)

diff --git a/QEfficient/base/onnx_transforms.py b/QEfficient/base/onnx_transforms.py
index 16697cec9b..247edfc2d2 100644
--- a/QEfficient/base/onnx_transforms.py
+++ b/QEfficient/base/onnx_transforms.py
@@ -106,13 +106,22 @@ class CustomOpTransform(BaseOnnxTransform):
     @classmethod
     def apply(cls, model: ModelProto) -> bool:
         op_applied = False
+
+        # Register with PyTorch ONNX exporter (for export time)
         for op_name, (func_class, _) in cls._custom_ops.items():
             if hasattr(func_class, "symbolic"):
                 torch.onnx.register_custom_op_symbolic(f"::{op_name}", func_class.symbolic, ONNX_EXPORT_OPSET)
+
+        used_op_types = {node.op_type for node in model.graph.node}
+        for function_proto in model.functions:
+            used_op_types.update(node.op_type for node in function_proto.node)
+
+        # Add function prototypes to model
         existing = {f.name for f in model.functions}
         for _, onnxscript_func in cls._custom_ops.values():
             proto = onnxscript_func.to_function_proto()
+            if proto.name not in used_op_types:
+                continue
             if proto.name not in existing:
                 model.functions.append(proto)
                 op_applied = True
         return op_applied
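
Reviewer note on PATCH 1/2: with SplitTensorsTransform and FP16ClipTransform
back in the default _onnx_transforms lists, a plain from_pretrained picks both
up without the proxy flag. A minimal sanity sketch, assuming network access to
the gpt2 checkpoint the quickcheck tests already use (the top-level QEfficient
import path is assumed to be the library's public one):

    from QEfficient import QEFFAutoModelForCausalLM

    # Defaults now include both ONNX transforms; no enable_proxy needed.
    qeff_model = QEFFAutoModelForCausalLM.from_pretrained("gpt2")
    names = qeff_model._transform_names()
    assert "FP16ClipTransform" in names
    assert "SplitTensorsTransform" in names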
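
Reviewer note on PATCH 2/2: the pruning rule is self-contained -- collect every
op_type referenced by the main graph (and by functions already attached to the
model), then append a custom-op FunctionProto only when its name is actually
referenced and not already present. A standalone sketch of the same logic;
append_used_function_protos and candidate_protos are illustrative names, not
QEfficient API:

    import onnx

    def append_used_function_protos(model: onnx.ModelProto, candidate_protos) -> bool:
        # Op types referenced by the main graph and by existing local functions.
        used_op_types = {node.op_type for node in model.graph.node}
        for function_proto in model.functions:
            used_op_types.update(node.op_type for node in function_proto.node)

        existing = {f.name for f in model.functions}
        applied = False
        for proto in candidate_protos:
            # Skip prototypes the exported graph never calls; avoid duplicates.
            if proto.name in used_op_types and proto.name not in existing:
                model.functions.append(proto)
                applied = True
        return applied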