Fix an NVFP4 weight amax attribute issue during export #785
Changes from all commits:
- 4b1bec5
- 9283649
- 97af258
- 1045532
- 4d54f55
- eee6d2f
- 8c0eb8f
```diff
@@ -236,6 +236,31 @@ def get_scaling_factor(quantizer: TensorQuantizer) -> torch.Tensor:
     return scaling_factor


+def _ensure_weight_quantizer_calibrated(
+    weight_quantizer: TensorQuantizer, weight: torch.Tensor, module_name: str = ""
+) -> None:
+    """Calibrate weight quantizer if amax is not set.
+
+    This is a lazy calibration pattern used during export when weight quantizers
+    may not have been calibrated during the main calibration phase.
+
+    Args:
+        weight_quantizer: The weight quantizer to calibrate
+        weight: The weight tensor to use for calibration
+        module_name: Optional module name for better warning messages
+    """
+    if not hasattr(weight_quantizer, "_amax") or weight_quantizer._amax is None:
+        warn(
+            f"Weight quantizer{f' for {module_name}' if module_name else ''} was not calibrated. "
+            f"Computing amax from weights. This may occur if some experts were not activated "
+            f"during calibration (expected for MoE models); try increasing --calib_size."
+        )
+        weight_quantizer.reset_amax()
+        enable_stats_collection(weight_quantizer)
+        weight_quantizer(weight)
+        finish_stats_collection(weight_quantizer)
```

Edwardf0t1 marked this conversation as resolved.
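The lazy-calibration pattern above can be illustrated without the library. The sketch below uses a toy quantizer whose names (`ToyQuantizer`, `ensure_calibrated`) are illustrative stand-ins for the real `TensorQuantizer` and the `enable_stats_collection`/`finish_stats_collection` helpers, not the actual modelopt API:

```python
# Dependency-free sketch: fall back to computing amax from the weight
# tensor when no calibration data populated it.
from warnings import warn


class ToyQuantizer:
    """Stand-in for TensorQuantizer: tracks amax while stats collection is on."""

    def __init__(self):
        self._amax = None          # populated during calibration
        self._collecting = False

    def reset_amax(self):
        self._amax = None

    def enable_stats_collection(self):
        self._collecting = True

    def finish_stats_collection(self):
        self._collecting = False

    def __call__(self, weight):
        if self._collecting:
            # amax = largest absolute value observed in the tensor
            observed = max(abs(w) for w in weight)
            self._amax = observed if self._amax is None else max(self._amax, observed)
        return weight


def ensure_calibrated(quantizer, weight, module_name=""):
    """Mirror of the lazy-calibration flow in the diff, on the toy quantizer."""
    if getattr(quantizer, "_amax", None) is None:
        warn(
            f"Weight quantizer{f' for {module_name}' if module_name else ''} "
            "was not calibrated; computing amax from weights."
        )
        quantizer.reset_amax()
        quantizer.enable_stats_collection()
        quantizer(weight)
        quantizer.finish_stats_collection()


q = ToyQuantizer()
ensure_calibrated(q, [-3.0, 1.5, 2.0], "Linear.weight")
print(q._amax)  # 3.0 — the max absolute weight value
```

Calling `ensure_calibrated` a second time is a no-op, since `_amax` is now set; that is the property the export path relies on.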
Comment on lines 239 to 262 — Contributor

🧩 Analysis chain

🌐 Web query
💡 Result: For vLLM NVFP4 (FP4) quantization, the activation-scale requirements are:
Sources: [1] [2] [3]

🌐 Web query
💡 Result: In TensorRT-LLM, NVFP4 input/activation scales are not something you pass at serving time; they are produced and embedded during quantization/calibration and then used by the engine at runtime. Implication for serving: the client still sends normal (e.g., FP16/BF16) inputs, and the TensorRT(-LLM) engine applies NVFP4 activation quantization internally using the calibrated scales.
Sources: [1] NVIDIA TensorRT "Working with Quantized Types" (NVFP4 block size + activation scheme) https://docs.nvidia.com/deeplearning/tensorrt/latest/inference-library/work-quantized-types.html, [2] same doc section on Dynamic Double Quantization

Activation scales are required by deployment frameworks, but only if activation quantization is enabled. Both vLLM and TensorRT-LLM require offline-calibrated activation scales (global scales/globalSf) when NVFP4 activation quantization is used. However, for weight-only NVFP4 deployment, activation scales are not needed. This fix addresses weight quantizer calibration, which is sufficient for weight-only NVFP4. For MoE models where experts may not be activated during calibration, missing activation scales would be problematic only if activation quantization is required. If the Kimi-K2 checkpoint uses weight-only NVFP4, this limitation is not a concern.
```diff
 def get_activation_scaling_factor(
     module: nn.Module, input_quantizer_name: str = "input_quantizer"
 ) -> torch.Tensor:
```
```diff
@@ -279,6 +304,10 @@ def get_weight_scaling_factor(module: nn.Module, weight_name: str = "weight") ->
         QUANTIZATION_NVFP4_SVDQUANT,
         QUANTIZATION_W4A8_NVFP4_FP8,
     ]:
+        # Calibrate weight quantizer if amax is not set
+        module_name = f"{type(module).__name__}.{weight_name}"
+        _ensure_weight_quantizer_calibrated(weight_quantizer, weight, module_name)
+
         if quantization_format == QUANTIZATION_W4A8_NVFP4_FP8:
             # weight_scaling_factor_2 for w4a8 needs to be amax/448, so that the wsf is in range 448/6.
             # This is because the kernel dequantizes weight to fp8, which is in range 448.
```
```diff
@@ -307,13 +336,26 @@ def get_weight_scaling_factor_2(module: nn.Module, weight_name: str = "weight")
     if weight_quantizer is None:
         return None

-    if get_quantization_format(module) in [
+    quantization_format = get_quantization_format(module)
+
+    # Calibrate weight quantizer if amax is not set for all NVFP4 variants
+    if quantization_format in [
+        QUANTIZATION_NVFP4,
+        QUANTIZATION_NVFP4_AWQ,
+        QUANTIZATION_NVFP4_SVDQUANT,
+        QUANTIZATION_W4A8_NVFP4_FP8,
+    ]:
+        weight = getattr(module, weight_name)
+        module_name = f"{type(module).__name__}.{weight_name}"
+        _ensure_weight_quantizer_calibrated(weight_quantizer, weight, module_name)
+
+    if quantization_format in [
         QUANTIZATION_NVFP4,
         QUANTIZATION_NVFP4_AWQ,
         QUANTIZATION_NVFP4_SVDQUANT,
     ]:
         return NVFP4QTensor.get_weights_scaling_factor_2_from_quantizer(weight_quantizer)
-    elif get_quantization_format(module) == QUANTIZATION_W4A8_NVFP4_FP8:
+    elif quantization_format == QUANTIZATION_W4A8_NVFP4_FP8:
         # weight_scaling_factor_2 for w4a8 needs to be amax/448, so that the wsf is in range 448/6.
         # This is because the kernel dequantizes weight to fp8, which is in range 448.
         return weight_quantizer._amax.float() / 448.0
```
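The amax/448 comment in the w4a8 branch can be checked numerically. The sketch below is illustrative only (`w4a8_scales` and its constants are stand-in names, not the library's API); it assumes the per-block scale is block_amax divided by the FP4 max (6), and that the second-level scale wsf_2 divides it, so the resulting wsf stays within 448/6 whenever block_amax ≤ global_amax:

```python
# Numeric illustration of the two-level w4a8 NVFP4 scaling described in the diff.
FP4_MAX = 6.0    # largest representable NVFP4 magnitude
FP8_MAX = 448.0  # largest representable FP8 (E4M3) magnitude


def w4a8_scales(global_amax: float, block_amax: float) -> tuple[float, float]:
    # Second-level scale: amax / 448, per the comment in the diff.
    wsf_2 = global_amax / FP8_MAX
    # Per-block scale normalized by wsf_2; bounded by 448/6 when block_amax <= global_amax.
    wsf = (block_amax / FP4_MAX) / wsf_2
    return wsf_2, wsf


# Worst case: the block containing the global maximum.
wsf_2, wsf = w4a8_scales(global_amax=448.0, block_amax=448.0)
print(wsf_2, wsf)  # 1.0 and 448/6 ≈ 74.67, the stated upper bound on wsf
```

Any smaller `block_amax` only shrinks `wsf`, so the bound holds across all blocks; this is why dividing amax by 448 (rather than 448·6, as in plain NVFP4) keeps the dequantized weight in FP8 range.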