Fix EP + FSDP2: experts silently overwritten by rank-0 broadcast #45662
AmineDiro wants to merge 5 commits into huggingface:main from
Conversation
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
ArthurZucker
left a comment
Really like the intention! 🤗
But we need to be careful and isolate the changes, making sure the `to_local()` calls are only run once, for perf etc.
```python
model.eval()  # Set model in evaluation mode to deactivate Dropout modules by default
model.set_use_kernels(use_kernels, kernel_config)

cls._wrap_ep_params_as_dtensor(model, device_mesh)
```
nope, we can't have that!
sorry, do you mean a staticmethod or just inlining? or removing it altogether? 😅
```python
@staticmethod
def _wrap_ep_params_as_dtensor(model, device_mesh) -> None:
    """Wrap EP-sharded params (`grouped_gemm` style) as DTensors in-place.

    Without this, the optimizer's foreach ops error with "mixed Tensor and DTensor"
    against the FSDP-wrapped DTensor params on the rest of the model.
    """
    if not model.has_ep:
        return
    plan = model.tp_plan
    for name, p in list(model.named_parameters()):
        if _get_parameter_tp_plan(parameter_name=name, tp_plan=plan, is_weight=True) != "grouped_gemm":
            continue
        parent, attr = get_module_from_name(model, name)
        dt = DTensor.from_local(p.data, device_mesh, [Shard(0)], run_check=False)
        setattr(parent, attr, nn.Parameter(dt, requires_grad=p.requires_grad))
```
that does not make sense to have here!
We should update `distribute_module`, and any parallelism-related code needs to live in `distributed_xxx.py`, not in the general modeling utils, which is already bloated as is
```python
@property
def has_ep(self) -> bool:
    """Whether expert parallelism is enabled for this model."""
    distributed_config = getattr(getattr(self, "config", None), "distributed_config", None)
    return distributed_config is not None and getattr(distributed_config, "enable_expert_parallel", False)

@property
def ep_sharded_param_names(self) -> list[str]:
    """FQNs of parameters whose data is per-rank unique under EP sharding."""
    if not self.has_ep:
        return []
    plan = self.tp_plan
    return [
        name
        for name, _ in self.named_parameters()
        if _get_parameter_tp_plan(parameter_name=name, tp_plan=plan, is_weight=True) == "grouped_gemm"
    ]
```
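As a usage illustration only (not code from this PR): the FQNs from `ep_sharded_param_names` could be turned into the set of parameters that FSDP2 should leave alone. The `fully_shard(..., ignored_params=...)` keyword exists on recent PyTorch, but treat the exact call below as an assumption.

```python
from torch.distributed.fsdp import fully_shard  # FSDP2 entry point on recent PyTorch

# Collect the per-rank-unique EP params so FSDP2 does not reshard/broadcast them.
ep_names = set(model.ep_sharded_param_names)
ep_params = {p for n, p in model.named_parameters() if n in ep_names}
fully_shard(model, ignored_params=ep_params)  # keyword availability depends on the torch version
```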
not modeling core, to be removed from here
this is fine, though as @3outeille said, this means anything that does not use our kernels will not work
yes, that's the tradeoff. In general this DTensor wrapping for TP params seems hacky and unnecessary.
The main issue is the optimizer, but I think that can be solved in a cleaner way
just to be extra sure, can you train for 4 steps and try comparing the loss at each step between a
@3outeille: Great idea!
script for testing: gist
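(Not the linked gist, just a rough sketch of what a 4-step loss comparison could look like: run it once per configuration under test and diff the printed losses. The model name, batch, and hyperparameters are placeholders, and `from_pretrained` should get the same kwargs as the setup being compared, e.g. the EP config from this PR.)

```python
import torch
from accelerate import Accelerator
from transformers import AutoModelForCausalLM, AutoTokenizer

accelerator = Accelerator()
tok = AutoTokenizer.from_pretrained("Qwen/Qwen3-30B-A3B")
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen3-30B-A3B")  # add EP/FSDP kwargs as needed
optim = torch.optim.AdamW(model.parameters(), lr=1e-5)
model, optim = accelerator.prepare(model, optim)

torch.manual_seed(0)  # same seed on both runs so the losses are comparable
batch = tok(["hello world"] * 2, return_tensors="pt").to(accelerator.device)

for step in range(4):
    out = model(**batch, labels=batch["input_ids"])
    accelerator.backward(out.loss)
    optim.step()
    optim.zero_grad()
    accelerator.print(f"step {step}: loss {out.loss.item():.6f}")
```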
I found it a bit odd that the loss is different at step 1, no?
Move EP parameter DTensor wrapping from post-load model wrapping to the tensor parallel layer's `post_shard_wrap` method, which applies during parameter loading. This ensures DTensor wrapping happens at the right time in the loading pipeline and removes duplicated logic.
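(A rough idea of what that hook body could look like, mirroring the `_wrap_ep_params_as_dtensor` logic above; the `post_shard_wrap` name comes from the commit message, while the standalone signature here is an assumption.)

```python
import torch.nn as nn
from torch.distributed.tensor import DTensor, Shard  # public path on recent PyTorch

def post_shard_wrap(param: nn.Parameter, device_mesh) -> nn.Parameter:
    # `param` already holds this rank's EP-local slice of the experts; wrapping it as a
    # DTensor keeps it consistent with the FSDP-wrapped DTensor params elsewhere, so the
    # optimizer's foreach ops never see a mixed Tensor/DTensor set.
    dt = DTensor.from_local(param.data, device_mesh, [Shard(0)], run_check=False)
    return nn.Parameter(dt, requires_grad=param.requires_grad)
```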
@ArthurZucker:
I also reran the test @3outeille: max abs diff 0.0029 👍🏼
Hopefully this is more aligned with the structure you have in mind. Thanks again for your time to review 🤗
View the CircleCI Test Summary for this PR: https://huggingface.co/spaces/transformers-community/circle-ci-viz?pr=45662&sha=9c712a
yes, but I think that's bf16 ULP stuff, because everything is the same: it runs through the same code, same seed, etc.
EDIT: @3outeille went ahead and ran in fp32:

The world would be great if fp32 was fast 🥲
What does this PR do?
Loading a MoE model with Expert Parallelism (`distributed_config=DistributedConfig(enable_expert_parallel=True)`) and then calling `accelerator.prepare` with FSDP2 silently loads the wrong experts on every rank. The model trains, but on broken weights.

Tested on `Qwen3-30B-A3B` with 128 experts and EP=8. `from_pretrained` correctly EP-shards the experts: rank 0 holds experts 0–15, rank 1 holds 16–31, … Then in `accelerate.utils.fsdp_utils.fsdp2_prepare_model` (with `cpu_ram_efficient_loading=True`):

1. `original_sd = model.state_dict()`: captures per-rank-unique data.
2. `model.to("meta")`: drops the values.
3. `fully_shard(model)`: wraps params as DTensors on the FSDP mesh, assuming all ranks started with the same data.
4. `fsdp2_load_full_state_dict`: rank 0 calls `dist.broadcast(full_param, src=0)` for each param. For an EP-sharded param, rank 0's local tensor contains only experts 0–15. Every rank receives that data. After `distribute_tensor`, each rank holds a slice of rank 0's 16 experts.

The router still picks among 128 experts, but all of them point at wrong weights.
Minimal repro
Run on 1 node, 8x H100s:
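(The original repro script isn't reproduced here; below is a rough sketch of the failure check implied by the walkthrough above. Assumptions: the `DistributedConfig` import path, the `experts`/`gate_up_proj` name pattern for the grouped expert weight, and launching with `accelerate launch --use_fsdp ...` with the FSDP2 config from the verification section.)

```python
import torch
import torch.distributed as dist
from accelerate import Accelerator
from transformers import AutoModelForCausalLM
from transformers.distributed import DistributedConfig  # import path may differ by version

def expert_checksum(model, gather_full=False):
    """Checksum of the first grouped expert weight visible to this rank."""
    for name, p in model.named_parameters():
        if "experts" in name and "gate_up_proj" in name:
            if gather_full and hasattr(p, "full_tensor"):
                t = p.full_tensor()   # gather the DTensor across its mesh
            elif hasattr(p, "to_local"):
                t = p.to_local()
            else:
                t = p.data
            return round(t.float().sum().item(), 2)
    return float("nan")

accelerator = Accelerator()
model = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen3-30B-A3B",
    distributed_config=DistributedConfig(enable_expert_parallel=True),
)
before = expert_checksum(model)                    # per-rank unique under EP sharding
model = accelerator.prepare(model)                 # FSDP2 path described above
after = expert_checksum(model, gather_full=True)   # materialize the wrapped param

world = [None] * dist.get_world_size()
dist.all_gather_object(world, (dist.get_rank(), before, after))
# Healthy: each rank's `after` covers all 128 experts (≈ sum of every rank's `before`).
# Broken (this issue): every rank's `after` derives from rank 0's 16 experts only.
accelerator.print(world)
```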
The fix
1. Tell FSDP to skip the EP-sharded expert modules: `fully_shard()` doesn't auto-skip DTensors on a non-FSDP mesh. Also gate the existing `ParallelismConfig(tp_size=...)` auto-build on `not has_ep`.
2. Wrap EP-sharded params as DTensors (`PreTrainedModel._wrap_ep_params_as_dtensor`). Without this, after `fully_shard()` the rest of the model is DTensors while the EP params stay plain `nn.Parameter`, and the optimizer crashes. We then use `.to_local()` in `grouped_mm_experts_forward` to get the local tensor, applied to the three weights (`gate_up_proj`, `up_proj`, `down_proj`); see the sketch below.
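(A minimal sketch of the forward-side change, assuming a small helper like the one below; the real `grouped_mm_experts_forward` lives in transformers and does more. Pulling the locals out once per call also lines up with the reviewer note above about running `.to_local()` only once.)

```python
from torch.distributed.tensor import DTensor

def _local(p):
    # Unwrap an EP DTensor param to its rank-local expert slice; plain tensors pass through.
    return p.to_local() if isinstance(p, DTensor) else p

# Inside the experts forward, once per call (attribute names from the description above):
# gate_up_proj = _local(self.gate_up_proj)
# up_proj = _local(self.up_proj)
# down_proj = _local(self.down_proj)
```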
End-to-end SFT on Qwen3-30B-A3B, EP=8, single 8-GPU node, real
trl/scripts/sft.pyviaaccelerate launch --use_fsdp --fsdp_cpu_ram_efficient_loading false:Who can review?
@ArthurZucker @IlyasMoutawwakil @3outeille