fix: cast input dtype in LinearLoRA.forward() to prevent dtype mismatch #1622
stanley1208 wants to merge 1 commit into NVIDIA-NeMo:main
Conversation
Signed-off-by: stanley1208 <stanley.mei08@gmail.com> Made-with: Cursor
| """ | ||
| # pylint: disable=C0115,C0116 | ||
| # Cast input to match weight dtype to avoid dtype mismatch (e.g. float32 input with bfloat16 weights) | ||
| x = x.to(self.weight.dtype) |
Hi @stanley1208 , thanks for the PR, have you tried this with QLoRA by any chance?
@akoumpa Thanks for the quick review! I haven't tested with QLoRA yet — I verified with standard LoRA on A100 with bfloat16 weights and float32 input.
For QLoRA, the lora_dtype parameter already handles setting the LoRA adapter dtype explicitly (lora.py line 108-109), and the super_fwd path (line 231-233) is used for quantized weights, which bypasses this cast. So the fix should be safe for QLoRA as well — but I'm happy to verify on Colab if you'd like. Let me know!
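For illustration, here is a rough sketch of the control flow being described. The names super_fwd, lora_A, lora_B, and lora_dtype come from the discussion above; the surrounding structure (including the quantized_base flag) is an assumption, not the actual lora.py code.

```python
# Hypothetical sketch of the described control flow, not the real LinearLoRA code.
import torch
import torch.nn.functional as F

def lora_linear_forward(self, x: torch.Tensor) -> torch.Tensor:
    if getattr(self, "quantized_base", False):
        # QLoRA-style path: the quantized base layer handles its own dtypes,
        # so the cast below is never reached.
        base_out = self.super_fwd(x)
    else:
        # Standard LoRA path: cast the input so F.linear does not see
        # mixed float32/bfloat16 operands.
        x = x.to(self.weight.dtype)
        base_out = F.linear(x, self.weight, self.bias)

    # The adapter is kept in an explicitly chosen dtype (lora_dtype), so its
    # matmuls stay consistent regardless of which base path was taken.
    lora_out = self.lora_B(self.lora_A(x.to(self.lora_A.weight.dtype)))
    return base_out + lora_out
```

The point is that the cast sits only on the non-quantized path, so the QLoRA route would be unaffected.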
Thanks @stanley1208 for the reply!
To be honest I'm not sure this is the root cause for #1540; I feel there might be an issue related to initialization, where the whole model should be initialized in bfloat16, so that the token embeddings come out in bfloat16, not float32.
In any case, I've started CI; let's see how that goes 🙏
@akoumpa That makes sense — my fix is a defensive guard, but the root cause may be in model initialization. Happy to investigate the init/checkpoint loading path if that would be helpful!
@akoumpa I traced the init/checkpoint loading path. Here's what I found:
The dtype inconsistency likely originates in checkpointing.py line 396. For custom models (like GPT-OSS), _is_custom_model() returns True, which skips the _load_hf_checkpoint_preserving_dtype() path — the function that explicitly loads weights while preserving their checkpoint dtype (e.g., bfloat16).
The flow:

- _init_model() creates the model within the local_torch_dtype() context (should be bfloat16)
- But during checkpoint loading, custom models don't go through _load_hf_checkpoint_preserving_dtype()
- The alternate loading path may not preserve dtype consistently
- After loading, embed_tokens may produce float32 hidden states
- These float32 hidden states flow into bfloat16 LoRA layers → dtype mismatch

The condition at checkpointing.py:396:

```python
_is_bin_checkpoint(model_path) or (is_safetensors and not _is_custom_model(...))
```
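For readability, a hedged sketch of how that condition selects the load path. Only the condition itself is quoted from checkpointing.py:396; the surrounding structure, the argument passed to _is_custom_model, and the _load_custom_model_checkpoint name are assumptions for illustration.

```python
# Sketch of the branch being described; helper names other than those quoted
# above are placeholders, not the repo's actual API.
def _load_base_checkpoint(model, model_path, is_safetensors):
    if _is_bin_checkpoint(model_path) or (is_safetensors and not _is_custom_model(model)):
        # HF path: weights are loaded with their checkpoint dtype preserved
        # (e.g. bfloat16), so downstream LoRA layers see consistent dtypes.
        _load_hf_checkpoint_preserving_dtype(model, model_path)
    else:
        # Custom-model path (e.g. GPT-OSS): an alternate loader is used and,
        # per the analysis above, may not preserve dtype consistently.
        _load_custom_model_checkpoint(model, model_path)
```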
This is awesome @stanley1208 , thanks for figuring this out so fast!
I wonder in this case what's the best way to proceed. On the one hand, we could add a similar _load_hf_checkpoint_preserving_dtype function, but that would need to:
(a) avoid allocating more memory than necessary, that is a problem for devices with unified memory (e.g., DGX Spark)
(b) handle things like dequantization, where the base checkpoint has quantized weights (e.g., mxfp4) but training happens in bf16
(c) weight remapping, where the base checkpoint has weights which are then translated to a different shape (as now happens in transformers v5).
Do you feel comfortable with the code to write a fix for this case? We are still in the process of adding the checkpoint test-suite, so please let me know.
CC @adil-a
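For illustration, a minimal sketch of a memory-conscious, dtype-preserving load loop under the constraints above. It assumes a plain safetensors checkpoint; the helper name and path are placeholders, and dequantization (b) and weight remapping (c) are not handled here.

```python
# Minimal sketch, assuming a plain safetensors checkpoint: tensors are streamed
# one at a time to limit peak memory (concern (a)) and keep the checkpoint
# dtype instead of the module's init dtype. Dequantization (b) and weight
# remapping (c) would need extra handling and are not shown.
import torch
from safetensors import safe_open

def load_preserving_dtype(model: torch.nn.Module, checkpoint_path: str) -> None:
    params = dict(model.named_parameters())
    with safe_open(checkpoint_path, framework="pt", device="cpu") as f:
        for name in f.keys():
            if name not in params:
                continue  # remapped or unexpected keys would be handled here
            tensor = f.get_tensor(name)  # arrives in its checkpoint dtype
            with torch.no_grad():
                params[name].data = tensor.to(params[name].device)
```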
@akoumpa Thanks! I'd love to take this on. The constraints you listed (memory, dequantization, weight remapping) make this more nuanced than a simple dtype cast — I want to make sure I get it right.
Could you point me to any existing examples of how custom models handle checkpoint loading currently? That would help me understand the expected flow before I propose a solution. Happy to start with a design sketch before writing code.
This is great @stanley1208 ,
For the custom models, I would look at https://github.com/NVIDIA-NeMo/Automodel/tree/main/nemo_automodel/components/models and pick gpt_oss/llama/mistral4 as examples. You can run something like
```python
from nemo_automodel import NeMoAutoModelForCausalLM

model = NeMoAutoModelForCausalLM.from_pretrained("/path/to/gpt_oss_20b")  # you can also use the hf id
```
This should internally trigger the custom model path (if you want to use the default you can do so by passing force_hf=True).
From there I would look into the _build_model method, _init_model and the apply_model_infrastructure (probably this one is the most important). The code got a bit complicated I must admit, so please let me know if you have any questions :)
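As a quick sanity check while tracing those methods, one could build on the snippet above and inspect which dtypes actually end up in the loaded model (the checkpoint path is the same placeholder as above):

```python
# Count parameter dtypes after loading; if the custom-model path preserved
# bfloat16, a single dtype should appear here.
from collections import Counter

from nemo_automodel import NeMoAutoModelForCausalLM

model = NeMoAutoModelForCausalLM.from_pretrained("/path/to/gpt_oss_20b")
dtype_counts = Counter(p.dtype for p in model.parameters())
print(dtype_counts)  # e.g. a single torch.bfloat16 entry if loading is consistent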
@akoumpa Thanks for the pointers! I'll dig into _build_model, _init_model, and apply_model_infrastructure and come back with a design sketch. Will reach out if I have questions.
/ok to test 99465b0
@akoumpa Here's my design sketch after tracing through the loading path.

Analysis

The divergence happens in the checkpointing.py:396 condition discussed above: custom models skip _load_hf_checkpoint_preserving_dtype() and go through the alternate path.

Three options I considered

Option A (Recommended): Post-DCP dtype reconciliation. After the DCP load completes, cast any loaded parameters that no longer match the model's intended dtype back to it.

Option B: Remove the custom-model exclusion. Let custom models use _load_hf_checkpoint_preserving_dtype() as well.

Option C: Read checkpoint dtype metadata + targeted cast. Parse safetensors headers for per-tensor dtype without loading weights, then cast after DCP. Most precise, but more complex.

Recommendation

Option A for initial fix — it's simple, memory-efficient, handles dequantization correctly (runs after the adapter), and is orthogonal to weight remapping. Can be enhanced to Option C later if per-tensor dtype preservation is needed.

Would this approach work? Happy to start implementing once you confirm the direction.
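A minimal sketch of what the Option A reconciliation could look like. The helper name and call site are assumptions, and the FSDP2 sharding complications mentioned in the reply below are not addressed here.

```python
# Minimal sketch of Option A, assumed to run right after the DCP load step;
# the helper name and call site are illustrative, not existing Automodel APIs.
import torch

def reconcile_param_dtypes(model: torch.nn.Module, target_dtype: torch.dtype = torch.bfloat16) -> None:
    """Cast floating-point parameters and buffers left in another dtype back to target_dtype."""
    with torch.no_grad():
        for tensor in list(model.parameters()) + list(model.buffers()):
            if tensor.is_floating_point() and tensor.dtype != target_dtype:
                tensor.data = tensor.data.to(target_dtype)
```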
Thanks a lot for looking into it @stanley1208! Option (A) sounds like a good plan to me. I expect there to be some complications due to FSDP2, but they should be solvable. In particular, IIRC, the current sequence is …

Thanks a lot again for tackling this, feel free to ping me when you have something I can help with.
@akoumpa I found an additional detail while investigating. GPT-OSS sets its layer norms to run in float32. The flow that triggers the bug: the float32 norm output becomes the hidden states that feed the LoRA linear layers, whose weights are bfloat16.

This means the float32 input to LoRA layers comes from the layer norms being intentionally in float32, not from incorrect checkpoint loading. The defensive cast in this PR addresses the symptom rather than this root cause. Should I still pursue the post-DCP dtype reconciliation (Option A) as a general hardening measure, or does this change the direction?
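To make the mechanism concrete, here is a hedged sketch of an RMSNorm that upcasts to float32 internally; whether it casts back before returning decides the dtype the downstream LoRA linear sees. It mirrors the behaviour described above, not the exact GPT-OSS code.

```python
# Illustrative RMSNorm, not the actual GPT-OSS implementation: it upcasts to
# float32 for numerical stability. If the final cast back to the input dtype
# is missing, every downstream linear (including LoRA) receives float32
# activations even though its weights are bfloat16.
import torch
import torch.nn as nn

class RMSNorm(nn.Module):
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        input_dtype = x.dtype
        x = x.to(torch.float32)  # compute in float32
        x = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
        return (self.weight * x).to(input_dtype)  # cast back; omitting this leaks float32 downstream
```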
Hi @stanley1208, thanks a lot for digging into the bug, I did not anticipate this finding. So I think there was a bug introduced in #1493 and not in the checkpointing; I apologize for the incorrect hint. I feel the right fix would be to actually fix the output of the rmsnorm, because right now it does a lot of the computation in float32, not bfloat16. You've done a great job digging into the details, and I think it would be great if you made an issue with your findings and then we can link the fix PR to the issue. That way we can keep track of this. I hope I'm not asking too much 🙇
Awesome @stanley1208, thanks a ton, this is great work, feel free to ping me as necessary!
Issue was solved in #1633. Thanks a ton @stanley1208 for investigating and solving the bug! Hope you keep contributing 🙇
@akoumpa Thanks for the guidance and quick reviews throughout, I've learned a lot from this investigation! Will definitely keep contributing. |
What does this PR do?
Fix a dtype mismatch error in LinearLoRA.forward() that causes LoRA training to fail when the input tensor dtype differs from the model weight dtype (e.g., float32 input with bfloat16 weights).

Fixes #1540.
Root Cause
When model weights are loaded in bfloat16 (standard for LLMs), both the main linear layer (self.weight) and the LoRA adapter layers (lora_A, lora_B) are in bfloat16. If the input x arrives as float32 (e.g., without autocast), F.linear raises:

RuntimeError: expected mat1 and mat2 to have the same dtype, but got: float != c10::BFloat16
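A minimal, self-contained reproduction of that error (an illustrative snippet, independent of the Automodel code):

```python
# Standalone reproduction of the dtype mismatch: a bfloat16 linear weight fed
# a float32 input raises the RuntimeError quoted above.
import torch
import torch.nn.functional as F

weight = torch.randn(8, 16, dtype=torch.bfloat16)  # bfloat16 weights, as loaded from the checkpoint
x = torch.randn(2, 16, dtype=torch.float32)         # float32 activations arriving without autocast

try:
    F.linear(x, weight)
except RuntimeError as err:
    print(err)  # expected mat1 and mat2 to have the same dtype ...

# The fix in this PR is equivalent to:
out = F.linear(x.to(weight.dtype), weight)
print(out.dtype)  # torch.bfloat16
```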
Fix
Cast input x to match self.weight.dtype at the top of forward():
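For context, a hedged sketch of how the patched method reads; apart from the added cast (which is the actual change), the body below paraphrases a typical LinearLoRA.forward(), and names like self.dropout and self.scale are assumptions rather than the repo's exact attributes.

```python
# Sketch of the patched method: only the x = x.to(self.weight.dtype) line is
# the actual change, the rest paraphrases a typical LoRA linear forward.
import torch.nn.functional as F

def forward(self, x):
    # Cast input to match weight dtype to avoid dtype mismatch
    # (e.g. float32 input with bfloat16 weights).
    x = x.to(self.weight.dtype)
    base_out = F.linear(x, self.weight, self.bias)
    lora_out = self.lora_B(self.lora_A(self.dropout(x))) * self.scale
    return base_out + lora_out
```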