
hook up gptq prototype to nvfp4 #4302

Merged
vkuzo merged 3 commits into main from gh/vkuzo/248/head
Apr 22, 2026

Conversation

vkuzo (Contributor) commented Apr 20, 2026

Summary:

For now, this is a numerical reference which hooks up nvfp4 and verifies
that a minimal unit test (random data + toy model) works as expected:
the NVFP4 + GPTQ loss is significantly lower than the baseline loss.

Future TODOs:

  1. validate on e2e model
  2. optimize dense performance
  3. add moe support (will require custom fwd to ensure every expert sees
    calibration data)

Test Plan:

> pytest test/prototype/gptq/ -s -k nvfp4
...
test/prototype/gptq/test_gptqv2.py GPTQ loss: 0.1582, Naive loss: 0.9259
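For context, a minimal sketch of what such a comparison looks like (toy model and random data; rtn_qdq below is a hypothetical stand-in for naive round-to-nearest quantize-dequantize, not the PR's actual quantizer):

import torch
import torch.nn.functional as F

torch.manual_seed(0)
linear = torch.nn.Linear(64, 64, bias=False)
x = torch.randn(256, 64)
y_ref = linear(x)

def rtn_qdq(w: torch.Tensor) -> torch.Tensor:
    # hypothetical symmetric int4-like round-to-nearest baseline
    scale = w.abs().amax(dim=-1, keepdim=True) / 7.0
    return torch.round(w / scale).clamp(-8, 7) * scale

# the naive loss; a GPTQ pass additionally adjusts not-yet-quantized
# columns to compensate for accumulated error, so its loss is lower
naive_loss = F.mse_loss(x @ rtn_qdq(linear.weight).t(), y_ref)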

vkuzo requested a review from jerryzh168 as a code owner, April 20, 2026 20:52
pytorch-bot commented Apr 20, 2026

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/4302

Note: Links to docs will display an error until the docs builds have been completed.

❌ 1 New Failure

As of commit 4c86363 with merge base b3e0db2:

NEW FAILURE - The following job has failed:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

meta-cla bot added the CLA Signed label, Apr 20, 2026
Comment thread: torchao/prototype/gptq/gptq_example.py (outdated)
vkuzo added the module: not user facing label, Apr 21, 2026
block_size=group_size,
orig_dtype=W_t.dtype,
per_tensor_scale=nvfp4_global_scale,
# TODO(future): get act_per_tensor_scale from calibration data?
Contributor:

it's fine we don't have this use case yet I think

combined_scale_swizzled,
block_size=group_size,
orig_dtype=W_t.dtype,
per_tensor_scale=nvfp4_global_scale,
Contributor:

although we do have use case of nvfp4_global_scale being generated from static quantization (can be added later)
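For reference, the dynamic nvfp4_global_scale under discussion is typically derived from the weight amax. A minimal sketch, assuming the standard NVFP4 constants (fp8 e4m3 block-scale max of 448, fp4 e2m1 element max of 6); torchao's actual helper may differ in details:

import torch

F8E4M3_MAX = 448.0  # max representable value of the fp8 e4m3 block-scale dtype
F4_E2M1_MAX = 6.0   # max representable value of the fp4 e2m1 element dtype

def nvfp4_global_scale_from_amax(w: torch.Tensor) -> torch.Tensor:
    # dynamic per-tensor scale, chosen so the largest per-block scale
    # still fits into fp8 e4m3 range after the two-level scaling
    return w.abs().amax().float() / (F8E4M3_MAX * F4_E2M1_MAX)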

Comment thread: torchao/prototype/gptq/api.py
jerryzh168 (Contributor) commented:

@claude review

claude bot commented Apr 21, 2026

Claude finished @jerryzh168's task in 1m 35s.


Code Review: PR #4302 — hook up gptq prototype to nvfp4

Overall this is a well-structured integration. The approach of pre-computing the global scale before the GPTQ loop and freezing block-wise scales per group is sound and follows the NVFP4 two-level scaling design correctly. A few items to address:

Bugs

  1. raise AssertionError("unsupported") used for input validation: torchao/prototype/gptq/api.py:329 and :446, plus test/prototype/gptq/test_gptqv2.py:595 all raise AssertionError("unsupported"). AssertionError is a valid Python exception, so this isn't a crash bug, but raising it directly is unusual when the surrounding pattern uses assert statements. Consider raise ValueError("unsupported config type: ...") instead for clearer error semantics, since these are input validation checks rather than logic assertions.

  2. Missing else clause in per-column quantize loop: At api.py:464-469, the elif isinstance(base_config, NVFP4DynamicActivationNVFP4WeightConfig) branch does not have a trailing else: raise ..., unlike the group qparams block at lines 445-446. This means if someone adds a new config type to GPTQConfig.base_config but forgets to update the inner loop, dq will be undefined and cause a confusing NameError. A sketch of the fix follows this list.
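
A minimal sketch of the fix for item 2 (the NVFP4 config class and _nvfp4_with_precalculated_scales_qdq are names from the PR; the other branch and helper names are hypothetical, for illustration only):

def qdq_dispatch(base_config, w_col):
    # every isinstance dispatch chain ends in an explicit raise, so dq
    # can never be silently left undefined when a new config type is added
    if isinstance(base_config, Int4WeightOnlyConfig):  # hypothetical branch
        dq = _int4_qdq(w_col)  # hypothetical helper
    elif isinstance(base_config, NVFP4DynamicActivationNVFP4WeightConfig):
        dq = _nvfp4_with_precalculated_scales_qdq(w_col)  # helper from the PR
    else:
        raise ValueError(f"unsupported config type: {type(base_config)}")
    return dq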

Design / Correctness Questions

  1. Agreeing with jerryzh168's review: assert use_dynamic_per_tensor_scale. At api.py:321-327, the NVFP4 branch computes nvfp4_global_scale from the full weight tensor amax, which corresponds to the dynamic per-tensor scale path. If use_dynamic_per_tensor_scale=False (static quantization with pre-calibrated scales), this code path would compute the wrong scale. Adding an assertion like assert base_config.use_dynamic_per_tensor_scale would be a good safety guard here to prevent silent correctness issues when static scale support is added later (see the sketch after this list).

  2. _nvfp4_with_precalculated_scales_qdq operates on (N, 1) columns — no reshape to 3D — This function is called with w_t of shape (N, 1) during the per-column GPTQ loop (api.py:465-469). Unlike _nvfp4_with_precalculated_scales_q which reshapes to (N, K//16, 16), the _qdq variant works directly on 2D. This is correct since it's quantizing a single scalar per row, but worth a brief comment noting this intentional difference.

  3. nvfp4_quantize is called just to get scales, discarding data_lp — At api.py:434-443, the code calls the full nvfp4_quantize to obtain block scales, then immediately del _data_lp. The TODO is already noted. This is fine for a prototype but worth tracking since the full quantize involves packing operations that are wasted work.
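
A minimal sketch of the guard from item 1 (class and attribute names are from the review; the surrounding code is illustrative, not the PR's exact structure):

if isinstance(base_config, NVFP4DynamicActivationNVFP4WeightConfig):
    # the amax-derived global scale is only correct for the dynamic
    # per-tensor-scale path; fail loudly until static scales are wired up
    assert base_config.use_dynamic_per_tensor_scale, (
        "static per-tensor scales not yet supported in the GPTQ prototype"
    )
    nvfp4_global_scale = W_t.abs().amax().float() / (448.0 * 6.0)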

Code Duplication

  1. Shared logic between _qdq and _q functions: _nvfp4_with_precalculated_scales_qdq (api.py:206-228) and _nvfp4_with_precalculated_scales_q (api.py:231-256) share the scaling computation (reciprocal_scale = (1.0 / per_tensor_scale) / scaled_block_scales_fp32), and both duplicate logic from nvfp4_quantize in nvfp4_tensor.py. The TODO at line 243 already calls this out; just flagging that extracting a shared _compute_nvfp4_reciprocal_scale helper (sketched below) could reduce this to a one-liner when the cleanup happens.
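
A minimal sketch of the extracted helper (the name follows the review's suggestion and does not exist in the PR):

import torch

def _compute_nvfp4_reciprocal_scale(
    per_tensor_scale: torch.Tensor,
    scaled_block_scales_fp32: torch.Tensor,
) -> torch.Tensor:
    # the shared computation quoted above, factored into one place so
    # the _qdq and _q paths cannot drift apart
    return (1.0 / per_tensor_scale) / scaled_block_scales_fp32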

Test Coverage

  1. SQNR threshold for NVFP4 (>15) is notably lower than int4 (>25) and int8 (>30): test_gptqv2.py:593. This is expected given fp4's lower precision, but worth a brief inline comment explaining why the bar is lower, so future readers don't wonder if it's a regression (see the metric sketch after this list).

  2. Test copy.deepcopy(model) result unused: test_gptqv2.py:560 creates a deep copy that's immediately discarded. This appears to be pre-existing in the test, not introduced by this PR, but it's worth noting.
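
For reference, a minimal sketch of the kind of metric these thresholds gate (standard SQNR definition; the test's exact helper may differ):

import torch

def sqnr_db(ref: torch.Tensor, approx: torch.Tensor) -> torch.Tensor:
    # signal-to-quantization-noise ratio in decibels; higher is better,
    # and fp4's coarser grid naturally lands lower than int4 or int8
    noise = ref - approx
    return 10 * torch.log10(ref.pow(2).mean() / noise.pow(2).mean())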

Minor

  1. combined_scale transpose dance (api.py:509-511): torch.cat(group_qparams, dim=0).reshape(K // group_size, N).t().contiguous(). The TODO explains this. The current approach works but is fragile if group_qparams ordering changes. The suggested fix of pre-allocating scales and writing directly is a good one (sketched after this list).

  2. Imports are unconditional — The nvfp4 imports at api.py:22-40 are always imported, even when NVFP4 isn't being used. If mx_formats has heavy dependencies that could fail on systems without SM >= 10.0, consider lazy imports (though this is minor for prototype code).
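
A minimal sketch of the pre-allocation approach from item 1 (the per-group scale computation here is a placeholder, not the PR's actual qparams math):

import torch

def collect_group_scales(W_t: torch.Tensor, group_size: int) -> torch.Tensor:
    # pre-allocate the (N, K // group_size) scale matrix and write each
    # group's column directly, avoiding the fragile
    # cat + reshape + t() + contiguous() dance after the loop
    N, K = W_t.shape
    scales = torch.empty(N, K // group_size, dtype=torch.float32, device=W_t.device)
    for g in range(K // group_size):
        group = W_t[:, g * group_size : (g + 1) * group_size]
        scales[:, g] = group.abs().amax(dim=-1)  # placeholder scale computation
    return scales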


Good work overall. The core GPTQ integration logic follows the algorithm correctly with the pre-computed global scale approach. The main actionable items are the missing else clause (Bugs item 2) and the use_dynamic_per_tensor_scale assertion (Design item 1, per jerryzh168's review).

return torch.cat(dequant_chunks, dim=-1)


def _nvfp4_with_precalculated_scales_qdq(
Contributor:

should this share code with _nvfp4_with_precalculated_scales_q so it's less error-prone? They should use the same quantization code, right?

Contributor Author (vkuzo):

the mathematical code is pretty simple and the shape operations are different (the qdq input is always shaped [N, 1] and the q input is always shaped [N, k_slice], which requires different broadcasting behavior); I think it's ok to keep them separate for simplicity
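
A minimal sketch of the shape difference described above (the round-based quantize-dequantize here is a placeholder for the actual fp4 math):

import torch

N, k_slice, block = 8, 32, 16

# qdq path: input is always [N, 1]; one pre-selected block scale per
# row broadcasts against the single column directly
w_col = torch.randn(N, 1)
col_scales = torch.rand(N, 1) + 0.1
dq_col = torch.round(w_col / col_scales) * col_scales

# q path: input is always [N, k_slice]; scales are per 16-element
# block, so the slice is reshaped to 3D before broadcasting
w_slice = torch.randn(N, k_slice)
block_scales = torch.rand(N, k_slice // block) + 0.1
w_3d = w_slice.reshape(N, k_slice // block, block)
dq_slice = (torch.round(w_3d / block_scales.unsqueeze(-1))
            * block_scales.unsqueeze(-1)).reshape(N, k_slice)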

Contributor:

Sounds OK to keep them separate for readability; my main worry was about the consistency of the two implementations. Is there a test to make sure these two code paths use the same quantization code?

Contributor Author (vkuzo):

I agree ^ is useful; I'm just punting it until later. There is a TODO in the code to track it. The numerical tests we already have would also capture any divergence indirectly.

jerryzh168 (Contributor) left a comment:

LGTM, please see comments inline

vkuzo merged commit b444bd0 into main, Apr 22, 2026
54 of 57 checks passed