Fix MXFP4 mlp_forward to handle 2D and 3D hidden_states shapes for multi-turn chat#40358
Conversation
… mlp_forward for multi-turn chat
|
mxfp4 -> perhaps @SunMarc ? |
ArthurZucker
left a comment
There was a problem hiding this comment.
LGTM otherwise, do you mind adding a small test!
…ion/mxfp4/test_mxfp4.py
|
[For maintainers] Suggested jobs to run (before merge) run-slow: mxfp4 |
|
@ArthurZucker I added a test for mlp forward dimensionality. This covers both 2D passthrough and 3D flatten/restore. Our test case is running successfully. The thing is I am getting different issues and errors in CI even while rerunning same code....I am working on it. Edit -All checks passed.... |
|
possibly related bug: #40183 (comment) |
| routing_data, gather_idx, scatter_idx = routing(router_logits, self.router.top_k) | ||
|
|
||
| routed_out = self.experts(hidden_states, routing_data, gather_idx, scatter_idx) | ||
| routed_out = routed_out.reshape(batch_size, -1, self.router.hidden_dim) |
There was a problem hiding this comment.
btw one thing I am not getting is that regardless, this line is equivalent (if 2d, you should still have batch size 1.
There was a problem hiding this comment.
That line kept inside if statement that only runs when the original input was 3D. For the second turn of a chat, the input is 2D, so the if condition is false and that reshape line is completely skipped. The code correctly returns a 2D output to match the 2D input. Otherwise if not in if statement then can cause fake dimension issues. 2D is for second and subsequent terns
| is_3d = hidden_states.ndim == 3 | ||
| if is_3d: | ||
| batch_size, seq_len, _ = hidden_states.shape | ||
| hidden_states = hidden_states.reshape(-1, self.router.hidden_dim) | ||
|
|
There was a problem hiding this comment.
same here, I am not sure I understand the fix
There was a problem hiding this comment.
Let me clarify the differences between the original code and my fix.
This is considering my understanding of current code and Key-Value cache system for second and subsequent terns.
How the Original Code Handled the First Turn (The Successful Path)
The original code focused only on receiving a 3D tensor. It followed a strict pipeline that worked perfectly for this situation.
Here is the flow of the original code when it received the 3D tensor on the first turn:
Step 1: Receive 3D Tensor
The mlp_forward function gets the 3D hidden_states tensor, for example, (1, 50, 8192).
Step 2: Unpack Shape
The code runs batch_size, seq_len, _ = hidden_states.shape.
For a 3D tensor, this works well. It correctly identifies batch_size=1 and seq_len=50.
Step 3: Flatten Tensor
The code then executes hidden_states = hidden_states.reshape(...).
This flattens the 3D tensor into a 2D tensor, such as (50, 8192). This also works perfectly.
Step 4: Process with Experts
The new 2D tensor is passed to the MlpExperts layer. It processes the data and returns a 2D result.
Step 5: Reshape Output
The code takes the 2D result and uses the batch_size and seq_len from Step 2 to reshape it back into a 3D tensor. This step is successful too.
Conclusion for the first turn: The original code functioned without issues because its strict, step-by-step process matched the 3D input it received.
Why the Original Code Failed on the Second Turn
The crash may happened because the rigid approach could not handle the 2D tensor from the second turn.
1 seq_len KV cache
On the second turn, the 2D tensor (1, 8192) came in.
The code attempted to run its first operation, batch_size, seq_len, _ = hidden_states.shape.
This caused a error since it tried to unpack a shape with only two values into three variables. The program crashed.
How My Fix Corrects This
My fix keeps the successful logic from the first turn intact. It simply wraps the parts safe for 3D tensors inside the if is_3d: check.
For the first turn, is_3d is True, so the code runs the same successful steps for unpacking and reshaping as the original code did.
For the second turn, is_3d is False, so the code skips the unsafe unpacking and reshaping steps, avoiding the crash.
There was a problem hiding this comment.
hmmm you are saying that it crashed when trying to unpack which doesn't any sense with the original issue.
This caused a error since it tried to unpack a shape with only two values into three variables. The program crashed.
| @require_torch | ||
| @patch("torch.cuda.device") | ||
| @patch("transformers.integrations.mxfp4.triton_kernels_hub", create=True) | ||
| def test_mlp_forward_dimensionality(self, mock_triton_hub, mock_cuda_device): |
There was a problem hiding this comment.
sorry I should have specify but what we want for a test is a multiturn example that waas failing before the PR and is now fixed 😉
There was a problem hiding this comment.
Hi @ArthurZucker , I added the test case for multi-turn simulation. run_tests are passed, can you check our case is PASSED or SKIPPED. I am unable to see it.
|
Can you check if this work @cmp-nct for your issue ? |
I'm currently testing it, I removed the commented out assertion in the kernel and at first glance it looks like the problem is gone. I'll let you know if the crash happens again. Update: So far it looks good. I've not had an instance of the assertion anymore |
|
Hey @PrathmeshAdsod, on your branch and with this PR #40501, i'm getting this error when doing multi-turn Traceback (most recent call last):
File "/home/marc/anaconda3/envs/sglang/lib/python3.10/threading.py", line 1009, in _bootstrap_inner
self.run()
File "/home/marc/anaconda3/envs/sglang/lib/python3.10/threading.py", line 946, in run
self._target(*self._args, **self._kwargs)
File "/home/marc/transformers/src/transformers/commands/serving.py", line 962, in generate_with_cache
generate_output = model.generate(**kwargs)
File "/home/marc/anaconda3/envs/sglang/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 120, in decorate_context
return func(*args, **kwargs)
File "/home/marc/transformers/src/transformers/generation/utils.py", line 2535, in generate
result = self._sample(
File "/home/marc/transformers/src/transformers/generation/utils.py", line 3483, in _sample
model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
File "/home/marc/transformers/src/transformers/generation/utils.py", line 573, in prepare_inputs_for_generation
inputs_embeds, input_ids = self._cache_dependant_input_preparation(
File "/home/marc/transformers/src/transformers/generation/utils.py", line 479, in _cache_dependant_input_preparation
or (cache_position[-1] >= input_ids.shape[1]) # Exception 3
IndexError: index -1 is out of bounds for dimension 0 with size 0 |
|
@cmp-nct did you manage to reproduce the error with multi-turn chat ? or you are talking about 5090 compatibility ? |
I frequently ran into the assertion error on my 5090 (never on 4090). I did not use multiturns. My solution was to comment the assertion in the cached kernel python file |
|
@SunMarc can you please provide me your code which you used for multi-turn |
Applied from PR huggingface#40358 without direct merge because the PR branch contains follow-up CI/test commits and is based on older history.
Fixes #40202
This pull request fixes the AssertionError: assert num_stages >= 1 error that happens during multi-turn chat generation in GPT-OSS models using MXFP4 quantization. The issue arises because the input tensor shape to the MXFP4 Triton kernel is different between the first chat turn, which uses a 3D tensor with the full sequence, and later turns, which use a 2D tensor or a single token due to key-value caching. This discrepancy causes the Triton kernel’s internal autotuning to choose invalid optimization flags, leading to a crash.
The solution changes the mlp_forward function in mxfp4.py to clearly check if the hidden_states tensor is 3D or 2D. If it is 3D, the tensor is reshaped to 2D before routing and expert dispatch, and then it is reshaped back to 3D afterward. If it is already 2D, it is passed through unchanged. This guarantees that consistent and valid tensor shapes are sent to the Triton kernel for all generation turns.
Note- This PR solves the root cause and by this way can support all GPU's for multi-turn
Before submitting
Pull Request section?
to it if that's the case.
documentation guidelines, and
here are tips on formatting docstrings.
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.
Models:
Library:
Maintained examples (not research project or legacy):