Refactor Qwen3.5 MoE quantization to use _QuantFunctionalMixin #1170
Conversation
Signed-off-by: Chenjie Luo <chenjiel@nvidia.com>
📝 Walkthrough

This change implements Qwen3.5 MoE expert quantization and export by updating expert classification, adding specialized expert splitting and quantization logic, integrating new export paths in the unified export pipeline, and replacing the decomposition-based quantization approach with a functional wrapper that intercepts linear operations.
Sequence Diagram

```mermaid
sequenceDiagram
    participant Exporter as Unified Exporter
    participant MoeModule as Qwen3.5 MoE Module
    participant SplitLogic as Expert Splitting Logic
    participant QuantLogic as Quantization Logic
    participant Storage as Module Storage
    Exporter->>MoeModule: Identify QuantQwen3_5MoeExperts
    Exporter->>SplitLogic: Call _export_qwen35_experts()
    SplitLogic->>MoeModule: Access fused gate_up_proj & down_proj
    SplitLogic->>SplitLogic: Decompose fused weights per expert
    loop For each expert slice
        SplitLogic->>QuantLogic: Export quantized weight & scales
        QuantLogic->>QuantLogic: Apply per-channel amax fallback
        QuantLogic->>QuantLogic: Compute amax if uncalibrated
    end
    SplitLogic->>Storage: Register per-expert submodules
    SplitLogic->>Storage: Remove fused parameters
    SplitLogic->>Exporter: Return with per-expert structure
```
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~45 minutes

🚥 Pre-merge checks: ✅ Passed checks (4 passed)
Signed-off-by: Chenjie Luo <chenjiel@nvidia.com>
🧹 Nitpick comments (2)
modelopt/torch/export/moe_utils.py (1)
105-117: Amax slicing logic is correct but inconsistent with line 130.

The proportional slicing for per-channel amax is mathematically correct. However, line 117 sets `w_quantizer._amax` (the internal attribute), while line 130 sets `w_quantizer.amax` (the property). Consider using the property setter consistently for proper validation:

```diff
- w_quantizer._amax = amax[slice_start:slice_end].contiguous()
+ w_quantizer.amax = amax[slice_start:slice_end].contiguous()
```

This ensures any property-level validation in `TensorQuantizer.amax.setter` is applied uniformly.

🤖 Prompt for AI Agents

Verify each finding against the current code and only fix it if needed. In `@modelopt/torch/export/moe_utils.py` around lines 105-117, the per-channel amax slice currently assigns directly to the internal attribute `w_quantizer._amax` (in the block that checks `hasattr(w_quantizer, "_amax")`), which bypasses any validation in the `TensorQuantizer.amax` property; change this to assign via the property (e.g., set `w_quantizer.amax = sliced_amax.contiguous()`) instead of writing to `_amax` so the `TensorQuantizer.amax.setter` runs consistently with the later code that uses `w_quantizer.amax`.

modelopt/torch/quantization/plugins/huggingface.py (1)
805-828: Consider thread-safety implications of the toggle mechanism.

The toggle state (`_down_proj_linear`, `_current_expert_idx`) is instance-level mutable state accessed during F.linear interception. If the same module instance is used concurrently (e.g., in data-parallel training without proper synchronization), the toggle could become inconsistent across threads.

This is likely fine for typical inference/calibration workloads (single-threaded forward), but worth noting for future maintainers if concurrent usage becomes a requirement.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@modelopt/torch/quantization/plugins/huggingface.py` around lines 805 - 828, The toggle state used in functionals_to_replace via the nested _quantized_linear (specifically instance fields _down_proj_linear and _current_expert_idx) is mutable and not thread-safe; replace the instance-level toggle with a thread-local or per-call state to avoid race conditions when F.linear is intercepted concurrently. Concretely, change _quantized_linear to use a threading.local() or local context object (created outside or on the stack) keyed per-thread/call to store the down-proj boolean and current expert index (instead of _down_proj_linear/_current_expert_idx), or protect access with a lightweight Lock around reads/writes; update uses of _get_expert_idx_from_gate_up, gate_up_proj_input_quantizers, gate_up_proj_weight_quantizers, down_proj_input_quantizers and down_proj_weight_quantizers to read/write the thread-local or locked state so concurrent forwards don’t clobber each other.
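One way to address this, sketched below with a hypothetical helper rather than the module's actual fields, is to keep the toggle in thread-local storage so concurrent forwards on the same instance each see their own state:

```python
import threading

class ExpertLinearState:
    """Hypothetical per-thread toggle for gate_up vs down_proj F.linear calls.

    Illustrative sketch only: the PR stores this state in instance fields
    (_down_proj_linear); here it lives in threading.local() instead.
    """

    def __init__(self):
        self._local = threading.local()

    def next_is_down_proj(self) -> bool:
        # Each thread sees its own toggle, so concurrent forwards on the
        # same module instance cannot clobber each other's state.
        flag = getattr(self._local, "down_proj", False)
        self._local.down_proj = not flag
        return flag

state = ExpertLinearState()
assert state.next_is_down_proj() is False  # first linear call: gate_up
assert state.next_is_down_proj() is True   # second linear call: down_proj

# A fresh thread starts from its own clean toggle.
results = []
t = threading.Thread(target=lambda: results.append(state.next_is_down_proj()))
t.start(); t.join()
assert results == [False]
```

This keeps the single-threaded behavior identical while removing the cross-thread race; a lock around the instance fields would work too, at the cost of serializing forwards.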
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Run ID: 1a839b2a-9e97-4d74-a751-5dd420978867
📒 Files selected for processing (4)
- modelopt/torch/export/layer_utils.py
- modelopt/torch/export/moe_utils.py
- modelopt/torch/export/unified_export_hf.py
- modelopt/torch/quantization/plugins/huggingface.py
Codecov Report

❌ Patch coverage is

Additional details and impacted files:

```
@@           Coverage Diff            @@
##             main    #1170    +/-  ##
=========================================
+ Coverage   74.27%   75.74%   +1.47%
=========================================
  Files         349      349
  Lines       39846    39886      +40
=========================================
+ Hits        29594    30212     +618
+ Misses      10252     9674     -578
```

Flags with carried forward coverage won't be shown. Click here to find out more.

☔ View full report in Codecov by Sentry.
Edwardf0t1 left a comment

Review
Key Concerns
1. Fragile toggle-based state machine (huggingface.py)

The `_quantized_linear` closure uses a boolean toggle (`_down_proj_linear`) to distinguish gate_up vs down_proj calls:

```python
self._down_proj_linear = not self._down_proj_linear
```

This assumes HF's forward calls F.linear exactly twice per expert in strict alternation. The comment acknowledges this, but:

- If HF changes the forward to add a third linear call (e.g., a shared expert gate), this silently misassigns quantizers.
- If an exception occurs mid-forward inside `super().forward()`, the toggle is reset at the next `forward()` call (good), but during gradient checkpointing or re-entrant autograd the toggle could get out of sync.

Consider validating with the weight shape or storage offset for both calls instead of only the first, so mismatches raise early rather than silently corrupting quantization.
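A defensive variant of that check could look like the following minimal sketch; the function name and dimensions are illustrative, not the actual Qwen3.5 shapes:

```python
def classify_linear_call(weight_shape, gate_up_shape, down_shape):
    """Decide whether an intercepted F.linear call is the fused gate_up
    projection or the down projection by its weight shape, instead of
    relying on a blind alternating toggle."""
    if weight_shape == gate_up_shape:
        return "gate_up"
    if weight_shape == down_shape:
        return "down_proj"
    # A third, unexpected linear call raises early instead of silently
    # misassigning quantizers.
    raise RuntimeError(f"unexpected linear weight shape {weight_shape}")

hidden, expert_dim = 512, 128
gate_up = (2 * expert_dim, hidden)  # fused gate+up rows
down = (hidden, expert_dim)
assert classify_linear_call(gate_up, gate_up, down) == "gate_up"
assert classify_linear_call(down, gate_up, down) == "down_proj"
```

Since gate_up is (2e, h) and down is (h, e), the two shapes can never coincide for positive dimensions, so the classification is unambiguous.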
2. Expert index recovery via storage offset (`_get_expert_idx_from_gate_up`)

```python
return (weight.storage_offset() - base_offset) // stride
```

This is clever but brittle; it breaks if:

- The weight is `.contiguous()`-copied (the docstring acknowledges this)
- FSDP2, tensor parallel, or other distributed wrappers reshard/redistribute the parameter
- `torch.compile` materializes a new tensor
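To make those failure modes concrete, here is a pure-Python sketch of the offset arithmetic with the defensive checks the reviewers ask for; names are illustrative, and the real code reads `weight.storage_offset()` from a tensor view:

```python
def recover_expert_idx(weight_offset: int, base_offset: int,
                       stride: int, num_experts: int) -> int:
    """Recover which expert a sliced weight view belongs to from its
    storage offset into the fused parameter. Illustrative sketch."""
    idx, rem = divmod(weight_offset - base_offset, stride)
    # Misalignment or an out-of-range index indicates the view was
    # copied, resharded, or materialized into a fresh tensor.
    if rem != 0 or not 0 <= idx < num_experts:
        raise RuntimeError(f"cannot recover expert index (idx={idx}, rem={rem})")
    return idx

# A fused [num_experts, out, in] parameter stored contiguously: expert i's
# slice starts i * out * in elements past the base offset.
num_experts, out_f, in_f = 8, 4, 16
stride = out_f * in_f
for i in range(num_experts):
    assert recover_expert_idx(i * stride, 0, stride, num_experts) == i
```

Any of the brittleness cases above would surface here as a nonzero remainder or an out-of-range index, turning silent corruption into an immediate error.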
There's no runtime assertion that the computed index is in [0, num_experts). Adding a bounds check would catch silent corruption:

```python
idx = (weight.storage_offset() - base_offset) // stride
assert 0 <= idx < self.num_experts, f"Invalid expert idx {idx}"
return idx
```

3. `break` in export loop (unified_export_hf.py:726)
```python
elif "QuantQwen3_5MoeExperts" in type(sub_module.experts).__name__:
    break  # exits the inner `for linear_name` loop; type check prevents re-entry
```

Using `break` to skip the inner loop is non-obvious, and the comment says "type check prevents re-entry" but the prevention is really just that this branch is hit on the first iteration. If the iteration order of `linear_name` changes or a new name is added before the Qwen3.5 one, this could silently skip processing. A `continue` at the outer loop or restructuring to avoid the `break` would be clearer.
4. Export: `copy.deepcopy` on quantizer (moe_utils.py:106)

```python
w_quantizer = copy.deepcopy(w_quantizer_src) if is_gate_up else w_quantizer_src
```

`deepcopy` on a TensorQuantizer can be expensive and may not correctly copy all internal state (e.g., registered hooks, CUDA state). Since only `_amax` needs to be sliced independently, consider cloning just the amax tensor instead of the entire quantizer.
5. Amax slicing math (moe_utils.py:109-118)

The proportional slicing logic:

```python
slice_start = fused_start * amax_dim0 // fused_total
slice_end = (fused_start + weight_slice.shape[0]) * amax_dim0 // fused_total
```

This integer division assumes `fused_total` is always divisible by `amax_dim0` (validated above), but the slice indices depend on `fused_start` also being aligned. For gate_proj (`fused_start=0`) this is fine, but for up_proj (`fused_start=expert_dim`), if `expert_dim * amax_dim0 % fused_total != 0`, the slicing would be wrong without error. Consider adding a check that `slice_start * fused_total == fused_start * amax_dim0`.
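The proposed check can be seen with small numbers in the sketch below; it reuses the review's variable names but is not the actual export code:

```python
def amax_slice(fused_start: int, slice_rows: int,
               amax_dim0: int, fused_total: int) -> tuple:
    """Map a fused-weight row range [fused_start, fused_start + slice_rows)
    proportionally onto the amax tensor's first dimension."""
    slice_start = fused_start * amax_dim0 // fused_total
    slice_end = (fused_start + slice_rows) * amax_dim0 // fused_total
    # Fail loudly when the boundary falls between amax entries, rather
    # than letting integer division silently truncate.
    assert slice_start * fused_total == fused_start * amax_dim0, (
        f"fused_start={fused_start} is not aligned with amax granularity")
    return slice_start, slice_end

# Aligned case: amax has one entry per fused row (per-channel).
assert amax_slice(0, 128, 256, 256) == (0, 128)      # gate_proj slice
assert amax_slice(128, 128, 256, 256) == (128, 256)  # up_proj slice
```

With a coarser amax (say `amax_dim0=2`, `fused_total=256`) a `fused_start` of 100 would truncate to `slice_start=0`; the assertion catches exactly that case instead of producing a silently wrong slice.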
6. Minor: `intermediate_size` vs `intermediate_dim` (moe_utils.py:48-53)

The dual-attribute check is good for cross-version compatibility, but the quantization plugin (huggingface.py) doesn't seem to have the same fallback; it references `self.intermediate_dim` directly. These should be consistent.
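A shared helper would keep the two call sites consistent. A minimal sketch, where the attribute names come from the review and the helper itself is hypothetical:

```python
def resolve_expert_dim(module) -> int:
    """Return the per-expert hidden dimension across HF versions that
    expose it as either intermediate_size or intermediate_dim."""
    for name in ("intermediate_size", "intermediate_dim"):
        value = getattr(module, name, None)
        if value is not None:
            return value
    raise AttributeError(
        "module has neither intermediate_size nor intermediate_dim")

class _Experts:  # minimal stand-in for the HF experts module
    intermediate_dim = 1536

assert resolve_expert_dim(_Experts()) == 1536
```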
Positive Aspects

- The core idea is sound: `_QuantFunctionalMixin` avoids rewriting the HF forward, which reduces maintenance burden when upstream HF code changes.
- Per-expert `ModuleList` quantizers preserve calibration granularity.
- Moving `Qwen3_5MoeSparseMoeBlock` to the fused expert names group in `layer_utils.py` is a clean fix.
- The export function is well-structured with clear separation of concerns.
Summary
The main risk is the toggle-based state machine for distinguishing linear calls — it's an implicit contract with HF's forward that has no runtime validation. Adding defensive assertions (expert index bounds, linear call count per expert) would significantly improve robustness. The storage-offset trick is clever but should also have a bounds check. The rest of the changes are clean.
meenchen left a comment
1. Dependencies — Open Questions
[QUESTION] vLLM export
huggingface.py:1491 (_QuantStep3p5MoeLinear docstring) explicitly notes vLLM requires stacked 3D scaling factors and that the add_module() per-expert approach (used here in _export_qwen35_experts) is not accepted by vLLM. Is vLLM export for Qwen3.5 intentionally out of scope, or does this need a separate path?
[QUESTION] FSDP2 compatibility
Has this been tested under FSDP2? Sharded parameters may have different storage layouts, breaking _get_expert_idx_from_gate_up during calibration (before export, where fsdp2_aware_weight_update protects things).
2. Design — Robustness
[SUGGESTION] Storage-offset expert index recovery — add bounds check
_get_expert_idx_from_gate_up relies on gate_up_proj[idx] always returning a contiguous-storage view. If the invariant breaks (distributed wrappers, .contiguous() copy), the index silently goes wrong. Add a bounds assertion:
```python
idx = (weight.storage_offset() - base_offset) // stride
assert 0 <= idx < self.num_experts, f"Recovered expert idx {idx} out of range"
```

[SUGGESTION] Toggle state machine tightly coupled to HF's forward
_down_proj_linear assumes F.linear is called exactly twice per expert in strict gate_up→down alternation. If a future HF release adds a third linear call, the toggle silently misaligns. Consider verifying weight shape (gate_up vs down dimensions) instead of a blind toggle as a defensive measure.
3. Issues — Code Quality
[SUGGESTION] break in _export_transformers_checkpoint — restructure control flow
The new branch at unified_export_hf.py:723 uses `break` inside the `for linear_name` loop but never uses `linear_name`. Cleaner to check the type before the loop:

```python
if "QuantQwen3_5MoeExperts" in type(sub_module.experts).__name__:
    continue  # amax + export handled by _export_qwen35_experts
for linear_name in expert_linear_names:
    ...
```
[SUGGESTION] `copy.deepcopy` on TensorQuantizer may carry unwanted state

In `_export_qwen35_experts`, `deepcopy(w_quantizer_src)` clones the entire quantizer including calibrators and hooks; only `_amax` slicing is needed. A lighter-weight approach (new quantizer + copy only the sliced `_amax`) would be more explicit and safer.
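That lighter-weight approach could look like the following sketch, which uses a minimal stand-in class rather than ModelOpt's real TensorQuantizer:

```python
class FakeQuantizer:
    """Minimal stand-in for TensorQuantizer; illustrative only."""

    def __init__(self, amax=None):
        self._amax = amax
        self.calibrator = object()  # heavy state that deepcopy would also clone

def split_amax(src: FakeQuantizer, start: int, end: int) -> FakeQuantizer:
    # Build a fresh quantizer and copy only the sliced amax, instead of
    # copy.deepcopy(src), which drags along calibrators, hooks, etc.
    out = FakeQuantizer()
    if src._amax is not None:
        out._amax = src._amax[start:end]  # .clone().contiguous() on a real tensor
    return out

src = FakeQuantizer(amax=[1.0, 2.0, 3.0, 4.0])
gate = split_amax(src, 0, 2)
assert gate._amax == [1.0, 2.0]
assert gate.calibrator is not src.calibrator
```

On real tensors the slice should be cloned so the per-expert quantizer owns its amax rather than aliasing the fused one.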
[QUESTION] Removal of `__len__`/`__iter__`/`__getitem__`

The old implementation made the experts module iterable. The `isinstance(sub_module.experts, collections.abc.Iterable)` fallback in `_export_transformers_checkpoint:733` would no longer match; the new type check short-circuits first, but is there any other downstream code that iterates over `sub_module.experts`?
Summary

- Moves `_QuantQwen35MoeExperts` from `QuantModule` with a custom forward to `_QuantFunctionalMixin`, keeping the original HF forward unmodified (single fused `F.linear` + chunk instead of two separate matmuls per expert)
- Per-expert quantizer `ModuleList`s with expert index recovery via storage offset, preserving per-expert calibration granularity
- Adds `_export_qwen35_experts` in `moe_utils.py` to split fused 3D params into per-expert named tensors at export time, reusing `_export_quantized_weight` for all quantization formats
- Moves `Qwen3_5MoeSparseMoeBlock` to the fused `gate_up_proj`/`down_proj` expert linear names group in `layer_utils.py`

Test plan

- `python -m pytest tests/unit/torch/quantization/plugins/test_sparse_moe.py -x`
- `python -m pytest tests/gpu/torch/export/ -x`
- Exported tensors follow the `experts.{E}.gate_proj.weight` convention

🤖 Generated with Claude Code