🚨 fix + tests dense & MoE TP all reduce (decoder only) #43722
Changes from all commits
ArthurZucker marked this conversation as resolved.

Collaborator
We need to decide now how much work we do for the user:
We broke this when we removed

Member
Author
we can always make

Collaborator
but if it's the first op, the replication comes from just TP; let me think a sec. It's just that it's breaking and can be cumbersome, but explicitness is good TBH.
@@ -460,6 +460,7 @@ def backward(ctx, grad_output):
         device_mesh = ctx.device_mesh
         if device_mesh.size() == 1:
             return grad_output, None
+        grad_output = grad_output.contiguous()
         dist.all_reduce(grad_output, op=dist.ReduceOp.SUM, group=device_mesh.get_group())
         return grad_output, None
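For context on the one-line change above: `dist.all_reduce` reduces its input in place and expects a contiguous buffer, but the incoming gradient can be a non-contiguous view (for example, the result of an upstream transpose). A minimal sketch, assuming an already-initialized process group (e.g. launched with `torchrun`):

```python
import torch
import torch.distributed as dist

grad = torch.randn(8, 4).t()   # transposed view: not contiguous
assert not grad.is_contiguous()

# Materialize a dense copy that the collective can reduce in place;
# this is a no-op when the tensor is already contiguous.
grad = grad.contiguous()
dist.all_reduce(grad, op=dist.ReduceOp.SUM)
```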
@@ -658,7 +659,7 @@ def shard_tensor(
     ) -> torch.Tensor:
         raise NotImplementedError

-    def prepare_module_tp(self, module: nn.Module, device_mesh) -> nn.Module:
+    def prepare_module_tp(self, module: nn.Module, device_mesh, **kwargs) -> nn.Module:
         distribute_module(
             module,
             device_mesh,
@@ -724,6 +725,86 @@ def get_expected_sharded_shape(self, full_shape: tuple[int, ...] | torch.Size) -
         return tuple(shape)


+class ReplicatedWithGradAllReduce(TensorParallelLayer):
+    """
+    Replicated parameter with gradient all-reduce.
+
+    For parameters like q_norm/k_norm that sit between colwise and rowwise
+    layers. The parameter is replicated (not sharded), but its gradient
+    accumulates from local heads only in TP mode. This class registers a
+    backward hook to all-reduce the parameter gradient.
+    """
+
+    @staticmethod
+    def _prepare_input_fn(mod, inputs, device_mesh):
+        return inputs
+
+    @staticmethod
+    def _prepare_output_fn(mod, outputs, device_mesh):
+        return outputs
+
+    def shard_tensor(self, param, tensor_idx=None, device=None, dtype=None):
+        return param[...].to(device=device, dtype=dtype)
+
+    def prepare_module_tp(self, module, device_mesh, **kwargs):
+        # Use a module-level backward hook (not param.register_hook) because parameters are replaced during weight loading after this method runs.
+        # Module hooks survive parameter replacement.
+        def _backward_hook(mod, grad_input, grad_output, mesh=device_mesh):
+            for param in mod.parameters():
+                if param.grad is not None:
+                    all_reduce_forward(param.grad, mesh)
+
+        module.register_full_backward_hook(_backward_hook)
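Editorial note, not part of the diff: the class above boils down to the standard "replicated weight, all-reduce its gradient" pattern. A self-contained sketch of that pattern using plain `torch.distributed` instead of the repo's `all_reduce_forward` helper (the hook wiring mirrors the diff; the function name and usage below are illustrative):

```python
import torch.distributed as dist
from torch import nn

def attach_grad_allreduce(module: nn.Module, group=None) -> None:
    """Sum rank-local gradients of a replicated module's parameters across TP ranks."""
    def _hook(mod, grad_input, grad_output):
        # Full backward hooks fire once gradients flowing through this module have been
        # computed; any parameter gradient already accumulated is still rank-local here.
        for p in mod.parameters():
            if p.grad is not None:
                dist.all_reduce(p.grad, op=dist.ReduceOp.SUM, group=group)
    module.register_full_backward_hook(_hook)

# Illustrative usage on a q_norm-style replicated layer:
# attach_grad_allreduce(model.layers[0].self_attn.q_norm)
```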
+class MlaKvAProjParallel(TensorParallelLayer):
+    """
+    For MLA attention used in DeepSeek-V2 style models (deepseek_v2, longcat_flash, glm_moe_dsa, glm4_moe_lite):
+    the kv_a_proj_with_mqa output is [kv_lora_rank + qk_rope_head_dim] (the names can differ, but the important
+    point is that the output is split).
+    Example below (from modeling_longcat_flash.py):
+
+            kv_a_proj_with_mqa
+                    |
+                  split
+                 /     \
+           k_pass       k_rot   <-- "bypasses kv_b_proj"
+              |            |        (goes straight to attention,
+      kv_a_layernorm       |         never touches kv_b_proj)
+              |            |
+          kv_b_proj        |
+          (colwise)        |
+              |            |
+           k_pass        k_rot
+                 \\      /
+                    cat
+                     |
+                key_states
+
+    k_pass goes through kv_b_proj (colwise), which has a built-in all_reduce_backward, so its gradient is
+    not partial. However, k_rot goes straight to attention and never touches kv_b_proj, so we need to
+    all-reduce its gradient across ranks; otherwise each rank only sees its own partial gradient.
+    """
+    def _prepare_output_fn(self, mod, output, device_mesh):
+        if not hasattr(mod.config, "qk_rope_head_dim"):
+            raise AttributeError(
+                f"Config for {type(mod).__name__} does not have `qk_rope_head_dim`. "
+                "MlaKvAProjParallel requires `qk_rope_head_dim` to be defined in the model config. "
+                "Please add it to the model's config or update the TP plan mapping."
+            )
+        rope_dim = mod.config.qk_rope_head_dim
Collaborator
We should raise an error if the attr does not exist, telling the user to add it to the auto mapping.
+        pass_output, rope_output = output.split([output.shape[-1] - rope_dim, rope_dim], dim=-1)
+        rope_output = all_reduce_backward(rope_output, device_mesh)
+        return torch.cat([pass_output, rope_output], dim=-1)
+
+    def shard_tensor(self, param, tensor_idx=None, device=None, dtype=None):
+        return param[...].to(device=device, dtype=dtype)
+
+    def prepare_module_tp(self, module, device_mesh, config=None, **kwargs):
+        module.config = config
+        distribute_module(module, device_mesh, output_fn=self._prepare_output_fn)
+
+
 class RowwiseParallel(TensorParallelLayer):
     """
     Row-wise parallel: weight is sharded on dim -1 (input features).
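Editorial note: `all_reduce_backward`, used on the rope slice above, follows the usual "identity in forward, all-reduce in backward" autograd pattern. A self-contained sketch of the idea, not the library implementation (class and function names below are illustrative):

```python
import torch
import torch.distributed as dist

class _AllReduceBackward(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, group=None):
        ctx.group = group
        return x  # identity in the forward pass

    @staticmethod
    def backward(ctx, grad_output):
        # Sum rank-local gradients so the rope slice gets a full gradient on every rank.
        grad_output = grad_output.contiguous()
        dist.all_reduce(grad_output, op=dist.ReduceOp.SUM, group=ctx.group)
        return grad_output, None

def split_and_fix_rope_grad(output: torch.Tensor, rope_dim: int, group=None) -> torch.Tensor:
    # Mirrors the _prepare_output_fn above: only the k_rot part needs the backward all-reduce.
    pass_out, rope_out = output.split([output.shape[-1] - rope_dim, rope_dim], dim=-1)
    rope_out = _AllReduceBackward.apply(rope_out, group)
    return torch.cat([pass_out, rope_out], dim=-1)
```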
@@ -1087,6 +1168,29 @@ def shard_tensor(
         return param[...].to(device=device, dtype=dtype)


+class MoeIdentityExpertParallel(TensorParallelLayer):
+    """
+    TP class for zero/identity experts in MoE layers.
+
+    Under TP, the parent MoeTensorParalellExperts does all_reduce_forward (sum)
+    on the expert module output. Identity experts produce the same output on
+    every rank, so the sum gives world_size * output. This class divides the
+    input by world_size to compensate.
+    """
+
+    @staticmethod
+    def _prepare_input_fn(mod, inputs, device_mesh):
+        input_tensor = inputs[0] if inputs else inputs
+        # TODO(fmom): when 2D-device mesh, need to select a //-ism axis to divide the input tensor by.
+        return input_tensor / device_mesh.size()
+
+    def shard_tensor(self, param, tensor_idx=None, device=None, dtype=None):
+        return param[...].to(device=device, dtype=dtype)
+
+    def prepare_module_tp(self, module, device_mesh, **kwargs):
+        distribute_module(module, device_mesh, input_fn=self._prepare_input_fn)
+
+
 class ParallelInterface(GeneralInterface):
     # Class instance object, so that a call to `register` can be reflected into all other files correctly, even if
     # a new instance is created (in order to locally override a given entry)
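Editorial note: the compensation above is easy to check by hand. With a TP world size of 2 and an identity expert, each rank produces the same output and the parent module's forward all-reduce sums them; pre-dividing the input by the world size restores the single-GPU result. A single-process toy check, with the all-reduce simulated by adding the two rank-local outputs:

```python
import torch

world_size = 2
x = torch.tensor([4.0, 6.0])   # hidden states routed to the identity expert

uncompensated = x + x                               # sum over ranks -> world_size * x = [8., 12.]
compensated = (x / world_size) + (x / world_size)   # -> [4., 6.], matching the single-GPU output
print(uncompensated, compensated)
```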
@@ -1103,6 +1207,9 @@ class ParallelInterface(GeneralInterface):
         "grouped_gemm": GroupedGemmParallel(),
         "ep_router": RouterParallel(),
         "moe_tp_experts": MoeTensorParalellExperts(),
+        "moe_identity_expert": MoeIdentityExpertParallel(),
+        "replicated_with_grad_allreduce": ReplicatedWithGradAllReduce(),
+        "mla_kv_a_proj": MlaKvAProjParallel(),
     }
     if is_torch_available() and _torch_distributed_available
     else {}
@@ -1120,6 +1227,8 @@ class ParallelInterface(GeneralInterface):
         "packed_rowwise": -1,
         "embedding_rowwise": 0,
         "sequence_parallel": None,
+        "replicated_with_grad_allreduce": None,
+        "mla_kv_a_proj": None,
     }

     # Bias sharding: colwise shards bias, rowwise doesn't (bias is replicated and all-reduced)
@@ -1132,6 +1241,8 @@ class ParallelInterface(GeneralInterface):
         "packed_rowwise": None,
         "embedding_rowwise": None,
         "sequence_parallel": None,
+        "replicated_with_grad_allreduce": None,
+        "mla_kv_a_proj": None,
     }
@@ -1258,13 +1369,14 @@ def add_tensor_parallel_hooks_to_module(
     if current_module_plan is not None:
         tp_layer = ALL_PARALLEL_STYLES[current_module_plan]
         try:
-            tp_layer.prepare_module_tp(module, device_mesh)
+            tp_layer.prepare_module_tp(module, device_mesh, config=model.config)
         except NotImplementedError as e:
            print(
                 f"Trying to prepare {layer_name}, but it's not supported. Corresponding module: {module} Fix it's TP plan: {e}"
             )

     module._hf_tp_plan = current_module_plan
     module._hf_device_mesh = device_mesh
     module.__repr__ = lambda: f"{module.__repr__()}\nTP Plan: {current_module_plan}"
Collaborator
cc @SunMarc this is valid, but happy if you can have a look.

Member
SG!
@@ -123,12 +123,19 @@ class DeepseekV2Config(PreTrainedConfig):
     base_model_tp_plan = {
         "layers.*.self_attn.q_proj": "colwise",
         "layers.*.self_attn.q_a_proj": "colwise",
         "layers.*.self_attn.q_b_proj": "colwise",
         "layers.*.self_attn.kv_a_proj_with_mqa": "mla_kv_a_proj",
         "layers.*.self_attn.kv_b_proj": "colwise",
         "layers.*.self_attn.o_proj": "rowwise",
         "layers.*.mlp.experts.gate_up_proj": "packed_colwise",
         "layers.*.mlp.experts.down_proj": "rowwise",
         "layers.*.mlp.experts": "moe_tp_experts",
         "layers.*.mlp.shared_experts.gate_proj": "colwise",
         "layers.*.mlp.shared_experts.up_proj": "colwise",
         "layers.*.mlp.shared_experts.down_proj": "rowwise",
         "layers.*.mlp.gate_proj": "colwise",
         "layers.*.mlp.up_proj": "colwise",
Comment on lines +132 to +137

Collaborator
IDK what's the most efficient way to avoid having too many comms, but LGTM otherwise.

Member
Author
until we have EP working properly, I think it's okay to leave it this way
         "layers.*.mlp.down_proj": "rowwise",
     }
     base_model_pp_plan = {
         "embed_tokens": (["input_ids"], ["inputs_embeds"]),