
extend GPTQ coverage to grouped_mm #4328

Merged

vkuzo merged 62 commits into main from gh/vkuzo/261/head on Apr 27, 2026
Conversation

@vkuzo
Contributor

@vkuzo vkuzo commented Apr 24, 2026

Summary:

Extend GPTQ for grouped_mm.

Punting the redefinition of counting batches vs tokens to a future PR.

Test Plan:

```
pytest test/prototype/gptq/test_gptqv2.py -s
```

vkuzo added 30 commits April 20, 2026 20:52
[ghstack-poisoned]
vkuzo added 6 commits April 23, 2026 18:46
@vkuzo
Contributor Author

vkuzo commented Apr 24, 2026

@vkuzo vkuzo requested a review from jerryzh168 as a code owner April 24, 2026 13:25
@pytorch-bot

pytorch-bot Bot commented Apr 24, 2026

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/4328

Note: Links to docs will display an error until the docs builds have been completed.

❗ 1 Active SEV

There is 1 currently active SEV. If your PR is affected, please view it below:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@meta-cla meta-cla Bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Apr 24, 2026
vkuzo added a commit that referenced this pull request Apr 24, 2026
Summary:

Extend GPTQ for grouped_mm.

Punting the redefinition of counting batches vs tokens to a future PR.

Test Plan:

```
pytest test/prototype/gptq/test_gptqv2.py -s
```
ghstack-source-id: d93a8c3
ghstack-comment-id: 4313533590
Pull-Request: #4328
@vkuzo vkuzo added the module: not user facing Use this tag if you don't want this PR to show up in release notes label Apr 24, 2026
vkuzo added 3 commits April 24, 2026 14:04
Comment thread on test/prototype/gptq/test_gptqv2.py
@jerryzh168
Contributor

@claude review

@claude

claude Bot commented Apr 24, 2026

Claude finished @jerryzh168's task in 1m 47s


Code Review: extend GPTQ coverage to grouped_mm

  • Read PR diff and changed files
  • Analyze implementation in torchao/prototype/gptq/observer.py
  • Analyze tests in test/prototype/gptq/test_gptqv2.py
  • Post review feedback

Overall this is a clean, well-structured extension. The implementation mirrors the existing update_3d / bmm pattern and the test is thorough. A few observations:

Implementation (torchao/prototype/gptq/observer.py)

1. Bug: prev_end not updated on skip path (line 119)

In update_3d_with_offs, when end == prev_end (expert has 0 tokens), the continue skips the prev_end = end assignment at line 119. This is actually correct only because end == prev_end means the assignment would be a no-op. However, the code is fragile — if someone refactors the skip condition (e.g., to end - prev_end < some_threshold), the bug would surface. Moving prev_end = end before the continue (or to the top of the loop body) would be more robust:

```
for e_idx in range(self.hessian.shape[0]):
    end = offs_cpu[e_idx]
    start = prev_end
    prev_end = end  # always advance, even when the expert is empty
    if end == start:
        continue
    x_cur = x[start:end]
    ...
```

This is a nit — the current code is correct for the current logic.
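For illustration, the robust ordering can be exercised in isolation with plain Python lists (slice_by_offsets and the sample rows are hypothetical stand-ins for this sketch, not the actual observer code):

```
def slice_by_offsets(x, offs_cpu):
    """Yield (expert_idx, chunk) pairs, skipping empty experts.

    The running offset is advanced unconditionally, so refactoring the
    skip condition cannot desynchronize it from the data.
    """
    prev_end = 0
    for e_idx, end in enumerate(offs_cpu):
        start = prev_end
        prev_end = end  # always advance, even for an empty expert
        if end == start:
            continue
        yield e_idx, x[start:end]

rows = ["t0", "t1", "t2", "t3", "t4"]
# expert 1 gets zero tokens (offsets 2 and 2 are equal) and is skipped
chunks = dict(slice_by_offsets(rows, [2, 2, 5]))
```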

2. No validation of offs length vs hessian.shape[0] (line 111)

update_3d_with_offs iterates range(self.hessian.shape[0]) but indexes into offs_cpu without checking that len(offs_cpu) == self.hessian.shape[0]. A mismatch would cause a silent IndexError or, worse, silently process the wrong number of experts. Consider adding an assertion:

```
assert len(offs_cpu) == self.hessian.shape[0], \
    f"offs length {len(offs_cpu)} != num_experts {self.hessian.shape[0]}"
```

3. _grouped_mm dispatch assumes mat_b is always the observer (line 175)

The dispatch handler asserts isinstance(mat_b, GPTQObserverTensor) but doesn't handle the case where mat_a could be the observer tensor. This is fine for the current usage pattern (x @ weight.T), but worth noting. The assert makes this explicit, which is good.

4. offs passed positionally to func (line 177)

```
return func(mat_a, mat_b.hp_data, offs)
```

The _grouped_mm signature may also accept other kwargs (e.g., bias). Passing **kwargs through would be more forward-compatible:

```
return func(mat_a, mat_b.hp_data, offs, **kwargs)
```

This matches what other dispatch handlers in the codebase do for extensibility, though it's not strictly needed today.

Test (test/prototype/gptq/test_gptqv2.py)

5. Good coverage of zero-token experts, but expert 3 never sees 0 tokens

Echoing @jerryzh168's nit: the current m_per_group_list exercises the 0-token path for experts 0, 1, and 2, but never for expert 3 (the last expert). Adding a case like [3, 5, 4, 0] would cover the edge case where the last offset equals the previous offset (i.e., the final expert is empty). This is the most likely place for an off-by-one error in offset-based slicing.
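A pure-Python sketch of the suggested case (m_per_group_list mirrors the test's name; itertools.accumulate stands in for the cumulative-sum that builds the offsets):

```
from itertools import accumulate

m_per_group_list = [3, 5, 4, 0]  # last expert receives no tokens
offs = list(accumulate(m_per_group_list))  # cumulative row offsets

# The final expert's slice is empty exactly when the last two offsets
# coincide -- the off-by-one-prone edge this case would cover.
assert offs == [3, 8, 12, 12]
assert offs[-1] == offs[-2]
```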

6. Test structure is solid

The dual-path verification (3D _grouped_mm vs per-expert 2D F.linear) with bitwise torch.equal checks is a strong correctness guarantee. The additional expected_total_batches cross-check is a nice touch.

CI

The H100 test failures appear across multiple PRs in the stack and seem infrastructure-related (Docker exec failure), not caused by this PR.

Summary

The PR is correct and well-tested. The main actionable suggestion is adding a test case where the last expert sees 0 tokens (point 5), and optionally adding an offs length assertion (point 2). Everything else is minor/nit.

vkuzo added 3 commits April 27, 2026 12:06
@vkuzo vkuzo changed the base branch from gh/vkuzo/260/head to main April 27, 2026 12:13
@vkuzo vkuzo merged commit d3fe10e into main Apr 27, 2026
24 of 40 checks passed