Fix MXFP4 mlp_forward to handle 2D and 3D hidden_states shapes for multi-turn chat #40358

Open

PrathmeshAdsod wants to merge 8 commits into huggingface:main from PrathmeshAdsod:fix-mxfp4-shape-handling-multi-turn-#40202

Conversation

@PrathmeshAdsod (Contributor) commented Aug 21, 2025

Fixes #40202

This pull request fixes the AssertionError: assert num_stages >= 1 error that happens during multi-turn chat generation in GPT-OSS models using MXFP4 quantization. The issue arises because the input tensor shape seen by the MXFP4 Triton kernel differs between the first chat turn, which passes a 3D tensor holding the full sequence, and later turns, which pass a 2D single-token tensor due to key-value caching. This discrepancy causes the Triton kernel's internal autotuning to choose invalid optimization flags, leading to a crash.

The solution changes the mlp_forward function in mxfp4.py to explicitly check whether the hidden_states tensor is 3D or 2D. If it is 3D, the tensor is reshaped to 2D before routing and expert dispatch, then reshaped back to 3D afterward. If it is already 2D, it passes through unchanged. This guarantees that consistent, valid tensor shapes reach the Triton kernel on every generation turn.
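
A minimal sketch of that shape handling, assembled from the snippets quoted later in this thread (names such as self.router, self.experts, and routing follow those snippets; treat this as an illustration, not the exact diff):

def mlp_forward(self, hidden_states):
    # Prefill passes (batch, seq_len, hidden); decode steps with a KV cache
    # may pass a 2D (tokens, hidden) tensor, so only flatten when the input is 3D.
    is_3d = hidden_states.ndim == 3
    if is_3d:
        batch_size, seq_len, _ = hidden_states.shape
        hidden_states = hidden_states.reshape(-1, self.router.hidden_dim)

    router_logits = self.router(hidden_states)
    routing_data, gather_idx, scatter_idx = routing(router_logits, self.router.top_k)
    routed_out = self.experts(hidden_states, routing_data, gather_idx, scatter_idx)

    # Restore the original rank so callers get back the shape they passed in.
    if is_3d:
        routed_out = routed_out.reshape(batch_size, -1, self.router.hidden_dim)
    return routed_out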

Note: this PR addresses the root cause, so it should support multi-turn generation on all GPUs.


Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a Github issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

Models:

Library:

Maintained examples (not research project or legacy):

  • PyTorch: See Models above and tag the person corresponding to the modality of the example.

@gante (Contributor) commented Aug 22, 2025

mxfp4 -> perhaps @SunMarc ?

@ArthurZucker (Collaborator) left a comment

LGTM otherwise, do you mind adding a small test!

@github-actions (Contributor)

[For maintainers] Suggested jobs to run (before merge)

run-slow: mxfp4

@PrathmeshAdsod (Contributor, Author) commented Aug 23, 2025

@ArthurZucker I added a test for mlp_forward dimensionality. It covers both the 2D passthrough and the 3D flatten/restore paths. The test itself runs successfully, but I am getting assorted CI issues and errors even when rerunning the same code; I am working on it.

Edit:

All checks passed. I made a few extra commits while fixing style and CI issues. Let me know if you would prefer me to squash them into a cleaner history.

@cmp-nct commented Aug 24, 2025

possibly related bug: #40183 (comment)

routing_data, gather_idx, scatter_idx = routing(router_logits, self.router.top_k)

routed_out = self.experts(hidden_states, routing_data, gather_idx, scatter_idx)
routed_out = routed_out.reshape(batch_size, -1, self.router.hidden_dim)
A Collaborator left a review comment:

btw, one thing I am not getting: regardless, this line should be equivalent (if the input is 2D, you should still have batch size 1).

@PrathmeshAdsod (Contributor, Author) replied:

That line is kept inside an if statement that only runs when the original input was 3D. On the second turn of a chat, the input is 2D, so the condition is false and the reshape is skipped entirely; the code correctly returns a 2D output to match the 2D input. If the reshape were outside the if statement, it could introduce a spurious dimension. The 2D path is what the second and subsequent turns take.

Comment on lines +292 to +296
is_3d = hidden_states.ndim == 3
if is_3d:
batch_size, seq_len, _ = hidden_states.shape
hidden_states = hidden_states.reshape(-1, self.router.hidden_dim)

A Collaborator left a review comment:

same here, I am not sure I understand the fix

@PrathmeshAdsod (Contributor, Author) commented Aug 26, 2025

Let me clarify the differences between the original code and my fix. This reflects my understanding of the current code and of how the key-value cache behaves on the second and subsequent turns.

How the Original Code Handled the First Turn (The Successful Path)
The original code focused only on receiving a 3D tensor. It followed a strict pipeline that worked perfectly for this situation.

Here is the flow of the original code when it received the 3D tensor on the first turn:

Step 1: Receive 3D Tensor
The mlp_forward function gets the 3D hidden_states tensor, for example, (1, 50, 8192).

Step 2: Unpack Shape
The code runs batch_size, seq_len, _ = hidden_states.shape.
For a 3D tensor, this works well. It correctly identifies batch_size=1 and seq_len=50.

Step 3: Flatten Tensor
The code then executes hidden_states = hidden_states.reshape(...).
This flattens the 3D tensor into a 2D tensor, such as (50, 8192). This also works perfectly.

Step 4: Process with Experts
The new 2D tensor is passed to the MlpExperts layer. It processes the data and returns a 2D result.

Step 5: Reshape Output
The code takes the 2D result and uses the batch_size and seq_len from Step 2 to reshape it back into a 3D tensor. This step is successful too.
Conclusion for the first turn: The original code functioned without issues because its strict, step-by-step process matched the 3D input it received.

Why the Original Code Failed on the Second Turn
The crash may have happened because the rigid approach could not handle the 2D tensor from the second turn.

With the KV cache, seq_len is 1: on the second turn, a 2D tensor of shape (1, 8192) came in.
The code attempted its first operation, batch_size, seq_len, _ = hidden_states.shape.
This raised an error because it tried to unpack a shape with only two values into three variables. The program crashed.

How My Fix Corrects This
My fix keeps the successful first-turn logic intact. It simply wraps the steps that are only safe for 3D tensors inside the if is_3d: check.

For the first turn, is_3d is True, so the code runs the same successful steps for unpacking and reshaping as the original code did.
For the second turn, is_3d is False, so the code skips the unsafe unpacking and reshaping steps, avoiding the crash.
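
For illustration, the failing unpack on a 2D decode-step tensor can be reproduced in isolation (the shapes are the example values used above):

import torch

h = torch.randn(1, 8192)  # 2D decode-step input, as described for the second turn
try:
    batch_size, seq_len, _ = h.shape  # unconditional unpack, as in the original code
except ValueError as e:
    print(e)  # not enough values to unpack (expected 3, got 2)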

A Member replied:

hmmm, you are saying that it crashed when trying to unpack, which doesn't make any sense given the original issue.

This raised an error because it tried to unpack a shape with only two values into three variables. The program crashed.

@require_torch
@patch("torch.cuda.device")
@patch("transformers.integrations.mxfp4.triton_kernels_hub", create=True)
def test_mlp_forward_dimensionality(self, mock_triton_hub, mock_cuda_device):
A Collaborator left a review comment:

sorry, I should have specified, but what we want for a test is a multi-turn example that was failing before this PR and is now fixed 😉

@PrathmeshAdsod (Contributor, Author) replied:

Hi @ArthurZucker, I added a test case for the multi-turn simulation. The run_tests pass; can you check whether our case is PASSED or SKIPPED? I am unable to see it.
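
A multi-turn regression test along the lines the reviewer requested might look like the sketch below (the model id comes from the linked issue; the generation settings and overall structure are assumptions, and the actual test in the PR may differ):

from transformers import AutoModelForCausalLM, AutoTokenizer

def test_multi_turn_generation():
    # Two chat turns: the second turn hits the KV-cache decode path that
    # previously crashed the MXFP4 Triton kernel with the num_stages assertion.
    model_id = "openai/gpt-oss-20b"
    tok = AutoTokenizer.from_pretrained(model_id)
    model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto")

    messages = [{"role": "user", "content": "Hello!"}]
    inputs = tok.apply_chat_template(messages, add_generation_prompt=True, return_tensors="pt").to(model.device)
    out = model.generate(inputs, max_new_tokens=32)
    reply = tok.decode(out[0, inputs.shape[1]:], skip_special_tokens=True)

    # Second turn: append the reply and a follow-up, then generate again.
    messages += [{"role": "assistant", "content": reply}, {"role": "user", "content": "And a follow-up?"}]
    inputs = tok.apply_chat_template(messages, add_generation_prompt=True, return_tensors="pt").to(model.device)
    model.generate(inputs, max_new_tokens=32)  # raised AssertionError before the fix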

@SunMarc (Member) commented Aug 26, 2025

@cmp-nct, can you check if this works for your issue?

@cmp-nct commented Aug 26, 2025

@cmp-nct, can you check if this works for your issue?

I'm currently testing it. I removed the commented-out assertion in the kernel, and at first glance it looks like the problem is gone.
However, the problem I had was very random and not reliably reproducible.
(Commenting the num_stages assertion out had solved the issue until now.)

I'll let you know if the crash happens again.
No negative things to report either, so all is running as it should with the PR.

Update: So far it looks good. I've not had an instance of the assertion anymore

@SunMarc (Member) commented Aug 27, 2025

Hey @PrathmeshAdsod, on your branch combined with PR #40501, I'm getting this error when doing multi-turn:

Traceback (most recent call last):
  File "/home/marc/anaconda3/envs/sglang/lib/python3.10/threading.py", line 1009, in _bootstrap_inner
    self.run()
  File "/home/marc/anaconda3/envs/sglang/lib/python3.10/threading.py", line 946, in run
    self._target(*self._args, **self._kwargs)
  File "/home/marc/transformers/src/transformers/commands/serving.py", line 962, in generate_with_cache
    generate_output = model.generate(**kwargs)
  File "/home/marc/anaconda3/envs/sglang/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 120, in decorate_context
    return func(*args, **kwargs)
  File "/home/marc/transformers/src/transformers/generation/utils.py", line 2535, in generate
    result = self._sample(
  File "/home/marc/transformers/src/transformers/generation/utils.py", line 3483, in _sample
    model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
  File "/home/marc/transformers/src/transformers/generation/utils.py", line 573, in prepare_inputs_for_generation
    inputs_embeds, input_ids = self._cache_dependant_input_preparation(
  File "/home/marc/transformers/src/transformers/generation/utils.py", line 479, in _cache_dependant_input_preparation
    or (cache_position[-1] >= input_ids.shape[1])  # Exception 3
IndexError: index -1 is out of bounds for dimension 0 with size 0

@SunMarc (Member) commented Aug 27, 2025

@cmp-nct did you manage to reproduce the error with multi-turn chat, or are you talking about 5090 compatibility?

@cmp-nct commented Aug 27, 2025

@cmp-nct did you manage to reproduce the error with multi-turn chat, or are you talking about 5090 compatibility?

I frequently ran into the assertion error on my 5090 (never on the 4090). I did not use multi-turn.
What I did was batched inference with various prompt lengths, generation lengths, and batch sizes.
The result was quite random occurrences of that assertion exception: it literally happened out of the blue, kept happening, and after a while it went away again (though the prompt combinations might have changed slightly).

My workaround was to comment out the assertion in the cached kernel Python file (#assert num_stages >= 1); this had no negative consequences (inference works). I've removed that workaround since I started using this PR and, so far, I've not seen the exception again.

@PrathmeshAdsod (Contributor, Author) commented:

@SunMarc, can you please share the code you used for multi-turn?

@SunMarc (Member) commented Aug 28, 2025

@SunMarc, can you please share the code you used for multi-turn?

just do transformers serve and transformers chat localhost:8000 --model-name-or-path openai/gpt-oss-20b
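
For reference, transformers serve exposes an OpenAI-compatible endpoint, so the same multi-turn exchange can be driven programmatically; here is a minimal sketch with plain requests (the URL path and payload fields are assumptions based on the OpenAI chat-completions schema):

import requests

url = "http://localhost:8000/v1/chat/completions"
messages = [{"role": "user", "content": "Hello!"}]
resp = requests.post(url, json={"model": "openai/gpt-oss-20b", "messages": messages})
messages.append(resp.json()["choices"][0]["message"])

# Second turn: the decode path with a populated KV cache is what used to crash.
messages.append({"role": "user", "content": "And a follow-up?"})
resp = requests.post(url, json={"model": "openai/gpt-oss-20b", "messages": messages})
print(resp.json()["choices"][0]["message"]["content"])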

#40504

evalstate added a commit to evalstate/transformers that referenced this pull request Apr 29, 2026
Applied from PR huggingface#40358 without direct merge because the PR branch contains follow-up CI/test commits and is based on older history.


Development

Successfully merging this pull request may close these issues.

Transformer serve gpt oss 20b cannot support multi-turn chat

5 participants