
Fix EP + FSDP2: experts silently overwritten by rank-0 broadcast #45662

Open

AmineDiro wants to merge 5 commits into huggingface:main from AmineDiro:fix-ep-fsdp-ignored-modules

Conversation

@AmineDiro
Member

What does this PR do?

Loading a MoE model with Expert Parallelism (distributed_config=DistributedConfig(enable_expert_parallel=True)) and then calling accelerator.prepare with FSDP2 silently loads the wrong experts on every rank. The model trains, but on broken weights.

Tested on Qwen3-30B-A3B with 128 experts and EP=8. from_pretrained correctly EP-shards the experts: rank 0 holds experts 0–15, rank 1 holds 16–31, and so on. Then in accelerate.utils.fsdp_utils.fsdp2_prepare_model (with cpu_ram_efficient_loading=True):

  1. Snapshot original_sd = model.state_dict(): captures the per-rank-unique expert data.
  2. model.to("meta"): drops the values.
  3. fully_shard(model): wraps params as DTensors on the FSDP mesh, assuming all ranks started with the same data.
  4. fsdp2_load_full_state_dict: rank 0 calls dist.broadcast(full_param, src=0) for each param. For an EP-sharded param, rank 0's local tensor contains only experts 0–15, and every rank receives that data. After distribute_tensor, each rank holds a slice of rank 0's 16 experts.

The router still picks among 128 experts, but they all carry the wrong weights.
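
For illustration, a minimal sketch of the destructive path described above; the function and variable names below are stand-ins, not accelerate's actual internals:

# Sketch only: how a rank-0 broadcast clobbers EP-sharded weights.
# Assumes an initialized process group; names are illustrative.
import torch
import torch.distributed as dist
from torch.distributed.tensor import Shard, distribute_tensor

def load_param_like_fsdp2(local_ep_shard: torch.Tensor, fsdp_mesh):
    # Every rank allocates a buffer and receives *rank 0's* local tensor,
    # i.e. only experts 0-15; the shards held by ranks 1-7 are never sent.
    full_param = local_ep_shard.clone()
    dist.broadcast(full_param, src=0)
    # distribute_tensor then re-slices rank 0's 16 experts across the FSDP mesh,
    # so the per-rank-unique expert data on the other ranks is silently discarded.
    return distribute_tensor(full_param, fsdp_mesh, [Shard(0)])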

Minimal repro

# repro_ep_fsdp.py — torchrun --nproc_per_node=8 repro_ep_fsdp.py
import os, torch, torch.distributed as dist
from transformers import AutoModelForCausalLM
from transformers.distributed import DistributedConfig
from accelerate import Accelerator
from accelerate.utils import FullyShardedDataParallelPlugin

rank = int(os.environ["RANK"])
model = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen3-30B-A3B",
    dtype=torch.bfloat16,
    distributed_config=DistributedConfig(enable_expert_parallel=True),
)
dist.barrier()

# Per-rank-unique values: each rank holds 16 of the original 128 experts.
gu = model.model.layers[0].mlp.experts.gate_up_proj
local = gu.to_local() if hasattr(gu, "to_local") else gu
before = (local[0, 0, 0].item(), local[15, 0, 0].item())
print(f"[rank {rank}] BEFORE  expert0={before[0]:+.4e}  expert15={before[1]:+.4e}", flush=True)
dist.barrier()

# Wrap with FSDP2 the same way the Trainer does
plugin = FullyShardedDataParallelPlugin(
    fsdp_version=2, auto_wrap_policy="transformer_based_wrap",
    cpu_ram_efficient_loading=False,  # True triggers the destructive broadcast path
)
acc = Accelerator(fsdp_plugin=plugin)
optim = torch.optim.AdamW([p for p in model.parameters() if p.requires_grad], lr=0.0)
model, optim = acc.prepare(model, optim)

gu = model.model.layers[0].mlp.experts.gate_up_proj
local = gu.to_local() if hasattr(gu, "to_local") else gu
after = (local[0, 0, 0].item(), local[15, 0, 0].item())
print(f"[rank {rank}] AFTER   expert0={after[0]:+.4e}  expert15={after[1]:+.4e}", flush=True)
dist.destroy_process_group()

Run on 1 node with 8×H100:

torchrun --nproc_per_node=8 repro_ep_fsdp.py
  • Without this PR: each rank's AFTER values match rank 0's BEFORE values (rank 0's 16 experts broadcast to everyone; ranks 1–7's data is lost).
  • With this PR: each rank's BEFORE and AFTER values match, i.e. all 8 unique slices are preserved (128/128 experts retained).

The fix

  • Tell FSDP to skip the EP-sharded expert modules: fully_shard() doesn't auto-skip DTensors on a non-FSDP mesh. Also gate the existing ParallelismConfig(tp_size=...) auto-build on `not has_ep`.

  • Wrap EP-sharded params as DTensors (PreTrainedModel._wrap_ep_params_as_dtensor). Without this, after fully_shard() the rest of the model is DTensors while the EP params stay plain nn.Parameter, and the optimizer crashes. We then call .to_local() in grouped_mm_experts_forward to get the local tensor, applied to the three weights (gate_up_proj, up_proj, down_proj); a sketch follows below.

Follow-up: batched_mm_experts_forward and sonicmoe_experts_forward need the same one-liner before they're EP-compatible. Kept out of scope here.
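
For reference, a minimal sketch of the .to_local() change mentioned above; the weight and function names come from the PR description, but the body is illustrative rather than the exact diff:

from torch.distributed.tensor import DTensor

def _local(w):
    # EP-sharded expert weights arrive as DTensors on the EP mesh; the grouped-GEMM
    # kernel wants the plain local shard, so unwrap once at the top of the forward.
    return w.to_local() if isinstance(w, DTensor) else w

# Inside grouped_mm_experts_forward (illustrative placement, not the exact upstream code):
#     gate_up = _local(self.gate_up_proj)
#     down = _local(self.down_proj)
#     ... run the grouped matmuls with the local shards as before ...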

Verification

End-to-end SFT on Qwen3-30B-A3B, EP=8, single 8-GPU node, real trl/scripts/sft.py via accelerate launch --use_fsdp --fsdp_cpu_ram_efficient_loading false:

                  step 0                 step 4
Before this PR    loss=62, grad=nan      loss=0, grad=nan
After this PR     loss=8.88, grad=2.3    loss=8.80, grad=2.8

Who can review?

@ArthurZucker @IlyasMoutawwakil @3outeille

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

Collaborator

@ArthurZucker left a comment


Really like the intention! 🤗
But we need to be careful and isolate changes, making sure to_local() calls are only run once, for perf etc

Comment thread src/transformers/modeling_utils.py Outdated
model.eval() # Set model in evaluation mode to deactivate Dropout modules by default
model.set_use_kernels(use_kernels, kernel_config)

cls._wrap_ep_params_as_dtensor(model, device_mesh)
Collaborator


nope we can't have that !

Member Author

@AmineDiro Apr 28, 2026


sorry, do you mean staticmethod or just inlining? or removing it altogether? 😅

Comment thread src/transformers/modeling_utils.py Outdated
Comment on lines +4379 to +4395
@staticmethod
def _wrap_ep_params_as_dtensor(model, device_mesh) -> None:
    """Wrap EP-sharded params (`grouped_gemm` style) as DTensors in-place.

    Without this, the optimizer's foreach ops error with "mixed Tensor and DTensor"
    against the FSDP-wrapped DTensor params on the rest of the model.
    """
    if not model.has_ep:
        return
    plan = model.tp_plan
    for name, p in list(model.named_parameters()):
        if _get_parameter_tp_plan(parameter_name=name, tp_plan=plan, is_weight=True) != "grouped_gemm":
            continue
        parent, attr = get_module_from_name(model, name)
        dt = DTensor.from_local(p.data, device_mesh, [Shard(0)], run_check=False)
        setattr(parent, attr, nn.Parameter(dt, requires_grad=p.requires_grad))
Collaborator


that does not make sense to have here!
We should update distribute_module instead; any parallel-related code needs to live in distributed_xxx.py, not in the general modeling utils, which is already bloated as is

Comment thread src/transformers/modeling_utils.py Outdated
Comment on lines +1378 to +1395
@property
def has_ep(self) -> bool:
    """Whether expert parallelism is enabled for this model."""
    distributed_config = getattr(getattr(self, "config", None), "distributed_config", None)
    return distributed_config is not None and getattr(distributed_config, "enable_expert_parallel", False)

@property
def ep_sharded_param_names(self) -> list[str]:
    """FQNs of parameters whose data is per-rank unique under EP sharding."""
    if not self.has_ep:
        return []
    plan = self.tp_plan
    return [
        name
        for name, _ in self.named_parameters()
        if _get_parameter_tp_plan(parameter_name=name, tp_plan=plan, is_weight=True) == "grouped_gemm"
    ]

Collaborator


not modeling core, remove from here

Collaborator


this is fine, though as @3outeille said, this means anything that does not use our kernels will not work

Member Author


yes, that's the tradeoff. In general this DTensor wrapping for TP params seems hacky and unnecessary.
The main issue is the optimizer, but I think that can be solved in a cleaner way

@3outeille
Member

just to be extra sure, can you train for 4 steps and try comparing the loss at each step between a full 4-step run vs a 2-step run -> save model and optimizers -> load model and optimizers -> last 2 steps please?

@AmineDiro
Member Author

AmineDiro commented Apr 28, 2026

just to be extra sure, can you train for 4 steps and try comparing the loss at each step between a full 4-step run vs a 2-step run -> save model and optimizers -> load model and optimizers -> last 2 steps please?

@3outeille : Great idea !

step   full4 (4 steps from scratch)   save2 + load2 (2 steps → save → load → 2 steps)   abs diff
0      1.608709                       1.608709                                           0.000000
1      1.101468                       1.102031                                           0.000564
2      0.810402                       0.809934                                           0.000468
3      0.607736                       0.606098                                           0.001638

script for testing: gist
max abs diff: 0.0016, within nondeterminism noise 👍🏼

@3outeille
Member


I found it a bit odd that loss is different at step 1 no ?

Move EP parameter DTensor wrapping from post-load model wrapping to
the tensor parallel layer's `post_shard_wrap` method, which applies
during parameter loading. This ensures DTensor wrapping happens at the
right time in the loading pipeline and removes duplicated logic.
@AmineDiro
Member Author

@ArthurZucker :
In 9c712a5 I tried to refactor a bit:

  • moved the EP-DTensor wrap into the tensor-parallel code, where it belongs. TensorParallelLayer now has a post_shard_wrap() method (no-op by default), overridden in GroupedGemmParallel to wrap using DTensor.from_local(..., [Shard(0)]). The wrap now happens per-param on the same path where sharding already runs 👍🏼
  • core_model_loading calls tp_layer.post_shard_wrap(param), so all loading paths should use it; every EP param goes through core_model_loading if I understand correctly. A rough sketch of the hook follows below.
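
For illustration, roughly what that hook could look like; the class and method names follow the comment above, while the exact signature and surrounding wiring are assumptions:

from torch import nn
from torch.distributed.tensor import DTensor, Shard

class TensorParallelLayer:
    # Sketch: the default hook leaves a freshly loaded/sharded parameter untouched.
    def post_shard_wrap(self, param: nn.Parameter, device_mesh) -> nn.Parameter:
        return param

class GroupedGemmParallel(TensorParallelLayer):
    # Sketch: EP-sharded expert weights get wrapped as DTensors on the EP mesh,
    # so the optimizer's foreach ops see DTensors everywhere instead of a mix.
    def post_shard_wrap(self, param: nn.Parameter, device_mesh) -> nn.Parameter:
        dt = DTensor.from_local(param.data, device_mesh, [Shard(0)], run_check=False)
        return nn.Parameter(dt, requires_grad=param.requires_grad)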

I also reran the test @3outeille

step   full4      save2 + load2   abs diff
0      1.608709   1.608709        0.000000
1      1.101315   1.102807        0.001491
2      0.808437   0.811338        0.002900
3      0.605437   0.603515        0.001922

max abs diff 0.0029 👍🏼

hopefully this is more aligned with the structure you have in mind. Thanks again for taking the time to review 🤗

@github-actions
Contributor

View the CircleCI Test Summary for this PR:

https://huggingface.co/spaces/transformers-community/circle-ci-viz?pr=45662&sha=9c712a

@AmineDiro
Member Author

AmineDiro commented Apr 28, 2026


I found it a bit odd that loss is different at step 1 no ?

yes, but I think that's bf16 ULP stuff; everything else is the same, it runs through the same code, same seed etc.

EDIT: @3outeille I went ahead and reran in fp32.
Same setup (Qwen3-30B-A3B, EP=8, FSDP2, seed=42, fixed batch), dtype=fp32, gradient checkpointing on, seq_len=512:

step   full4 (fp32)         save2 + load2 (fp32)   abs diff
0      1.5371711254119873   1.5371711254119873     0
1      0.9394540190696716   0.9394540190696716     0
2      0.5880602598190308   0.5880602598190308     0
3      0.4099684059619904   0.4099684059619904     0

The world would be great if fp32 was fast 🥲

