fix: cast input dtype in LinearLoRA.forward() to prevent dtype mismatch #1622

Closed
stanley1208 wants to merge 1 commit into NVIDIA-NeMo:main from stanley1208:fix/lora-dtype-mismatch

Conversation

@stanley1208
Contributor

What does this PR do?

Fix a dtype mismatch error in LinearLoRA.forward() that causes LoRA training to fail when the input tensor dtype differs from the model weight dtype (e.g., float32 input with bfloat16 weights).

Fixes #1540.

Root Cause

When model weights are loaded in bfloat16 (standard for LLMs), both the main linear layer (self.weight) and LoRA adapter layers (lora_A, lora_B) are in bfloat16. If the input x arrives as float32 (e.g., without autocast), F.linear raises:

RuntimeError: expected mat1 and mat2 to have the same dtype, but got: float != c10::BFloat16
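
For context, the same failure reproduces standalone in PyTorch, independent of LinearLoRA (illustrative snippet, not repository code):

import torch
import torch.nn.functional as F

w = torch.randn(768, 768, dtype=torch.bfloat16)  # weights as loaded for the model
x = torch.randn(2, 768, dtype=torch.float32)     # input arriving without autocast

# F.linear(x, w) raises the RuntimeError above; casting the input first avoids it:
out = F.linear(x.to(w.dtype), w)
print(out.dtype)  # torch.bfloat16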

Fix

Cast input x to match self.weight.dtype at the top of forward():

x = x.to(self.weight.dtype)

This single line handles all downstream operations (the main linear, lora_A, lora_B, and the DoRA paths) without needing per-call dtype casts.
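
For illustration, a simplified sketch of where the cast sits; the real LinearLoRA.forward() also handles dropout, scaling details, and the DoRA path, so the structure below is an assumption rather than the repository code:

def forward(self, x):
    # Cast once so every downstream matmul sees the weight dtype
    x = x.to(self.weight.dtype)
    base_out = F.linear(x, self.weight)       # frozen base projection
    lora_out = self.lora_B(self.lora_A(x))    # low-rank update in the same dtype
    return base_out + (self.alpha / self.dim) * lora_out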

Verification
Reproduced and verified on A100 GPU:

import torch
import torch.nn as nn
# LinearLoRA is imported from the repository's LoRA module (import path omitted here)

original_linear = nn.Linear(768, 768, bias=False).to(dtype=torch.bfloat16, device='cuda')
lora_linear = LinearLoRA(original_linear, dim=8, alpha=32).to('cuda')

# Before fix: RuntimeError (dtype mismatch)
# After fix: works correctly
x = torch.randn(2, 128, 768, dtype=torch.float32, device='cuda')
output = lora_linear(x)  # SUCCESS: shape=[2, 128, 768], dtype=bfloat16

Both float32 and bfloat16 inputs produce correct outputs after the fix.

Signed-off-by: stanley1208 <stanley.mei08@gmail.com>
Made-with: Cursor
@copy-pr-bot

copy-pr-bot Bot commented Mar 30, 2026

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

"""
# pylint: disable=C0115,C0116
# Cast input to match weight dtype to avoid dtype mismatch (e.g. float32 input with bfloat16 weights)
x = x.to(self.weight.dtype)
Contributor

Hi @stanley1208, thanks for the PR. Have you tried this with QLoRA by any chance?

Contributor Author

@akoumpa Thanks for the quick review! I haven't tested with QLoRA yet — I verified with standard LoRA on A100 with bfloat16 weights and float32 input.

For QLoRA, the lora_dtype parameter already handles setting the LoRA adapter dtype explicitly (lora.py line 108-109), and the super_fwd path (line 231-233) is used for quantized weights, which bypasses this cast. So the fix should be safe for QLoRA as well — but I'm happy to verify on Colab if you'd like. Let me know!

Contributor

Thanks @stanley1208 for the reply!

To be honest, I'm not sure this is the root cause for #1540. I suspect there is an issue related to initialization: the whole model should be initialized in bfloat16, so when it gets the token embeddings, those should be in bfloat16, not float32.

Contributor

In any case, I've started CI; let's see how that goes 🙏

Contributor Author

@akoumpa That makes sense — my fix is a defensive guard, but the root cause may be in model initialization. Happy to investigate the init/checkpoint loading path if that would be helpful!

Contributor Author

@akoumpa I traced the init/checkpoint loading path. Here's what I found:

The dtype inconsistency likely originates in checkpointing.py line 396. For custom models (like GPT-OSS), _is_custom_model() returns True, which skips the _load_hf_checkpoint_preserving_dtype() path — the function that explicitly loads weights while preserving their checkpoint dtype (e.g., bfloat16).

The flow:

  1. _init_model() creates the model with local_torch_dtype() context (should be bfloat16)
  2. But during checkpoint loading, custom models don't go through _load_hf_checkpoint_preserving_dtype()
  3. The alternate loading path may not preserve dtype consistently
  4. After loading, embed_tokens may produce float32 hidden states
  5. These float32 hidden states flow into bfloat16 LoRA layers → dtype mismatch

The condition at checkpointing.py:396:

_is_bin_checkpoint(model_path) or (is_safetensors and not _is_custom_model(...))
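
If it helps, step 4 of the flow above can be checked directly after loading (hypothetical diagnostic snippet, assuming an HF-style model API and an already-loaded model object):

import torch

emb = model.get_input_embeddings()
print(emb.weight.dtype)    # expected bfloat16; float32 here would confirm the theory
hidden = emb(torch.tensor([[1, 2, 3]], device=emb.weight.device))
print(hidden.dtype)        # dtype of the hidden states that reach the LoRA layers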

Contributor

This is awesome @stanley1208 , thanks for figuring this out so fast!

I wonder in this case what's the best way to proceed. On the one hand, we could add a similar _load_hf_checkpoint_preserving_dtype function, but that would need to:
(a) avoid allocating more memory than necessary, which is a problem for devices with unified memory (e.g., DGX Spark);
(b) handle things like dequantization, where the base checkpoint has quantized weights (e.g., mxfp4) but training happens in bf16;
(c) handle weight remapping, where the base checkpoint has weights that are then translated to a different shape (as now happens in transformers v5).

Do you feel comfortable with the code to write a fix for this case? We are still in the process of adding the checkpoint test-suite, so please let me know.

CC @adil-a

Contributor Author

@akoumpa Thanks! I'd love to take this on. The constraints you listed (memory, dequantization, weight remapping) make this more nuanced than a simple dtype cast — I want to make sure I get it right.

Could you point me to any existing examples of how custom models handle checkpoint loading currently? That would help me understand the expected flow before I propose a solution. Happy to start with a design sketch before writing code.

Contributor

This is great @stanley1208 ,

For the custom models, I would look at https://github.com/NVIDIA-NeMo/Automodel/tree/main/nemo_automodel/components/models and pick gpt_oss/llama/mistral4 as examples; you can run something like:

from nemo_automodel import NeMoAutoModelForCausalLM
model = NeMoAutoModelForCausalLM.from_pretrained("/path/to/gpt_oss_20b") # you can also use the hf id

This should internally trigger the custom model path (if you want to use the default, you can do so by passing force_hf=True).

From there I would look into the _build_model method, _init_model, and apply_model_infrastructure (probably this one is the most important). The code got a bit complicated, I must admit, so please let me know if you have any questions :)

Contributor Author

@akoumpa Thanks for the pointers! I'll dig into _build_model, _init_model, and apply_model_infrastructure and come back with a design sketch. Will reach out if I have questions.

@akoumpa
Contributor

akoumpa commented Mar 30, 2026

/ok to test 99465b0

@stanley1208
Contributor Author

@akoumpa Here's my design sketch after tracing through _build_model → _init_model → apply_model_infrastructure → load_model:

Analysis

The divergence happens in checkpointing.py:396. HF models go through _load_hf_checkpoint_preserving_dtype() which loads weights from disk preserving checkpoint dtype. Custom models hit the DCP path (line 437+) where the comment says: "DCP copies into model's existing tensors; dtypes follow the model" — so if the model's init dtype is wrong, the loaded weights are wrong too.

Three options I considered

Option A (Recommended): Post-DCP dtype reconciliation

After dcp.load() + state_dict_adapter.from_hf(), cast tensors to the target training dtype. The fix would go after line 456 in load_model(), reading config.torch_dtype and casting any mismatched floating-point tensors to the target dtype.

  • Memory efficient (no extra state dict copy)
  • Works after dequantization (adapter runs first)
  • Orthogonal to weight remapping (only changes dtype, not shapes)

Option B: Remove not _is_custom_model() exclusion

Let custom models use _load_hf_checkpoint_preserving_dtype() too, adding _maybe_adapt_state_dict_from_hf() before loading. Simpler, but loads full state dict into memory (violates DGX Spark constraint).

Option C: Read checkpoint dtype metadata + targeted cast

Parse safetensors headers for per-tensor dtype without loading weights, then cast after DCP. Most precise, but more complex.
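
For what it's worth, the header metadata can be read without loading any tensor data; a rough sketch (hypothetical helper, relying only on the safetensors file layout of an 8-byte little-endian header length followed by a JSON header):

import json
import struct

def read_safetensors_dtypes(path):
    """Return {tensor_name: dtype_string} from a .safetensors header, without loading weights."""
    with open(path, "rb") as f:
        header_len = struct.unpack("<Q", f.read(8))[0]  # u64 size of the JSON header
        header = json.loads(f.read(header_len))
    return {name: meta["dtype"] for name, meta in header.items() if name != "__metadata__"}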

Recommendation

Option A for initial fix — it's simple, memory-efficient, handles dequantization correctly (runs after adapter), and is orthogonal to weight remapping. Can be enhanced to Option C later if per-tensor dtype preservation is needed.
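
To make Option A concrete, a minimal sketch of the reconciliation step (the exact hook point after dcp.load() + state_dict_adapter.from_hf(), and the target_dtype argument, are assumptions; FSDP2/DTensor parameters may need extra handling):

import torch
import torch.nn as nn

def reconcile_dtypes(model: nn.Module, target_dtype: torch.dtype) -> None:
    """Cast floating-point parameters/buffers that don't match the training dtype.

    Only dtypes change, never shapes, so this stays orthogonal to weight remapping
    and runs after any dequantization performed by the state-dict adapter.
    """
    with torch.no_grad():
        for _, param in model.named_parameters():
            if param.is_floating_point() and param.dtype != target_dtype:
                param.data = param.data.to(target_dtype)
        for _, buf in model.named_buffers():
            if buf.is_floating_point() and buf.dtype != target_dtype:
                buf.data = buf.data.to(target_dtype)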

Would this approach work? Happy to start implementing once you confirm the direction.

@akoumpa
Contributor

akoumpa commented Mar 31, 2026

Thanks a lot for looking into it @stanley1208!

Option (A) sounds like a good plan to me. I expect there to be some complications due to FSDP2, but they should be solvable. In particular, IIRC, the current sequence is: init model on meta device --> wrap with fully_shard --> restore checkpoint. So if restoring the checkpoint needs to cast the dtype, I'm not sure how that would work with fully_shard, which also keeps track of the parameter dtype (mainly the whole reason preserve_dtype got added in the first place). But I would recommend to try and re-adjust course as necessary. :)

Thanks a lot again for tackling this, feel free to ping me when you have something I can help with.

@stanley1208
Contributor Author

stanley1208 commented Mar 31, 2026

@akoumpa I found an additional detail while investigating. GPT-OSS sets _keep_in_fp32_modules = ["post_attention_layernorm", "input_layernorm", "norm"] (model.py line 196), which means cast_model_to_dtype() intentionally keeps layer norms in float32.

The flow that triggers the bug:

  1. input_layernorm (float32) produces float32 output
  2. This flows into self_attn(float32_input)
  3. Inside attention, q_proj is LoRA-wrapped with bfloat16 weights
  4. F.linear(float32_x, bfloat16_weight) → dtype mismatch

This means the float32 input to LoRA layers comes from the layer norms being intentionally in float32, not from incorrect checkpoint loading. The defensive cast x = x.to(self.weight.dtype) in LinearLoRA.forward() may actually be the right fix for this specific case.
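
A tiny self-contained illustration of that interaction (not the repository code; the RMSNorm math is written out by hand):

import torch
import torch.nn.functional as F

hidden = torch.randn(2, 768, dtype=torch.bfloat16)
norm_weight = torch.ones(768, dtype=torch.float32)           # kept in fp32 via _keep_in_fp32_modules

x = hidden.float()                                            # norm computes in float32...
normed = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + 1e-6) * norm_weight
print(normed.dtype)                                           # torch.float32

q_proj_weight = torch.randn(768, 768, dtype=torch.bfloat16)   # LoRA-wrapped projection weights
try:
    F.linear(normed, q_proj_weight)
except RuntimeError as e:
    print(e)                                                  # expected mat1 and mat2 to have the same dtype ...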

Should I still pursue the post-DCP dtype reconciliation (Option A) as a general hardening measure, or does this change the direction?

@akoumpa
Contributor

akoumpa commented Mar 31, 2026

Hi @stanley1208, thanks a lot for digging into the bug, I did not anticipate this finding.

So I think there was a bug introduced in #1493 and not in the checkpointing; I apologize for the incorrect hint.

I feel the right fix would be to actually fix the output of the rmsnorm, because right now it does a lot of the computation in float32, not bfloat16.
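
For reference, the common pattern (e.g., Hugging Face Llama-style RMSNorm) is to compute in float32 internally but cast the result back to the input dtype, which keeps the hidden states in bfloat16 downstream. A hedged sketch of that pattern, not the repository's actual fix:

import torch
import torch.nn as nn

class RMSNorm(nn.Module):
    """Sketch: compute in float32 for numerical stability, return the caller's dtype."""

    def __init__(self, dim, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(dim))
        self.eps = eps

    def forward(self, x):
        input_dtype = x.dtype
        x = x.float()
        x = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
        return (self.weight * x).to(input_dtype)  # bfloat16 in, bfloat16 out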

You've done a great job digging into the details, and I think it would be great if you made an issue with your findings and then we can link the fix PR to the issue. That way we can keep track of this. I hope I'm not asking too much 🙇

@stanley1208
Contributor Author

@akoumpa Not asking too much at all — happy to create the issue and fix PR! I'll document the findings and look into the RMSNorm output dtype fix. Will check #1493 to understand what changed.

@akoumpa
Contributor

akoumpa commented Mar 31, 2026

Awesome @stanley1208 , thanks a ton, this is great work, feel free to ping me as necessary!

@akoumpa
Contributor

akoumpa commented Mar 31, 2026

The issue was solved in #1633. Thanks a ton @stanley1208 for investigating and solving the bug! Hope you keep contributing 🙇

@akoumpa akoumpa closed this Mar 31, 2026
@stanley1208
Contributor Author

stanley1208 commented Apr 1, 2026

@akoumpa Thanks for the guidance and quick reviews throughout, I've learned a lot from this investigation! Will definitely keep contributing.

Successfully merging this pull request may close these issues.

gpt-oss single-gpu PEFT fails with type-mismatch error
