From c2f216ad5fa8e68c31e5308c5f3bfba2d9741bb4 Mon Sep 17 00:00:00 2001 From: Will Guo Date: Sun, 1 Feb 2026 22:43:57 +0000 Subject: [PATCH 1/8] Integrate Automated QDQ placement tool - part 3.2 Signed-off-by: Will Guo --- .../onnx/quantization/autotune/__init__.py | 62 + .../onnx/quantization/autotune/autotuner.py | 1092 +++++++++++++++++ .../autotune/autotune/test_autotuner.py | 345 ++++++ 3 files changed, 1499 insertions(+) create mode 100644 modelopt/onnx/quantization/autotune/__init__.py create mode 100644 modelopt/onnx/quantization/autotune/autotuner.py create mode 100644 tests/unit/onnx/quantization/autotune/autotune/test_autotuner.py diff --git a/modelopt/onnx/quantization/autotune/__init__.py b/modelopt/onnx/quantization/autotune/__init__.py new file mode 100644 index 000000000..eee4baa46 --- /dev/null +++ b/modelopt/onnx/quantization/autotune/__init__.py @@ -0,0 +1,62 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Pattern-Based Q/DQ Autotuning for ONNX Models. + +This package provides automated optimization of Quantize/Dequantize (Q/DQ) node placement +in ONNX computation graphs to minimize TensorRT inference latency. It uses pattern-based +region analysis to efficiently explore and optimize Q/DQ insertion strategies. 
+""" + +# Core data structures +from .autotuner import QDQAutotuner +from .common import ( + AutotunerError, + AutotunerNotInitializedError, + Config, + InsertionScheme, + InvalidSchemeError, + PatternCache, + PatternSchemes, + Region, + RegionType, +) +from .insertion_points import ( + ChildRegionInputInsertionPoint, + ChildRegionOutputInsertionPoint, + NodeInputInsertionPoint, + ResolvedInsertionPoint, +) +from .region_pattern import RegionPattern +from .region_search import CombinedRegionSearch + +__all__ = [ + "AutotunerError", + "AutotunerNotInitializedError", + "ChildRegionInputInsertionPoint", + "ChildRegionOutputInsertionPoint", + "CombinedRegionSearch", + "Config", + "InsertionScheme", + "InvalidSchemeError", + "NodeInputInsertionPoint", + "PatternCache", + "PatternSchemes", + "QDQAutotuner", + "Region", + "RegionPattern", + "RegionType", + "ResolvedInsertionPoint", +] diff --git a/modelopt/onnx/quantization/autotune/autotuner.py b/modelopt/onnx/quantization/autotune/autotuner.py new file mode 100644 index 000000000..9eb8724dc --- /dev/null +++ b/modelopt/onnx/quantization/autotune/autotuner.py @@ -0,0 +1,1092 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Automatic Q/DQ insertion optimization for ONNX models via pattern-based profiling.""" + +import copy +import os +import random +from collections import deque +from datetime import datetime, timezone + +import numpy as np +import onnx +import onnx_graphsurgeon as gs +import yaml + +from modelopt.onnx.logging_config import logger +from modelopt.onnx.quantization.autotune.common import ( + AutotunerNotInitializedError, + Config, + InsertionScheme, + InvalidSchemeError, + PatternCache, + PatternSchemes, + Region, + RegionType, +) +from modelopt.onnx.quantization.autotune.insertion_points import ( + ResolvedInsertionPoint, + merge_resolved_insertion_points, +) +from modelopt.onnx.quantization.autotune.region_pattern import RegionPattern +from modelopt.onnx.quantization.autotune.region_search import CombinedRegionSearch +from modelopt.onnx.quantization.fp8 import int8_to_fp8 +from modelopt.onnx.quantization.graph_utils import get_tensor_consumer_node_indices + + +class QDQAutotunerBase: + """Base class for pattern-based Q/DQ node insertion optimization in ONNX models.""" + + def __init__(self, model: onnx.ModelProto | gs.Graph): + """Initialize the autotuner with an ONNX model.""" + if isinstance(model, onnx.ModelProto): + self.onnx_model = model + elif isinstance(model, gs.Graph): + self.onnx_model = gs.export_onnx(model) + else: + raise TypeError(f"Expected onnx.ModelProto or gs.Graph, got {type(model)}") + + self.graph = self._copy_graph() + self.graph.tensor_users_map = get_tensor_consumer_node_indices(self.graph) + self.regions: list[Region] = [] + self.current_profile_region: Region | None = None + self.profiled_patterns: list[PatternSchemes] = [] + self.current_profile_pattern_schemes: PatternSchemes | None = None + self.current_insertion_scheme_index: int | None = None + self.config = Config() + self.initialized = False + self.baseline_latency_ms: float | None = None + self.pattern_cache: PatternCache | None = None + + logger.debug(f"Initialized autotuner 
with model type: {type(model).__name__}") + + def initialize( + self, config: Config | None = None, pattern_cache: PatternCache | None = None + ) -> None: + """Initialize autotuning session with configuration and pattern cache.""" + if config is not None: + self.config = config + + if pattern_cache is None: + pattern_cache = PatternCache( + minimum_distance=self.config.pattern_cache_minimum_distance, + max_entries_per_pattern=self.config.pattern_cache_max_entries_per_pattern, + ) + self.pattern_cache = pattern_cache + + logger.debug( + f"Loaded pattern cache with {pattern_cache.num_patterns} patterns and " + f"{pattern_cache.total_schemes} schemes" + ) + + self.initialized = False + self.baseline_latency_ms = None + self.profiled_patterns.clear() + self.regions.clear() + self.current_profile_region = None + self.current_profile_pattern_schemes = None + self.current_insertion_scheme_index = None + + logger.info("Initializing autotuner") + logger.debug( + f"Configuration: q_scale={self.config.default_q_scale}, " + f"q_zero_point={self.config.default_q_zero_point}, quant_type={self.config.default_quant_type}" + ) + + self.initialized = True + + def set_profile_region(self, region: Region | None, commit: bool = True) -> None: + """Set the target region for profiling and scheme generation.""" + if not self.initialized: + raise AutotunerNotInitializedError( + "QDQAutotunerBase not initialized. Call initialize() first." 
+ ) + + if commit: + if self.current_profile_pattern_schemes is not None: + num_schemes = len(self.current_profile_pattern_schemes.schemes) + best_scheme = self.current_profile_pattern_schemes.best_scheme + best_latency = best_scheme.latency_ms if best_scheme else float("inf") + + samples_before_best, time_to_best = self._compute_convergence_metrics( + self.current_profile_pattern_schemes.schemes, best_scheme + ) + + logger.info( + f"Pattern complete: {num_schemes} schemes tested, best latency {best_latency:.3f} ms" + ) + logger.debug( + f"Pattern signature: {self.current_profile_pattern_schemes.pattern_signature}" + ) + if samples_before_best is not None: + logger.debug(f"Convergence: best found at sample {samples_before_best}") + if time_to_best is not None: + logger.debug(f"Time to best: {time_to_best:.2f}s") + self.profiled_patterns.append(self.current_profile_pattern_schemes) + + if commit or region is None: + self.current_profile_region = None + self.current_profile_pattern_schemes = None + self.current_insertion_scheme_index = None + if region is None: + return + + if region not in self.regions: + raise ValueError(f"Region {region.id} not found in regions") + + region_pattern = RegionPattern.from_region(region, self.graph) + + if self._is_region_profiled(region): + logger.info(f"Skipping region {region.id} (pattern already profiled)") + logger.debug(f"Pattern signature: {region_pattern.signature}") + return + + pattern_schemes = None + num_seeded = 0 + + if self.pattern_cache is not None: + cache_schemes = self.pattern_cache.get_pattern_schemes(region_pattern.signature) + + if cache_schemes is not None and len(cache_schemes.schemes) > 0: + pattern_schemes = PatternSchemes() + pattern_schemes.pattern = region_pattern + + for cached_scheme in cache_schemes.schemes: + scheme_copy = copy.deepcopy(cached_scheme) + scheme_copy.latency_ms = float("inf") + scheme_copy.error = False + pattern_schemes.schemes.append(scheme_copy) + num_seeded += 1 + + 
logger.debug(f"Seeded {num_seeded} scheme(s) from pattern cache") + else: + logger.debug("No pattern cache entries for this region") + + if pattern_schemes is None: + pattern_schemes = PatternSchemes() + pattern_schemes.pattern = region_pattern + logger.debug("Initialized with empty scheme collection") + + self.current_profile_region = region + self.current_profile_pattern_schemes = pattern_schemes + + mode_info = f"seeded with {num_seeded} schemes" if num_seeded > 0 else "starting fresh" + logger.info( + f"Profiling region {region.id} [pattern mode, level {region.level}, " + f"size {region.get_size_of_region_and_descendants()}, {mode_info}]" + ) + logger.debug(f"Pattern signature: {region_pattern.signature}") + + def generate(self) -> int: + """Generate a new Q/DQ insertion scheme for the current pattern or region.""" + if not self.initialized: + raise AutotunerNotInitializedError( + "QDQAutotunerBase not initialized. Call initialize() first." + ) + elif self.current_profile_pattern_schemes is None: + raise InvalidSchemeError("No region selected. 
Call set_profile_region() first.") + + pattern_schemes = self.current_profile_pattern_schemes + cached_schemes = [ + (idx, scheme) + for idx, scheme in enumerate(pattern_schemes.schemes) + if not scheme.is_profiled + ] + + if cached_schemes: + scheme_index, cached_scheme_data = cached_schemes[0] + num_node_points = len(cached_scheme_data.node_inputs) + num_region_composite_points = len(cached_scheme_data.child_region_inputs) + num_region_output_points = len(cached_scheme_data.region_outputs) + total_points = num_node_points + num_region_composite_points + num_region_output_points + + logger.info( + f"Scheme #{scheme_index + 1}: profiling cached scheme ({total_points} Q/DQ points)" + ) + logger.debug( + f"Cached scheme breakdown: {num_node_points} node input, " + f"{num_region_composite_points} region composite, " + f"{num_region_output_points} region output points ({len(cached_schemes)} cached schemes remaining)" + ) + + self.current_insertion_scheme_index = scheme_index + return self.current_insertion_scheme_index + + known_schemes = {scheme.hash for scheme in pattern_schemes.schemes} + max_attempts = getattr(self.config, "maximum_generation_attempts", 100) + + logger.debug(f"Generating new scheme ({len(pattern_schemes.schemes)} schemes exist)") + + for attempts in range(max_attempts): + new_scheme = self._generate_next_insertion_sample() + if new_scheme.hash not in known_schemes and not new_scheme.error: + pattern_schemes.schemes.append(new_scheme) + scheme_index = len(pattern_schemes.schemes) - 1 + num_node_points = len(new_scheme.node_inputs) + num_region_composite_points = len(new_scheme.child_region_inputs) + num_region_output_points = len(new_scheme.region_outputs) + total_points = ( + num_node_points + num_region_composite_points + num_region_output_points + ) + + logger.info( + f"Scheme #{scheme_index + 1}: generated new scheme ({total_points} Q/DQ points)" + ) + logger.debug( + f"Scheme breakdown: {num_node_points} node input, " + 
f"{num_region_composite_points} region composite, " + f"{num_region_output_points} region output points " + f"(hash: {new_scheme.hash[:16]}..., attempts: {attempts + 1})" + ) + + self.current_insertion_scheme_index = scheme_index + return self.current_insertion_scheme_index + + logger.warning(f"Could not generate unique scheme after {max_attempts} attempts") + return -1 + + def export_onnx( + self, output_path: str | None = None, insert_qdq: bool = True, best: bool = False + ) -> bytes: + """Export ONNX model with Q/DQ nodes inserted according to tested schemes.""" + if not self.initialized: + raise AutotunerNotInitializedError( + "QDQAutotunerBase not initialized. Call initialize() first." + ) + + output_desc = output_path if output_path is not None else "" + original_quant_type = self.config.default_quant_type + needs_fp8_conversion = insert_qdq and original_quant_type == "fp8" + resolved_insertion_points = set() + + logger.debug( + f"Exporting model to {output_desc} (insert_qdq={insert_qdq}, " + f"regions={len(self.regions)}, profiled_patterns={len(self.profiled_patterns)})" + ) + + if needs_fp8_conversion: + logger.debug("FP8 conversion: creating INT8 model first") + self.config.default_quant_type = "int8" + + if insert_qdq: + matched_regions = 0 + + logger.debug(f"Resolving Q/DQ insertion points from {len(self.regions)} regions") + + for region in self.regions: + pattern = RegionPattern.from_region(region, self.graph) + logger.debug(f"Region {region.id} (level {region.level})") + logger.debug(f" → Pattern signature: {pattern.signature}") + + matched = next((ps for ps in self.profiled_patterns if ps.pattern == pattern), None) + current_scheme = matched.best_scheme if matched else None + + if matched: + if current_scheme: + logger.debug( + f" → Matched profiled pattern (latency={current_scheme.latency_ms:.3f} ms)" + ) + else: + logger.debug(" → Matched profiled pattern but no valid schemes") + + if current_scheme is None: + current_scheme = 
self.current_profile_pattern_schemes + if current_scheme is None or pattern != current_scheme.pattern: + pass + elif best: + current_scheme = current_scheme.best_scheme + else: + scheme_index = self.current_insertion_scheme_index + if scheme_index is not None: + assert scheme_index < len(current_scheme.schemes), ( + f"Invalid scheme index: {scheme_index}" + ) + current_scheme = current_scheme.schemes[scheme_index] + logger.debug(f" → Using current pattern scheme #{scheme_index}") + + if current_scheme is None and self.pattern_cache is not None: + pattern_schemes = self.pattern_cache.get_pattern_schemes(pattern.signature) + if pattern_schemes is not None: + schemes = pattern_schemes.schemes + if schemes is not None and len(schemes) == 1 and not schemes[0].is_profiled: + current_scheme = schemes[0] + logger.debug(" → Using imported pattern from cache") + + if current_scheme is None: + logger.debug(" → No scheme available, skipping") + continue + + full_insertion_scheme = pattern.get_full_insertion_scheme(region, self.graph) + assert full_insertion_scheme is not None + all_region_ips = pattern.matches(region, self.graph, full_insertion_scheme) + assert isinstance(all_region_ips, set) + resolved_insertion_points.difference_update(all_region_ips) + excluded_tensors = all_region_ips - resolved_insertion_points + if excluded_tensors: + logger.debug( + f" → Excluded {len(excluded_tensors)} overlapping insertion points" + ) + + new_ips = pattern.matches(region, self.graph, current_scheme) + if new_ips: + resolved_insertion_points.update(new_ips) + matched_regions += 1 + logger.debug(f" → Added {len(new_ips)} insertion points") + + logger.debug( + f"Matched {matched_regions}/{len(self.regions)} regions, " + f"total {len(resolved_insertion_points)} unique insertion points" + ) + + graph_copy = self._copy_graph() + unique_tensors = len(resolved_insertion_points) + + logger.debug(f"Inserting {unique_tensors} Q/DQ pairs into graph") + + if insert_qdq and 
resolved_insertion_points: + self._insert_qdq_at_tensors(graph_copy, resolved_insertion_points) + + logger.debug("Serializing to ONNX format") + model = gs.export_onnx(graph_copy) + + if insert_qdq and resolved_insertion_points: + self._fix_zero_point_initializers(model) + + if needs_fp8_conversion: + logger.debug("Converting INT8 to FP8") + model = int8_to_fp8(model) + + self.config.default_quant_type = original_quant_type + model_bytes = model.SerializeToString() + quant_type_str = "baseline" + output_dest = "" + + if insert_qdq: + quant_type_str = f"{original_quant_type.upper()}" if needs_fp8_conversion else "INT8" + + if output_path is not None: + onnx.save(model, output_path) + output_dest = f" → {output_path}" + + logger.info( + f"Exported {quant_type_str} model with {unique_tensors} Q/DQ pairs {output_dest}" + ) + return model_bytes + + def submit(self, latency_ms: float, success: bool = True) -> None: + """Submit performance measurement for the most recently generated scheme.""" + if not self.initialized: + raise AutotunerNotInitializedError( + "QDQAutotunerBase not initialized. Call initialize() first." + ) + + if self.baseline_latency_ms is None: + self.baseline_latency_ms = latency_ms + logger.info(f"Baseline latency: {latency_ms:.3f} ms") + return + + if self.current_profile_pattern_schemes is None: + raise InvalidSchemeError( + "No pattern or region selected. Call set_profile_region() first." + ) + + schemes_collection = self.current_profile_pattern_schemes + if not schemes_collection.schemes: + raise InvalidSchemeError("No schemes available. 
Call generate() first.") + + pattern_schemes = schemes_collection + + if self.current_insertion_scheme_index is not None: + scheme_index = self.current_insertion_scheme_index + if scheme_index >= len(pattern_schemes.schemes): + raise InvalidSchemeError(f"Invalid scheme index: {scheme_index}") + scheme = pattern_schemes.schemes[scheme_index] + else: + scheme = pattern_schemes.schemes[-1] + scheme_index = len(pattern_schemes.schemes) - 1 + + scheme.latency_ms = latency_ms + scheme.error = not success + scheme.profile_timestamp = datetime.now(timezone.utc).isoformat() + display_index = scheme_index + 1 + + if not success: + logger.warning( + f"Scheme #{display_index}: measurement failed (latency={latency_ms:.3f} ms)" + ) + logger.debug("Marking scheme with error flag") + return + + speedup = self.baseline_latency_ms / latency_ms if latency_ms > 0 else 0.0 + + logger.info(f"Scheme #{display_index}: {latency_ms:.3f} ms ({speedup:.2f}x speedup)") + logger.debug(f"Compared to baseline: {self.baseline_latency_ms:.3f} ms") + + old_best = ( + pattern_schemes.schemes[0].latency_ms if pattern_schemes.schemes else float("inf") + ) + pattern_schemes.schemes.sort( + key=lambda s: s.latency_ms if s.latency_ms > 0 else float("inf") + ) + new_best = ( + pattern_schemes.schemes[0].latency_ms if pattern_schemes.schemes else float("inf") + ) + + if new_best < old_best: + new_speedup = self.baseline_latency_ms / new_best if new_best > 0 else 0.0 + logger.info(f" ★ New best: {new_best:.3f} ms ({new_speedup:.2f}x speedup)") + logger.debug(f"Previous best: {old_best:.3f} ms") + + if self.current_profile_pattern_schemes is not None and self.pattern_cache is not None: + self.pattern_cache.add_pattern_schemes(pattern_schemes) + logger.debug( + f"Pattern cache updated: {self.pattern_cache.num_patterns} patterns, " + f"{self.pattern_cache.total_schemes} schemes" + ) + + def save_state(self, output_path: str) -> None: + """Save complete autotuner state to a YAML file for later reuse.""" + 
current_pattern_sig = None + if self.current_profile_pattern_schemes is not None: + current_pattern_sig = self.current_profile_pattern_schemes.pattern_signature + + state = { + "baseline_latency_ms": self.baseline_latency_ms, + "current_profile_pattern_schemes_signature": current_pattern_sig, + "config": { + "default_q_scale": self.config.default_q_scale, + "default_q_zero_point": self.config.default_q_zero_point, + "default_quant_type": self.config.default_quant_type, + "verbose": self.config.verbose, + }, + "patterns": [pattern_schemes.to_dict() for pattern_schemes in self.profiled_patterns], + } + + with open(output_path, "w") as f: + yaml.dump(state, f, default_flow_style=False, sort_keys=False) + + num_patterns = len(self.profiled_patterns) + total_schemes = sum(len(p.schemes) for p in self.profiled_patterns) + + logger.info( + f"Saved state → {output_path} ({num_patterns} patterns, {total_schemes} schemes)" + ) + logger.debug(f"State: baseline={self.baseline_latency_ms:.3f} ms") + + if self.pattern_cache is not None and self.pattern_cache.num_patterns > 0: + base_path, ext = os.path.splitext(output_path) + cache_path = f"{base_path}_pattern_cache{ext}" + self.pattern_cache.save(cache_path) + + logger.info(f"Saved pattern cache → {cache_path}") + logger.debug( + f"Cache: {self.pattern_cache.num_patterns} patterns, " + f"{self.pattern_cache.total_schemes} schemes" + ) + + def load_state(self, input_path: str) -> None: + """Load autotuner state from a previously saved YAML file.""" + if not self.initialized: + raise AutotunerNotInitializedError( + "QDQAutotunerBase not initialized. Call initialize() first." 
+ ) + + with open(input_path) as f: + state = yaml.safe_load(f) + + if state.get("baseline_latency_ms") is not None: + self.baseline_latency_ms = state["baseline_latency_ms"] + logger.debug(f"Baseline latency: {self.baseline_latency_ms:.3f} ms") + + if "config" in state: + config_data = state["config"] + if "default_q_scale" in config_data: + self.config.default_q_scale = config_data["default_q_scale"] + if "default_q_zero_point" in config_data: + self.config.default_q_zero_point = config_data["default_q_zero_point"] + if "default_quant_type" in config_data: + self.config.default_quant_type = config_data["default_quant_type"] + if "verbose" in config_data: + self.config.verbose = config_data["verbose"] + logger.debug(f"Config merged: quant_type={self.config.default_quant_type}") + + if "patterns" in state: + num_loaded_patterns = 0 + num_loaded_schemes = 0 + + for pattern_data in state["patterns"]: + try: + pattern_schemes = PatternSchemes.from_dict(pattern_data) + + if pattern_schemes.schemes: + self.profiled_patterns.append(pattern_schemes) + num_loaded_patterns += 1 + num_loaded_schemes += len(pattern_schemes.schemes) + else: + logger.debug( + f"Skipped empty pattern {pattern_schemes.pattern_signature[:16]}..." 
+ ) + + except Exception as e: # noqa: PERF203 + logger.warning(f"Failed to load pattern: {e}") + continue + + logger.info( + f"Loaded state from {input_path} ({num_loaded_patterns} patterns, " + f"{num_loaded_schemes} schemes)" + ) + + base_path, ext = os.path.splitext(input_path) + cache_path = f"{base_path}_pattern_cache{ext}" + + if os.path.exists(cache_path): + try: + loaded_cache = PatternCache.load(cache_path) + + if self.pattern_cache is not None: + for pattern_schemes in loaded_cache.pattern_schemes: + self.pattern_cache.add_pattern_schemes(pattern_schemes) + else: + self.pattern_cache = loaded_cache + logger.info( + f"Loaded pattern cache from {cache_path} ({loaded_cache.num_patterns} patterns, " + f"{loaded_cache.total_schemes} schemes)" + ) + except Exception as e: + logger.warning(f"Failed to load pattern cache: {e}") + else: + logger.debug(f"No pattern cache file at {cache_path}") + + def import_insertion_points(self, quantized_tensors: set[str] | list[str]) -> None: + """Import Q/DQ insertion points from a list of quantized tensors and update pattern cache.""" + if not self.initialized: + raise AutotunerNotInitializedError( + "QDQAutotunerBase not initialized. Call initialize() first." 
+ ) + + if isinstance(quantized_tensors, list): + quantized_tensors = set(quantized_tensors) + + logger.info(f"Importing insertion points from {len(quantized_tensors)} quantized tensors") + logger.debug(f"Processing {len(self.regions)} regions") + + if self.pattern_cache is None: + logger.warning("Pattern cache not initialized, skipping import") + return + + patterns_before = self.pattern_cache.num_patterns + schemes_before = self.pattern_cache.total_schemes + + for region in self.regions: + self.pattern_cache.add_pattern_from_region(region, self.graph, quantized_tensors) + + patterns_added = self.pattern_cache.num_patterns - patterns_before + schemes_added = self.pattern_cache.total_schemes - schemes_before + + logger.info( + f"Import complete: {patterns_added} patterns, {schemes_added} schemes added to cache" + ) + logger.debug( + f"Total cache: {self.pattern_cache.num_patterns} patterns, " + f"{self.pattern_cache.total_schemes} schemes" + ) + + def _compute_convergence_metrics( + self, schemes: list[InsertionScheme], best_scheme: InsertionScheme | None + ) -> tuple[int | None, float | None]: + """Compute convergence metrics for a collection of schemes.""" + samples_before_best = None + time_to_best = None + + if not best_scheme or not best_scheme.profile_timestamp: + return samples_before_best, time_to_best + + schemes_with_time = [s for s in schemes if s.profile_timestamp is not None] + + if not schemes_with_time: + return samples_before_best, time_to_best + + schemes_with_time.sort(key=lambda s: s.profile_timestamp or "") + + try: + best_position = next( + i for i, s in enumerate(schemes_with_time) if s.hash == best_scheme.hash + ) + samples_before_best = best_position + + first_ts = schemes_with_time[0].profile_timestamp + best_ts = best_scheme.profile_timestamp + assert first_ts is not None and best_ts is not None + first_timestamp = datetime.fromisoformat(first_ts) + best_timestamp = datetime.fromisoformat(best_ts) + time_to_best = (best_timestamp - 
first_timestamp).total_seconds() + except (StopIteration, ValueError): + pass + + return samples_before_best, time_to_best + + def _is_region_profiled(self, region: Region) -> bool: + """Check if a region's pattern has already been fully profiled.""" + + def match_pattern(pattern: PatternSchemes, region: Region) -> bool: + """Check if a pattern matches a region.""" + if pattern.pattern is None or not pattern.pattern.matches(region, self.graph): + return False + return not any(not scheme.is_profiled for scheme in pattern.schemes) + + return any(match_pattern(pattern, region) for pattern in self.profiled_patterns) + + def _mutate_insertion_points( + self, base_points, all_points, point_type: str, max_mutations: int + ) -> list: + """Mutate a set of insertion points by adding, removing, or both.""" + key_fn = { + "node input points": lambda p: (p.node_index, p.input_index), + "region composite points": lambda p: (p.region_index, p.input_index), + "region output points": lambda p: (p.region_index, p.node_index, p.output_index), + }.get(point_type) + + if not key_fn: + return [] + + current_points = set(base_points) + initial_count = len(current_points) + mutation_type = random.choice(["add", "remove", "both"]) + + if mutation_type in ["add", "both"] and len(current_points) < len(all_points): + all_keys = {key_fn(p) for p in all_points} + available_keys = all_keys - current_points + if available_keys: + max_add = min(max_mutations, len(available_keys)) + num_to_add = random.randint(1, max_add) + to_add = random.sample(list(available_keys), num_to_add) + current_points.update(to_add) + + if mutation_type in ["remove", "both"] and current_points: + max_remove = min(max_mutations, len(current_points)) + num_to_remove = random.randint(1, max_remove) if len(current_points) > 1 else 1 + num_to_remove = min(num_to_remove, len(current_points)) + to_remove = random.sample(list(current_points), num_to_remove) + for p in to_remove: + current_points.discard(p) + + logger.debug( + 
f"Mutated {point_type}: {initial_count} → {len(current_points)} ({mutation_type})" + ) + + return [p for p in all_points if key_fn(p) in current_points] + + def _generate_next_insertion_sample(self) -> InsertionScheme: + """Generate a new insertion scheme by mutating top performers.""" + if self.current_profile_region is None: + return InsertionScheme() + + if self.current_profile_pattern_schemes is not None: + schemes_collection = self.current_profile_pattern_schemes + else: + return InsertionScheme() + + region = self.current_profile_region + pattern_schemes = schemes_collection + + if not isinstance(schemes_collection, PatternSchemes) or schemes_collection.pattern is None: + return InsertionScheme() + pattern = schemes_collection.pattern + full_insertion_scheme = pattern.get_full_insertion_scheme(region, self.graph) + + logger.debug( + f"Available insertion points: {len(full_insertion_scheme.node_inputs)} node input, " + f"{len(full_insertion_scheme.child_region_inputs)} region composite, " + f"{len(full_insertion_scheme.region_outputs)} region output" + ) + + top_percent = getattr(self.config, "top_percent_to_mutate", 0.1) + minimum_schemes = getattr(self.config, "minimum_schemes_to_mutate", 1) + + measured_schemes = [s for s in pattern_schemes.schemes if s.latency_ms > 0 and not s.error] + measured_schemes.sort(key=lambda s: s.latency_ms) + + num_top_schemes = max( + int(len(measured_schemes) * top_percent), min(minimum_schemes, len(measured_schemes)) + ) + top_schemes = measured_schemes[:num_top_schemes] + + if len(top_schemes) == 0: + logger.debug("No measured schemes yet, generating baseline (empty) scheme") + return InsertionScheme() + + base_scheme = random.choice(top_schemes) + total_base_points = ( + len(base_scheme.node_inputs) + + len(base_scheme.child_region_inputs) + + len(base_scheme.region_outputs) + ) + logger.debug( + f"Mutating from top {len(top_schemes)} schemes: " + f"selected base with {total_base_points} points 
(latency={base_scheme.latency_ms:.3f} ms)" + ) + + max_mutations = getattr(self.config, "maximum_mutations", 3) + + scheme = InsertionScheme() + base_node_points = {(p.node_index, p.input_index) for p in base_scheme.node_inputs} + scheme.node_inputs = self._mutate_insertion_points( + base_node_points, full_insertion_scheme.node_inputs, "node input points", max_mutations + ) + + base_region_composite_points = { + (p.region_index, p.input_index) for p in base_scheme.child_region_inputs + } + scheme.child_region_inputs = self._mutate_insertion_points( + base_region_composite_points, + full_insertion_scheme.child_region_inputs, + "region composite points", + max_mutations, + ) + + base_region_output_points = { + (p.region_index, p.node_index, p.output_index) for p in base_scheme.region_outputs + } + scheme.region_outputs = self._mutate_insertion_points( + base_region_output_points, + full_insertion_scheme.region_outputs, + "region output points", + max_mutations, + ) + + return scheme + + def _copy_graph(self) -> gs.Graph: + """Create an independent copy of the computation graph.""" + new_graph = gs.import_onnx(self.onnx_model) + new_graph.toposort() + return new_graph + + def _get_quant_dtype(self, quant_type: str) -> np.dtype: + """Get numpy dtype for quantization type.""" + if quant_type == "fp8": + try: + return np.dtype(np.float8_e4m3fn) + except (AttributeError, TypeError): + logger.warning( + "FP8 dtype not available (requires numpy >= 2.0), " + "using uint8 as placeholder. Note: This may not produce " + "correct results without proper FP8 support." 
+ ) + return np.uint8 + + dtype_map = { + "int8": np.int8, + "uint8": np.uint8, + } + + if quant_type not in dtype_map: + logger.warning(f"Unknown quantization type '{quant_type}', defaulting to int8") + return np.int8 + + return dtype_map[quant_type] + + def _get_dq_output_dtype(self, dtype_str: str) -> np.dtype: + """Convert DQ dtype string to numpy dtype.""" + dtype_map = { + "float16": np.float16, + "float32": np.float32, + } + + if hasattr(np, "bfloat16"): + dtype_map["bfloat16"] = np.bfloat16 + + if dtype_str not in dtype_map: + logger.warning(f"Unknown DQ dtype '{dtype_str}', defaulting to float32") + return np.float32 + + return dtype_map[dtype_str] + + def _build_tensor_map(self, graph: gs.Graph) -> dict[str, gs.Tensor]: + """Build mapping from tensor names to tensor objects.""" + tensor_map = {} + + for node in graph.nodes: + for output in node.outputs: + if hasattr(output, "name") and output.name: + tensor_map[output.name] = output + + for input_tensor in graph.inputs: + if hasattr(input_tensor, "name") and input_tensor.name: + tensor_map[input_tensor.name] = input_tensor + + for node in graph.nodes: + for input_tensor in node.inputs: + if ( + isinstance(input_tensor, gs.Constant) + and hasattr(input_tensor, "name") + and input_tensor.name + ): + tensor_map[input_tensor.name] = input_tensor + + return tensor_map + + def _get_tensor_metadata( + self, tensor: gs.Tensor, is_constant: bool + ) -> tuple[tuple | None, np.dtype]: + """Extract shape and dtype metadata from a tensor.""" + default_dtype = self._get_dq_output_dtype(self.config.default_dq_dtype) + + if is_constant and hasattr(tensor, "values") and tensor.values is not None: + return tensor.values.shape, tensor.values.dtype + elif hasattr(tensor, "shape"): + dtype = ( + tensor.dtype + if hasattr(tensor, "dtype") and tensor.dtype is not None + else default_dtype + ) + return tensor.shape, dtype + return None, default_dtype + + def _fix_zero_point_initializers(self, model: onnx.ModelProto) -> None: + 
"""Fix INT8 zero_point initializers to use int32_data instead of raw_data.""" + fixed_count = 0 + + for initializer in model.graph.initializer: + if ( + "_zp_" in initializer.name + and initializer.data_type == onnx.TensorProto.INT8 + and len(initializer.raw_data) > 0 + and len(initializer.int32_data) == 0 + ): + np_array = onnx.numpy_helper.to_array(initializer) + int32_values = np_array.astype(np.int32).flatten().tolist() + + new_tensor = onnx.helper.make_tensor( + initializer.name, + onnx.TensorProto.INT8, + list(initializer.dims), + int32_values, + ) + initializer.CopyFrom(new_tensor) + fixed_count += 1 + + if fixed_count > 0: + logger.debug(f"Fixed {fixed_count} zero_point initializers (int32_data format)") + + def _create_qdq_nodes( + self, + tensor_name: str, + qdq_input: gs.Tensor, + output_shape: tuple | None, + output_dtype: np.dtype, + quant_dtype: np.dtype, + q_scale: float, + ) -> tuple[gs.Node, gs.Node]: + """Create QuantizeLinear and DequantizeLinear node pair.""" + # Create unique names for Q/DQ nodes + q_name = f"QDQ_Q_{tensor_name}".replace("/", "_").replace(":", "_") + dq_name = f"QDQ_DQ_{tensor_name}".replace("/", "_").replace(":", "_") + # Determine scale dtype from output_dtype (fp16/tf32/fp32) + # Scale should match the precision of the original I/O tensor + dtype_map = {"float16": np.float16, "float32": np.float32} + if hasattr(np, "bfloat16"): + dtype_map["bfloat16"] = np.bfloat16 + scale_dtype = dtype_map.get(np.dtype(output_dtype).name, np.float32) + + logger.debug( + f"Creating Q/DQ pair for '{tensor_name}' (scale_dtype={np.dtype(scale_dtype).name})" + ) + + q_scale_values = np.array([q_scale], dtype=scale_dtype) + q_zp_values = np.array([0], dtype=quant_dtype) + q_inputs = [ + qdq_input, + gs.Constant(f"q_scale_{tensor_name}", values=q_scale_values), + gs.Constant(f"q_zp_{tensor_name}", values=q_zp_values), + ] + q_node = gs.Node( + op="QuantizeLinear", + name=q_name, + inputs=q_inputs, + outputs=[ + 
gs.Variable(f"{tensor_name}_quantized", dtype=quant_dtype, shape=output_shape) + ], + ) + + dq_scale_values = np.array([q_scale], dtype=scale_dtype) + dq_zp_values = np.array([0], dtype=quant_dtype) + dq_inputs = [ + q_node.outputs[0], + gs.Constant(f"dq_scale_{tensor_name}", values=dq_scale_values), + gs.Constant(f"dq_zp_{tensor_name}", values=dq_zp_values), + ] + dq_node = gs.Node( + op="DequantizeLinear", + name=dq_name, + inputs=dq_inputs, + outputs=[ + gs.Variable(f"{tensor_name}_dequantized", dtype=output_dtype, shape=output_shape) + ], + ) + + return q_node, dq_node + + def _insert_qdq_at_tensors( + self, graph: gs.Graph, resolved_insertion_points: set[ResolvedInsertionPoint] + ) -> None: + """Insert Q/DQ (Quantize/Dequantize) node pairs at specified locations.""" + q_scale = self.config.default_q_scale + quant_type = self.config.default_quant_type + quant_dtype = self._get_quant_dtype(quant_type) + + logger.debug(f"Q/DQ parameters: type={quant_type}, scale={q_scale}, zero_point=0") + + resolved_insertion_points = merge_resolved_insertion_points( + graph, resolved_insertion_points + ) + + tensor_map = self._build_tensor_map(graph) + tensor_users_map = get_tensor_consumer_node_indices(graph) + logger.debug( + f"Built tensor maps: {len(tensor_map)} tensors, {len(tensor_users_map)} with users" + ) + + for insertion_point in resolved_insertion_points: + tensor_name = insertion_point.tensor_name + node_index = insertion_point.node_index + input_index = insertion_point.input_index + + original_tensor = tensor_map[tensor_name] + if node_index is not None: + assert node_index < len(graph.nodes), "Node index out of range" + target_node = graph.nodes[node_index] + assert input_index is not None, "Input index must be set when node index is set" + assert input_index < len(target_node.inputs), ( + f"Input index out of range for node {target_node.name}" + ) + original_tensor = target_node.inputs[input_index] + assert tensor_name == original_tensor.name, ( + f"Tensor name 
mismatch for node {target_node.name} input {input_index}" + ) + else: + assert tensor_name in tensor_map, f"Tensor {tensor_name} not found in tensor map" + assert input_index is None, "Input index must be None when node index is None" + + is_constant = isinstance(original_tensor, gs.Constant) + output_shape, output_dtype = self._get_tensor_metadata(original_tensor, is_constant) + + unique_suffix = "qdq" + if node_index is not None: + unique_suffix = f"n{node_index}_i{input_index}" + unique_tensor_name = f"{tensor_name}_{unique_suffix}" + + q_node, dq_node = self._create_qdq_nodes( + unique_tensor_name, + original_tensor, + output_shape, + output_dtype, + quant_dtype, + q_scale, + ) + + graph.nodes.extend([q_node, dq_node]) + + if node_index is not None: + target_node.inputs[input_index] = dq_node.outputs[0] + logger.debug( + f" Q/DQ inserted: tensor '{tensor_name}' → node #{node_index} " + f"({target_node.name}) input #{input_index}" + ) + else: + users = tensor_users_map[tensor_name] + for user_index in users: + user_node = graph.nodes[user_index] + for i, input_tensor in enumerate(user_node.inputs): + if hasattr(input_tensor, "name") and input_tensor.name == tensor_name: + user_node.inputs[i] = dq_node.outputs[0] + break + logger.debug(f" Q/DQ inserted: tensor '{tensor_name}' → {len(users)} users") + + logger.debug("Running graph cleanup and topological sort") + try: + graph.cleanup().toposort() + logger.debug("Graph cleanup completed") + except Exception as e: + logger.error(f"Graph cleanup failed: {e}") + raise RuntimeError(f"Graph cleanup failed after Q/DQ insertion: {e}") from e + + +class QDQAutotuner(QDQAutotunerBase): + """Q/DQ autotuner with automatic region discovery around compute-intensive ops.""" + + def initialize( + self, config: Config | None = None, pattern_cache: PatternCache | None = None + ) -> None: + """Initialize autotuner and discover optimization regions automatically.""" + super().initialize(config, pattern_cache) + self._search_regions() 
+ + def _visit_region_recursively(self, region: Region) -> list[Region]: + """Recursively traverse region hierarchy and collect all regions.""" + regions = [region] + + for child in region.get_children(): + regions.extend(self._visit_region_recursively(child)) + + return regions + + def _reassign_region_ids(self, regions: list[Region]) -> None: + """Reassign sequential IDs to regions in breadth-first order.""" + region_id = 0 + + queue = deque(regions) + + while queue: + region = queue.popleft() + region.id = region_id + region_id += 1 + queue.extend(region.get_children()) + + def _search_regions(self) -> None: + """Discover and organize optimization regions automatically.""" + logger.info("Discovering optimization regions") + search = CombinedRegionSearch( + self.graph, + maximum_sequence_region_size=self.config.maximum_sequence_region_size, + minimum_topdown_search_size=self.config.minimum_topdown_search_size, + ) + self.regions = search.search_regions() + + self._reassign_region_ids(self.regions) + logger.debug(f"Found {len(self.regions)} top-level regions") + + all_regions = [] + for region in self.regions: + all_regions.extend(self._visit_region_recursively(region)) + + logger.debug(f"Flattened hierarchy to {len(all_regions)} total regions") + + leaf_regions = [region for region in all_regions if region.type == RegionType.LEAF] + other_regions = [region for region in all_regions if region.type != RegionType.LEAF] + + all_regions = leaf_regions + other_regions + self.regions = all_regions + + num_leaf = sum(1 for r in self.regions if r.type == RegionType.LEAF) + num_composite = sum(1 for r in self.regions if r.type == RegionType.COMPOSITE) + num_root = sum(1 for r in self.regions if r.type == RegionType.ROOT) + + logger.info( + f"Discovery complete: {len(self.regions)} regions " + f"({num_leaf} LEAF, {num_composite} COMPOSITE, {num_root} ROOT)" + ) + logger.debug("Regions prioritized: LEAF regions first for profiling") diff --git 
a/tests/unit/onnx/quantization/autotune/autotune/test_autotuner.py b/tests/unit/onnx/quantization/autotune/autotune/test_autotuner.py new file mode 100644 index 000000000..fe4240047 --- /dev/null +++ b/tests/unit/onnx/quantization/autotune/autotune/test_autotuner.py @@ -0,0 +1,345 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Tests for QDQAutotuner class. + +Tests the main autotuner class public API. +Note: Full integration tests with TensorRT benchmarking should be in separate integration test files. +""" + +import os +import sys +import tempfile +import unittest + +# Add parent directory to path +sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + +import onnx +import onnx_graphsurgeon as gs +from onnx import helper + +from modelopt.onnx.quantization.autotune import Config, QDQAutotuner, RegionPattern +from modelopt.onnx.quantization.autotune.common import PatternCache, RegionType + + +def create_simple_conv_model(): + """ + Create a simple ONNX model: Input -> Conv -> Relu -> Output. + + This is a minimal model for testing autotuner initialization. 
+ """ + # Input + input_tensor = helper.make_tensor_value_info("input", onnx.TensorProto.FLOAT, [1, 3, 224, 224]) + + # Output + output_tensor = helper.make_tensor_value_info( + "output", onnx.TensorProto.FLOAT, [1, 64, 224, 224] + ) + + # Conv node + conv_node = helper.make_node( + "Conv", inputs=["input", "conv_weight"], outputs=["conv_out"], name="conv" + ) + + # Relu node + relu_node = helper.make_node("Relu", inputs=["conv_out"], outputs=["output"], name="relu") + + # Create graph + graph = helper.make_graph( + [conv_node, relu_node], + "simple_conv", + [input_tensor], + [output_tensor], + initializer=[ + helper.make_tensor( + "conv_weight", onnx.TensorProto.FLOAT, [64, 3, 3, 3], [0.1] * (64 * 3 * 3 * 3) + ) + ], + ) + + # Create model + model = helper.make_model(graph, producer_name="test") + return model + + +class TestQDQAutotuner(unittest.TestCase): + """Test QDQAutotuner functionality.""" + + @staticmethod + def _create_test_config(): + """ + Create a reasonable config for testing. 
+ + Uses sensible defaults suitable for unit tests: + - verbose=False: Keep test output clean + - maximum_sequence_region_size=50: Allow larger test regions + - Other parameters: Match Config defaults for typical behavior + """ + return Config( + # Logging + verbose=False, + # Performance Requirements + # Quantization Parameters + default_q_scale=0.1, + default_q_zero_point=0, + default_quant_type="int8", + # Region Builder Settings + maximum_sequence_region_size=50, + minimum_topdown_search_size=10, + # Scheme Generation Settings + top_percent_to_mutate=0.1, + minimum_schemes_to_mutate=10, + maximum_mutations=3, + maximum_generation_attempts=100, + # Pattern Cache Settings + pattern_cache_minimum_distance=4, + pattern_cache_max_entries_per_pattern=32, + ) + + def test_creation_with_onnx_model(self): + """Test creating autotuner with ONNX ModelProto.""" + model = create_simple_conv_model() + autotuner = QDQAutotuner(model) + + assert autotuner is not None + assert autotuner.onnx_model is not None + assert autotuner.graph is not None + + def test_creation_with_gs_graph(self): + """Test creating autotuner with GraphSurgeon graph.""" + model = create_simple_conv_model() + gs_graph = gs.import_onnx(model) + + autotuner = QDQAutotuner(gs_graph) + + assert autotuner is not None + assert autotuner.graph is not None + + def test_initialize_with_default_config(self): + """Test initialization with default test config.""" + model = create_simple_conv_model() + autotuner = QDQAutotuner(model) + + config = self._create_test_config() + autotuner.initialize(config) + + # Should have provided config + assert autotuner.config is not None + assert autotuner.config.maximum_sequence_region_size == 50 + + # Should have discovered regions + assert len(autotuner.regions) > 0 + + def test_initialize_with_config(self): + """Test initialization with custom config (different from default).""" + model = create_simple_conv_model() + autotuner = QDQAutotuner(model) + + # Create custom config 
with different values + config = Config( + verbose=True, + default_q_scale=0.05, + default_q_zero_point=128, + default_quant_type="fp8", + maximum_sequence_region_size=20, + minimum_topdown_search_size=5, + top_percent_to_mutate=0.2, + minimum_schemes_to_mutate=5, + maximum_mutations=5, + maximum_generation_attempts=50, + pattern_cache_minimum_distance=2, + pattern_cache_max_entries_per_pattern=16, + ) + autotuner.initialize(config) + + # Should use provided custom config values + assert autotuner.config.verbose + assert autotuner.config.default_q_scale == 0.05 + assert autotuner.config.default_q_zero_point == 128 + assert autotuner.config.default_quant_type == "fp8" + assert autotuner.config.maximum_sequence_region_size == 20 + assert autotuner.config.minimum_topdown_search_size == 5 + assert autotuner.config.top_percent_to_mutate == 0.2 + assert autotuner.config.minimum_schemes_to_mutate == 5 + assert autotuner.config.maximum_mutations == 5 + assert autotuner.config.maximum_generation_attempts == 50 + assert autotuner.config.pattern_cache_minimum_distance == 2 + assert autotuner.config.pattern_cache_max_entries_per_pattern == 16 + + def test_initialize_with_pattern_cache(self): + """Test initialization with pattern cache.""" + model = create_simple_conv_model() + autotuner = QDQAutotuner(model) + + config = self._create_test_config() + pattern_cache = PatternCache() + autotuner.initialize(config, pattern_cache=pattern_cache) + + assert autotuner.pattern_cache is not None + + def test_region_discovery(self): + """Test that regions are automatically discovered.""" + model = create_simple_conv_model() + autotuner = QDQAutotuner(model) + + config = self._create_test_config() + autotuner.initialize(config) + + # Should discover at least one region + assert len(autotuner.regions) > 0 + + # Regions should be valid + for region in autotuner.regions: + assert region.get_id() is not None + assert region.get_type() in [RegionType.LEAF, RegionType.COMPOSITE, RegionType.ROOT] 
+ + def test_export_baseline_model(self): + """Test exporting baseline model without Q/DQ.""" + model = create_simple_conv_model() + autotuner = QDQAutotuner(model) + config = self._create_test_config() + autotuner.initialize(config) + + with tempfile.NamedTemporaryFile(suffix=".onnx", delete=False) as f: + output_path = f.name + + try: + # Export baseline without Q/DQ insertion + autotuner.export_onnx(output_path, insert_qdq=False) + # Verify file was created + assert os.path.exists(output_path) + # Verify it's a valid ONNX model + exported_model = onnx.load(output_path) + assert exported_model is not None + finally: + if os.path.exists(output_path): + os.unlink(output_path) + + def test_set_profile_region(self): + """Test setting a region for profiling.""" + model = create_simple_conv_model() + autotuner = QDQAutotuner(model) + config = self._create_test_config() + autotuner.initialize(config) + + if len(autotuner.regions) > 0: + region = autotuner.regions[0] + autotuner.set_profile_region(region) + # Should set current profile region + assert autotuner.current_profile_region == region + assert autotuner.current_profile_pattern_schemes is not None + else: + self.skipTest("No regions discovered") + + def test_generate_scheme(self): + """Test generating an insertion scheme.""" + model = create_simple_conv_model() + autotuner = QDQAutotuner(model) + config = self._create_test_config() + autotuner.initialize(config) + + if len(autotuner.regions) > 0: + region = autotuner.regions[0] + autotuner.set_profile_region(region) + # Generate a scheme + scheme_idx = autotuner.generate() + # Should return a valid index (>= 0) or -1 if no more unique schemes + assert isinstance(scheme_idx, int) + else: + self.skipTest("No regions discovered") + + def test_submit_latency(self): + """Test submitting performance measurement.""" + model = create_simple_conv_model() + autotuner = QDQAutotuner(model) + config = self._create_test_config() + autotuner.initialize(config) + # Submit 
baseline latency + autotuner.submit(10.5) + # Baseline should be recorded + assert autotuner.baseline_latency_ms == 10.5 + + def test_save_and_load_state(self): + """Test saving and loading autotuner state.""" + model = create_simple_conv_model() + autotuner = QDQAutotuner(model) + config = self._create_test_config() + autotuner.initialize(config) + + # Submit some results + autotuner.submit(10.5) # baseline + + with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml", delete=False) as f: + state_path = f.name + + try: + # Save state + autotuner.save_state(state_path) + assert os.path.exists(state_path) + + # Create new autotuner and load state + autotuner2 = QDQAutotuner(model) + config2 = self._create_test_config() + autotuner2.initialize(config2) + autotuner2.load_state(state_path) + + # Baseline should match + assert autotuner2.baseline_latency_ms == 10.5 + finally: + if os.path.exists(state_path): + os.unlink(state_path) + + def test_regions_prioritization(self): + """Test that LEAF regions are prioritized.""" + model = create_simple_conv_model() + autotuner = QDQAutotuner(model) + config = self._create_test_config() + autotuner.initialize(config) + + # Check that LEAF regions come before non-LEAF + leaf_indices = [ + i for i, r in enumerate(autotuner.regions) if r.get_type() == RegionType.LEAF + ] + non_leaf_indices = [ + i for i, r in enumerate(autotuner.regions) if r.get_type() != RegionType.LEAF + ] + + if leaf_indices and non_leaf_indices: + # All LEAF should come before non-LEAF + assert max(leaf_indices) < min(non_leaf_indices) + + def test_profiled_patterns_tracking(self): + """Test that profiled patterns are tracked.""" + model = create_simple_conv_model() + autotuner = QDQAutotuner(model) + config = self._create_test_config() + autotuner.initialize(config) + autotuner.submit(10.0) + + if len(autotuner.regions) > 0: + region = autotuner.regions[0] + autotuner.set_profile_region(region) + + scheme_idx = autotuner.generate() + if scheme_idx >= 0: + 
autotuner.submit(12.0) + autotuner.set_profile_region(None, commit=True) + pattern_sig = RegionPattern.from_region(region, autotuner.graph).signature + profiled_patterns = [p.pattern.signature for p in autotuner.profiled_patterns] + assert pattern_sig in profiled_patterns + else: + self.skipTest("No regions discovered") From aae60a219395ebec3f7318941d24f221c6fd9edb Mon Sep 17 00:00:00 2001 From: Will Guo Date: Mon, 9 Feb 2026 08:36:26 +0000 Subject: [PATCH 2/8] pick back docstrings Signed-off-by: Will Guo --- .../onnx/quantization/autotune/autotuner.py | 261 ++++++++++++++++-- 1 file changed, 242 insertions(+), 19 deletions(-) diff --git a/modelopt/onnx/quantization/autotune/autotuner.py b/modelopt/onnx/quantization/autotune/autotuner.py index 9eb8724dc..86074ed15 100644 --- a/modelopt/onnx/quantization/autotune/autotuner.py +++ b/modelopt/onnx/quantization/autotune/autotuner.py @@ -51,7 +51,20 @@ class QDQAutotunerBase: """Base class for pattern-based Q/DQ node insertion optimization in ONNX models.""" def __init__(self, model: onnx.ModelProto | gs.Graph): - """Initialize the autotuner with an ONNX model.""" + """Initialize the autotuner with an ONNX model. + + Creates a clean copy of the model graph and initializes internal state. + After construction, call initialize() to configure the autotuner, then + use a subclass strategy to populate regions (e.g., QDQAutotuner does this + automatically during initialize()). + + Args: + model: ONNX model (onnx.ModelProto) or graph (gs.Graph) to optimize. + A clean copy is created internally, leaving the original unchanged. 
+ + Raises: + TypeError: If model is neither onnx.ModelProto nor gs.Graph + """ if isinstance(model, onnx.ModelProto): self.onnx_model = model elif isinstance(model, gs.Graph): @@ -76,7 +89,22 @@ def __init__(self, model: onnx.ModelProto | gs.Graph): def initialize( self, config: Config | None = None, pattern_cache: PatternCache | None = None ) -> None: - """Initialize autotuning session with configuration and pattern cache.""" + """Initialize autotuning session with configuration and pattern cache. + + Prepares the autotuner for profiling by setting configuration parameters + and optionally loading pattern cache data. This base method resets all profiling + state and sets up the pattern cache storage. + + Args: + config: Autotuning configuration parameters. If None, uses default Config(). + Controls Q/DQ parameters, performance thresholds, and scheme generation. + pattern_cache: Optional PatternCache object for seeding with known-good schemes. + If None, creates a new empty pattern cache for tracking best schemes. + If provided, uses existing schemes to warm-start optimization. + + Raises: + None (safe to call multiple times - will reset state each time) + """ if config is not None: self.config = config @@ -109,7 +137,24 @@ def initialize( self.initialized = True def set_profile_region(self, region: Region | None, commit: bool = True) -> None: - """Set the target region for profiling and scheme generation.""" + """Set the target region for profiling and scheme generation. + + This method manages the profiling workflow: + 1. If commit=True: Saves current schemes to profiled_patterns + 2. Creates a RegionPattern from the new region's structure + 3. For pattern-based: tries to seed schemes from pattern cache if available + 4. Sets as current for generate() and submit() calls + + Pass region=None to clear the current profile target without setting a new one. 
+ + Args: + region: The region to profile next (None to clear current target) + commit: If True, commit current schemes to profiled_patterns + before switching. Set to False during initialization. + + Raises: + AutotunerNotInitializedError: If initialize() hasn't been called + """ if not self.initialized: raise AutotunerNotInitializedError( "QDQAutotunerBase not initialized. Call initialize() first." @@ -185,13 +230,24 @@ def set_profile_region(self, region: Region | None, commit: bool = True) -> None mode_info = f"seeded with {num_seeded} schemes" if num_seeded > 0 else "starting fresh" logger.info( - f"Profiling region {region.id} [pattern mode, level {region.level}, " - f"size {region.get_size_of_region_and_descendants()}, {mode_info}]" + f"Profiling region {region.id} [level {region.level}, size" + f"{region.get_size_of_region_and_descendants()}, {mode_info}]" ) logger.debug(f"Pattern signature: {region_pattern.signature}") def generate(self) -> int: - """Generate a new Q/DQ insertion scheme for the current pattern or region.""" + """Generate a new Q/DQ insertion scheme for the current pattern or region. + + Creates a new InsertionScheme by mutating the top-performing schemes: + 1. Checks if there are any cached schemes (error=False, latency_ms=inf) + 2. If cached schemes exist, picks one to re-profile + 3. Otherwise, generates a new scheme by mutation + 4. Selects a random scheme from the top 10 performers + 5. Mutates it by adding/removing insertion points + 6. Ensures the new scheme is unique (different from existing schemes) + 7. Adds the scheme to current_profile_pattern_schemes + + """ if not self.initialized: raise AutotunerNotInitializedError( "QDQAutotunerBase not initialized. Call initialize() first." 
@@ -261,7 +317,28 @@ def generate(self) -> int: def export_onnx( self, output_path: str | None = None, insert_qdq: bool = True, best: bool = False ) -> bytes: - """Export ONNX model with Q/DQ nodes inserted according to tested schemes.""" + """Export ONNX model with Q/DQ nodes inserted according to tested schemes. + + This method creates a modified version of the model by: + 1. For each region, finding the matching pattern + 2. Applying the best scheme for profiled patterns + 3. Applying the current scheme for the active profile pattern + 4. Resolving pattern-relative insertion points to actual tensor names + 5. Inserting Q/DQ pairs at the resolved locations + 6. Converting to FP8 if needed (always creates INT8 first, then converts) + + Args: + output_path: Optional file path where the modified ONNX model will be saved. + If None, the model is not saved to disk and only bytes are returned. + insert_qdq: If True, insert Q/DQ nodes. If False, export unmodified model + (useful for baseline measurements) + + Returns: + bytes: Serialized ONNX model as bytes + + Raises: + AutotunerNotInitializedError: If initialize() hasn't been called + """ if not self.initialized: raise AutotunerNotInitializedError( "QDQAutotunerBase not initialized. Call initialize() first." @@ -387,7 +464,19 @@ def export_onnx( return model_bytes def submit(self, latency_ms: float, success: bool = True) -> None: - """Submit performance measurement for the most recently generated scheme.""" + """Submit performance measurement for the most recently generated scheme. + + This method records the measured latency and manages the optimization state: + + Args: + latency_ms: Measured latency in milliseconds (must be > 0) + success: Whether the measurement succeeded. If False, sets scheme.error=True, + logs a warning, and skips speedup calculation. 
+ + Raises: + AutotunerNotInitializedError: If initialize() hasn't been called + InvalidSchemeError: If no pattern or region is set, or no schemes have been generated + """ if not self.initialized: raise AutotunerNotInitializedError( "QDQAutotunerBase not initialized. Call initialize() first." @@ -458,7 +547,19 @@ def submit(self, latency_ms: float, success: bool = True) -> None: ) def save_state(self, output_path: str) -> None: - """Save complete autotuner state to a YAML file for later reuse.""" + """Save complete autotuner state to a YAML file for later reuse. + + Serializes all optimization results including: + - Baseline latency measurement + - All profiled patterns with their signatures + - All generated schemes with insertion points and latencies + - Configuration parameters + - Current profiling state + + Args: + output_path: File path where the YAML state file will be written. + Pattern cache will be saved to _pattern_cache.yaml + """ current_pattern_sig = None if self.current_profile_pattern_schemes is not None: current_pattern_sig = self.current_profile_pattern_schemes.pattern_signature @@ -498,7 +599,20 @@ def save_state(self, output_path: str) -> None: ) def load_state(self, input_path: str) -> None: - """Load autotuner state from a previously saved YAML file.""" + """Load autotuner state from a previously saved YAML file. + + Restores optimization results from a previous session: + 1. Matches saved patterns to current model's patterns by signature + 2. Loads all schemes with their insertion points and latencies (including unmeasured ones) + 3. Restores baseline latency and configuration + + Args: + input_path: File path to the YAML state file to load + + Raises: + AutotunerNotInitializedError: If initialize() hasn't been called + FileNotFoundError: If the input_path doesn't exist + """ if not self.initialized: raise AutotunerNotInitializedError( "QDQAutotunerBase not initialized. Call initialize() first." 
@@ -571,7 +685,20 @@ def load_state(self, input_path: str) -> None: logger.debug(f"No pattern cache file at {cache_path}") def import_insertion_points(self, quantized_tensors: set[str] | list[str]) -> None: - """Import Q/DQ insertion points from a list of quantized tensors and update pattern cache.""" + """Import Q/DQ insertion points from a list of quantized tensors and update pattern cache. + + Analyzes the current model's regions against the provided quantized tensors + to extract Q/DQ insertion patterns. For each region, creates a pattern cache + entry that captures which insertion points correspond to the quantized tensors. + These cached patterns can then be used as seeds for future autotuning sessions. + + Args: + quantized_tensors: Set or list of tensor names that are quantized + (i.e., tensors that have Q/DQ nodes applied to them) + + Raises: + AutotunerNotInitializedError: If initialize() hasn't been called + """ if not self.initialized: raise AutotunerNotInitializedError( "QDQAutotunerBase not initialized. Call initialize() first." @@ -607,7 +734,22 @@ def import_insertion_points(self, quantized_tensors: set[str] | list[str]) -> No def _compute_convergence_metrics( self, schemes: list[InsertionScheme], best_scheme: InsertionScheme | None ) -> tuple[int | None, float | None]: - """Compute convergence metrics for a collection of schemes.""" + """Compute convergence metrics for a collection of schemes. + + Analyzes when the best scheme was discovered during the profiling process + by sorting schemes by their profile timestamps and finding the position + of the best scheme. 
+ + Args: + schemes: List of insertion schemes with profile timestamps + best_scheme: The best performing scheme (lowest latency) + + Returns: + Tuple of (samples_before_best, time_to_best) where: + - samples_before_best: Number of samples tested before finding best (0-based index) + - time_to_best: Time in seconds from first sample to best sample + Both values are None if metrics cannot be computed (e.g., missing timestamps) + """ samples_before_best = None time_to_best = None @@ -690,7 +832,29 @@ def _mutate_insertion_points( return [p for p in all_points if key_fn(p) in current_points] def _generate_next_insertion_sample(self) -> InsertionScheme: - """Generate a new insertion scheme by mutating top performers.""" + """Generate a new insertion scheme by mutating top performers. + + This is the core scheme generation algorithm: + 1. Identifies top schemes by latency + 2. Randomly selects one as the base + 3. Mutates node input insertion points (add, remove, or both) + 4. Mutates region composite insertion points (child boundaries) + 5. Mutates region output insertion points + 6. Returns new unique scheme + + **Mutation Strategy:** + - Node input points: Add/remove 1-3 insertion points + - Region composite points: Add/remove 1-3 boundary points + - Region output points: Add/remove 1-3 output points + - Mutation type chosen randomly: 'add', 'remove', or 'both' + + **Baseline Case:** + If no schemes exist yet, returns an empty baseline scheme. + + Returns: + New InsertionScheme with mutated insertion points. + Returns empty scheme if no region is set or no candidates exist. + """ if self.current_profile_region is None: return InsertionScheme() @@ -891,7 +1055,20 @@ def _create_qdq_nodes( quant_dtype: np.dtype, q_scale: float, ) -> tuple[gs.Node, gs.Node]: - """Create QuantizeLinear and DequantizeLinear node pair.""" + """Create QuantizeLinear and DequantizeLinear node pair. 
+ + Args: + tensor_name: Name of the tensor being quantized + qdq_input: Input tensor to the Q node + output_shape: Shape for Q/DQ outputs (may be None) + output_dtype: Dtype for DQ output (also used for scale dtype) + quant_dtype: Dtype for quantized values + quant_type: Quantization type string + q_scale: Quantization scale + + Returns: + Tuple of (q_node, dq_node) + """ # Create unique names for Q/DQ nodes q_name = f"QDQ_Q_{tensor_name}".replace("/", "_").replace(":", "_") dq_name = f"QDQ_DQ_{tensor_name}".replace("/", "_").replace(":", "_") @@ -943,7 +1120,17 @@ def _create_qdq_nodes( def _insert_qdq_at_tensors( self, graph: gs.Graph, resolved_insertion_points: set[ResolvedInsertionPoint] ) -> None: - """Insert Q/DQ (Quantize/Dequantize) node pairs at specified locations.""" + """Insert Q/DQ (Quantize/Dequantize) node pairs at specified locations. + + This is the main entry point for Q/DQ insertion. It: + 1. Builds tensor map and tensor-to-users map for efficient lookup + 2. Processes each resolved insertion point to insert Q/DQ nodes + 3. Handles two insertion modes based on node_index + + Args: + graph: Graph to modify in-place + resolved_insertion_points: Set of ResolvedInsertionPoint objects specifying where to insert Q/DQ + """ q_scale = self.config.default_q_scale quant_type = self.config.default_quant_type quant_dtype = self._get_quant_dtype(quant_type) @@ -1031,12 +1218,28 @@ class QDQAutotuner(QDQAutotunerBase): def initialize( self, config: Config | None = None, pattern_cache: PatternCache | None = None ) -> None: - """Initialize autotuner and discover optimization regions automatically.""" + """Initialize autotuner and discover optimization regions automatically. + + Extends base class initialization by automatically searching for regions + after configuration is set up. Regions are discovered using pattern-based + search around compute-intensive operations. 
+ """ super().initialize(config, pattern_cache) self._search_regions() def _visit_region_recursively(self, region: Region) -> list[Region]: - """Recursively traverse region hierarchy and collect all regions.""" + """Recursively traverse region hierarchy and collect all regions. + + Performs depth-first traversal of the region tree starting from a given + region. Collects the root region and all descendant regions (children, + grandchildren, etc.) into a flat list. + + Args: + region: Root region to start traversal from + + Returns: + List of all regions in the subtree (including root), in pre-order DFS. + """ regions = [region] for child in region.get_children(): @@ -1045,7 +1248,15 @@ def _visit_region_recursively(self, region: Region) -> list[Region]: return regions def _reassign_region_ids(self, regions: list[Region]) -> None: - """Reassign sequential IDs to regions in breadth-first order.""" + """Reassign sequential IDs to regions in breadth-first order. + + Traverses the region hierarchy (including children) and assigns new + sequential IDs starting from 0. This ensures clean, predictable region + numbering after region discovery and manipulation. + + Args: + regions: List of top-level regions (children will be processed too) + """ region_id = 0 queue = deque(regions) @@ -1057,7 +1268,19 @@ def _reassign_region_ids(self, regions: list[Region]) -> None: queue.extend(region.get_children()) def _search_regions(self) -> None: - """Discover and organize optimization regions automatically.""" + """Discover and organize optimization regions automatically. + + This is the core region discovery method that: + 1. Runs automatic region search to find optimization targets + 2. Flattens hierarchical structure into a list + 3. Prioritizes LEAF regions (contain actual nodes) + 4. 
Reassigns IDs for clean indexing + + **Search Strategy:** + Uses CombinedRegionSearch which performs: + - Phase 1: Bottom-up partitioning based on divergence/convergence + - Phase 2: Top-down refinement creating hierarchical structure + """ logger.info("Discovering optimization regions") search = CombinedRegionSearch( self.graph, From bd2ff485430a12f730c56e95cb9bf0f47c8ff280 Mon Sep 17 00:00:00 2001 From: Will Guo Date: Tue, 10 Feb 2026 03:13:53 +0000 Subject: [PATCH 3/8] resolve comments Signed-off-by: Will Guo --- .../unit/onnx/quantization/autotune/autotune/test_autotuner.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/unit/onnx/quantization/autotune/autotune/test_autotuner.py b/tests/unit/onnx/quantization/autotune/autotune/test_autotuner.py index fe4240047..17a8dd4cc 100644 --- a/tests/unit/onnx/quantization/autotune/autotune/test_autotuner.py +++ b/tests/unit/onnx/quantization/autotune/autotune/test_autotuner.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
# SPDX-License-Identifier: Apache-2.0 # # Licensed under the Apache License, Version 2.0 (the "License"); From 05f70ab93f30a522c13adcbe739addcc8f5bdf9e Mon Sep 17 00:00:00 2001 From: Will Guo Date: Wed, 11 Feb 2026 13:28:18 +0000 Subject: [PATCH 4/8] resolve comments Signed-off-by: Will Guo --- .../quantization/autotune/autotune/models.py | 47 ++++ .../autotune/autotune/test_autotuner.py | 204 +++++++----------- 2 files changed, 126 insertions(+), 125 deletions(-) create mode 100644 tests/unit/onnx/quantization/autotune/autotune/models.py diff --git a/tests/unit/onnx/quantization/autotune/autotune/models.py b/tests/unit/onnx/quantization/autotune/autotune/models.py new file mode 100644 index 000000000..4090cfef3 --- /dev/null +++ b/tests/unit/onnx/quantization/autotune/autotune/models.py @@ -0,0 +1,47 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Shared test ONNX models for autotuner unit tests. + +Model creation functions live here; tests import and call them directly. 
+""" + +import onnx +from onnx import helper + + +def _create_simple_conv_onnx_model(): + """Build ONNX model: Input -> Conv -> Relu -> Output (minimal for autotuner tests).""" + input_tensor = helper.make_tensor_value_info("input", onnx.TensorProto.FLOAT, [1, 3, 224, 224]) + output_tensor = helper.make_tensor_value_info( + "output", onnx.TensorProto.FLOAT, [1, 64, 224, 224] + ) + conv_node = helper.make_node( + "Conv", inputs=["input", "conv_weight"], outputs=["conv_out"], name="conv" + ) + relu_node = helper.make_node("Relu", inputs=["conv_out"], outputs=["output"], name="relu") + graph = helper.make_graph( + [conv_node, relu_node], + "simple_conv", + [input_tensor], + [output_tensor], + initializer=[ + helper.make_tensor( + "conv_weight", onnx.TensorProto.FLOAT, [64, 3, 3, 3], [0.1] * (64 * 3 * 3 * 3) + ) + ], + ) + return helper.make_model(graph, producer_name="test") diff --git a/tests/unit/onnx/quantization/autotune/autotune/test_autotuner.py b/tests/unit/onnx/quantization/autotune/autotune/test_autotuner.py index 17a8dd4cc..ef49e53b5 100644 --- a/tests/unit/onnx/quantization/autotune/autotune/test_autotuner.py +++ b/tests/unit/onnx/quantization/autotune/autotune/test_autotuner.py @@ -23,118 +23,82 @@ import os import sys import tempfile -import unittest -# Add parent directory to path -sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) +# Add parent and current directory to path +_test_dir = os.path.dirname(os.path.abspath(__file__)) +sys.path.insert(0, os.path.dirname(_test_dir)) +sys.path.insert(0, _test_dir) +import models as _test_models import onnx import onnx_graphsurgeon as gs -from onnx import helper +import pytest from modelopt.onnx.quantization.autotune import Config, QDQAutotuner, RegionPattern from modelopt.onnx.quantization.autotune.common import PatternCache, RegionType -def create_simple_conv_model(): - """ - Create a simple ONNX model: Input -> Conv -> Relu -> Output. 
- - This is a minimal model for testing autotuner initialization. - """ - # Input - input_tensor = helper.make_tensor_value_info("input", onnx.TensorProto.FLOAT, [1, 3, 224, 224]) +@pytest.fixture +def simple_conv_model(): + """Simple ONNX model: Input -> Conv -> Relu -> Output. Created via models.py.""" + return _test_models._create_simple_conv_onnx_model() - # Output - output_tensor = helper.make_tensor_value_info( - "output", onnx.TensorProto.FLOAT, [1, 64, 224, 224] - ) - # Conv node - conv_node = helper.make_node( - "Conv", inputs=["input", "conv_weight"], outputs=["conv_out"], name="conv" - ) +def _create_test_config(): + """ + Create a reasonable config for testing. - # Relu node - relu_node = helper.make_node("Relu", inputs=["conv_out"], outputs=["output"], name="relu") - - # Create graph - graph = helper.make_graph( - [conv_node, relu_node], - "simple_conv", - [input_tensor], - [output_tensor], - initializer=[ - helper.make_tensor( - "conv_weight", onnx.TensorProto.FLOAT, [64, 3, 3, 3], [0.1] * (64 * 3 * 3 * 3) - ) - ], + Uses sensible defaults suitable for unit tests: + - verbose=False: Keep test output clean + - maximum_sequence_region_size=50: Allow larger test regions + - Other parameters: Match Config defaults for typical behavior + """ + return Config( + # Logging + verbose=False, + # Performance Requirements + # Quantization Parameters + default_q_scale=0.1, + default_q_zero_point=0, + default_quant_type="int8", + # Region Builder Settings + maximum_sequence_region_size=50, + minimum_topdown_search_size=10, + # Scheme Generation Settings + top_percent_to_mutate=0.1, + minimum_schemes_to_mutate=10, + maximum_mutations=3, + maximum_generation_attempts=100, + # Pattern Cache Settings + pattern_cache_minimum_distance=4, + pattern_cache_max_entries_per_pattern=32, ) - # Create model - model = helper.make_model(graph, producer_name="test") - return model - -class TestQDQAutotuner(unittest.TestCase): +class TestQDQAutotuner: """Test QDQAutotuner 
functionality.""" - @staticmethod - def _create_test_config(): - """ - Create a reasonable config for testing. - - Uses sensible defaults suitable for unit tests: - - verbose=False: Keep test output clean - - maximum_sequence_region_size=50: Allow larger test regions - - Other parameters: Match Config defaults for typical behavior - """ - return Config( - # Logging - verbose=False, - # Performance Requirements - # Quantization Parameters - default_q_scale=0.1, - default_q_zero_point=0, - default_quant_type="int8", - # Region Builder Settings - maximum_sequence_region_size=50, - minimum_topdown_search_size=10, - # Scheme Generation Settings - top_percent_to_mutate=0.1, - minimum_schemes_to_mutate=10, - maximum_mutations=3, - maximum_generation_attempts=100, - # Pattern Cache Settings - pattern_cache_minimum_distance=4, - pattern_cache_max_entries_per_pattern=32, - ) - - def test_creation_with_onnx_model(self): + def test_creation_with_onnx_model(self, simple_conv_model): """Test creating autotuner with ONNX ModelProto.""" - model = create_simple_conv_model() - autotuner = QDQAutotuner(model) + autotuner = QDQAutotuner(simple_conv_model) assert autotuner is not None assert autotuner.onnx_model is not None assert autotuner.graph is not None - def test_creation_with_gs_graph(self): + def test_creation_with_gs_graph(self, simple_conv_model): """Test creating autotuner with GraphSurgeon graph.""" - model = create_simple_conv_model() - gs_graph = gs.import_onnx(model) - + gs_graph = gs.import_onnx(simple_conv_model) autotuner = QDQAutotuner(gs_graph) assert autotuner is not None assert autotuner.graph is not None - def test_initialize_with_default_config(self): + def test_initialize_with_default_config(self, simple_conv_model): """Test initialization with default test config.""" - model = create_simple_conv_model() - autotuner = QDQAutotuner(model) + autotuner = QDQAutotuner(simple_conv_model) - config = self._create_test_config() + config = _create_test_config() 
autotuner.initialize(config) # Should have provided config @@ -144,10 +108,9 @@ def test_initialize_with_default_config(self): # Should have discovered regions assert len(autotuner.regions) > 0 - def test_initialize_with_config(self): + def test_initialize_with_config(self, simple_conv_model): """Test initialization with custom config (different from default).""" - model = create_simple_conv_model() - autotuner = QDQAutotuner(model) + autotuner = QDQAutotuner(simple_conv_model) # Create custom config with different values config = Config( @@ -180,23 +143,21 @@ def test_initialize_with_config(self): assert autotuner.config.pattern_cache_minimum_distance == 2 assert autotuner.config.pattern_cache_max_entries_per_pattern == 16 - def test_initialize_with_pattern_cache(self): + def test_initialize_with_pattern_cache(self, simple_conv_model): """Test initialization with pattern cache.""" - model = create_simple_conv_model() - autotuner = QDQAutotuner(model) + autotuner = QDQAutotuner(simple_conv_model) - config = self._create_test_config() + config = _create_test_config() pattern_cache = PatternCache() autotuner.initialize(config, pattern_cache=pattern_cache) assert autotuner.pattern_cache is not None - def test_region_discovery(self): + def test_region_discovery(self, simple_conv_model): """Test that regions are automatically discovered.""" - model = create_simple_conv_model() - autotuner = QDQAutotuner(model) + autotuner = QDQAutotuner(simple_conv_model) - config = self._create_test_config() + config = _create_test_config() autotuner.initialize(config) # Should discover at least one region @@ -207,11 +168,10 @@ def test_region_discovery(self): assert region.get_id() is not None assert region.get_type() in [RegionType.LEAF, RegionType.COMPOSITE, RegionType.ROOT] - def test_export_baseline_model(self): + def test_export_baseline_model(self, simple_conv_model): """Test exporting baseline model without Q/DQ.""" - model = create_simple_conv_model() - autotuner = 
QDQAutotuner(model) - config = self._create_test_config() + autotuner = QDQAutotuner(simple_conv_model) + config = _create_test_config() autotuner.initialize(config) with tempfile.NamedTemporaryFile(suffix=".onnx", delete=False) as f: @@ -229,11 +189,10 @@ def test_export_baseline_model(self): if os.path.exists(output_path): os.unlink(output_path) - def test_set_profile_region(self): + def test_set_profile_region(self, simple_conv_model): """Test setting a region for profiling.""" - model = create_simple_conv_model() - autotuner = QDQAutotuner(model) - config = self._create_test_config() + autotuner = QDQAutotuner(simple_conv_model) + config = _create_test_config() autotuner.initialize(config) if len(autotuner.regions) > 0: @@ -243,13 +202,12 @@ def test_set_profile_region(self): assert autotuner.current_profile_region == region assert autotuner.current_profile_pattern_schemes is not None else: - self.skipTest("No regions discovered") + pytest.skip("No regions discovered") - def test_generate_scheme(self): + def test_generate_scheme(self, simple_conv_model): """Test generating an insertion scheme.""" - model = create_simple_conv_model() - autotuner = QDQAutotuner(model) - config = self._create_test_config() + autotuner = QDQAutotuner(simple_conv_model) + config = _create_test_config() autotuner.initialize(config) if len(autotuner.regions) > 0: @@ -260,24 +218,22 @@ def test_generate_scheme(self): # Should return a valid index (>= 0) or -1 if no more unique schemes assert isinstance(scheme_idx, int) else: - self.skipTest("No regions discovered") + pytest.skip("No regions discovered") - def test_submit_latency(self): + def test_submit_latency(self, simple_conv_model): """Test submitting performance measurement.""" - model = create_simple_conv_model() - autotuner = QDQAutotuner(model) - config = self._create_test_config() + autotuner = QDQAutotuner(simple_conv_model) + config = _create_test_config() autotuner.initialize(config) # Submit baseline latency 
autotuner.submit(10.5) # Baseline should be recorded assert autotuner.baseline_latency_ms == 10.5 - def test_save_and_load_state(self): + def test_save_and_load_state(self, simple_conv_model): """Test saving and loading autotuner state.""" - model = create_simple_conv_model() - autotuner = QDQAutotuner(model) - config = self._create_test_config() + autotuner = QDQAutotuner(simple_conv_model) + config = _create_test_config() autotuner.initialize(config) # Submit some results @@ -292,8 +248,8 @@ def test_save_and_load_state(self): assert os.path.exists(state_path) # Create new autotuner and load state - autotuner2 = QDQAutotuner(model) - config2 = self._create_test_config() + autotuner2 = QDQAutotuner(simple_conv_model) + config2 = _create_test_config() autotuner2.initialize(config2) autotuner2.load_state(state_path) @@ -303,11 +259,10 @@ def test_save_and_load_state(self): if os.path.exists(state_path): os.unlink(state_path) - def test_regions_prioritization(self): + def test_regions_prioritization(self, simple_conv_model): """Test that LEAF regions are prioritized.""" - model = create_simple_conv_model() - autotuner = QDQAutotuner(model) - config = self._create_test_config() + autotuner = QDQAutotuner(simple_conv_model) + config = _create_test_config() autotuner.initialize(config) # Check that LEAF regions come before non-LEAF @@ -322,11 +277,10 @@ def test_regions_prioritization(self): # All LEAF should come before non-LEAF assert max(leaf_indices) < min(non_leaf_indices) - def test_profiled_patterns_tracking(self): + def test_profiled_patterns_tracking(self, simple_conv_model): """Test that profiled patterns are tracked.""" - model = create_simple_conv_model() - autotuner = QDQAutotuner(model) - config = self._create_test_config() + autotuner = QDQAutotuner(simple_conv_model) + config = _create_test_config() autotuner.initialize(config) autotuner.submit(10.0) @@ -342,4 +296,4 @@ def test_profiled_patterns_tracking(self): profiled_patterns = [p.pattern.signature 
for p in autotuner.profiled_patterns] assert pattern_sig in profiled_patterns else: - self.skipTest("No regions discovered") + pytest.skip("No regions discovered") From 0216d47dc6756dca9be90beb69114bb97fb0bbba Mon Sep 17 00:00:00 2001 From: Will Guo Date: Wed, 11 Feb 2026 13:43:28 +0000 Subject: [PATCH 5/8] resolve comments Signed-off-by: Will Guo --- .../autotune/{autotune => }/models.py | 0 .../autotune/{autotune => }/test_autotuner.py | 49 +++++++++++++------ 2 files changed, 35 insertions(+), 14 deletions(-) rename tests/unit/onnx/quantization/autotune/{autotune => }/models.py (100%) rename tests/unit/onnx/quantization/autotune/{autotune => }/test_autotuner.py (88%) diff --git a/tests/unit/onnx/quantization/autotune/autotune/models.py b/tests/unit/onnx/quantization/autotune/models.py similarity index 100% rename from tests/unit/onnx/quantization/autotune/autotune/models.py rename to tests/unit/onnx/quantization/autotune/models.py diff --git a/tests/unit/onnx/quantization/autotune/autotune/test_autotuner.py b/tests/unit/onnx/quantization/autotune/test_autotuner.py similarity index 88% rename from tests/unit/onnx/quantization/autotune/autotune/test_autotuner.py rename to tests/unit/onnx/quantization/autotune/test_autotuner.py index ef49e53b5..f6a920bec 100644 --- a/tests/unit/onnx/quantization/autotune/autotune/test_autotuner.py +++ b/tests/unit/onnx/quantization/autotune/test_autotuner.py @@ -21,14 +21,8 @@ """ import os -import sys import tempfile -# Add parent and current directory to path -_test_dir = os.path.dirname(os.path.abspath(__file__)) -sys.path.insert(0, os.path.dirname(_test_dir)) -sys.path.insert(0, _test_dir) - import models as _test_models import onnx import onnx_graphsurgeon as gs @@ -205,20 +199,47 @@ def test_set_profile_region(self, simple_conv_model): pytest.skip("No regions discovered") def test_generate_scheme(self, simple_conv_model): - """Test generating an insertion scheme.""" + """Test generating multiple schemes and that Q/DQ nodes 
appear in exported model.""" autotuner = QDQAutotuner(simple_conv_model) config = _create_test_config() autotuner.initialize(config) - if len(autotuner.regions) > 0: - region = autotuner.regions[0] - autotuner.set_profile_region(region) - # Generate a scheme + if len(autotuner.regions) == 0: + pytest.skip("No regions discovered") + + autotuner.submit(10.0) # baseline + region = autotuner.regions[0] + autotuner.set_profile_region(region) + + # Generate multiple schemes and submit a latency for each + num_generated = 0 + while True: scheme_idx = autotuner.generate() - # Should return a valid index (>= 0) or -1 if no more unique schemes + if scheme_idx < 0: + break assert isinstance(scheme_idx, int) - else: - pytest.skip("No regions discovered") + autotuner.submit(10.0 + num_generated * 0.1) # dummy latency + num_generated += 1 + if num_generated >= 5: # cap iterations + break + + assert num_generated > 0, "Expected at least one scheme to be generated" + autotuner.set_profile_region(None, commit=True) + + # Export with Q/DQ and verify Q/DQ nodes are in the model + with tempfile.NamedTemporaryFile(suffix=".onnx", delete=False) as f: + output_path = f.name + try: + autotuner.export_onnx(output_path, insert_qdq=True) + exported = onnx.load(output_path) + node_ops = [n.op_type for n in exported.graph.node] + assert "QuantizeLinear" in node_ops, "Expected QuantizeLinear nodes in exported model" + assert "DequantizeLinear" in node_ops, ( + "Expected DequantizeLinear nodes in exported model" + ) + finally: + if os.path.exists(output_path): + os.unlink(output_path) def test_submit_latency(self, simple_conv_model): """Test submitting performance measurement.""" From a13c9576cd1c3ef63c729f4853c821fb7931a12c Mon Sep 17 00:00:00 2001 From: Will Guo Date: Mon, 16 Feb 2026 23:46:36 +0000 Subject: [PATCH 6/8] resolve comments Signed-off-by: Will Guo --- .../onnx/quantization/autotune/autotuner.py | 432 +++++++++--------- 1 file changed, 211 insertions(+), 221 deletions(-) diff --git 
a/modelopt/onnx/quantization/autotune/autotuner.py b/modelopt/onnx/quantization/autotune/autotuner.py index 86074ed15..e633f30a1 100644 --- a/modelopt/onnx/quantization/autotune/autotuner.py +++ b/modelopt/onnx/quantization/autotune/autotuner.py @@ -16,9 +16,10 @@ """Automatic Q/DQ insertion optimization for ONNX models via pattern-based profiling.""" import copy +import functools import os import random -from collections import deque +from collections import Counter, deque from datetime import datetime, timezone import numpy as np @@ -46,10 +47,45 @@ from modelopt.onnx.quantization.fp8 import int8_to_fp8 from modelopt.onnx.quantization.graph_utils import get_tensor_consumer_node_indices +_MUTATION_SPECS = [ + ("node_inputs", "node input points", lambda p: (p.node_index, p.input_index)), + ( + "region_composite_inputs", + "region composite points", + lambda p: (p.region_index, p.input_index), + ), + ( + "region_output_points", + "region output points", + lambda p: (p.region_index, p.node_index, p.output_index), + ), +] + + +def _requires_init(method): + """Decorator that raises AutotunerNotInitializedError if initialize() has not been called.""" + + @functools.wraps(method) + def wrapper(self, *args, **kwargs): + if not self.initialized: + raise AutotunerNotInitializedError( + "QDQAutotunerBase not initialized. Call initialize() first." + ) + return method(self, *args, **kwargs) + + return wrapper + class QDQAutotunerBase: """Base class for pattern-based Q/DQ node insertion optimization in ONNX models.""" + _DTYPE_MAP = { + "int8": np.int8, + "uint8": np.uint8, + "float16": np.float16, + "float32": np.float32, + } + def __init__(self, model: onnx.ModelProto | gs.Graph): """Initialize the autotuner with an ONNX model. 
@@ -86,6 +122,8 @@ def __init__(self, model: onnx.ModelProto | gs.Graph): logger.debug(f"Initialized autotuner with model type: {type(model).__name__}") + requires_init = _requires_init + def initialize( self, config: Config | None = None, pattern_cache: PatternCache | None = None ) -> None: @@ -136,6 +174,54 @@ def initialize( self.initialized = True + def _commit_current_pattern(self, save: bool = True) -> None: + """Save current pattern schemes to profiled_patterns (if save) and clear current state.""" + if save and self.current_profile_pattern_schemes is not None: + num_schemes = len(self.current_profile_pattern_schemes.schemes) + best_scheme = self.current_profile_pattern_schemes.best_scheme + best_latency = best_scheme.latency_ms if best_scheme else float("inf") + + samples_before_best, time_to_best = self._compute_convergence_metrics( + self.current_profile_pattern_schemes.schemes, best_scheme + ) + + logger.info( + f"Pattern complete: {num_schemes} schemes tested, best latency {best_latency:.3f} ms" + ) + logger.debug( + f"Pattern signature: {self.current_profile_pattern_schemes.pattern_signature}" + ) + if samples_before_best is not None: + logger.debug(f"Convergence: best found at sample {samples_before_best}") + if time_to_best is not None: + logger.debug(f"Time to best: {time_to_best:.2f}s") + self.profiled_patterns.append(self.current_profile_pattern_schemes) + + self.current_profile_region = None + self.current_profile_pattern_schemes = None + self.current_insertion_scheme_index = None + + def _seed_from_cache(self, pattern: RegionPattern) -> tuple[PatternSchemes | None, int]: + """Seed PatternSchemes from pattern cache for the given pattern. 
Returns (schemes, num_seeded).""" + if self.pattern_cache is None: + return None, 0 + cache_schemes = self.pattern_cache.get_pattern_schemes(pattern.signature) + if cache_schemes is None or len(cache_schemes.schemes) == 0: + logger.debug("No pattern cache entries for this region") + return None, 0 + pattern_schemes = PatternSchemes() + pattern_schemes.pattern = pattern + num_seeded = 0 + for cached_scheme in cache_schemes.schemes: + scheme_copy = copy.deepcopy(cached_scheme) + scheme_copy.latency_ms = float("inf") + scheme_copy.error = False + pattern_schemes.schemes.append(scheme_copy) + num_seeded += 1 + logger.debug(f"Seeded {num_seeded} scheme(s) from pattern cache") + return pattern_schemes, num_seeded + + @_requires_init def set_profile_region(self, region: Region | None, commit: bool = True) -> None: """Set the target region for profiling and scheme generation. @@ -155,37 +241,8 @@ def set_profile_region(self, region: Region | None, commit: bool = True) -> None Raises: AutotunerNotInitializedError: If initialize() hasn't been called """ - if not self.initialized: - raise AutotunerNotInitializedError( - "QDQAutotunerBase not initialized. Call initialize() first." 
- ) - - if commit: - if self.current_profile_pattern_schemes is not None: - num_schemes = len(self.current_profile_pattern_schemes.schemes) - best_scheme = self.current_profile_pattern_schemes.best_scheme - best_latency = best_scheme.latency_ms if best_scheme else float("inf") - - samples_before_best, time_to_best = self._compute_convergence_metrics( - self.current_profile_pattern_schemes.schemes, best_scheme - ) - - logger.info( - f"Pattern complete: {num_schemes} schemes tested, best latency {best_latency:.3f} ms" - ) - logger.debug( - f"Pattern signature: {self.current_profile_pattern_schemes.pattern_signature}" - ) - if samples_before_best is not None: - logger.debug(f"Convergence: best found at sample {samples_before_best}") - if time_to_best is not None: - logger.debug(f"Time to best: {time_to_best:.2f}s") - self.profiled_patterns.append(self.current_profile_pattern_schemes) - if commit or region is None: - self.current_profile_region = None - self.current_profile_pattern_schemes = None - self.current_insertion_scheme_index = None + self._commit_current_pattern(save=commit) if region is None: return @@ -199,27 +256,7 @@ def set_profile_region(self, region: Region | None, commit: bool = True) -> None logger.debug(f"Pattern signature: {region_pattern.signature}") return - pattern_schemes = None - num_seeded = 0 - - if self.pattern_cache is not None: - cache_schemes = self.pattern_cache.get_pattern_schemes(region_pattern.signature) - - if cache_schemes is not None and len(cache_schemes.schemes) > 0: - pattern_schemes = PatternSchemes() - pattern_schemes.pattern = region_pattern - - for cached_scheme in cache_schemes.schemes: - scheme_copy = copy.deepcopy(cached_scheme) - scheme_copy.latency_ms = float("inf") - scheme_copy.error = False - pattern_schemes.schemes.append(scheme_copy) - num_seeded += 1 - - logger.debug(f"Seeded {num_seeded} scheme(s) from pattern cache") - else: - logger.debug("No pattern cache entries for this region") - + pattern_schemes, 
num_seeded = self._seed_from_cache(region_pattern) if pattern_schemes is None: pattern_schemes = PatternSchemes() pattern_schemes.pattern = region_pattern @@ -235,6 +272,7 @@ def set_profile_region(self, region: Region | None, commit: bool = True) -> None ) logger.debug(f"Pattern signature: {region_pattern.signature}") + @_requires_init def generate(self) -> int: """Generate a new Q/DQ insertion scheme for the current pattern or region. @@ -248,11 +286,7 @@ def generate(self) -> int: 7. Adds the scheme to current_profile_pattern_schemes """ - if not self.initialized: - raise AutotunerNotInitializedError( - "QDQAutotunerBase not initialized. Call initialize() first." - ) - elif self.current_profile_pattern_schemes is None: + if self.current_profile_pattern_schemes is None: raise InvalidSchemeError("No region selected. Call set_profile_region() first.") pattern_schemes = self.current_profile_pattern_schemes @@ -314,6 +348,77 @@ def generate(self) -> int: logger.warning(f"Could not generate unique scheme after {max_attempts} attempts") return -1 + def _resolve_scheme_for_region( + self, region: Region, best: bool + ) -> tuple[InsertionScheme | None, RegionPattern]: + """Resolve the insertion scheme to use for a region from profiled/current/cache. 
+ + Args: + region: The region to resolve the scheme for + best: If True, return the best scheme for the region + + Returns: + tuple[InsertionScheme | None, RegionPattern]: The scheme and pattern for the region + """ + pattern = RegionPattern.from_region(region, self.graph) + logger.debug(f"Region {region.id} (level {region.level})") + logger.debug(f" → Pattern signature: {pattern.signature}") + + matched = next((ps for ps in self.profiled_patterns if ps.pattern == pattern), None) + current_scheme = matched.best_scheme if matched else None + + if matched: + if current_scheme: + logger.debug( + f" → Matched profiled pattern (latency={current_scheme.latency_ms:.3f} ms)" + ) + else: + logger.debug(" → Matched profiled pattern but no valid schemes") + + if current_scheme is None: + pattern_schemes = self.current_profile_pattern_schemes + if pattern_schemes is None or pattern != pattern_schemes.pattern: + pass + elif best: + current_scheme = pattern_schemes.best_scheme + else: + scheme_index = self.current_insertion_scheme_index + if scheme_index is not None: + assert scheme_index < len(pattern_schemes.schemes), ( + f"Invalid scheme index: {scheme_index}" + ) + current_scheme = pattern_schemes.schemes[scheme_index] + logger.debug(f" → Using current pattern scheme #{scheme_index}") + + if current_scheme is None and self.pattern_cache is not None: + cache_schemes = self.pattern_cache.get_pattern_schemes(pattern.signature) + if cache_schemes is not None: + schemes = cache_schemes.schemes + if schemes is not None and len(schemes) == 1 and not schemes[0].is_profiled: + current_scheme = schemes[0] + logger.debug(" → Using imported pattern from cache") + + if current_scheme is None: + logger.debug(" → No scheme available, skipping") + + return current_scheme, pattern + + def _exclude_overlapping_insertion_points( + self, + resolved_insertion_points: set[ResolvedInsertionPoint], + region: Region, + pattern: RegionPattern, + ) -> None: + """Remove this region's full insertion 
points from resolved set so they can be replaced.""" + full_insertion_scheme = pattern.get_full_insertion_scheme(region, self.graph) + assert full_insertion_scheme is not None + all_region_ips = pattern.matches(region, self.graph, full_insertion_scheme) + assert isinstance(all_region_ips, set) + resolved_insertion_points.difference_update(all_region_ips) + if all_region_ips: + logger.debug(f" → Excluded {len(all_region_ips)} overlapping insertion points") + + @_requires_init def export_onnx( self, output_path: str | None = None, insert_qdq: bool = True, best: bool = False ) -> bytes: @@ -339,11 +444,6 @@ def export_onnx( Raises: AutotunerNotInitializedError: If initialize() hasn't been called """ - if not self.initialized: - raise AutotunerNotInitializedError( - "QDQAutotunerBase not initialized. Call initialize() first." - ) - output_desc = output_path if output_path is not None else "" original_quant_type = self.config.default_quant_type needs_fp8_conversion = insert_qdq and original_quant_type == "fp8" @@ -364,58 +464,13 @@ def export_onnx( logger.debug(f"Resolving Q/DQ insertion points from {len(self.regions)} regions") for region in self.regions: - pattern = RegionPattern.from_region(region, self.graph) - logger.debug(f"Region {region.id} (level {region.level})") - logger.debug(f" → Pattern signature: {pattern.signature}") - - matched = next((ps for ps in self.profiled_patterns if ps.pattern == pattern), None) - current_scheme = matched.best_scheme if matched else None - - if matched: - if current_scheme: - logger.debug( - f" → Matched profiled pattern (latency={current_scheme.latency_ms:.3f} ms)" - ) - else: - logger.debug(" → Matched profiled pattern but no valid schemes") - - if current_scheme is None: - current_scheme = self.current_profile_pattern_schemes - if current_scheme is None or pattern != current_scheme.pattern: - pass - elif best: - current_scheme = current_scheme.best_scheme - else: - scheme_index = self.current_insertion_scheme_index - if 
scheme_index is not None: - assert scheme_index < len(current_scheme.schemes), ( - f"Invalid scheme index: {scheme_index}" - ) - current_scheme = current_scheme.schemes[scheme_index] - logger.debug(f" → Using current pattern scheme #{scheme_index}") - - if current_scheme is None and self.pattern_cache is not None: - pattern_schemes = self.pattern_cache.get_pattern_schemes(pattern.signature) - if pattern_schemes is not None: - schemes = pattern_schemes.schemes - if schemes is not None and len(schemes) == 1 and not schemes[0].is_profiled: - current_scheme = schemes[0] - logger.debug(" → Using imported pattern from cache") - + current_scheme, pattern = self._resolve_scheme_for_region(region, best) if current_scheme is None: - logger.debug(" → No scheme available, skipping") continue - full_insertion_scheme = pattern.get_full_insertion_scheme(region, self.graph) - assert full_insertion_scheme is not None - all_region_ips = pattern.matches(region, self.graph, full_insertion_scheme) - assert isinstance(all_region_ips, set) - resolved_insertion_points.difference_update(all_region_ips) - excluded_tensors = all_region_ips - resolved_insertion_points - if excluded_tensors: - logger.debug( - f" → Excluded {len(excluded_tensors)} overlapping insertion points" - ) + self._exclude_overlapping_insertion_points( + resolved_insertion_points, region, pattern + ) new_ips = pattern.matches(region, self.graph, current_scheme) if new_ips: @@ -463,6 +518,7 @@ def export_onnx( ) return model_bytes + @_requires_init def submit(self, latency_ms: float, success: bool = True) -> None: """Submit performance measurement for the most recently generated scheme. 
@@ -477,11 +533,6 @@ def submit(self, latency_ms: float, success: bool = True) -> None: AutotunerNotInitializedError: If initialize() hasn't been called InvalidSchemeError: If no pattern or region is set, or no schemes have been generated """ - if not self.initialized: - raise AutotunerNotInitializedError( - "QDQAutotunerBase not initialized. Call initialize() first." - ) - if self.baseline_latency_ms is None: self.baseline_latency_ms = latency_ms logger.info(f"Baseline latency: {latency_ms:.3f} ms") @@ -598,6 +649,7 @@ def save_state(self, output_path: str) -> None: f"{self.pattern_cache.total_schemes} schemes" ) + @_requires_init def load_state(self, input_path: str) -> None: """Load autotuner state from a previously saved YAML file. @@ -613,11 +665,6 @@ def load_state(self, input_path: str) -> None: AutotunerNotInitializedError: If initialize() hasn't been called FileNotFoundError: If the input_path doesn't exist """ - if not self.initialized: - raise AutotunerNotInitializedError( - "QDQAutotunerBase not initialized. Call initialize() first." - ) - with open(input_path) as f: state = yaml.safe_load(f) @@ -684,6 +731,7 @@ def load_state(self, input_path: str) -> None: else: logger.debug(f"No pattern cache file at {cache_path}") + @_requires_init def import_insertion_points(self, quantized_tensors: set[str] | list[str]) -> None: """Import Q/DQ insertion points from a list of quantized tensors and update pattern cache. @@ -699,11 +747,6 @@ def import_insertion_points(self, quantized_tensors: set[str] | list[str]) -> No Raises: AutotunerNotInitializedError: If initialize() hasn't been called """ - if not self.initialized: - raise AutotunerNotInitializedError( - "QDQAutotunerBase not initialized. Call initialize() first." 
- ) - if isinstance(quantized_tensors, list): quantized_tensors = set(quantized_tensors) @@ -782,14 +825,12 @@ def _compute_convergence_metrics( def _is_region_profiled(self, region: Region) -> bool: """Check if a region's pattern has already been fully profiled.""" - - def match_pattern(pattern: PatternSchemes, region: Region) -> bool: - """Check if a pattern matches a region.""" - if pattern.pattern is None or not pattern.pattern.matches(region, self.graph): - return False - return not any(not scheme.is_profiled for scheme in pattern.schemes) - - return any(match_pattern(pattern, region) for pattern in self.profiled_patterns) + return any( + p.pattern is not None + and p.pattern.matches(region, self.graph) + and all(s.is_profiled for s in p.schemes) + for p in self.profiled_patterns + ) def _mutate_insertion_points( self, base_points, all_points, point_type: str, max_mutations: int @@ -904,32 +945,20 @@ def _generate_next_insertion_sample(self) -> InsertionScheme: ) max_mutations = getattr(self.config, "maximum_mutations", 3) - scheme = InsertionScheme() - base_node_points = {(p.node_index, p.input_index) for p in base_scheme.node_inputs} - scheme.node_inputs = self._mutate_insertion_points( - base_node_points, full_insertion_scheme.node_inputs, "node input points", max_mutations - ) - - base_region_composite_points = { - (p.region_index, p.input_index) for p in base_scheme.child_region_inputs - } - scheme.child_region_inputs = self._mutate_insertion_points( - base_region_composite_points, - full_insertion_scheme.child_region_inputs, - "region composite points", - max_mutations, - ) - base_region_output_points = { - (p.region_index, p.node_index, p.output_index) for p in base_scheme.region_outputs - } - scheme.region_outputs = self._mutate_insertion_points( - base_region_output_points, - full_insertion_scheme.region_outputs, - "region output points", - max_mutations, - ) + for attr, point_type, key_fn in _MUTATION_SPECS: + base_points = {key_fn(p) for p in 
getattr(base_scheme, attr)} + setattr( + scheme, + attr, + self._mutate_insertion_points( + base_points, + getattr(full_insertion_scheme, attr), + point_type, + max_mutations, + ), + ) return scheme @@ -939,9 +968,9 @@ def _copy_graph(self) -> gs.Graph: new_graph.toposort() return new_graph - def _get_quant_dtype(self, quant_type: str) -> np.dtype: - """Get numpy dtype for quantization type.""" - if quant_type == "fp8": + def _resolve_dtype(self, dtype_str: str, default: np.dtype = np.int8) -> np.dtype: + """Resolve a dtype string (quant or DQ output) to a numpy dtype.""" + if dtype_str == "fp8": try: return np.dtype(np.float8_e4m3fn) except (AttributeError, TypeError): @@ -951,63 +980,30 @@ def _get_quant_dtype(self, quant_type: str) -> np.dtype: "correct results without proper FP8 support." ) return np.uint8 - - dtype_map = { - "int8": np.int8, - "uint8": np.uint8, - } - - if quant_type not in dtype_map: - logger.warning(f"Unknown quantization type '{quant_type}', defaulting to int8") - return np.int8 - - return dtype_map[quant_type] - - def _get_dq_output_dtype(self, dtype_str: str) -> np.dtype: - """Convert DQ dtype string to numpy dtype.""" - dtype_map = { - "float16": np.float16, - "float32": np.float32, - } - - if hasattr(np, "bfloat16"): - dtype_map["bfloat16"] = np.bfloat16 - - if dtype_str not in dtype_map: - logger.warning(f"Unknown DQ dtype '{dtype_str}', defaulting to float32") - return np.float32 - - return dtype_map[dtype_str] + if hasattr(np, "bfloat16") and dtype_str == "bfloat16": + return np.bfloat16 + if dtype_str in self._DTYPE_MAP: + return self._DTYPE_MAP[dtype_str] + logger.warning(f"Unknown dtype '{dtype_str}', using default {default}") + return default def _build_tensor_map(self, graph: gs.Graph) -> dict[str, gs.Tensor]: """Build mapping from tensor names to tensor objects.""" - tensor_map = {} - + tensor_map = {t.name: t for t in graph.inputs if hasattr(t, "name") and t.name} for node in graph.nodes: - for output in node.outputs: - if 
hasattr(output, "name") and output.name: - tensor_map[output.name] = output - - for input_tensor in graph.inputs: - if hasattr(input_tensor, "name") and input_tensor.name: - tensor_map[input_tensor.name] = input_tensor - - for node in graph.nodes: - for input_tensor in node.inputs: - if ( - isinstance(input_tensor, gs.Constant) - and hasattr(input_tensor, "name") - and input_tensor.name - ): - tensor_map[input_tensor.name] = input_tensor - + for t in node.inputs: + if hasattr(t, "name") and t.name: + tensor_map[t.name] = t + for t in node.outputs: + if isinstance(t, gs.Constant) and hasattr(t, "name") and t.name: + tensor_map[t.name] = t return tensor_map def _get_tensor_metadata( self, tensor: gs.Tensor, is_constant: bool ) -> tuple[tuple | None, np.dtype]: """Extract shape and dtype metadata from a tensor.""" - default_dtype = self._get_dq_output_dtype(self.config.default_dq_dtype) + default_dtype = self._resolve_dtype(self.config.default_dq_dtype, np.float32) if is_constant and hasattr(tensor, "values") and tensor.values is not None: return tensor.values.shape, tensor.values.dtype @@ -1133,7 +1129,7 @@ def _insert_qdq_at_tensors( """ q_scale = self.config.default_q_scale quant_type = self.config.default_quant_type - quant_dtype = self._get_quant_dtype(quant_type) + quant_dtype = self._resolve_dtype(quant_type, np.int8) logger.debug(f"Q/DQ parameters: type={quant_type}, scale={q_scale}, zero_point=0") @@ -1227,6 +1223,7 @@ def initialize( super().initialize(config, pattern_cache) self._search_regions() + @staticmethod def _visit_region_recursively(self, region: Region) -> list[Region]: """Recursively traverse region hierarchy and collect all regions. 
@@ -1288,28 +1285,21 @@ def _search_regions(self) -> None: minimum_topdown_search_size=self.config.minimum_topdown_search_size, ) self.regions = search.search_regions() - self._reassign_region_ids(self.regions) logger.debug(f"Found {len(self.regions)} top-level regions") + # Flatten the hierarchy into a list of all regions all_regions = [] for region in self.regions: all_regions.extend(self._visit_region_recursively(region)) - logger.debug(f"Flattened hierarchy to {len(all_regions)} total regions") - - leaf_regions = [region for region in all_regions if region.type == RegionType.LEAF] - other_regions = [region for region in all_regions if region.type != RegionType.LEAF] - - all_regions = leaf_regions + other_regions + all_regions.sort(key=lambda r: r.type != RegionType.LEAF) self.regions = all_regions - num_leaf = sum(1 for r in self.regions if r.type == RegionType.LEAF) - num_composite = sum(1 for r in self.regions if r.type == RegionType.COMPOSITE) - num_root = sum(1 for r in self.regions if r.type == RegionType.ROOT) - + type_counts = Counter(r.type for r in self.regions) logger.info( f"Discovery complete: {len(self.regions)} regions " - f"({num_leaf} LEAF, {num_composite} COMPOSITE, {num_root} ROOT)" + f"({type_counts[RegionType.LEAF]} LEAF, {type_counts[RegionType.COMPOSITE]} COMPOSITE, " + f"{type_counts[RegionType.ROOT]} ROOT)" ) logger.debug("Regions prioritized: LEAF regions first for profiling") From ebba6e63225eca75dcb79d410189fd532beac230 Mon Sep 17 00:00:00 2001 From: Will Guo Date: Mon, 23 Feb 2026 02:16:20 +0000 Subject: [PATCH 7/8] fix test failures Signed-off-by: Will Guo --- .../onnx/quantization/autotune/autotuner.py | 10 +- modelopt/onnx/quantization/autotune/common.py | 547 +++++++++++++++++- .../quantization/autotune/test_autotuner.py | 60 +- 3 files changed, 576 insertions(+), 41 deletions(-) diff --git a/modelopt/onnx/quantization/autotune/autotuner.py b/modelopt/onnx/quantization/autotune/autotuner.py index e633f30a1..f4f1adc63 100644 --- 
a/modelopt/onnx/quantization/autotune/autotuner.py +++ b/modelopt/onnx/quantization/autotune/autotuner.py @@ -50,12 +50,12 @@ _MUTATION_SPECS = [ ("node_inputs", "node input points", lambda p: (p.node_index, p.input_index)), ( - "region_composite_inputs", + "child_region_inputs", "region composite points", lambda p: (p.region_index, p.input_index), ), ( - "region_output_points", + "region_outputs", "region output points", lambda p: (p.region_index, p.node_index, p.output_index), ), @@ -1224,7 +1224,7 @@ def initialize( self._search_regions() @staticmethod - def _visit_region_recursively(self, region: Region) -> list[Region]: + def _visit_region_recursively(region: Region) -> list[Region]: """Recursively traverse region hierarchy and collect all regions. Performs depth-first traversal of the region tree starting from a given @@ -1240,7 +1240,7 @@ def _visit_region_recursively(self, region: Region) -> list[Region]: regions = [region] for child in region.get_children(): - regions.extend(self._visit_region_recursively(child)) + regions.extend(QDQAutotuner._visit_region_recursively(child)) return regions @@ -1291,7 +1291,7 @@ def _search_regions(self) -> None: # Flatten the hierarchy into a list of all regions all_regions = [] for region in self.regions: - all_regions.extend(self._visit_region_recursively(region)) + all_regions.extend(QDQAutotuner._visit_region_recursively(region)) all_regions.sort(key=lambda r: r.type != RegionType.LEAF) self.regions = all_regions diff --git a/modelopt/onnx/quantization/autotune/common.py b/modelopt/onnx/quantization/autotune/common.py index a8929315a..922ab09eb 100644 --- a/modelopt/onnx/quantization/autotune/common.py +++ b/modelopt/onnx/quantization/autotune/common.py @@ -18,15 +18,22 @@ import hashlib from dataclasses import dataclass, field from enum import Enum -from typing import Any +from typing import TYPE_CHECKING, Any, Optional + +import onnx_graphsurgeon as gs +import yaml from modelopt.onnx.logging_config import logger 
from modelopt.onnx.quantization.autotune.insertion_points import ( ChildRegionInputInsertionPoint, ChildRegionOutputInsertionPoint, NodeInputInsertionPoint, + ResolvedInsertionPoint, ) +if TYPE_CHECKING: + from modelopt.onnx.quantization.autotune.region_pattern import RegionPattern + class AutotunerError(Exception): """Base exception for autotuner-related errors.""" @@ -315,3 +322,541 @@ def __str__(self) -> str: f"region_output_insertions={len(self.region_outputs)}, " f"latency={self.latency_ms:.3f}ms{error_str})" ) + + +@dataclass +class PatternSchemes: + """Collection of Q/DQ insertion schemes for a single pattern. + + Manages multiple InsertionScheme candidates for a region pattern, tracking + their performance and identifying the best-performing configuration. This + enables pattern-based optimization where all regions with the same structure + use the same Q/DQ insertion strategy. + + **Workflow:** + 1. Pattern is identified from region structure + 2. Multiple schemes are generated and tested + 3. Each scheme is measured (latency_ms) + 4. Best scheme is selected (lowest latency) + 5. 
Best scheme is applied to all matching regions + + **Best Scheme Selection:** + - Automatically identifies scheme with lowest latency + - Excludes schemes with errors (error=True) + - Schemes with latency_ms = inf are considered unmeasured + - best_scheme property provides easy access to optimal configuration + + **Attributes:** + pattern: RegionPattern defining the structural signature + schemes: List of InsertionScheme candidates with measurements + """ + + pattern: Optional["RegionPattern"] = None # Structural pattern signature + schemes: list[InsertionScheme] = field(default_factory=list) # Candidate schemes + + @property + def pattern_signature(self) -> str: + """Get the pattern signature string.""" + return self.pattern.signature if self.pattern else "" + + @property + def pattern_size(self) -> int: + """Get the pattern size (total node count).""" + return self.pattern.size if self.pattern else 0 + + @property + def best_scheme(self) -> InsertionScheme | None: + """Get the best performing scheme (lowest latency). + + Scans all schemes to find the one with minimum latency_ms, + excluding schemes with errors. + + Returns: + InsertionScheme with lowest latency (excluding error schemes), + or None if no valid schemes exist + """ + if len(self.schemes) == 0: + return None + min_idx, min_latency = -1, float("inf") + for idx, scheme in enumerate(self.schemes): + if not scheme.error and scheme.latency_ms < min_latency: + min_idx = idx + min_latency = scheme.latency_ms + if min_idx < 0: + return None + return self.schemes[min_idx] + + @property + def num_schemes(self) -> int: + """Get total number of schemes.""" + return len(self.schemes) + + def to_dict(self) -> dict[str, Any]: + """Convert to dictionary for serialization. + + Note: Excludes runtime objects like pattern (RegionPattern). + Only serializes metadata and schemes. 
+ """ + return { + "pattern_signature": self.pattern_signature, + "pattern_size": self.pattern_size, + "schemes": [scheme.to_dict() for scheme in self.schemes], + } + + @classmethod + def from_dict( + cls, data: dict[str, Any], pattern: Optional["RegionPattern"] = None + ) -> "PatternSchemes": + """Create PatternSchemes from serialized dictionary. + + Reconstructs the pattern schemes collection from saved data. The + RegionPattern object must be provided separately since it's not + serialized (it's a runtime object computed from the graph). + + If no pattern is provided, creates a minimal RegionPattern from the + saved signature and size for signature matching purposes. + + Args: + data: Dictionary containing 'pattern_signature', 'pattern_size', + and 'schemes' keys + pattern: RegionPattern object to associate (must match signature). + If None, creates minimal pattern from saved data. + + Returns: + Reconstructed PatternSchemes instance + """ + # Import here to avoid circular dependency at runtime + from modelopt.onnx.quantization.autotune.region_pattern import RegionPattern + + ps = cls() + + # If no pattern provided, create minimal one from saved data + if pattern is None and "pattern_signature" in data: + pattern = RegionPattern( + signature=data["pattern_signature"], size=data.get("pattern_size", 0) + ) + + ps.pattern = pattern + + ps.schemes = [ + InsertionScheme.from_dict(scheme_data) for scheme_data in data.get("schemes", []) + ] + + return ps + + def __str__(self) -> str: + """String representation for debugging.""" + best_latency = self.best_scheme.latency_ms if self.best_scheme else 0.0 + return ( + f"PatternSchemes(pattern='{self.pattern_signature[:40]}...', " + f"schemes={self.num_schemes}, best_latency={best_latency:.3f}ms)" + ) + + +@dataclass +class PatternCache: + """Pattern cache containing best-performing schemes for patterns with automatic eviction. + + Stores a collection of PatternSchemes that can be used as seeds for autotuning. 
+ Each PatternSchemes contains high-performing insertion schemes for a specific + pattern signature. The cache automatically evicts non-performant schemes based on: + - Error status (schemes with errors are evicted) + - Duplicate schemes (only better-performing duplicate is kept) + - Similarity (similar schemes where only better-performing one is kept) + - Count limit (only top N best schemes are kept per pattern) + """ + + pattern_schemes: list[PatternSchemes] = field(default_factory=list) + minimum_distance: int = 4 # Minimum distance between schemes in cache + max_entries_per_pattern: int = 32 # Maximum number of schemes per pattern (0 = no limit) + + def add_pattern_schemes(self, pattern_schemes: PatternSchemes) -> None: + """Add PatternSchemes to pattern cache with automatic eviction of non-performant entries. + + Merges new schemes with existing schemes for the same pattern, automatically + evicting schemes that are non-performant based on multiple criteria. + + Args: + pattern_schemes: PatternSchemes to add to the cache + """ + if not pattern_schemes or not pattern_schemes.pattern: + return + + pattern_sig = pattern_schemes.pattern_signature + + # Find existing PatternSchemes for this pattern + existing_idx = None + for idx, ps in enumerate(self.pattern_schemes): + if ps.pattern_signature == pattern_sig: + existing_idx = idx + break + + # Collect all schemes (existing + new) + all_schemes = list(pattern_schemes.schemes) + if existing_idx is not None: + all_schemes.extend(self.pattern_schemes[existing_idx].schemes) + + # Filter out schemes with errors and deduplicate by hash + valid_schemes = [s for s in all_schemes if not s.error] + unique_schemes = {} + for scheme in valid_schemes: + scheme_hash = scheme.hash + if ( + scheme_hash not in unique_schemes + or scheme.latency_ms < unique_schemes[scheme_hash].latency_ms + ): + unique_schemes[scheme_hash] = scheme + + # Sort by latency to get best schemes + sorted_schemes = sorted(unique_schemes.values(), 
key=lambda s: s.latency_ms) + + # Apply distance-based filtering if minimum_distance > 0 + if self.minimum_distance > 0: + filtered_schemes = [] + for scheme in sorted_schemes: + # Check if this scheme is too similar to any already-filtered scheme + too_similar = False + for existing_scheme in filtered_schemes: + distance = scheme.distance(existing_scheme) + if distance < self.minimum_distance: + # Schemes are too similar, keep the better one + if scheme.latency_ms < existing_scheme.latency_ms: + # New scheme is better, remove existing and add new + filtered_schemes.remove(existing_scheme) + break + else: + # Existing scheme is better, skip new one + too_similar = True + break + + if not too_similar: + filtered_schemes.append(scheme) + + sorted_schemes = filtered_schemes + + # Apply count limit if max_entries_per_pattern > 0 + # Keep only the top N best-performing schemes per pattern + if self.max_entries_per_pattern > 0: + sorted_schemes = sorted_schemes[: self.max_entries_per_pattern] + + # Create PatternSchemes with all schemes that passed the eviction criteria + result = PatternSchemes(pattern=pattern_schemes.pattern) + result.schemes = sorted_schemes + + # Replace existing or append new + if existing_idx is not None: + self.pattern_schemes[existing_idx] = result + else: + self.pattern_schemes.append(result) + + def get_pattern_schemes(self, pattern_signature: str) -> PatternSchemes | None: + """Get PatternSchemes for a specific pattern signature. + + Args: + pattern_signature: Pattern signature to lookup + + Returns: + PatternSchemes if found, None otherwise + """ + for ps in self.pattern_schemes: + if ps.pattern_signature == pattern_signature: + return ps + return None + + def has_pattern(self, pattern_signature: str) -> bool: + """Check if pattern cache contains a specific pattern. 
+ + Args: + pattern_signature: Pattern signature to check + + Returns: + True if pattern exists in pattern cache + """ + return any(ps.pattern_signature == pattern_signature for ps in self.pattern_schemes) + + def add_pattern_from_region( + self, region: Region, graph: gs.Graph, quantized_tensors: set[str] + ) -> None: + """Build and add a pattern cache entry from a region in a quantized model. + + Analyzes a region from an already-quantized model to extract its Q/DQ + insertion scheme. This allows capturing known-good quantization strategies + from existing models and using them as seeds for autotuning. + + Args: + region: Region from the quantized model to analyze + graph: ONNX graph containing the region + quantized_tensors: Set of tensor names that have Q/DQ nodes + + Example: + >>> cache = PatternCache() + >>> for region in all_regions: + ... cache.add_pattern_from_region(region, graph, quantized_tensors) + >>> cache.save("learned_patterns.yaml") + """ + # Import here to avoid circular dependency at runtime + from modelopt.onnx.quantization.autotune.region_pattern import RegionPattern + + # Create pattern from region + pattern = RegionPattern.from_region(region, graph) + # Track insertion points + scheme = InsertionScheme( + node_inputs=[], + child_region_inputs=[], + region_outputs=[], + latency_ms=float("inf"), + error=False, + ) + # Analyze node inputs + full_insertion_scheme = pattern.get_full_insertion_scheme(region, graph) + for point in full_insertion_scheme.node_inputs: + temp_scheme = InsertionScheme( + node_inputs=[point], + child_region_inputs=[], + region_outputs=[], + latency_ms=float("inf"), + error=False, + ) + temp_insertion_points: list[ResolvedInsertionPoint] = pattern.matches( + region, graph, temp_scheme + ) + temp_tensor_names = {tensor.tensor_name for tensor in temp_insertion_points} + if len(temp_tensor_names.intersection(quantized_tensors)) > 0: + scheme.node_inputs.append(point) + # Analyze region boundaries (for COMPOSITE regions) + 
if region.type == RegionType.COMPOSITE: + for child_point in full_insertion_scheme.child_region_inputs: + temp_scheme = InsertionScheme( + node_inputs=[], + child_region_inputs=[child_point], + region_outputs=[], + latency_ms=float("inf"), + error=False, + ) + temp_insertion_points = pattern.matches(region, graph, temp_scheme) + temp_tensor_names = {tensor.tensor_name for tensor in temp_insertion_points} + if len(temp_tensor_names.intersection(quantized_tensors)) > 0: + scheme.child_region_inputs.append(child_point) + # Analyze region outputs + for output_point in full_insertion_scheme.region_outputs: + temp_scheme = InsertionScheme( + node_inputs=[], + child_region_inputs=[], + region_outputs=[output_point], + latency_ms=float("inf"), + error=False, + ) + temp_insertion_points = pattern.matches(region, graph, temp_scheme) + temp_tensor_names = {tensor.tensor_name for tensor in temp_insertion_points} + if len(temp_tensor_names.intersection(quantized_tensors)) > 0: + scheme.region_outputs.append(output_point) + # Add pattern and scheme to pattern cache + pattern_schemes = PatternSchemes(pattern=pattern, schemes=[scheme]) + self.add_pattern_schemes(pattern_schemes) + num_points = ( + len(scheme.node_inputs) + len(scheme.child_region_inputs) + len(scheme.region_outputs) + ) + logger.debug(f"Added pattern from region {region.id} with {num_points} insertion points") + # Add patterns from child regions + if region.type == RegionType.COMPOSITE: + for child_region in region.get_children(): + self.add_pattern_from_region(child_region, graph, quantized_tensors) + + @property + def num_patterns(self) -> int: + """Get number of patterns in pattern cache.""" + return len(self.pattern_schemes) + + @property + def total_schemes(self) -> int: + """Get total number of schemes across all patterns.""" + return sum(ps.num_schemes for ps in self.pattern_schemes) + + def merge(self, other: "PatternCache", prefer_existing: bool = True) -> None: + """Merge another PatternCache into this 
one. + + Args: + other: PatternCache to merge + prefer_existing: If True, keep existing patterns when there's a conflict. + If False, overwrite with other's patterns. + """ + for schemes in other.pattern_schemes: + if not self.has_pattern(schemes.pattern_signature) or not prefer_existing: + self.add_pattern_schemes(schemes) + + def to_dict(self) -> dict[str, Any]: + """Convert to dictionary for serialization. + + Returns: + Dictionary with 'minimum_distance', 'max_entries_per_pattern', and 'pattern_schemes' keys + """ + return { + "minimum_distance": self.minimum_distance, + "max_entries_per_pattern": self.max_entries_per_pattern, + "pattern_schemes": [ps.to_dict() for ps in self.pattern_schemes], + } + + @classmethod + def from_dict(cls, data: dict[str, Any]) -> "PatternCache": + """Create PatternCache from serialized dictionary. + + Note: RegionPattern objects are not restored (they're runtime objects). + Only pattern signatures and scheme data are loaded. + + Args: + data: Dictionary containing pattern cache data + + Returns: + Reconstructed PatternCache instance + """ + cache = cls( + minimum_distance=data.get("minimum_distance", 4), + max_entries_per_pattern=data.get("max_entries_per_pattern", 32), + ) + + for ps_data in data.get("pattern_schemes", []): + # Create PatternSchemes without pattern object (pattern=None) + ps = PatternSchemes.from_dict(ps_data, pattern=None) + cache.pattern_schemes.append(ps) + + return cache + + def save(self, output_path: str) -> None: + """Save pattern cache to a YAML file. + + Serializes all pattern schemes and their insertion points to a YAML file + that can be loaded later for seeded autotuning. The format matches the + autotuner state file format for consistency. 
+ + Args: + output_path: File path where the YAML pattern cache file will be written + """ + state = self.to_dict() + + with open(output_path, "w") as f: + yaml.dump(state, f, default_flow_style=False, sort_keys=False) + + logger.info( + f"Saved pattern cache → {output_path} ({self.num_patterns} patterns, " + f"{self.total_schemes} schemes)" + ) + logger.debug( + f"Cache settings: min_distance={self.minimum_distance}, " + f"max_per_pattern={self.max_entries_per_pattern}" + ) + + @classmethod + def load(cls, input_path: str) -> "PatternCache": + """Load pattern cache from a YAML file. + + Reads a previously saved pattern cache file and reconstructs all pattern + schemes. The loaded pattern cache can be used to seed autotuning with + known-good insertion schemes. + + Args: + input_path: File path to the YAML pattern cache file to load + + Returns: + PatternCache instance with all pattern schemes loaded + + Raises: + FileNotFoundError: If the input_path doesn't exist + """ + with open(input_path) as f: + state = yaml.safe_load(f) + + cache = cls.from_dict(state) + + logger.info( + f"Loaded pattern cache from {input_path} ({cache.num_patterns} patterns, " + f"{cache.total_schemes} schemes)" + ) + logger.debug( + f"Cache settings: min_distance={cache.minimum_distance}, " + f"max_per_pattern={cache.max_entries_per_pattern}" + ) + + return cache + + def __str__(self) -> str: + """String representation for debugging.""" + return ( + f"PatternCache(patterns={self.num_patterns}, " + f"schemes={self.total_schemes}, " + f"minimum_distance={self.minimum_distance}, " + f"max_entries_per_pattern={self.max_entries_per_pattern})" + ) + + +@dataclass +class Config: + """Configuration parameters for QDQ autotuning. + + Controls the autotuning process including performance requirements, quantization + parameters, region building, scheme generation, and finetuning behavior. 
+ + Attributes: + # Logging + verbose: Enable detailed logging of autotuning progress (default: False) + + # Performance Requirements + performance_threshold: Minimum speedup ratio to accept a scheme. + 1.0 = no improvement required, 1.02 = 2% improvement (default: 1.02) + + # Quantization Parameters + default_q_scale: Default scale parameter for Q/DQ nodes. Controls quantization + granularity. Typical range: 0.01-0.1 (default: 0.1) + default_q_zero_point: Default zero-point for Q/DQ nodes. Use 0 for signed int8, + 128 for unsigned uint8 (default: 0) + default_quant_type: Quantization type for Q/DQ nodes. Options: "int8" (default), "fp8". default_dq_dtype: Output data type of DequantizeLinear nodes (default: "float32") + + # Region Builder Settings + maximum_sequence_region_size: Maximum number of nodes in a sequence region during + top-down refinement. Prevents overly large merged regions (default: 10) + minimum_topdown_search_size: Minimum number of nodes in a region to trigger + top-down search during region building (default: 10) + + # Scheme Generation Settings + top_percent_to_mutate: Top percentage of best schemes to use as mutation seeds + during scheme generation. Range: 0.0-1.0 (default: 0.1 = top 10%) + minimum_schemes_to_mutate: Minimum number of schemes to keep as mutation seeds, + even if top_percent_to_mutate results in fewer (default: 10) + maximum_mutations: Maximum number of mutations to apply to a single scheme + during generation (default: 3) + maximum_generation_attempts: Maximum attempts to generate a unique new scheme + before giving up (default: 100) + + # Pattern Cache Settings + pattern_cache_minimum_distance: Minimum edit distance required between schemes in cache. + When adding schemes, if a scheme is too similar (distance < minimum_distance) + to an existing scheme, only the better-performing one is kept (default: 4) + pattern_cache_max_entries_per_pattern: Maximum number of schemes to keep per pattern + in pattern cache. Only the top N best-performing schemes are kept for each pattern. 
+ Use 0 to keep all schemes (default: 32) + """ + + # Logging + verbose: bool = False + + # Performance Requirements + performance_threshold: float = 1.02 + + # Quantization Parameters + default_q_scale: float = 0.1 + default_q_zero_point: int = 0 + default_quant_type: str = "int8" + default_dq_dtype: str = "float32" + + # Region Builder Settings + maximum_sequence_region_size: int = 10 + minimum_topdown_search_size: int = 10 + + # Scheme Generation Settings + top_percent_to_mutate: float = 0.1 + minimum_schemes_to_mutate: int = 10 + maximum_mutations: int = 3 + maximum_generation_attempts: int = 100 + + # Pattern Cache Settings + pattern_cache_minimum_distance: int = 4 + pattern_cache_max_entries_per_pattern: int = 32 diff --git a/tests/unit/onnx/quantization/autotune/test_autotuner.py b/tests/unit/onnx/quantization/autotune/test_autotuner.py index f6a920bec..b64cb23b1 100644 --- a/tests/unit/onnx/quantization/autotune/test_autotuner.py +++ b/tests/unit/onnx/quantization/autotune/test_autotuner.py @@ -159,8 +159,8 @@ def test_region_discovery(self, simple_conv_model): # Regions should be valid for region in autotuner.regions: - assert region.get_id() is not None - assert region.get_type() in [RegionType.LEAF, RegionType.COMPOSITE, RegionType.ROOT] + assert region.id is not None + assert region.type in [RegionType.LEAF, RegionType.COMPOSITE, RegionType.ROOT] def test_export_baseline_model(self, simple_conv_model): """Test exporting baseline model without Q/DQ.""" @@ -207,39 +207,33 @@ def test_generate_scheme(self, simple_conv_model): if len(autotuner.regions) == 0: pytest.skip("No regions discovered") - autotuner.submit(10.0) # baseline + autotuner.submit(10.0) region = autotuner.regions[0] autotuner.set_profile_region(region) - # Generate multiple schemes and submit a latency for each - num_generated = 0 - while True: - scheme_idx = autotuner.generate() - if scheme_idx < 0: - break - assert isinstance(scheme_idx, int) - autotuner.submit(10.0 + num_generated * 
0.1) # dummy latency - num_generated += 1 - if num_generated >= 5: # cap iterations - break - - assert num_generated > 0, "Expected at least one scheme to be generated" - autotuner.set_profile_region(None, commit=True) - - # Export with Q/DQ and verify Q/DQ nodes are in the model with tempfile.NamedTemporaryFile(suffix=".onnx", delete=False) as f: output_path = f.name - try: - autotuner.export_onnx(output_path, insert_qdq=True) - exported = onnx.load(output_path) - node_ops = [n.op_type for n in exported.graph.node] - assert "QuantizeLinear" in node_ops, "Expected QuantizeLinear nodes in exported model" - assert "DequantizeLinear" in node_ops, ( - "Expected DequantizeLinear nodes in exported model" + + has_q = False + has_dq = False + for _ in range(5): + scheme_idx = autotuner.generate() + assert isinstance(scheme_idx, int) + autotuner.submit(10.0 + _ * 0.1) + + autotuner.export_onnx(output_path, insert_qdq=True) + exported = onnx.load(output_path) + node_ops = [n.op_type for n in exported.graph.node] + for node_op in node_ops: + if node_op == "QuantizeLinear": + has_q = True + if node_op == "DequantizeLinear": + has_dq = True + if has_q and has_dq: + break + assert has_q and has_dq, ( + "Expected QuantizeLinear and DequantizeLinear nodes in exported model" ) - finally: - if os.path.exists(output_path): - os.unlink(output_path) def test_submit_latency(self, simple_conv_model): """Test submitting performance measurement.""" @@ -287,12 +281,8 @@ def test_regions_prioritization(self, simple_conv_model): autotuner.initialize(config) # Check that LEAF regions come before non-LEAF - leaf_indices = [ - i for i, r in enumerate(autotuner.regions) if r.get_type() == RegionType.LEAF - ] - non_leaf_indices = [ - i for i, r in enumerate(autotuner.regions) if r.get_type() != RegionType.LEAF - ] + leaf_indices = [i for i, r in enumerate(autotuner.regions) if r.type == RegionType.LEAF] + non_leaf_indices = [i for i, r in enumerate(autotuner.regions) if r.type != RegionType.LEAF] 
if leaf_indices and non_leaf_indices: # All LEAF should come before non-LEAF From 6124ac9335cc0a751c2065ffac085e3948137817 Mon Sep 17 00:00:00 2001 From: Will Guo Date: Mon, 23 Feb 2026 02:41:56 +0000 Subject: [PATCH 8/8] move models to utils Signed-off-by: Will Guo --- .../onnx/quantization/autotune/models.py | 0 tests/unit/onnx/quantization/autotune/test_autotuner.py | 6 +++--- 2 files changed, 3 insertions(+), 3 deletions(-) rename tests/{unit => _test_utils}/onnx/quantization/autotune/models.py (100%) diff --git a/tests/unit/onnx/quantization/autotune/models.py b/tests/_test_utils/onnx/quantization/autotune/models.py similarity index 100% rename from tests/unit/onnx/quantization/autotune/models.py rename to tests/_test_utils/onnx/quantization/autotune/models.py diff --git a/tests/unit/onnx/quantization/autotune/test_autotuner.py b/tests/unit/onnx/quantization/autotune/test_autotuner.py index b64cb23b1..26e390a23 100644 --- a/tests/unit/onnx/quantization/autotune/test_autotuner.py +++ b/tests/unit/onnx/quantization/autotune/test_autotuner.py @@ -23,10 +23,10 @@ import os import tempfile -import models as _test_models import onnx import onnx_graphsurgeon as gs import pytest +from _test_utils.onnx.quantization.autotune.models import _create_simple_conv_onnx_model from modelopt.onnx.quantization.autotune import Config, QDQAutotuner, RegionPattern from modelopt.onnx.quantization.autotune.common import PatternCache, RegionType @@ -34,8 +34,8 @@ @pytest.fixture def simple_conv_model(): - """Simple ONNX model: Input -> Conv -> Relu -> Output. Created via models.py.""" - return _test_models._create_simple_conv_onnx_model() + """Simple ONNX model: Input -> Conv -> Relu -> Output. Created via _test_utils models.""" + return _create_simple_conv_onnx_model() def _create_test_config():