[FSDP2] Cast model to uniform dtype before fully_shard to fix mixed-dtype AssertionError #3985

roycho96 wants to merge 5 commits into huggingface:main
Conversation
```python
# FSDP2 requires a uniform original parameter dtype within each param group.
# Norm/embedding weights stored in fp32 while the rest is bf16/fp16 cause an
# _init_mp_dtypes() AssertionError.
if accelerator.mixed_precision != "no" and not model_has_params4bit:
    mp_policy = fsdp2_plugin.mixed_precision_policy
    if mp_policy is not None and getattr(mp_policy, "param_dtype", None) is not None:
        target_dtype = mp_policy.param_dtype
    elif accelerator.mixed_precision == "bf16":
        target_dtype = torch.bfloat16
    elif accelerator.mixed_precision == "fp16":
        target_dtype = torch.float16
    else:
        target_dtype = None
    if target_dtype is not None:
        model.to(target_dtype)
```
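For context, the mixed-dtype state this guard normalizes can be reproduced in isolation. A minimal standalone sketch (illustrative only, not part of the PR; `param_dtypes` is a hypothetical helper):

```python
import torch
import torch.nn as nn

def param_dtypes(model: nn.Module) -> set:
    """Return the distinct dtypes found among the model's parameters."""
    return {p.dtype for p in model.parameters()}

# A tiny model where the Linear is bf16 but the LayerNorm stays fp32 --
# the non-uniform state that trips FSDP2's _init_mp_dtypes().
model = nn.Sequential(nn.Linear(4, 4), nn.LayerNorm(4))
model[0].to(torch.bfloat16)
assert param_dtypes(model) == {torch.bfloat16, torch.float32}

# The fix in the snippet above: cast the whole model to one dtype.
model.to(torch.bfloat16)
assert param_dtypes(model) == {torch.bfloat16}
```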
Hmm, for mixed precision, does FSDP need the weights in fp16/bf16 while keeping an fp32 copy to actually perform mixed precision, or is it the other way around? Can you share a reproducer as well?
You're right. Proper mixed precision keeps fp32 master weights and casts to bf16/fp16 only for forward/backward compute. My current fix downcasts the master weights themselves, which collapses mixed precision into pure bf16/fp16 training and loses optimizer update precision.
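That master-weight scheme can be sketched with plain tensors (an illustrative sketch of the general idea, not accelerate's or FSDP2's actual implementation):

```python
import torch

# fp32 master weight; the optimizer state lives at full precision.
master = torch.randn(4, 4, dtype=torch.float32, requires_grad=True)
x = torch.randn(2, 4, dtype=torch.bfloat16)

# Forward/backward run on a bf16 cast of the master weight.
out = x @ master.to(torch.bfloat16).t()
out.sum().backward()

# The update is still applied to the fp32 master copy.
with torch.no_grad():
    master -= 1e-3 * master.grad

assert out.dtype == torch.bfloat16       # compute in low precision
assert master.dtype == torch.float32     # update precision preserved
assert master.grad.dtype == torch.float32
```

Downcasting `master` itself up front, as the earlier version of this fix did, removes the fp32 copy and with it the extra update precision.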
Here is a reproducer.

`repro.py`:

```python
import os

import torch
from datasets import Dataset
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    DataCollatorForLanguageModeling,
    Trainer,
    TrainingArguments,
)

model = AutoModelForCausalLM.from_pretrained(
    "HuggingFaceTB/SmolLM2-135M-Instruct", torch_dtype=torch.bfloat16
)
tokenizer = AutoTokenizer.from_pretrained("HuggingFaceTB/SmolLM2-135M-Instruct")
tokenizer.pad_token = tokenizer.eos_token

# Inspect dtypes before forcing any changes
if int(os.environ.get("RANK", "0")) == 0:
    dtype_counts = {}
    for name, p in model.named_parameters():
        dtype_counts[p.dtype] = dtype_counts.get(p.dtype, 0) + 1
    print(f"[rank0] param dtype distribution (initial): {dtype_counts}")

# Force norm layers to fp32 to reliably trigger the mixed-dtype FSDP2 issue
for name, module in model.named_modules():
    if "norm" in name.lower():
        module.to(torch.float32)

if int(os.environ.get("RANK", "0")) == 0:
    dtype_counts = {}
    for name, p in model.named_parameters():
        dtype_counts[p.dtype] = dtype_counts.get(p.dtype, 0) + 1
    print(f"[rank0] param dtype distribution (after forcing norms fp32): {dtype_counts}")

dataset = Dataset.from_list([{"text": f"Hello {i}"} for i in range(32)])
dataset = dataset.map(
    lambda x: tokenizer(x["text"], truncation=True, max_length=32, padding="max_length"),
    batched=True,
    remove_columns=["text"],
)

trainer = Trainer(
    model=model,
    args=TrainingArguments(
        output_dir="/tmp/repro",
        bf16=True,
        num_train_epochs=1,
        per_device_train_batch_size=2,
        logging_steps=1,
        save_strategy="no",
        report_to="none",
    ),
    train_dataset=dataset,
    data_collator=DataCollatorForLanguageModeling(tokenizer, mlm=False),
)
trainer.train()
```

`fsdp2_config.yaml`:

```yaml
compute_environment: LOCAL_MACHINE
distributed_type: FSDP
fsdp_config:
  fsdp_version: 2
  fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
  fsdp_transformer_layer_cls_to_wrap: LlamaDecoderLayer
mixed_precision: bf16
num_processes: 2
```

Launch with:

```shell
accelerate launch --config_file fsdp2_config.yaml repro.py
```
@SunMarc Also, I added a `model_has_params4bit` guard to skip the upcast for QLoRA.
What do you mean by no-op (it worked if the whole model was in bf16, no)? Also, btw, I found this when looking into why they decided to have the upcast so late in the preparation of the model #2674 (comment), but it shouldn't be an issue for FSDP2 to move this, I think!
I mean the upcast actually works only in the FSDP1 path (accelerator.py:2020 uses `param.data = param.data.to(torch.float32)`). The FSDP2 version only reassigned a local variable, so nothing ever got upcast regardless of input dtype. Btw, pushed a ruff fix.
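The difference between the two patterns can be shown in isolation. A minimal sketch (not accelerate's code) of rebinding versus in-place mutation of `param.data`:

```python
import torch
import torch.nn as nn

model = nn.Linear(2, 2).to(torch.bfloat16)

# Buggy pattern: `param` is rebound to a new tensor inside the loop;
# the module's parameters are never touched.
for param in model.parameters():
    param = param.to(torch.float32)
assert next(model.parameters()).dtype == torch.bfloat16

# Working pattern (what the FSDP1 path does): swap the storage in place,
# so the module's own parameter objects are upcast.
for param in model.parameters():
    param.data = param.data.to(torch.float32)
assert all(p.dtype == torch.float32 for p in model.parameters())
```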
What does this PR do?

When `mixed_precision` is enabled, casts model parameters to a uniform dtype before `fully_shard()` to prevent an `_init_mp_dtypes()` AssertionError.

Problem

FSDP2's `_init_mp_dtypes()` requires a uniform `orig_dtype` across all trainable parameters in a param group. With mixed dtypes, the first forward call crashes. FSDP2's `fsdp2_prepare_model()` currently passes the mixed-dtype model directly to `fully_shard()` without normalizing dtypes.

Fix

Cast all parameters to the mixed precision `param_dtype` before `fully_shard()`, after `model_has_params4bit` detection. Params4bit models are skipped to avoid destroying quantized weights.
Who can review?
@SunMarc