Generic Fused MoE Quantization + Export for transformers 5.0+ #1187

Open
Edwardf0t1 wants to merge 10 commits into main from zhiyu/ptq-export-transformers5

Conversation

@Edwardf0t1
Contributor

@Edwardf0t1 Edwardf0t1 commented Apr 7, 2026

What does this PR do?

Add generic quantization and export support for fused MoE expert modules in HuggingFace transformers 5.0+.

In transformers 5.0+, all major MoE models switched from sequential per-expert nn.ModuleList to fused 3D tensor parameters (gate_up_proj, down_proj). This breaks ModelOpt's existing per-expert quantization and export pipeline, which assumes iterable expert submodules.

Affected models (verified against transformers v5.5.0 source):

  • MixtralExperts (Mixtral)
  • Qwen2MoeExperts (Qwen2-MoE)
  • Qwen3MoeExperts (Qwen3-MoE)
  • Qwen3_5MoeExperts (Qwen3.5-MoE)
  • DeepseekV3NaiveMoe (DeepSeek-V3)
  • JambaExperts, OlmoeExperts, and any future model following the same HF standard pattern

Key insight: All these models share an identical fused expert structure and forward pattern. A single generic solution replaces N model-specific implementations.
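The shared pattern can be illustrated with a minimal synthetic module. This is a simplified sketch, not the exact HF implementation: real forward signatures include routing, and the gate/up interleaving inside gate_up_proj varies by model; all sizes here are arbitrary examples.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class FusedExperts(nn.Module):
    """Sketch of the fused 3D-parameter expert layout used by HF 5.0+ MoE models."""

    def __init__(self, num_experts=4, hidden=8, intermediate=16):
        super().__init__()
        self.num_experts = num_experts
        # One 3D parameter per projection, replacing a per-expert nn.ModuleList.
        self.gate_up_proj = nn.Parameter(torch.randn(num_experts, 2 * intermediate, hidden))
        self.down_proj = nn.Parameter(torch.randn(num_experts, hidden, intermediate))
        self.act_fn = nn.SiLU()

    def forward(self, hidden_states, expert_idx):
        # Indexing the fused parameter yields a 2D view that shares storage
        # with it; each expert invocation makes exactly two F.linear calls.
        gate_up = F.linear(hidden_states, self.gate_up_proj[expert_idx])
        gate, up = gate_up.chunk(2, dim=-1)
        return F.linear(self.act_fn(gate) * up, self.down_proj[expert_idx])


experts = FusedExperts()
out = experts(torch.randn(3, 8), expert_idx=1)
print(out.shape)  # torch.Size([3, 8])
```

Because every affected model follows this shape and call pattern, a single F.linear interception point covers all of them.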

Context: relationship to PR #975 and PR #1170

Changes

Quantization (modelopt/torch/quantization/plugins/huggingface.py):

  • _QuantFusedExperts(_QuantFunctionalMixin) -- Generic wrapper that intercepts F.linear calls and applies per-expert quantization via storage-offset-based expert index recovery. Each expert gets its own weight and input quantizers (nn.ModuleList).
  • _is_fused_experts_module() -- Structural detector: gate_up_proj (3D) + down_proj (3D) + num_experts + act_fn.
  • register_fused_experts_on_the_fly() -- Auto-registration callback, added to CUSTOM_MODEL_PLUGINS before register_sparse_moe_on_the_fly so explicit registrations (Llama4, GptOss, etc.) take priority.
  • _get_fused_expert_intermediate_dim() -- Helper for cross-version attribute name resolution (intermediate_dim / intermediate_size / fallback to shape).
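The structural detector can be sketched as below. This is a hypothetical re-implementation based on the description above, not the actual ModelOpt code, which may apply additional shape and orientation validation:

```python
import torch
import torch.nn as nn


def is_fused_experts_module(module: nn.Module) -> bool:
    # Detect the fused pattern purely structurally: 3D gate_up_proj and
    # down_proj tensors plus num_experts and act_fn attributes.
    gate_up = getattr(module, "gate_up_proj", None)
    down = getattr(module, "down_proj", None)
    return (
        isinstance(gate_up, torch.Tensor) and gate_up.dim() == 3
        and isinstance(down, torch.Tensor) and down.dim() == 3
        and isinstance(getattr(module, "num_experts", None), int)
        and getattr(module, "act_fn", None) is not None
    )


# A minimal module matching the pattern is detected; a plain Linear is not.
fused = nn.Module()
fused.gate_up_proj = nn.Parameter(torch.randn(4, 32, 8))
fused.down_proj = nn.Parameter(torch.randn(4, 8, 16))
fused.num_experts = 4
fused.act_fn = nn.SiLU()
print(is_fused_experts_module(fused), is_fused_experts_module(nn.Linear(8, 8)))
```

Structural detection is what lets future models that follow the same HF convention work without a model-specific registration.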

Export (modelopt/torch/export/moe_utils.py, unified_export_hf.py, layer_utils.py):

  • _export_fused_experts() -- Splits fused 3D weights into per-expert 2D projections (gate_proj, up_proj, down_proj), handles amax fallback for uncalibrated experts, proportionally slices per-channel amax, and registers results under the standard experts.{E}.gate_proj.weight naming convention.
  • Integration in _process_quantized_modules and _export_transformers_checkpoint to dispatch to _export_fused_experts for fused expert modules.
  • Structural detection in get_expert_linear_names for fused experts. Added MixtralSparseMoeBlock to the gate_proj/down_proj/up_proj group (transformers 5.0 naming).
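The weight-splitting step can be sketched as follows. Taking the first half of the fused rows as gate_proj and the second half as up_proj is an assumption for illustration; the actual split order depends on each model's fused layout:

```python
import torch


def split_fused_expert(gate_up_proj: torch.Tensor, down_proj: torch.Tensor, expert_idx: int):
    # Split one expert's slab of the fused 3D weights into the three
    # standard 2D projections (gate_proj, up_proj, down_proj).
    two_i, _hidden = gate_up_proj[expert_idx].shape
    inter = two_i // 2
    gate = gate_up_proj[expert_idx, :inter].contiguous()
    up = gate_up_proj[expert_idx, inter:].contiguous()
    down = down_proj[expert_idx].contiguous()
    return gate, up, down


gate_up = torch.randn(4, 32, 8)   # (num_experts, 2*intermediate, hidden)
down = torch.randn(4, 8, 16)      # (num_experts, hidden, intermediate)
g, u, d = split_fused_expert(gate_up, down, expert_idx=2)
print(g.shape, u.shape, d.shape)
```

The real export routine additionally slices any per-channel amax the same way and registers the results under the experts.{E}.gate_proj.weight naming convention.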

Tests (tests/unit/torch/quantization/plugins/test_fused_experts.py):

  • Synthetic fused expert model matching the exact HF 5.0+ pattern.
  • Tests for structural detection, auto-registration, two-level registration (block + expert), quantizer creation, forward pass-through correctness, expert index recovery, and export output structure.

Two-level registration design

SparseMoeBlock --> _QuantSparseMoe (calibration control, token counting, top_k override)
  .experts --> _QuantFusedExperts (per-expert F.linear interception + quantization)

register_fused_experts_on_the_fly runs first to register the inner expert module; register_sparse_moe_on_the_fly then registers the outer block. _QuantSparseMoe.layer_sync_moe_local_experts_amax skips fused experts (they are not iterable), as per-expert amax is managed internally by _QuantFusedExperts.
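The offset arithmetic behind expert index recovery, and the failure mode under copies, can be demonstrated directly:

```python
import torch

# Each view gate_up_proj[i] shares storage with the fused parameter, so the
# offset delta divided by stride(0) recovers the expert index i.
gate_up_proj = torch.randn(4, 6, 5)
base = gate_up_proj.storage_offset()
stride0 = gate_up_proj.stride(0)
for i in range(gate_up_proj.shape[0]):
    view = gate_up_proj[i]
    assert (view.storage_offset() - base) // stride0 == i

# A copied slice gets fresh storage at offset 0, which is why copies
# (.clone(), FSDP2 redistribution, torch.compile materialization) break
# the invariant; hence the runtime assertions mentioned below.
copied = gate_up_proj[2].clone()
print(copied.storage_offset())  # 0
```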

Known limitations

  • Alternative experts_implementation backends: The F.linear interception only works with experts_implementation="eager" (the default). The batched_mm / grouped_mm backends use torch.bmm / torch._grouped_mm instead and are not intercepted.
  • Storage offset fragility: Expert index recovery via storage_offset() breaks under .contiguous(), FSDP2 redistribution, or torch.compile materialization. Runtime assertions are included.
  • Toggle state machine: Assumes exactly 2 F.linear calls per expert. Documented in docstrings.
  • Non-standard MoE models: DBRX, GptOss, Llama4, Step3p5 have different layouts and are already explicitly handled. The generic solution does not attempt to cover these.

Testing

  • Unit tests with synthetic fused expert model: detection, registration, quantization, export
  • Verify existing sequential MoE tests still pass (test_sparse_moe.py)
  • GPU test with a real MoE model on transformers 5.x

Before your PR is "Ready for review"

  • Is this change backward compatible?: Yes -- existing explicit registrations take priority; sequential MoE models are unaffected.
  • Did you write any new necessary tests?: Yes
  • Did you update Changelog?: No (pending)

Summary by CodeRabbit

  • New Features

    • Added quantization support for fused Mixture-of-Experts (MoE) modules with automatic detection, per-expert quantization handling, and export to per-expert submodules; unified checkpoint export now supports fused MoE experts.
  • Tests

    • Added end-to-end tests covering fused-experts detection, conversion, forward correctness, expert index recovery, and export.
  • Changelog

    • Updated release notes to announce fused MoE expert support for Hugging Face exports.

@copy-pr-bot

copy-pr-bot bot commented Apr 7, 2026

Auto-sync is disabled for draft pull requests in this repository. Workflows must be run manually.

Contributors can view more details about this message here.

@coderabbitai
Contributor

coderabbitai bot commented Apr 7, 2026

Note

Reviews paused

It looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. You can configure this behavior by changing the reviews.auto_review.auto_pause_after_reviewed_commits setting.

Use the following commands to manage reviews:

  • @coderabbitai resume to resume automatic reviews.
  • @coderabbitai review to trigger a single review.


No actionable comments were generated in the recent review. 🎉

ℹ️ Recent review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: df4cd47b-33c3-4de4-80d4-cc2a288d3984

📥 Commits

Reviewing files that changed from the base of the PR and between 3cceca6 and d4f06b5.

📒 Files selected for processing (1)
  • CHANGELOG.rst
✅ Files skipped from review due to trivial changes (1)
  • CHANGELOG.rst

📝 Walkthrough

Walkthrough

Adds detection, quantization wrapping, and export conversion for fused HuggingFace MoE expert modules: new fused-expert wrapper and helpers, export routine to split fused gate/up/down weights into per-expert submodules with quantization state, and integration into the unified HF quantization/export pipeline.

Changes

Cohort / File(s) Summary
Layer utilities
modelopt/torch/export/layer_utils.py
Detect fused MoE experts in get_expert_linear_names() by checking gate_up_proj_weight_quantizers; return fused or legacy expert name mappings.
Fused-expert export logic
modelopt/torch/export/moe_utils.py
Added _export_fused_experts(module, dtype) to split fused gate_up_proj/down_proj into per-expert gate_proj/up_proj/down_proj, copy/slice quantizers, register per-expert submodules and exported buffers, and remove fused attributes.
Unified HF export integration
modelopt/torch/export/unified_export_hf.py
Detect fused-expert containers during quantized-module processing and invoke _export_fused_experts; short-circuit expert amax calibration loop for fused layouts.
Quantization plugin (HuggingFace)
modelopt/torch/quantization/plugins/huggingface.py
Added _QuantFusedExperts, helpers for intermediate-dim detection and expert-index recovery, and register_fused_experts_on_the_fly to register fused-expert wrappers with the quantization registry.
Tests
tests/unit/torch/quantization/plugins/test_fused_experts.py
New tests for detection, registration, conversion, forward equivalence, expert-index recovery, and _export_fused_experts using synthetic fused-expert and sparse-MoE constructs.
Changelog
CHANGELOG.rst
Updated Experimental note to list unified HF checkpoint export support now includes fused MoE expert modules and several MoE families.

Sequence Diagrams

sequenceDiagram
    participant Model as HF Model
    participant Plugin as Quant Plugin
    participant Detector as _is_fused_experts_module()
    participant Registry as QuantModuleRegistry
    participant Wrapper as _QuantFusedExperts

    Model->>Plugin: Load model with fused MoE experts
    Plugin->>Detector: Inspect submodule shape/attrs
    Detector-->>Registry: Register fused-expert wrapper if matched
    Registry->>Wrapper: Instantiate wrapper for expert container
    Wrapper-->>Model: Replace/attach quantized expert wrapper
    Note over Wrapper: shared input quantizers<br/>per-expert weight quantizers
sequenceDiagram
    participant Module as Fused Expert Module
    participant Export as _export_fused_experts()
    participant Split as Weight Slicer
    participant Quant as Quantizer Handler
    participant Register as Submodule Registration

    Export->>Module: Identify fused params & num_experts
    loop for each expert idx
        Export->>Split: Slice gate_up_proj -> gate_proj + up_proj
        Export->>Split: Slice down_proj -> down_proj[idx]
        Export->>Quant: Deep-copy per-expert quantizers, slice _amax or compute if missing
        Quant-->>Export: Exported quantized weight buffers
        Export->>Register: Create/register gate_proj/up_proj/down_proj under module.{idx}
    end
    Export->>Module: Delete fused params and quantizer lists

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~60 minutes


Important

Pre-merge checks failed

Please resolve all errors before merging. Addressing warnings is optional.

❌ Failed checks (1 warning, 1 inconclusive)

Check name | Status | Explanation / Resolution
Docstring Coverage | ⚠️ Warning | Docstring coverage is 46.15%, below the required threshold of 80.00%. Write docstrings for the functions missing them to satisfy the coverage threshold.
Security Anti-Patterns | ❓ Inconclusive | Search for unsafe patterns in Python files (torch.load, numpy.load, eval/exec, hardcoded trust_remote_code, nosec comments). No modified or new Python files were provided for scanning.
✅ Passed checks (2 passed)
Check name Status Explanation
Description Check ✅ Passed Check skipped - CodeRabbit’s high-level summary is enabled.
Title check ✅ Passed The PR title accurately summarizes the main change: adding generic quantization and export support for fused MoE expert modules in transformers 5.0+, which aligns perfectly with the core functionality across all modified files.

Comment @coderabbitai help to get the list of available commands and usage tips.

@github-actions
Contributor

github-actions bot commented Apr 7, 2026

PR Preview Action v1.8.1


🚀 View preview at
https://NVIDIA.github.io/Model-Optimizer/pr-preview/pr-1187/

Built to branch gh-pages at 2026-04-09 04:44 UTC.
Preview will be ready when the GitHub Pages deployment is complete.

@codecov

codecov bot commented Apr 7, 2026

Codecov Report

❌ Patch coverage is 88.69565% with 13 lines in your changes missing coverage. Please review.
✅ Project coverage is 76.42%. Comparing base (04cd596) to head (d4f06b5).

Files with missing lines | Patch % | Missing lines
modelopt/torch/export/moe_utils.py | 86.66% | 6 ⚠️
modelopt/torch/export/unified_export_hf.py | 33.33% | 4 ⚠️
modelopt/torch/quantization/plugins/huggingface.py | 96.66% | 2 ⚠️
modelopt/torch/export/layer_utils.py | 75.00% | 1 ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main    #1187      +/-   ##
==========================================
+ Coverage   75.68%   76.42%   +0.74%     
==========================================
  Files         353      353              
  Lines       40491    40606     +115     
==========================================
+ Hits        30644    31035     +391     
+ Misses       9847     9571     -276     
Flag | Coverage | Δ
examples | 42.84% <23.47%> | +1.11% ⬆️
gpu | 56.72% <36.52%> | -0.39% ⬇️
unit | 55.35% <84.34%> | +0.11% ⬆️

Flags with carried forward coverage won't be shown.

@Edwardf0t1 Edwardf0t1 marked this pull request as ready for review April 7, 2026 20:31
@Edwardf0t1 Edwardf0t1 requested review from a team as code owners April 7, 2026 20:31
Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 3

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
modelopt/torch/export/layer_utils.py (1)

968-995: ⚠️ Potential issue | 🟠 Major

Generic support is still incomplete for the AWQ/SVDQuant resmooth path.

get_expert_linear_names() now recognizes fused blocks, but requantize_resmooth_fused_llm_layers() still goes through get_experts_list(), and that helper only accepts Mixtral/Qwen model types. For the new targets called out in this PR (DeepSeek, Jamba, OLMoE, etc.), AWQ/NVFP4_SVDQuant export will still hit NotImplementedError before _export_fused_experts() is reached. Please add a structural fallback there, or bypass the old model-type path for fused experts entirely.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/export/layer_utils.py` around lines 968 - 995, The
requantize_resmooth_fused_llm_layers flow fails for new fused expert types
because requantize_resmooth_fused_llm_layers still relies on get_experts_list
(which only knows Mixtral/Qwen) and raises NotImplementedError before
_export_fused_experts is reached; update requantize_resmooth_fused_llm_layers
(or get_experts_list) to use the same structural detection as
get_expert_linear_names() (e.g., check hasattr(module, "experts") and
hasattr(module.experts, "gate_up_proj_weight_quantizers") or call
get_expert_linear_names(module)) as a fallback path so fused experts (DeepSeek,
Jamba, OLMoE, etc.) are handled generically instead of raising
NotImplementedError, or bypass the old model-type branch entirely and directly
dispatch to _export_fused_experts() when the structural checks match.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@modelopt/torch/quantization/plugins/huggingface.py`:
- Around line 871-879: _the detector is too permissive and may misidentify
BMM-style expert containers as F.linear-style fused experts; tighten
_is_fused_experts_module() (and the fallback in
_get_fused_expert_intermediate_dim()) by explicitly validating tensor
orientations expected by _QuantFusedExperts: require gate_up_proj and
gate_down_proj to be 2-D, experts to be 3-D, compute candidate_intermediate =
gate_up_proj.shape[1] // 2 and assert gate_up_proj.shape[1] is even, then verify
experts.shape[2] == candidate_intermediate and experts.shape[0] ==
int(getattr(module, "num_experts", experts.shape[0])); only auto-register the
_QuantFusedExperts path in CUSTOM_MODEL_PLUGINS when those orientation checks
pass so alternative layouts fall back to their own handlers (e.g.,
_QuantLlama4TextExperts).
- Around line 906-925: The computed idx must only be trusted if the slice truly
shares the underlying storage with self.gate_up_proj; first check that
weight.storage().data_ptr() (or weight.storage() identity) equals
self.gate_up_proj.storage().data_ptr() and raise a clear error if it does not,
instead of silently proceeding; then compute idx from storage_offset() and
stride as before and keep the existing range assertion using
_get_expert_idx_from_gate_up, gate_up_proj, storage_offset, stride, and
num_experts to locate and fix the code path.

In `@tests/unit/torch/quantization/plugins/test_fused_experts.py`:
- Around line 229-230: The assertion comparing out_ref and out_test is too
strict (atol=1e-5) causing CI flakiness; update the test in
test_fused_experts.py (the comparison of out_ref and out_test) to use
torch.testing.assert_close with a slightly larger tolerance (e.g., atol=2e-5 or
rtol=1e-6) or, alternatively, explicitly disable the new quantizers before
producing out_test so the values match exactly; modify the assertion accordingly
to reference out_ref and out_test.


ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: 06b7ed0f-85fc-4c97-a13b-76a86a374933

📥 Commits

Reviewing files that changed from the base of the PR and between 80d2f02 and fa93aec.

📒 Files selected for processing (5)
  • modelopt/torch/export/layer_utils.py
  • modelopt/torch/export/moe_utils.py
  • modelopt/torch/export/unified_export_hf.py
  • modelopt/torch/quantization/plugins/huggingface.py
  • tests/unit/torch/quantization/plugins/test_fused_experts.py

Comment on lines +906 to +925
def _get_expert_idx_from_gate_up(self, weight: torch.Tensor) -> int:
    """Recover expert index from a ``gate_up_proj`` weight slice's storage offset.

    When HF indexes ``gate_up_proj[idx]``, the result is a view sharing the
    same underlying storage. The offset delta divided by the stride along
    dim-0 gives the expert index.

    The invariant breaks if the tensor is ``.contiguous()``-copied or
    redistributed by certain distributed wrappers (FSDP2, tensor parallel).
    """
    base_offset = self.gate_up_proj.storage_offset()
    stride = self.gate_up_proj.stride(0)
    if stride == 0:
        return 0
    idx = (weight.storage_offset() - base_offset) // stride
    assert 0 <= idx < self.num_experts, (
        f"Computed expert index {idx} out of range [0, {self.num_experts}). "
        "This can happen if the weight was .contiguous()-copied or redistributed."
    )
    return idx
Contributor


⚠️ Potential issue | 🟠 Major

🧩 Analysis chain

🏁 Script executed:

#!/bin/bash
python - <<'PY'
import torch

w = torch.randn(4, 8, 16)
base_offset = w.storage_offset()
stride0 = w.stride(0)

for i in range(w.shape[0]):
    copied = w[i].clone()  # simulates a materialized expert slice
    idx = (copied.storage_offset() - base_offset) // stride0
    print(
        f"expert={i} computed_idx={idx} "
        f"shares_storage={copied.untyped_storage().data_ptr() == w.untyped_storage().data_ptr()}"
    )
PY

Repository: NVIDIA/Model-Optimizer

Length of output: 183


🏁 Script executed:

# First, locate and examine the file mentioned in the review
find . -name "huggingface.py" -path "*/quantization/plugins/*" | head -5

Repository: NVIDIA/Model-Optimizer

Length of output: 118


🏁 Script executed:

# Examine the actual code at lines 906-925
sed -n '906,925p' ./modelopt/torch/quantization/plugins/huggingface.py

Repository: NVIDIA/Model-Optimizer

Length of output: 1045


🏁 Script executed:

# Check the context around this method - see the full class and where it's used
sed -n '882,930p' ./modelopt/torch/quantization/plugins/huggingface.py

Repository: NVIDIA/Model-Optimizer

Length of output: 2592


🏁 Script executed:

# Check how _get_expert_idx_from_gate_up is called and what impact it has
rg "_get_expert_idx_from_gate_up" ./modelopt/torch/quantization/plugins/huggingface.py -A 3 -B 3

Repository: NVIDIA/Model-Optimizer

Length of output: 806


🏁 Script executed:

# Look for usage patterns and understand when this method is called
sed -n '920,980p' ./modelopt/torch/quantization/plugins/huggingface.py

Repository: NVIDIA/Model-Optimizer

Length of output: 2473


🏁 Script executed:

# Check if there are any existing uses of data_ptr() or storage sharing checks in the codebase
rg "data_ptr\(\)|untyped_storage" ./modelopt/torch/ -l | head -10

Repository: NVIDIA/Model-Optimizer

Length of output: 152


🏁 Script executed:

# Verify the storage offset logic - let's analyze the mathematical relationship
# Check if there are any comments or documentation about storage offset semantics
rg "storage_offset|storage" ./modelopt/torch/quantization/plugins/huggingface.py | head -20

Repository: NVIDIA/Model-Optimizer

Length of output: 437


🏁 Script executed:

# Check how gate_up_proj is structured and accessed in typical usage
sed -n '870,895p' ./modelopt/torch/quantization/plugins/huggingface.py

Repository: NVIDIA/Model-Optimizer

Length of output: 1382


🏁 Script executed:

# Let's understand the flow better - check how the method is integrated
sed -n '940,960p' ./modelopt/torch/quantization/plugins/huggingface.py

Repository: NVIDIA/Model-Optimizer

Length of output: 1064


🏁 Script executed:

# Let's verify the assertion behavior with a simple logical analysis
# Check what happens in typical cases
cat > /tmp/analysis.py << 'EOF'
# Simulating the index computation without needing torch

# Example: gate_up_proj shape (num_experts=8, 2*4096, 4096)
# When accessing gate_up_proj[i], we get a view at storage offset: i * stride(0)
# stride(0) is the number of elements to skip to go from expert i to expert i+1

num_experts = 8
intermediate_dim = 4096
hidden_dim = 4096

# For a 3D tensor (num_experts, 2*intermediate_dim, hidden_dim)
# stride(0) = 2*intermediate_dim * hidden_dim = 2*4096*4096
stride_0 = 2 * intermediate_dim * hidden_dim
print(f"stride(0) = {stride_0}")

# Expert 0 view: storage_offset = 0
# Expert 1 view: storage_offset = stride_0
# Expert 2 view: storage_offset = 2*stride_0
# etc.

# If a tensor is cloned (copied), it gets its own storage with offset typically 0
# But when computing idx = (weight_offset - base_offset) // stride
# - If base_offset = 0 (expert 0 is accessed first)
# - And cloned tensor has offset = 0
# - Then idx = (0 - 0) // stride = 0  <- Valid index!
# - This would incorrectly route to expert 0's quantizer even if original was expert 5

# The current assertion only checks 0 <= idx < num_experts
# It doesn't check if the weight tensor shares storage

for expert_idx in range(num_experts):
    base_offset = 0  # or expert_idx * stride_0 depending on which expert is base
    weight_offset_if_shared = expert_idx * stride_0  # correct view
    weight_offset_if_cloned = 0  # cloned tensors usually start at 0
    
    computed_idx_shared = (weight_offset_if_shared - base_offset) // stride_0
    computed_idx_cloned = (weight_offset_if_cloned - base_offset) // stride_0
    
    print(f"Expert {expert_idx}: shared→{computed_idx_shared}, cloned→{computed_idx_cloned}")
    
    # If base_offset == 0, cloned tensor always computes to 0
    # If base_offset == expert_idx * stride, shared computes to expert_idx correctly
EOF
python /tmp/analysis.py

Repository: NVIDIA/Model-Optimizer

Length of output: 318


🏁 Script executed:

# Let's verify that the codebase uses untyped_storage() or data_ptr() elsewhere
rg "\.data_ptr\(\)|untyped_storage" ./modelopt/torch/quantization/ -B 2 -A 2 | head -30

Repository: NVIDIA/Model-Optimizer

Length of output: 48


🏁 Script executed:

# Check if there are any tests that validate the storage offset behavior
find ./modelopt -name "*test*huggingface*" -o -name "*test*fused*expert*" | xargs ls -la 2>/dev/null | head -20

Repository: NVIDIA/Model-Optimizer

Length of output: 1291


🏁 Script executed:

# Search for test files in the quantization directory
find ./modelopt/torch/quantization -name "*test*.py" | head -20

Repository: NVIDIA/Model-Optimizer

Length of output: 48


🏁 Script executed:

# Search for any test files related to this module
find . -path "./tests/*" -name "*hugging*" -o -path "./tests/*" -name "*expert*" 2>/dev/null | head -20

Repository: NVIDIA/Model-Optimizer

Length of output: 187


🏁 Script executed:

# Check the test file for fused experts
cat ./tests/unit/torch/quantization/plugins/test_fused_experts.py

Repository: NVIDIA/Model-Optimizer

Length of output: 12539


🏁 Script executed:

# Check PyTorch version requirements and untyped_storage availability
grep -r "torch" ./pyproject.toml ./requirements.txt 2>/dev/null | grep -i version | head -5

Repository: NVIDIA/Model-Optimizer

Length of output: 48


Assert shared storage before trusting storage_offset().

A copied/materialized expert slice can still produce an in-range idx (typically 0), so the current range assertion does not catch the .contiguous()-copied or redistributed failure mode described in the docstring. This silently routes subsequent experts to expert 0's quantizers instead of failing fast.

Suggested guard
     def _get_expert_idx_from_gate_up(self, weight: torch.Tensor) -> int:
         """Recover expert index from a ``gate_up_proj`` weight slice's storage offset.
         ...
         """
         base_offset = self.gate_up_proj.storage_offset()
         stride = self.gate_up_proj.stride(0)
         if stride == 0:
             return 0
+        if weight.untyped_storage().data_ptr() != self.gate_up_proj.untyped_storage().data_ptr():
+            raise AssertionError(
+                "Expected a gate_up_proj view that shares storage with self.gate_up_proj. "
+                "Copied/materialized weights cannot be mapped back to a stable expert index."
+            )
         idx = (weight.storage_offset() - base_offset) // stride
         assert 0 <= idx < self.num_experts, (
             f"Computed expert index {idx} out of range [0, {self.num_experts}). "
             "This can happen if the weight was .contiguous()-copied or redistributed."
         )

Contributor

@coderabbitai coderabbitai bot left a comment


🧹 Nitpick comments (1)
tests/unit/torch/quantization/plugins/test_fused_experts.py (1)

160-165: Consider using a pytest fixture for registry cleanup to ensure isolation on test failure.

The current manual cleanup approach at start and end of each test works but is fragile if a test fails mid-execution. A fixture with yield guarantees cleanup.

♻️ Proposed fixture-based cleanup
+@pytest.fixture(autouse=True)
+def cleanup_registry():
+    """Ensure registry is clean before and after each test."""
+    yield
+    # Cleanup after test completes (even on failure)
+    for mod_type in [_SyntheticFusedExperts, _SyntheticSparseMoeBlock]:
+        if QuantModuleRegistry.get(mod_type) is not None:
+            QuantModuleRegistry.unregister(mod_type)
+
 class TestQuantFusedExperts:
-    @staticmethod
-    def _cleanup_registry(mod_type):
-        if QuantModuleRegistry.get(mod_type) is not None:
-            QuantModuleRegistry.unregister(mod_type)

Then remove the explicit _cleanup_registry calls from each test method.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/unit/torch/quantization/plugins/test_fused_experts.py` around lines 160
- 165, Replace the ad-hoc _cleanup_registry pattern with a pytest fixture that
yields and performs teardown to guarantee registry cleanup on test failures:
create a fixture (e.g., quant_module_registry_cleanup) that checks
QuantModuleRegistry.get(mod_type) before yield and calls
QuantModuleRegistry.unregister(mod_type) in the finally/teardown section, then
use that fixture in TestQuantFusedExperts tests and remove explicit calls to
TestQuantFusedExperts._cleanup_registry from each test method.

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: 9d8f0410-80ad-49d4-bed2-3fb4f4022ae4

📥 Commits

Reviewing files that changed from the base of the PR and between fa93aec and fb81b00.

📒 Files selected for processing (1)
  • tests/unit/torch/quantization/plugins/test_fused_experts.py


def _setup(self):
    n = self.num_experts
    self.gate_up_proj_input_quantizers = nn.ModuleList([TensorQuantizer() for _ in range(n)])
Contributor


The downstream frameworks use a shared input quantization scale across all experts. With that context, would we just use one input_quantizer here?

Contributor Author


Good point - updated.

w = nn.Module()
w.input_quantizer = q
wrappers.append(w)
set_expert_quantizer_amax(modules=wrappers, quantizer_attrs=["input_quantizer"])
Contributor


Do we have a message somewhere else for the case with uncalibrated quantizers? So users can decide if they want to use more data.

Contributor Author

@Edwardf0t1 Edwardf0t1 Apr 8, 2026
Added warnings.warn() when a weight quantizer was never calibrated and falls back to weight-derived amax, telling users to "consider using more calibration data to activate all experts."
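The fallback described above can be sketched with a small illustrative helper (hypothetical names, not the actual ModelOpt implementation): calibrated experts keep their amax, while never-activated experts derive one from their weights and emit a warning.

```python
import warnings


def backfill_expert_amax(weights, amaxes):
    """Hypothetical sketch of the fallback: keep the calibrated amax when
    present, otherwise derive one from the expert's weights and warn so the
    user can add calibration data (illustrative, not ModelOpt code)."""
    out = []
    for idx, (w, a) in enumerate(zip(weights, amaxes)):
        if a is None:
            warnings.warn(
                f"Expert {idx} was never activated during calibration; "
                "falling back to weight-derived amax. Consider using more "
                "calibration data to activate all experts."
            )
            a = max(abs(x) for x in w)
        out.append(a)
    return out


# Expert 0 was calibrated (amax=1.3); expert 1 never fired, so its amax
# comes from its weights: max(|0.5|, |0.1|) = 0.5.
print(backfill_expert_amax([[0.2, -0.9], [0.5, 0.1]], [1.3, None]))  # → [1.3, 0.5]
```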

@kevalmorabia97
Collaborator

Test failing: AttributeError: 'MixtralBlockSparseTop2MLP' object has no attribute 'gate_proj'

@Edwardf0t1
Contributor Author

Test failing: AttributeError: 'MixtralBlockSparseTop2MLP' object has no attribute 'gate_proj'

It should be fixed now; could you double-check?

@coderabbitai coderabbitai bot left a comment
Actionable comments posted: 1

♻️ Duplicate comments (2)
modelopt/torch/quantization/plugins/huggingface.py (2)

907-926: ⚠️ Potential issue | 🟠 Major

Reject materialized slices before deriving idx.

Line 921 only checks that the computed index is in range. A copied/materialized gate_up_proj[idx] can still satisfy that check—typically by collapsing to expert 0—so this path silently reuses the wrong quantizers instead of failing fast. Guard on shared storage, and ideally on expert-boundary alignment, before computing idx.

🛠️ Suggested guard
     def _get_expert_idx_from_gate_up(self, weight: torch.Tensor) -> int:
         ...
         base_offset = self.gate_up_proj.storage_offset()
         stride = self.gate_up_proj.stride(0)
         if stride == 0:
             return 0
-        idx = (weight.storage_offset() - base_offset) // stride
+        if weight.untyped_storage().data_ptr() != self.gate_up_proj.untyped_storage().data_ptr():
+            raise AssertionError(
+                "Expected gate_up_proj[idx] to share storage with self.gate_up_proj."
+            )
+        delta = weight.storage_offset() - base_offset
+        if delta % stride != 0:
+            raise AssertionError("Expected gate_up_proj[idx] to start on an expert boundary.")
+        idx = delta // stride
In PyTorch, if an expert index is computed from `(weight.storage_offset() - base_offset) // stride`, can `weight.clone()` or `weight.contiguous()` still yield an in-range index even though the tensor no longer shares storage with the original parameter?
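The question can be checked directly in PyTorch. A minimal, self-contained demonstration of offset-based index recovery (not ModelOpt's actual implementation) shows that a `clone()` of an expert slice resets `storage_offset` to 0, so a naive offset-only computation collapses to an in-range but wrong index of 0, which is exactly why the storage-identity guard is needed:

```python
import torch

num_experts, two_i, hidden = 4, 16, 8
gate_up = torch.randn(num_experts, two_i, hidden)  # fused [N, 2I, H]


def expert_idx(weight, fused):
    # Only a true view of the fused tensor can be mapped back to an
    # expert index; a copied tensor must be rejected, not guessed at.
    if weight.untyped_storage().data_ptr() != fused.untyped_storage().data_ptr():
        raise AssertionError("weight does not share storage with the fused tensor")
    return (weight.storage_offset() - fused.storage_offset()) // fused.stride(0)


# A slice is a view sharing storage, so the index is recoverable:
assert expert_idx(gate_up[2], gate_up) == 2

# A materialized copy has its own storage with offset 0, so the naive
# offset-only computation collapses to expert 0 -- in range, but wrong:
copied = gate_up[2].clone()
naive = (copied.storage_offset() - gate_up.storage_offset()) // gate_up.stride(0)
assert naive == 0  # silently the wrong expert without the storage check

try:
    expert_idx(copied, gate_up)
except AssertionError:
    print("copy rejected")  # fail fast instead of reusing expert 0's quantizers
```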
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/quantization/plugins/huggingface.py` around lines 907 - 926,
The _get_expert_idx_from_gate_up method should reject materialized/copy slices
before deriving idx: verify that weight shares the same underlying storage
object as self.gate_up_proj (e.g., compare storage/data_ptr) and that the offset
aligns with expert boundaries ((weight.storage_offset() - base_offset) is
divisible by stride) before computing idx; if either check fails (including the
stride==0 special case), raise a clear error instead of proceeding so we don't
silently reuse the wrong quantizers.

871-879: ⚠️ Potential issue | 🟠 Major

Auto-register only the exact eager fused-expert layout.

The detector still accepts any module with 3-D gate_up_proj/down_proj, num_experts, and act_fn. That is broad enough to register transposed/BMM-style containers or non-eager fused variants, while _QuantFusedExperts assumes [N, 2I, H] / [N, H, I], an even split, and two alternating F.linear calls. Because the registry is keyed by class, one false positive affects every instance of that type. Validate tensor orientation, num_experts consistency, and the resolved experts backend before registering. The shape[1] // 2 fallback should also reject odd widths instead of truncating.

In Hugging Face transformers 5.x fused MoE experts, what tensor layouts are used for eager vs batched_mm/grouped_mm implementations, and where is the active experts implementation exposed on the module or config?

Also applies to: 1458-1505, 1729-1735

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/quantization/plugins/huggingface.py` around lines 871 - 879,
The current detector in _get_fused_expert_intermediate_dim and the fused-expert
registration is too permissive; restrict registration to the exact eager
fused-expert layout by validating that module.gate_up_proj and
module.down_proj have the expected orientations (e.g., gate_up_proj.shape
== (num_experts, 2*I, H) and down_proj.shape == (num_experts, H, I) for the
eager layout), that num_experts on the module matches the leading dimension of
these tensors, that the fused width (gate_up_proj.shape[1]) is even (reject odd
widths rather than floor-dividing), and that the module’s resolved experts
backend/type (from config or exposed attribute) corresponds to the eager
implementation before returning or registering _QuantFusedExperts; update
_get_fused_expert_intermediate_dim to return None or raise when these
validations fail so only exact eager fused-expert modules are auto-registered.
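The stricter detection being asked for can be sketched as a pure shape check (hypothetical helper, not ModelOpt code): it accepts only the eager `[N, 2I, H]` / `[N, H, I]` orientation, requires `num_experts` to match the leading dimension, and rejects odd fused widths instead of truncating.

```python
def is_eager_fused_layout(num_experts, gate_up_shape, down_shape):
    """Hypothetical structural check (illustrative only): accept only the
    eager layout gate_up_proj [N, 2I, H] / down_proj [N, H, I], rejecting
    transposed/BMM-style tensors and odd fused widths."""
    if len(gate_up_shape) != 3 or len(down_shape) != 3:
        return False
    n, two_i, h = gate_up_shape
    if n != num_experts or two_i % 2 != 0:  # reject odd widths, don't truncate
        return False
    intermediate = two_i // 2
    # down_proj must be the exact complement of gate_up_proj:
    return down_shape == (num_experts, h, intermediate)


assert is_eager_fused_layout(4, (4, 16, 8), (4, 8, 8))      # eager layout
assert not is_eager_fused_layout(4, (4, 15, 8), (4, 8, 7))  # odd fused width
assert not is_eager_fused_layout(4, (4, 8, 16), (4, 8, 8))  # transposed layout
```

A check like this keeps false positives out of the class-keyed registry, since one misregistered class would affect every instance of that type.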
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@tests/unit/torch/quantization/plugins/test_fused_experts.py`:
- Around line 255-298: Add a new CPU regression test that mirrors
test_export_creates_per_expert_submodules but uses the legacy non-fused MoE
container (the test suite's non-fused Synthetic expert class), e.g., instantiate
the non-fused container, register it with QuantModuleRegistry, call
QuantModuleRegistry.convert(...) to get converted, run a small forward on CPU to
calibrate, call _export_fused_experts(converted, torch.float32), then assert
per-expert submodules (getattr(converted, str(idx))) exist and contain
gate_proj/up_proj/down_proj with expected shapes and that fused attributes like
gate_up_proj and gate_up_proj_weight_quantizers are removed, and finally
unregister the expert type from QuantModuleRegistry.

---

Duplicate comments:
In `@modelopt/torch/quantization/plugins/huggingface.py`:
- Around line 907-926: The _get_expert_idx_from_gate_up method should reject
materialized/copy slices before deriving idx: verify that weight shares the same
underlying storage object as self.gate_up_proj (e.g., compare storage/data_ptr)
and that the offset aligns with expert boundaries ((weight.storage_offset() -
base_offset) is divisible by stride) before computing idx; if either check fails
(including the stride==0 special case), raise a clear error instead of
proceeding so we don't silently reuse the wrong quantizers.
- Around line 871-879: The current detector in
_get_fused_expert_intermediate_dim and the fused-expert registration is too
permissive; restrict registration to the exact eager fused-expert layout by
validating that module.gate_up_proj and module.down_proj have the expected
orientations (e.g., gate_up_proj.shape == (num_experts, 2*I, H) and
down_proj.shape == (num_experts, H, I) for the eager layout), that num_experts
on the module matches the leading dimension of these tensors, that the
fused width (gate_up_proj.shape[1]) is even (reject odd widths rather than
floor-dividing), and that the module’s resolved experts backend/type (from
config or exposed attribute) corresponds to the eager implementation before
returning or registering _QuantFusedExperts; update
_get_fused_expert_intermediate_dim to return None or raise when these
validations fail so only exact eager fused-expert modules are auto-registered.
ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: 31275e25-b8d3-4777-baaa-ab979d03d640

📥 Commits

Reviewing files that changed from the base of the PR and between fb81b00 and 196748e.

📒 Files selected for processing (4)
  • modelopt/torch/export/layer_utils.py
  • modelopt/torch/export/moe_utils.py
  • modelopt/torch/quantization/plugins/huggingface.py
  • tests/unit/torch/quantization/plugins/test_fused_experts.py
🚧 Files skipped from review as they are similar to previous changes (2)
  • modelopt/torch/export/moe_utils.py
  • modelopt/torch/export/layer_utils.py

Comment on lines +255 to +298
class TestExportFusedExperts:
def test_export_creates_per_expert_submodules(self):
"""_export_fused_experts should create per-expert submodules with standard naming."""
from modelopt.torch.export.moe_utils import _export_fused_experts

experts = _SyntheticFusedExperts()
expert_type = type(experts)

# Manually register and convert
if QuantModuleRegistry.get(expert_type) is None:
QuantModuleRegistry.register({expert_type: "test.SyntheticFusedExperts"})(
_QuantFusedExperts
)
converted = QuantModuleRegistry.convert(experts)

# Run a forward pass to calibrate (set amaxes)
seq_len = 16
hidden_states = torch.randn(seq_len, HIDDEN_DIM)
top_k_index = torch.randint(0, NUM_EXPERTS, (seq_len, TOP_K))
top_k_weights = torch.softmax(torch.randn(seq_len, TOP_K), dim=-1)
with torch.no_grad():
converted(hidden_states, top_k_index, top_k_weights)

_export_fused_experts(converted, torch.float16)

# Verify per-expert submodules exist
for idx in range(NUM_EXPERTS):
expert_mod = getattr(converted, str(idx), None)
assert expert_mod is not None, f"Missing expert submodule {idx}"
assert hasattr(expert_mod, "gate_proj"), f"Expert {idx} missing gate_proj"
assert hasattr(expert_mod, "up_proj"), f"Expert {idx} missing up_proj"
assert hasattr(expert_mod, "down_proj"), f"Expert {idx} missing down_proj"

assert expert_mod.gate_proj.weight.shape == (INTERMEDIATE_DIM, HIDDEN_DIM)
assert expert_mod.up_proj.weight.shape == (INTERMEDIATE_DIM, HIDDEN_DIM)
assert expert_mod.down_proj.weight.shape == (HIDDEN_DIM, INTERMEDIATE_DIM)

# Verify fused params are removed
assert not hasattr(converted, "gate_up_proj")
assert not hasattr(converted, "down_proj")
assert not hasattr(converted, "gate_up_proj_weight_quantizers")

if QuantModuleRegistry.get(expert_type) is not None:
QuantModuleRegistry.unregister(expert_type)
Contributor

⚠️ Potential issue | 🟠 Major

Please add a legacy non-fused export regression.

This export test only exercises the new fused layout. The reported backward-compat break is on the older per-expert expert-module path, so this suite still would not catch an export branch that assumes gate_proj/up_proj naming everywhere. Add one CPU regression for a non-fused MoE expert container as well.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/unit/torch/quantization/plugins/test_fused_experts.py` around lines 255
- 298, Add a new CPU regression test that mirrors
test_export_creates_per_expert_submodules but uses the legacy non-fused MoE
container (the test suite's non-fused Synthetic expert class), e.g., instantiate
the non-fused container, register it with QuantModuleRegistry, call
QuantModuleRegistry.convert(...) to get converted, run a small forward on CPU to
calibrate, call _export_fused_experts(converted, torch.float32), then assert
per-expert submodules (getattr(converted, str(idx))) exist and contain
gate_proj/up_proj/down_proj with expected shapes and that fused attributes like
gate_up_proj and gate_up_proj_weight_quantizers are removed, and finally
unregister the expert type from QuantModuleRegistry.
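The unfusing step the export test verifies can be sketched independently of ModelOpt (illustrative function name and layout assumptions from the review above, not the actual `_export_fused_experts` internals): each expert's Linear weights are plain slices of the fused `[N, 2I, H]` / `[N, H, I]` tensors, already in the standard `[out, in]` orientation.

```python
import torch


def split_fused_expert(gate_up, down, idx):
    """Hypothetical unfusing step: slice expert `idx` out of fused
    gate_up [N, 2I, H] / down [N, H, I] tensors into per-expert
    gate/up/down weights with the standard Linear [out, in] shape."""
    two_i = gate_up.shape[1]
    assert two_i % 2 == 0, "fused gate_up width must split evenly"
    inter = two_i // 2
    gate_w = gate_up[idx, :inter, :]  # [I, H]
    up_w = gate_up[idx, inter:, :]    # [I, H]
    down_w = down[idx]                # [H, I]
    return gate_w, up_w, down_w


n, inter, hidden = 4, 8, 16
gate_up = torch.randn(n, 2 * inter, hidden)
down = torch.randn(n, hidden, inter)
g, u, d = split_fused_expert(gate_up, down, 1)
print(g.shape, u.shape, d.shape)  # gate/up: [8, 16] each, down: [16, 8]
```

These shapes match the assertions in the test above: `(INTERMEDIATE_DIM, HIDDEN_DIM)` for gate/up and `(HIDDEN_DIM, INTERMEDIATE_DIM)` for down.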

@Edwardf0t1 Edwardf0t1 force-pushed the zhiyu/ptq-export-transformers5 branch from 196748e to 3e1a0fc Compare April 8, 2026 19:36
@Edwardf0t1 Edwardf0t1 requested review from a team as code owners April 8, 2026 19:36
@kevalmorabia97 kevalmorabia97 requested review from mxinO and removed request for a team, Fridah-nv, chadvoegele, kaix-nv, kinjalpatel27 and vishalpandya1990 April 8, 2026 19:39
Base automatically changed from kmorabi/bump-transformers-5.0 to main April 9, 2026 04:29
Edwardf0t1 and others added 9 commits April 9, 2026 04:36
Signed-off-by: Zhiyu Cheng <zhiyuc@nvidia.com>
Signed-off-by: Zhiyu Cheng <zhiyuc@nvidia.com>
Signed-off-by: Zhiyu Cheng <zhiyuc@nvidia.com>
Signed-off-by: Zhiyu Cheng <zhiyuc@nvidia.com>
Signed-off-by: Zhiyu Cheng <zhiyuc@nvidia.com>
Signed-off-by: Zhiyu Cheng <zhiyuc@nvidia.com>
Signed-off-by: Zhiyu Cheng <zhiyuc@nvidia.com>
Signed-off-by: Zhiyu Cheng <zhiyuc@nvidia.com>
Update test imports and assertions after _is_sparse_moe_block was renamed
to _is_sparse_sequaential_moe_block in PR #975. Fused MoE blocks with
non-iterable experts are correctly not detected as sequential MoE blocks.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Signed-off-by: Zhiyu Cheng <zhiyuc@nvidia.com>
@Edwardf0t1 Edwardf0t1 force-pushed the zhiyu/ptq-export-transformers5 branch from 3e1a0fc to 3cceca6 Compare April 9, 2026 04:36
Signed-off-by: Zhiyu Cheng <zhiyuc@nvidia.com>
@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 1

♻️ Duplicate comments (1)
modelopt/torch/quantization/plugins/huggingface.py (1)

926-945: ⚠️ Potential issue | 🟠 Major

Fail fast if the expert slice does not share storage with gate_up_proj.

A materialized/copied gate_up_proj[idx] can still yield an in-range idx, so the current range assertion does not actually catch the .contiguous() / redistribution case described in the docstring. In that case this silently routes later experts through the wrong quantizer instead of stopping.

Suggested fix
         base_offset = self.gate_up_proj.storage_offset()
         stride = self.gate_up_proj.stride(0)
         if stride == 0:
             return 0
+        if weight.untyped_storage().data_ptr() != self.gate_up_proj.untyped_storage().data_ptr():
+            raise AssertionError(
+                "Expected a gate_up_proj view that shares storage with self.gate_up_proj. "
+                "Copied/materialized weights cannot be mapped back to a stable expert index."
+            )
         idx = (weight.storage_offset() - base_offset) // stride
         assert 0 <= idx < self.num_experts, (
             f"Computed expert index {idx} out of range [0, {self.num_experts}). "
             "This can happen if the weight was .contiguous()-copied or redistributed."
         )
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/quantization/plugins/huggingface.py` around lines 926 - 945,
The method _get_expert_idx_from_gate_up must fail fast when the provided weight
slice does not share the same underlying storage as self.gate_up_proj; add a
check that weight.storage().data_ptr() (or weight.storage() ==
self.gate_up_proj.storage()) equals self.gate_up_proj.storage().data_ptr()
before computing the offset and if it differs raise a clear RuntimeError
indicating the slice is materialized/copied and cannot be mapped to an expert
index; keep the existing stride==0 handling and range assertion but perform the
storage identity check first (using gate_up_proj.storage()/data_ptr()) to detect
.contiguous() or redistributed tensors.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@modelopt/torch/export/unified_export_hf.py`:
- Around line 731-733: The current "break" inside the branch for
sub_module.experts with "gate_up_proj_weight_quantizers" prematurely exits the
loop and prevents shared fused-expert input quantizers
(gate_up_proj_input_quantizer / down_proj_input_quantizer) from being preserved,
causing fused experts to export without activation scales; replace the "break"
with a "continue" (or remove the break) so that _export_fused_experts still
handles weight amax fallback but the loop continues and shared input quantizers
are allowed to flow into _export_quantized_weight for proper input_scale
emission.

---

Duplicate comments:
In `@modelopt/torch/quantization/plugins/huggingface.py`:
- Around line 926-945: The method _get_expert_idx_from_gate_up must fail fast
when the provided weight slice does not share the same underlying storage as
self.gate_up_proj; add a check that weight.storage().data_ptr() (or
weight.storage() == self.gate_up_proj.storage()) equals
self.gate_up_proj.storage().data_ptr() before computing the offset and if it
differs raise a clear RuntimeError indicating the slice is materialized/copied
and cannot be mapped to an expert index; keep the existing stride==0 handling
and range assertion but perform the storage identity check first (using
gate_up_proj.storage()/data_ptr()) to detect .contiguous() or redistributed
tensors.
ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: bf27fa93-e1f8-498f-9305-72259d11a9cd

📥 Commits

Reviewing files that changed from the base of the PR and between 196748e and 3cceca6.

📒 Files selected for processing (5)
  • modelopt/torch/export/layer_utils.py
  • modelopt/torch/export/moe_utils.py
  • modelopt/torch/export/unified_export_hf.py
  • modelopt/torch/quantization/plugins/huggingface.py
  • tests/unit/torch/quantization/plugins/test_fused_experts.py
🚧 Files skipped from review as they are similar to previous changes (2)
  • modelopt/torch/export/layer_utils.py
  • tests/unit/torch/quantization/plugins/test_fused_experts.py

Comment on lines +731 to +733
elif hasattr(sub_module.experts, "gate_up_proj_weight_quantizers"):
# _QuantFusedExperts: amax fallback is handled in _export_fused_experts
break
Contributor

⚠️ Potential issue | 🟠 Major

Keep the shared fused-expert input quantizers on the fallback path.

_export_fused_experts() only backfills missing amax for weight quantizers. The shared gate_up_proj_input_quantizer / down_proj_input_quantizer still flow into _export_quantized_weight(), which only emits input_scale when amax is already populated, so this break can export fused experts without activation scales.

Suggested fix
                 elif hasattr(sub_module.experts, "gate_up_proj_weight_quantizers"):
-                    # _QuantFusedExperts: amax fallback is handled in _export_fused_experts
+                    set_expert_quantizer_amax(
+                        modules=sub_module.experts,
+                        quantizer_attrs=[
+                            "gate_up_proj_input_quantizer",
+                            "down_proj_input_quantizer",
+                        ],
+                        device=sub_module.experts.gate_up_proj.device,
+                    )
                     break
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/export/unified_export_hf.py` around lines 731 - 733, The
current "break" inside the branch for sub_module.experts with
"gate_up_proj_weight_quantizers" prematurely exits the loop and prevents shared
fused-expert input quantizers (gate_up_proj_input_quantizer /
down_proj_input_quantizer) from being preserved, causing fused experts to export
without activation scales; replace the "break" with a "continue" (or remove the
break) so that _export_fused_experts still handles weight amax fallback but the
loop continues and shared input quantizers are allowed to flow into
_export_quantized_weight for proper input_scale emission.

Collaborator

@kevalmorabia97 kevalmorabia97 Apr 9, 2026

Can you update the test matrix in https://github.com/NVIDIA/Model-Optimizer/blob/main/tests/gpu/torch/export/test_unified_hf_export_and_check_safetensors.py to include different tiny models (LLMs, MoE, and VLMs)? Currently it's all Llama only.
