
[FSDP2] Cast model to uniform dtype before fully_shard to fix mixed-dtype AssertionError #3985

Open
roycho96 wants to merge 5 commits into huggingface:main from roycho96:fix/mixed-precision

Conversation

@roycho96
Contributor

What does this PR do?

When mixed_precision is enabled, this PR casts model parameters to a uniform dtype before fully_shard() to prevent the _init_mp_dtypes() AssertionError.

Problem

FSDP2's _init_mp_dtypes() requires uniform orig_dtype across all trainable parameters in a param group. With mixed dtypes, the first forward call crashes:

AssertionError: FSDP expects uniform original parameter dtype but got {torch.bfloat16, torch.float32}

Accelerate's fsdp2_prepare_model() currently passes the mixed-dtype model directly to fully_shard() without normalizing dtypes.
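A minimal CPU-only sketch of how this layout arises (toy module, not the actual model): the bulk of the model is downcast to bf16 while a norm layer stays fp32, which is exactly the non-uniform dtype set that trips the assertion.

```python
import torch
import torch.nn as nn

# Toy module (not the real model): linear weights in bf16, norm left in fp32,
# mirroring the layout that makes _init_mp_dtypes() assert on the first forward.
model = nn.Sequential(nn.Linear(4, 4), nn.LayerNorm(4))
model[0].to(torch.bfloat16)   # bulk of the model downcast to bf16
model[1].to(torch.float32)    # norm weights stay fp32

dtypes = {p.dtype for p in model.parameters()}
assert dtypes == {torch.bfloat16, torch.float32}  # non-uniform -> FSDP2 asserts
```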

Fix

Cast all parameters to the mixed precision param_dtype before fully_shard(), after model_has_params4bit detection. Params4bit models are skipped to avoid destroying quantized weights.

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a Github issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

@SunMarc

@roycho96 roycho96 changed the title from "Fix/mixed precision" to "[FSDP2] Cast model to uniform dtype before fully_shard to fix mixed-dtype AssertionError" on Mar 20, 2026
@github-actions
Contributor

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

Member

@SunMarc SunMarc left a comment


Left a comment !

Comment thread src/accelerate/utils/fsdp_utils.py Outdated
Comment on lines +669 to +682
# FSDP2 requires a uniform original parameter dtype within each param group.
# Norm/embedding weights kept in fp32 while the rest is bf16/fp16 cause an
# AssertionError in _init_mp_dtypes().
if accelerator.mixed_precision != "no" and not model_has_params4bit:
    mp_policy = fsdp2_plugin.mixed_precision_policy
    if mp_policy is not None and getattr(mp_policy, "param_dtype", None) is not None:
        target_dtype = mp_policy.param_dtype
    elif accelerator.mixed_precision == "bf16":
        target_dtype = torch.bfloat16
    elif accelerator.mixed_precision == "fp16":
        target_dtype = torch.float16
    else:
        target_dtype = None
    if target_dtype is not None:
        model.to(target_dtype)
Member


Hmmm, for mixed precision, does FSDP need the weights in fp16/bf16 and keep a trace of the weights in fp32 to actually perform mixed precision, or is it the other way around? Can you also share a reproducer?

Contributor Author


> Hmmm, for mixed precision, does FSDP need the weights in fp16/bf16 and keep a trace of the weights in fp32 to actually perform mixed precision, or is it the other way around? Can you also share a reproducer?

you're right. proper mixed precision keeps fp32 master weights and casts to bf16/fp16 only during forward/backward compute. my current fix downcasts master weights, which collapses mixed precision into pure bf16/fp16 training and loses optimizer update precision.
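As a rough sketch of that contract with plain tensors (no FSDP involved): the optimizer sees fp32 master weights, and only the forward compute runs in bf16 via an explicit cast.

```python
import torch
import torch.nn as nn

# Master weights stay fp32; compute is done on a bf16 cast of them.
layer = nn.Linear(8, 8)        # fp32 master weights
x = torch.randn(2, 8)

out = x.to(torch.bfloat16) @ layer.weight.to(torch.bfloat16).T  # bf16 compute
assert layer.weight.dtype == torch.float32   # master weights untouched
assert out.dtype == torch.bfloat16
```

FSDP2's MixedPrecisionPolicy(param_dtype=...) performs this cast automatically around forward/backward.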

Contributor Author


here is a reproducer

repro.py

import os

import torch
from datasets import Dataset
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    DataCollatorForLanguageModeling,
    Trainer,
    TrainingArguments,
)

model = AutoModelForCausalLM.from_pretrained(
    "HuggingFaceTB/SmolLM2-135M-Instruct", torch_dtype=torch.bfloat16
)
tokenizer = AutoTokenizer.from_pretrained("HuggingFaceTB/SmolLM2-135M-Instruct")
tokenizer.pad_token = tokenizer.eos_token

# Inspect dtypes before forcing any changes
if int(os.environ.get("RANK", "0")) == 0:
    dtype_counts = {}
    for name, p in model.named_parameters():
        dtype_counts[p.dtype] = dtype_counts.get(p.dtype, 0) + 1
    print(f"[rank0] param dtype distribution (initial): {dtype_counts}")

# Force norm layers to fp32 to reliably trigger the mixed-dtype FSDP2 issue
for name, module in model.named_modules():
    if "norm" in name.lower():
        module.to(torch.float32)

if int(os.environ.get("RANK", "0")) == 0:
    dtype_counts = {}
    for name, p in model.named_parameters():
        dtype_counts[p.dtype] = dtype_counts.get(p.dtype, 0) + 1
    print(f"[rank0] param dtype distribution (after forcing norms fp32): {dtype_counts}")

dataset = Dataset.from_list([{"text": f"Hello {i}"} for i in range(32)])
dataset = dataset.map(
    lambda x: tokenizer(x["text"], truncation=True, max_length=32, padding="max_length"),
    batched=True,
    remove_columns=["text"],
)

trainer = Trainer(
    model=model,
    args=TrainingArguments(
        output_dir="/tmp/repro",
        bf16=True,
        num_train_epochs=1,
        per_device_train_batch_size=2,
        logging_steps=1,
        save_strategy="no",
        report_to="none",
    ),
    train_dataset=dataset,
    data_collator=DataCollatorForLanguageModeling(tokenizer, mlm=False),
)
trainer.train()

fsdp2_config.yaml

compute_environment: LOCAL_MACHINE
distributed_type: FSDP
fsdp_config:
  fsdp_version: 2
  fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
  fsdp_transformer_layer_cls_to_wrap: LlamaDecoderLayer
mixed_precision: bf16
num_processes: 2

accelerate launch --config_file fsdp2_config.yaml repro.py

@roycho96
Contributor Author

roycho96 commented Apr 20, 2026

@SunMarc
I moved the existing (but no-op) fp32 upcasting logic to before fully_shard so it actually takes effect. This keeps fp32 master weights, letting MixedPrecisionPolicy.param_dtype handle the bf16/fp16 cast during compute.
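Roughly, the reordered logic looks like this (hypothetical standalone helper for illustration, not the actual accelerate code): upcast trainable params to fp32 before fully_shard, and skip quantized models so their packed weights survive.

```python
import torch
import torch.nn as nn

def upcast_before_shard(model, mixed_precision, model_has_params4bit):
    """Sketch of the reordered upcast (hypothetical helper, not accelerate's API):
    bring floating-point params up to fp32 *before* fully_shard, skipping
    Params4bit models so their quantized weights are not destroyed."""
    if mixed_precision != "no" and not model_has_params4bit:
        for param in model.parameters():
            if torch.is_floating_point(param):
                param.data = param.data.to(torch.float32)  # in-place, so it sticks
    return model

# Mixed-dtype toy model: linear in bf16, norm already fp32
model = nn.Sequential(nn.Linear(4, 4).to(torch.bfloat16), nn.LayerNorm(4))
upcast_before_shard(model, "bf16", model_has_params4bit=False)
assert {p.dtype for p in model.parameters()} == {torch.float32}
```

After this, MixedPrecisionPolicy.param_dtype=bf16 handles the compute cast inside FSDP2.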

Also, I added a model_has_params4bit guard to skip the upcast for QLoRA.

@SunMarc
Member

SunMarc commented Apr 21, 2026

> I moved the existing (but no-op) fp32 upcasting logic to before fully_shard so it actually takes effect. This keeps fp32 master weights, letting MixedPrecisionPolicy.param_dtype handle the bf16/fp16 cast during compute.

WDYM by no-op? (It worked if the whole model was in bf16, no?) Also, btw, I found this while looking into why they decided to do the upcast so late in the model preparation: #2674 (comment). But it shouldn't be an issue for FSDP2 to move this, I think!

Member

@SunMarc SunMarc left a comment


Thanks, this looks better!

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@roycho96
Contributor Author

roycho96 commented Apr 21, 2026

> I moved the existing (but no-op) fp32 upcasting logic to before fully_shard so it actually takes effect. This keeps fp32 master weights, letting MixedPrecisionPolicy.param_dtype handle the bf16/fp16 cast during compute.

> WDYM by no-op? (It worked if the whole model was in bf16, no?) Also, btw, I found this while looking into why they decided to do the upcast so late in the model preparation: #2674 (comment). But it shouldn't be an issue for FSDP2 to move this, I think!

I mean the upcast actually works only in the FSDP1 path (accelerator.py:2020 uses param.data = param.data.to(torch.float32)). The FSDP2 version reassigned a local variable instead, so nothing ever got upcast regardless of input dtype.
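The difference can be shown with a toy module (a sketch of the pattern, not the actual accelerate code):

```python
import torch
import torch.nn as nn

model = nn.Linear(2, 2).to(torch.bfloat16)

# Buggy pattern: rebinding the loop variable casts a temporary tensor;
# the parameter the model actually holds is untouched.
for param in model.parameters():
    param = param.to(torch.float32)   # no-op as far as the model is concerned
assert next(model.parameters()).dtype == torch.bfloat16

# Working pattern (what the FSDP1 path does): mutate the storage in place.
for param in model.parameters():
    param.data = param.data.to(torch.float32)
assert next(model.parameters()).dtype == torch.float32
```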

Btw, pushed a ruff fix.
