Merged
188 commits
9152e86
introducing test tensor parallel mixing to catch TP related error
3outeille Feb 3, 2026
3234776
Remove test file for tensor parallel functionality
3outeille Feb 3, 2026
1c2684c
Refactor dense and MoE test scripts for parallel execution and improv…
3outeille Feb 3, 2026
ec2ed1d
restore
ArthurZucker Feb 4, 2026
33ca330
add all reduce for ep
ArthurZucker Feb 4, 2026
e545ac1
fix init and bias sharding
ArthurZucker Feb 4, 2026
fa78068
fix finalize weight init
ArthurZucker Feb 4, 2026
6e4d234
add full stacktracing
ArthurZucker Feb 4, 2026
05fc1fa
fix
ArthurZucker Feb 4, 2026
ac291e8
add report to run tests
3outeille Feb 4, 2026
819698c
okay big improvement here
ArthurZucker Feb 4, 2026
d99f834
the only case shard index should be used is when we are actually col…
ArthurZucker Feb 4, 2026
f0d0de1
more fixes
ArthurZucker Feb 4, 2026
c5cbdc8
fix EP forward gpt oss
ArthurZucker Feb 4, 2026
381d773
add test that triggers the weight converter or only dynamic loading
3outeille Feb 4, 2026
b8901a8
Merge branch 'fix-ep' into fix-moe-ep
3outeille Feb 4, 2026
73851ae
Update test scripts to use new tensor parallel test keyword
3outeille Feb 4, 2026
a2aa66a
Merge branch 'fix-moe-ep' of https://github.com/huggingface/transform…
3outeille Feb 4, 2026
1dca5f9
cleaning + find_port + remove comments
3outeille Feb 4, 2026
94d676c
revert some shit
ArthurZucker Feb 4, 2026
959b46f
when you are stupid sometimes you really need a brain :) :) :) :)
ArthurZucker Feb 4, 2026
01c5774
fix TP
ArthurZucker Feb 4, 2026
9dbb634
Ok GPT oss is fixed now
ArthurZucker Feb 4, 2026
8374298
try to fix perms
ArthurZucker Feb 4, 2026
989bd9a
test only causal llm
3outeille Feb 4, 2026
8e46655
attempt to fix
ArthurZucker Feb 4, 2026
ce256ae
Merge branch 'fix-ep' into fix-moe-ep
3outeille Feb 4, 2026
104f80d
am I a doomer and AI is not that bad?
ArthurZucker Feb 5, 2026
14dca0c
fix
ArthurZucker Feb 5, 2026
3600fbe
it "passes" but the output is shit
ArthurZucker Feb 5, 2026
fabde8a
Merge branch 'fix-ep' of github.com:huggingface/transformers into fix-ep
ArthurZucker Feb 5, 2026
20dee9a
style my man
ArthurZucker Feb 5, 2026
73e77fb
Merge PR #43730 (fix-ep) into fix-moe-ep
3outeille Feb 5, 2026
0b95c64
outputs are gonna be gibberish but at least the forward pass "works"
ArthurZucker Feb 5, 2026
9137fa2
Merge branch 'main' of github.com:huggingface/transformers into fix-ep
ArthurZucker Feb 5, 2026
646cbe3
style
ArthurZucker Feb 5, 2026
4b32a6b
fix mixtral
ArthurZucker Feb 5, 2026
7253e80
Merge branch 'fix-moe-ep' of https://github.com/huggingface/transform…
3outeille Feb 5, 2026
1460762
okay shape fixes
ArthurZucker Feb 5, 2026
2c4143a
Merge branch 'pr-43730-latest' into fix-moe-ep
3outeille Feb 5, 2026
f789395
tensor idx is only for grouped gemm / EP
ArthurZucker Feb 5, 2026
8b4ed7b
fix gate_up shard
ArthurZucker Feb 5, 2026
4d2878a
Merge branch 'fix-ep' into fix-moe-ep
3outeille Feb 5, 2026
d8cd533
fix :)
ArthurZucker Feb 5, 2026
b5fa932
Merge branch 'pr-43730' into fix-moe-ep
3outeille Feb 5, 2026
76c904c
revert some EP changes that are breaking other stuff
ArthurZucker Feb 5, 2026
cfd92d7
style
ArthurZucker Feb 5, 2026
f673207
fix solar open tp
3outeille Feb 5, 2026
6184091
Merge branch 'fix-moe-ep' of https://github.com/huggingface/transform…
3outeille Feb 5, 2026
cdf2f50
Merge branch 'fix-ep' into fix-moe-ep
3outeille Feb 5, 2026
a6d4a32
trigger test on deepseek v3
3outeille Feb 5, 2026
46f61ad
fix glm4_moe tp
3outeille Feb 7, 2026
2ac7fed
fix glm4 moe lite tensor parallel
3outeille Feb 7, 2026
e5fa6fe
fix longcat and glm4_moe_lite by all reducing gradients of k_rot
3outeille Feb 7, 2026
7c9f0d8
fix ernie4_5_moe
3outeille Feb 7, 2026
9975163
fix qwen3 by all reduce grads of q_norm
3outeille Feb 7, 2026
8ff9475
fix deepseek v3 tp (need a constant dropout other different RNG + all…
3outeille Feb 7, 2026
65815b6
Rename ReplicatedInTP to ReplicatedWithGradAllReduce and update refer…
3outeille Feb 7, 2026
fe2aa69
fix minimax_m2
3outeille Feb 7, 2026
0b2e0e4
fix deepseek v2 for TP
3outeille Feb 7, 2026
0be8add
fix minimax
3outeille Feb 7, 2026
1d3457b
fix qwen3_next for TP
3outeille Feb 7, 2026
7015912
fix dots1 tp
3outeille Feb 7, 2026
a2e2bac
fix flex_olmo TP
3outeille Feb 7, 2026
d376210
fix qwen3 tp dense
3outeille Feb 7, 2026
383ce33
fix exaone4 tp
3outeille Feb 7, 2026
782b366
fix gemma3 tp
3outeille Feb 7, 2026
3ac4b4f
fix apertus TP
3outeille Feb 7, 2026
577aa2d
fix seed_oss tp by setting 0 to dropout
3outeille Feb 7, 2026
c5cb269
fix gemma3n for TP
3outeille Feb 8, 2026
ba233da
dropout set to 0 for test + gradient slicing depending on fused weigh…
3outeille Feb 8, 2026
ec75498
Merge branch 'main' of https://github.com/huggingface/transformers in…
3outeille Feb 8, 2026
782274d
make fixup + glm4 important fix on tp plan to avoid assigning wrong T…
3outeille Feb 8, 2026
4940a91
linting
3outeille Feb 8, 2026
2976144
remove shell scripts
3outeille Feb 8, 2026
c1084ae
make test tensor parallel triggering the CI
3outeille Feb 8, 2026
c0bd345
fix ci
3outeille Feb 8, 2026
f18c79c
fix ci
3outeille Feb 8, 2026
fef43aa
mark it as ep_plan
3outeille Feb 9, 2026
205f43c
add @require_torch_multi_accelerator
3outeille Feb 9, 2026
9b844f9
Merge branch 'main' into fix-moe-ep
3outeille Feb 9, 2026
36ffe3f
fix CI
3outeille Feb 9, 2026
bf86273
Merge branch 'fix-moe-ep' of https://github.com/huggingface/transform…
3outeille Feb 9, 2026
a126ea7
Merge branch 'main' into fix-moe-ep
3outeille Feb 10, 2026
cc6f26b
undo pr merge tensor parallel
3outeille Feb 10, 2026
41223d1
revert core model loading file
3outeille Feb 10, 2026
9210a86
revert modeling_utils file
3outeille Feb 10, 2026
bac58b3
small fix in modeling_utils
3outeille Feb 10, 2026
362ebe6
Update tensor parallel test configurations to enable tests by default…
3outeille Feb 10, 2026
5b43d0d
linting
3outeille Feb 10, 2026
f997d96
Reorganize imports in modeling_utils.py to maintain consistency
3outeille Feb 10, 2026
0c86c2a
fix qwen3_5_moe tp
3outeille Feb 10, 2026
8f840b4
Merge branch 'main' into fix-moe-ep
3outeille Feb 10, 2026
0b40acb
fix glm moe dsa tp
3outeille Feb 10, 2026
e2b1eeb
Merge branch 'main' into fix-moe-ep
3outeille Feb 10, 2026
4ba8784
fix qwen3_5 tp
3outeille Feb 10, 2026
a4b65a0
Merge branch 'fix-moe-ep' of https://github.com/huggingface/transform…
3outeille Feb 10, 2026
4c2dfb0
Add training_overfit_steps parameter to Gemma3nTextModelTest
3outeille Feb 10, 2026
7f7c3cf
fix 16 bits alignment
3outeille Feb 10, 2026
50c946d
Add WeightConverter for gate_up_proj and down_proj with 16 bytes alig…
3outeille Feb 10, 2026
274f643
Merge branch 'main' into fix-moe-ep
3outeille Feb 10, 2026
51c2c83
Add solar_open mapping with WeightConverter for gate_up_proj and down…
3outeille Feb 10, 2026
983a3f6
Merge branch 'fix-moe-ep' of https://github.com/huggingface/transform…
3outeille Feb 10, 2026
78dd9e8
Merge branch 'main' into fix-moe-ep
3outeille Feb 10, 2026
fb81843
Update hub metadata (#43892)
zucchini-nlp Feb 10, 2026
f5ca722
Add MlaKvAProjParallel layer for MLA attention and update TP plans
3outeille Feb 10, 2026
4f04c56
Merge branch 'fix-moe-ep' of https://github.com/huggingface/transform…
3outeille Feb 10, 2026
662824b
Merge branch 'main' into fix-moe-ep
3outeille Feb 10, 2026
edd81d9
Merge branch 'main' into fix-moe-ep
3outeille Feb 11, 2026
9ad50c0
fix doc
3outeille Feb 11, 2026
899238b
Merge branch 'fix-moe-ep' of https://github.com/huggingface/transform…
3outeille Feb 11, 2026
14d0ecc
force 16 Bytes Alignment
3outeille Feb 11, 2026
00ba8e8
Merge branch 'main' into fix-moe-ep
3outeille Feb 11, 2026
33a9567
fix slice tensor
3outeille Feb 11, 2026
f5136fd
Merge branch 'fix-moe-ep' of https://github.com/huggingface/transform…
3outeille Feb 11, 2026
af49ad0
Merge branch 'main' into fix-moe-ep
3outeille Feb 11, 2026
7eb3263
more doc
3outeille Feb 11, 2026
6d81f36
better abstraction for zero experts
3outeille Feb 11, 2026
cb9035e
linting
3outeille Feb 11, 2026
abdc144
refactor
3outeille Feb 11, 2026
afa2812
Merge branch 'fix-moe-ep' of https://github.com/huggingface/transform…
3outeille Feb 11, 2026
084269a
redundancy in tests
3outeille Feb 12, 2026
ea0abf8
simplify
3outeille Feb 12, 2026
d01896e
Merge branch 'main' into fix-moe-ep
3outeille Feb 12, 2026
0db56c8
Merge branch 'main' into fix-moe-ep
3outeille Feb 12, 2026
1662b5b
Merge branch 'main' into fix-moe-ep
3outeille Feb 12, 2026
c038773
revert
3outeille Feb 12, 2026
59e9860
Merge branch 'main' into fix-moe-ep
3outeille Feb 12, 2026
3cde599
Merge branch 'main' into fix-moe-ep
3outeille Feb 12, 2026
2b5d952
Merge branch 'main' into fix-moe-ep
3outeille Feb 12, 2026
df5f993
fix gemma2
3outeille Feb 12, 2026
db23a99
Merge branch 'fix-moe-ep' of https://github.com/huggingface/transform…
3outeille Feb 12, 2026
c97dd50
fix
3outeille Feb 12, 2026
95619cd
make tests work only on CPU
3outeille Feb 12, 2026
d0d351c
linting
3outeille Feb 12, 2026
550b142
skip tests for run_slow
3outeille Feb 12, 2026
4d0e21d
Merge branch 'main' into fix-moe-ep
3outeille Feb 13, 2026
2931680
cleaning
3outeille Feb 17, 2026
5ce65f0
cleaning
3outeille Feb 17, 2026
8153084
Merge branch 'fix-moe-ep' of https://github.com/huggingface/transform…
3outeille Feb 17, 2026
679beab
enhance doc on dynamic weight loading
3outeille Feb 17, 2026
ef565c2
add config instead of model for tp
3outeille Feb 17, 2026
98bdba6
more doc to tensor parallel for MlaKvAProjParallel
3outeille Feb 17, 2026
cc6fb29
Merge branch 'main' into fix-moe-ep
3outeille Feb 17, 2026
07ee05b
use -1 instead of self.num_heads, this way when TP is used, it can in…
3outeille Feb 17, 2026
4d831d3
Merge branch 'main' into fix-moe-ep
3outeille Feb 17, 2026
a19c922
fix modular glm_moe_dsa
3outeille Feb 17, 2026
4d1f1cd
Merge branch 'fix-moe-ep' of https://github.com/huggingface/transform…
3outeille Feb 17, 2026
467d978
Merge branch 'main' into fix-moe-ep
3outeille Feb 17, 2026
6c8be21
Merge branch 'main' into fix-moe-ep
3outeille Feb 17, 2026
8cb24c1
Merge branch 'main' into fix-moe-ep
3outeille Feb 18, 2026
5f750aa
Merge branch 'main' of https://github.com/huggingface/transformers in…
3outeille Feb 25, 2026
1963db3
collect all gradient failure tests before stopping at first one
3outeille Feb 25, 2026
9330ee8
Merge branch 'main' into fix-moe-ep
3outeille Feb 25, 2026
b0eba09
Merge branch 'main' into fix-moe-ep
3outeille Feb 25, 2026
adab809
generate more max new tokens for tensor parallel test as models are s…
3outeille Feb 25, 2026
f8a23c8
Merge branch 'main' into fix-moe-ep
3outeille Feb 25, 2026
27b16f6
compare generated tokens for tensor parallel tests
3outeille Feb 25, 2026
4b2e724
use attr config as much as possible
3outeille Feb 25, 2026
c0b04fc
Merge branch 'main' into fix-moe-ep
3outeille Mar 2, 2026
90b7077
add TP + quantized tests
3outeille Mar 2, 2026
773af8e
raise error if attr does not exist to say add it to the auto mapping
3outeille Mar 2, 2026
be0b732
update doc
3outeille Mar 2, 2026
d577c4e
install torchao for tp + quantization tests
3outeille Mar 2, 2026
c081a8b
Merge branch 'main' into fix-moe-ep
3outeille Mar 2, 2026
bd96ba8
update doc
3outeille Mar 2, 2026
4bfbd70
update doc
3outeille Mar 2, 2026
aff64db
update doc
3outeille Mar 2, 2026
4fde282
Merge branch 'main' into fix-moe-ep
3outeille Mar 2, 2026
7825c9f
Merge branch 'fix-moe-ep' of https://github.com/huggingface/transform…
3outeille Mar 2, 2026
39fbbaf
update doc
3outeille Mar 2, 2026
9215cac
update doc
3outeille Mar 2, 2026
6a80390
update doc
3outeille Mar 2, 2026
3cdbf54
Merge branch 'main' into fix-moe-ep
3outeille Mar 2, 2026
f7b9aa5
partially fix tp + quantization generation
3outeille Mar 3, 2026
b2fc24f
partially fix tp + quantize
3outeille Mar 3, 2026
6863072
Merge branch 'fix-moe-ep' of https://github.com/huggingface/transform…
3outeille Mar 3, 2026
2c43f85
skipping some tp + quantized test for now
3outeille Mar 3, 2026
df40b73
guard torchao import for test_training_ci
3outeille Mar 3, 2026
0bb98fe
Merge branch 'main' into fix-moe-ep
3outeille Mar 3, 2026
575fdbd
Update src/transformers/models/longcat_flash/modular_longcat_flash.py
3outeille Mar 3, 2026
3e8e408
Merge branch 'main' into fix-moe-ep
3outeille Mar 3, 2026
5f38642
move file
3outeille Mar 3, 2026
c3e9f10
Merge branch 'main' into fix-moe-ep
3outeille Mar 3, 2026
41e2373
Merge branch 'fix-moe-ep' of https://github.com/huggingface/transform…
3outeille Mar 3, 2026
1de9baa
fix linting
3outeille Mar 3, 2026
de6d9aa
fix linting
3outeille Mar 3, 2026
ebc29a8
fix port conflict in test
3outeille Mar 3, 2026
12 changes: 11 additions & 1 deletion .circleci/create_circleci_config.py
@@ -328,6 +328,15 @@ def job_name(self):
parallelism=6,
)

tensor_parallel_ci_job = CircleCIJob(
"tensor_parallel_ci",
additional_env={"RUN_TENSOR_PARALLEL_TESTS": True},
docker_image=[{"image": "huggingface/transformers-torch-light"}],
install_steps=["uv pip install .", "uv pip install torchao"],
marker="is_tensor_parallel_test",
parallelism=6,
)
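To approximate what this job runs locally, the new marker can be selected directly with pytest; a minimal sketch (the exact pytest arguments and test path are assumptions, not taken from the CI config):

```python
# Hypothetical local reproduction of the tensor_parallel_ci selection.
import os

import pytest

os.environ["RUN_TENSOR_PARALLEL_TESTS"] = "1"  # mirrors additional_env above
# Run only tests carrying the marker registered in pyproject.toml / conftest.py.
raise SystemExit(pytest.main(["-m", "is_tensor_parallel_test", "tests/"]))
```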

# We also include a `dummy.py` file in the files to be doc-tested to prevent edge case failure. Otherwise, the pytest
# hangs forever during test collection while showing `collecting 0 items / 21 errors`. (To see this, we have to remove
# the bash output redirection.)
@@ -358,7 +367,8 @@ def job_name(self):
REPO_UTIL_TESTS = [repo_utils_job]
DOC_TESTS = [doc_test_job]
TRAINING_CI_TESTS = [training_ci_job]
ALL_TESTS = REGULAR_TESTS + EXAMPLES_TESTS + PIPELINE_TESTS + REPO_UTIL_TESTS + DOC_TESTS + [custom_tokenizers_job] + [exotic_models_job] + TRAINING_CI_TESTS # fmt: skip
TENSOR_PARALLEL_CI_TESTS = [tensor_parallel_ci_job]
ALL_TESTS = REGULAR_TESTS + EXAMPLES_TESTS + PIPELINE_TESTS + REPO_UTIL_TESTS + DOC_TESTS + [custom_tokenizers_job] + [exotic_models_job] + TRAINING_CI_TESTS + TENSOR_PARALLEL_CI_TESTS # fmt: skip


def create_circleci_config(folder=None):
1 change: 1 addition & 0 deletions conftest.py
@@ -91,6 +91,7 @@ def pytest_configure(config):
config.addinivalue_line("markers", "flash_attn_test: mark test which tests flash attention functionality")
config.addinivalue_line("markers", "flash_attn_3_test: mark test which tests flash attention 3 functionality")
config.addinivalue_line("markers", "training_ci: mark test for training CI validation")
config.addinivalue_line("markers", "tensor_parallel_ci: mark test for tensor parallel CI validation")

os.environ["DISABLE_SAFETENSORS_CONVERSION"] = "true"

422 changes: 416 additions & 6 deletions docs/source/en/weightconverter.md

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions pyproject.toml
@@ -75,6 +75,7 @@ markers = [
"bitsandbytes: select (or deselect with `not`) bitsandbytes integration tests",
"generate: marks tests that use the GenerationTesterMixin",
"is_training_test: marks tests that use the TrainingTesterMixin (deselect with '-m \"not is_training_test\"')",
"is_tensor_parallel_test: marks tests that use the TensorParallelTesterMixin (deselect with '-m \"not is_tensor_parallel_test\"')",
]
log_cli = 1
log_cli_level = "WARNING"
34 changes: 31 additions & 3 deletions src/transformers/conversion_mapping.py
@@ -352,9 +352,21 @@ def _build_checkpoint_conversion_mapping():
),
]

mapping["ernie4_5_moe"] = mapping["qwen2_moe"].copy()
mapping["ernie4_5_moe"] += [
WeightRenaming("mlp.moe_statics.e_score_correction_bias", "mlp.gate.moe_statics.e_score_correction_bias")
mapping["ernie4_5_moe"] = [
WeightRenaming("mlp.moe_statics.e_score_correction_bias", "mlp.gate.moe_statics.e_score_correction_bias"),
WeightConverter(
source_patterns=[
"mlp.experts.*.gate_proj.weight",
"mlp.experts.*.up_proj.weight",
],
target_patterns="mlp.experts.gate_up_proj",
operations=[MergeModulelist(dim=0), Concatenate(dim=1), Force16BytesAlignment()],
),
WeightConverter(
source_patterns="mlp.experts.*.down_proj.weight",
target_patterns="mlp.experts.down_proj",
operations=[MergeModulelist(dim=0), Force16BytesAlignment()],
),
]
mapping["minimax_m2"] = mapping["mixtral"].copy()
mapping["minimax_m2"] += [
@@ -363,6 +375,22 @@
mapping["exaone_moe"] = mapping["qwen2_moe"].copy()
mapping["exaone_moe"] += [WeightRenaming("mlp.e_score_correction_bias", "mlp.gate.e_score_correction_bias")]

mapping["solar_open"] = [
WeightConverter(
source_patterns=[
"mlp.experts.*.gate_proj.weight",
"mlp.experts.*.up_proj.weight",
],
target_patterns="mlp.experts.gate_up_proj",
operations=[MergeModulelist(dim=0), Concatenate(dim=1), Force16BytesAlignment()],
),
WeightConverter(
source_patterns="mlp.experts.*.down_proj.weight",
target_patterns="mlp.experts.down_proj",
operations=[MergeModulelist(dim=0), Force16BytesAlignment()],
),
]
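For intuition, here is a rough tensor-level sketch of what the gate_up_proj/down_proj conversions above amount to for a tiny hypothetical checkpoint; the exact fused layout is model-dependent and Force16BytesAlignment (padding the storage to a 16-byte boundary) is omitted, so treat the shapes as illustrative only:

```python
import torch

num_experts, hidden, intermediate = 4, 8, 16
gate = [torch.randn(intermediate, hidden) for _ in range(num_experts)]  # mlp.experts.*.gate_proj.weight
up = [torch.randn(intermediate, hidden) for _ in range(num_experts)]    # mlp.experts.*.up_proj.weight

# MergeModulelist(dim=0): stack the per-expert weights into one tensor.
gate_merged = torch.stack(gate, dim=0)  # [num_experts, intermediate, hidden]
up_merged = torch.stack(up, dim=0)

# Concatenate(dim=1): fuse gate and up into a single gate_up_proj tensor.
gate_up_proj = torch.cat([gate_merged, up_merged], dim=1)  # [num_experts, 2 * intermediate, hidden]

# down_proj only needs the stacking step.
down = [torch.randn(hidden, intermediate) for _ in range(num_experts)]
down_proj = torch.stack(down, dim=0)  # [num_experts, hidden, intermediate]
```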

mapping["qwen3_5_moe_text"] = mapping["qwen3_5_text"].copy()
mapping["qwen3_5_moe_text"] += mapping["qwen2_moe"].copy()

116 changes: 114 additions & 2 deletions src/transformers/integrations/tensor_parallel.py
Collaborator: We need to decide now how much work we do for the user:

  • colwise: should it split the input by itself if not already split
  • rowwise: same
  • rowwise_split_input: when would you explicitly need that in that case

We broke this when we removed DTensor.

Member Author: We can always make rowwise_split_input handled automatically inside rowwise by checking for a shape mismatch, but then it hides the fact that an op before has replicated the input, which is less expressive. As you want, no strong opinion.

Collaborator: But if it's the first op, the replication comes just from TP; let me think a sec. It's just that it's breaking and can be cumbersome, but explicitness is good TBH.

@@ -460,6 +460,7 @@ def backward(ctx, grad_output):
device_mesh = ctx.device_mesh
if device_mesh.size() == 1:
return grad_output, None
grad_output = grad_output.contiguous()
dist.all_reduce(grad_output, op=dist.ReduceOp.SUM, group=device_mesh.get_group())
return grad_output, None

@@ -658,7 +659,7 @@ def shard_tensor(
) -> torch.Tensor:
raise NotImplementedError

def prepare_module_tp(self, module: nn.Module, device_mesh) -> nn.Module:
def prepare_module_tp(self, module: nn.Module, device_mesh, **kwargs) -> nn.Module:
distribute_module(
module,
device_mesh,
@@ -724,6 +725,86 @@ def get_expected_sharded_shape(self, full_shape: tuple[int, ...] | torch.Size) -
return tuple(shape)


class ReplicatedWithGradAllReduce(TensorParallelLayer):
"""
Replicated parameter with gradient all-reduce.

For parameters like q_norm/k_norm that sit between colwise and rowwise
layers. The parameter is replicated (not sharded), but its gradient
accumulates from local heads only in TP mode. This class registers a
backward hook to all-reduce the parameter gradient.
"""

@staticmethod
def _prepare_input_fn(mod, inputs, device_mesh):
return inputs

@staticmethod
def _prepare_output_fn(mod, outputs, device_mesh):
return outputs

def shard_tensor(self, param, tensor_idx=None, device=None, dtype=None):
return param[...].to(device=device, dtype=dtype)

def prepare_module_tp(self, module, device_mesh, **kwargs):
# Use a module-level backward hook (not param.register_hook) because parameters are replaced during weight loading after this method runs.
# Module hooks survive parameter replacement.
def _backward_hook(mod, grad_input, grad_output, mesh=device_mesh):
for param in mod.parameters():
if param.grad is not None:
all_reduce_forward(param.grad, mesh)

module.register_full_backward_hook(_backward_hook)
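A minimal sketch of what the hook compensates for, assuming a TP group where each rank only backpropagates through the heads it owns (the helper below is illustrative, not the actual hook wiring):

```python
import torch
import torch.distributed as dist


def sync_replicated_grads(module: torch.nn.Module, group=None) -> None:
    # Each rank holds a partial gradient for a replicated q_norm/k_norm weight
    # (computed only from its local shard of heads). Summing the partial gradients
    # across ranks recovers what a single-GPU run would have produced.
    for param in module.parameters():
        if param.grad is not None:
            dist.all_reduce(param.grad, op=dist.ReduceOp.SUM, group=group)
```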


class MlaKvAProjParallel(TensorParallelLayer):
"""
For MLA attention used in DeepSeek-V2 style models (deepseek_v2, longcat_flash, glm_moe_dsa, glm4_moe_lite):
kv_a_proj_with_mqa output is [kv_lora_rank + qk_rope_head_dim] (the names can differ between models, but the
important point is that this output gets split)
Example below (from modeling_longcat_flash.py):

kv_a_proj_with_mqa
|
split
/ \
k_pass k_rot <-- "bypasses kv_b_proj"
| | (goes straight to attention,
kv_a_layernorm | never touches kv_b_proj)
| |
kv_b_proj |
(colwise) |
| |
k_pass k_rot
\\ /
cat
|
key_states

k_pass is passed to kv_b_proj (colwise), which has a built-in all_reduce_backward, so its gradient is not partial.
However, k_rot goes straight to attention and never touches kv_b_proj, so we need to all-reduce its gradient across
ranks ourselves; otherwise each rank only ends up with a partial gradient.
"""

def _prepare_output_fn(self, mod, output, device_mesh):
if not hasattr(mod.config, "qk_rope_head_dim"):
raise AttributeError(
f"Config for {type(mod).__name__} does not have `qk_rope_head_dim`. "
"MlaKvAProjParallel requires `qk_rope_head_dim` to be defined in the model config. "
"Please add it to the model's config or update the TP plan mapping."
)
rope_dim = mod.config.qk_rope_head_dim
Collaborator: We should raise an error if the attribute does not exist, telling the user to add it to the auto mapping.
pass_output, rope_output = output.split([output.shape[-1] - rope_dim, rope_dim], dim=-1)
rope_output = all_reduce_backward(rope_output, device_mesh)
return torch.cat([pass_output, rope_output], dim=-1)

def shard_tensor(self, param, tensor_idx=None, device=None, dtype=None):
return param[...].to(device=device, dtype=dtype)

def prepare_module_tp(self, module, device_mesh, config=None, **kwargs):
module.config = config
distribute_module(module, device_mesh, output_fn=self._prepare_output_fn)
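A hedged sketch of the same idea with a hand-rolled autograd function (illustrative only; the class above relies on the existing all_reduce_backward helper):

```python
import torch
import torch.distributed as dist


class _AllReduceBackward(torch.autograd.Function):
    """Identity in forward, gradient all-reduce in backward."""

    @staticmethod
    def forward(ctx, x, group=None):
        ctx.group = group
        return x

    @staticmethod
    def backward(ctx, grad):
        grad = grad.contiguous()
        dist.all_reduce(grad, op=dist.ReduceOp.SUM, group=ctx.group)
        return grad, None


def sync_rope_slice_grad(kv_a_out: torch.Tensor, rope_dim: int, group=None) -> torch.Tensor:
    # k_pass flows through kv_b_proj (colwise, gradient already synced); k_rot bypasses it,
    # so only the rope slice needs an explicit gradient all-reduce.
    k_pass, k_rot = kv_a_out.split([kv_a_out.shape[-1] - rope_dim, rope_dim], dim=-1)
    k_rot = _AllReduceBackward.apply(k_rot, group)
    return torch.cat([k_pass, k_rot], dim=-1)
```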


class RowwiseParallel(TensorParallelLayer):
"""
Row-wise parallel: weight is sharded on dim -1 (input features).
@@ -1087,6 +1168,29 @@ def shard_tensor(
return param[...].to(device=device, dtype=dtype)


class MoeIdentityExpertParallel(TensorParallelLayer):
"""
TP class for zero/identity experts in MoE layers.

Under TP, the parent MoeTensorParalellExperts does all_reduce_forward (sum)
on the expert module output. Identity experts produce the same output on
every rank, so the sum gives world_size * output. This class divides the
input by world_size to compensate.
"""

@staticmethod
def _prepare_input_fn(mod, inputs, device_mesh):
input_tensor = inputs[0] if inputs else inputs
# TODO(fmom): with a 2D device mesh, we need to pick the parallelism axis along which to divide the input tensor.
return input_tensor / device_mesh.size()

def shard_tensor(self, param, tensor_idx=None, device=None, dtype=None):
return param[...].to(device=device, dtype=dtype)

def prepare_module_tp(self, module, device_mesh, **kwargs):
distribute_module(module, device_mesh, input_fn=self._prepare_input_fn)
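A tiny numeric check of the compensation argument; a pure-tensor sketch with no distributed setup, simulating world_size identical identity-expert outputs being summed:

```python
import torch

world_size = 4
x = torch.tensor([2.0, -1.0, 0.5])

# Without compensation: every rank returns x, so the all-reduce sum is world_size * x.
summed = sum(x.clone() for _ in range(world_size))

# With compensation: each rank divides its input by world_size before the identity expert.
compensated = sum(x / world_size for _ in range(world_size))

assert torch.allclose(summed, world_size * x)
assert torch.allclose(compensated, x)
```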


class ParallelInterface(GeneralInterface):
# Class instance object, so that a call to `register` can be reflected into all other files correctly, even if
# a new instance is created (in order to locally override a given entry)
@@ -1103,6 +1207,9 @@ class ParallelInterface(GeneralInterface):
"grouped_gemm": GroupedGemmParallel(),
"ep_router": RouterParallel(),
"moe_tp_experts": MoeTensorParalellExperts(),
"moe_identity_expert": MoeIdentityExpertParallel(),
"replicated_with_grad_allreduce": ReplicatedWithGradAllReduce(),
"mla_kv_a_proj": MlaKvAProjParallel(),
}
if is_torch_available() and _torch_distributed_available
else {}
@@ -1120,6 +1227,8 @@ class ParallelInterface(GeneralInterface):
"packed_rowwise": -1,
"embedding_rowwise": 0,
"sequence_parallel": None,
"replicated_with_grad_allreduce": None,
"mla_kv_a_proj": None,
}

# Bias sharding: colwise shards bias, rowwise doesn't (bias is replicated and all-reduced)
@@ -1132,6 +1241,8 @@
"packed_rowwise": None,
"embedding_rowwise": None,
"sequence_parallel": None,
"replicated_with_grad_allreduce": None,
"mla_kv_a_proj": None,
}


@@ -1258,13 +1369,14 @@ def add_tensor_parallel_hooks_to_module(
if current_module_plan is not None:
tp_layer = ALL_PARALLEL_STYLES[current_module_plan]
try:
tp_layer.prepare_module_tp(module, device_mesh)
tp_layer.prepare_module_tp(module, device_mesh, config=model.config)
except NotImplementedError as e:
print(
f"Trying to prepare {layer_name}, but it's not supported. Corresponding module: {module} Fix it's TP plan: {e}"
)

module._hf_tp_plan = current_module_plan
module._hf_device_mesh = device_mesh
module.__repr__ = lambda: f"{module.__repr__()}\nTP Plan: {current_module_plan}"


12 changes: 12 additions & 0 deletions src/transformers/integrations/torchao.py
Collaborator: cc @SunMarc, this is valid but happy if you can have a look.

Member: SG!

@@ -148,13 +148,21 @@ def convert(
quantize_(module, c, (lambda x, fqn: True))
missing_keys.discard(full_layer_name)
module._is_hf_initialized = True
# torchao quantizes weights into a module but some models access the weight directly
# (e.g. module.o_proj.weight). The _is_hf_initialized flag is set at the module
# level only, so we also set it on each parameter to prevent _init_weights from
# calling normal_() on already-quantized Float8Tensors.
for param in module.parameters(recurse=False):
param._is_hf_initialized = True
return {"lm_head.weight": lm_head} if is_embedding_param and untie_embedding_weights else {}
else:
# need to apply to custom param name
custom_param_fqn_config = FqnToConfig({top_level_param_name: c})
quantize_(module, custom_param_fqn_config, filter_fn=None)
missing_keys.discard(full_layer_name)
module._is_hf_initialized = True
for param in module.parameters(recurse=False):
param._is_hf_initialized = True
return {}
return {full_layer_name: value}

@@ -189,6 +197,8 @@ def convert(
quantize_(module, c, filter_fn=lambda x, fqn: True)
missing_keys.discard(full_layer_name)
module._is_hf_initialized = True
for param in module.parameters(recurse=False):
param._is_hf_initialized = True
return {"lm_head.weight": lm_head} if is_embedding_param and untie_embedding_weights else {}

return {full_layer_name: value}
@@ -198,6 +208,8 @@ def convert(
quantize_(module, self.hf_quantizer.quantization_config.get_apply_tensor_subclass())
missing_keys.discard(full_layer_name)
module._is_hf_initialized = True
for param in module.parameters(recurse=False):
param._is_hf_initialized = True
return {"lm_head.weight": lm_head} if is_embedding_param and untie_embedding_weights else {}
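A condensed sketch of the flag-propagation pattern repeated in the hunks above (simplified; `_init_weights` and the quantized tensor types belong to the surrounding transformers/torchao machinery):

```python
import torch.nn as nn


def mark_module_and_params_initialized(module: nn.Module) -> None:
    # The module-level flag alone is not enough: some code paths access
    # module.<name>.weight directly, so each (already quantized) parameter also
    # carries the flag to keep _init_weights from re-initializing it.
    module._is_hf_initialized = True
    for param in module.parameters(recurse=False):
        param._is_hf_initialized = True
```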


2 changes: 2 additions & 0 deletions src/transformers/models/apertus/configuration_apertus.py
@@ -101,6 +101,8 @@ class ApertusConfig(PreTrainedConfig):
"layers.*.self_attn.q_proj": "colwise",
"layers.*.self_attn.k_proj": "colwise",
"layers.*.self_attn.v_proj": "colwise",
"layers.*.self_attn.q_norm": "replicated_with_grad_allreduce",
"layers.*.self_attn.k_norm": "replicated_with_grad_allreduce",
"layers.*.self_attn.o_proj": "rowwise",
"layers.*.mlp.up_proj": "colwise",
"layers.*.mlp.down_proj": "rowwise",
2 changes: 2 additions & 0 deletions src/transformers/models/apertus/modular_apertus.py
@@ -121,6 +121,8 @@ class ApertusConfig(PreTrainedConfig):
"layers.*.self_attn.q_proj": "colwise",
"layers.*.self_attn.k_proj": "colwise",
"layers.*.self_attn.v_proj": "colwise",
"layers.*.self_attn.q_norm": "replicated_with_grad_allreduce",
"layers.*.self_attn.k_norm": "replicated_with_grad_allreduce",
"layers.*.self_attn.o_proj": "rowwise",
"layers.*.mlp.up_proj": "colwise",
"layers.*.mlp.down_proj": "rowwise",
@@ -123,12 +123,19 @@ class DeepseekV2Config(PreTrainedConfig):

base_model_tp_plan = {
"layers.*.self_attn.q_proj": "colwise",
"layers.*.self_attn.q_a_proj": "colwise",
"layers.*.self_attn.q_b_proj": "colwise",
"layers.*.self_attn.kv_a_proj_with_mqa": "mla_kv_a_proj",
"layers.*.self_attn.kv_b_proj": "colwise",
"layers.*.self_attn.o_proj": "rowwise",
"layers.*.mlp.experts.gate_up_proj": "packed_colwise",
"layers.*.mlp.experts.down_proj": "rowwise",
"layers.*.mlp.experts": "moe_tp_experts",
"layers.*.mlp.shared_experts.gate_proj": "colwise",
"layers.*.mlp.shared_experts.up_proj": "colwise",
"layers.*.mlp.shared_experts.down_proj": "rowwise",
"layers.*.mlp.gate_proj": "colwise",
"layers.*.mlp.up_proj": "colwise",
Comment on lines +132 to +137

Collaborator: IDK what's most efficient to avoid having too many comms, but LGTM otherwise.

Member Author: Until we have EP working properly, I think it's okay to leave it this way.

"layers.*.mlp.down_proj": "rowwise",
}
base_model_pp_plan = {
"embed_tokens": (["input_ids"], ["inputs_embeds"]),
@@ -360,6 +360,7 @@ def forward(
k_nope, value_states = torch.split(k_nope, [self.qk_nope_head_dim, self.v_head_dim], dim=-1)

k_pe = k_pe.view(batch_size, 1, seq_length, self.qk_rope_head_dim)

q_pe, k_pe = apply_rotary_emb(q_pe, k_pe, position_embeddings.to(q_pe.device))

k_pe = k_pe.expand(*k_nope.shape[:-1], -1)
10 changes: 9 additions & 1 deletion src/transformers/models/deepseek_v2/modular_deepseek_v2.py
@@ -140,12 +140,19 @@ class DeepseekV2Config(LlamaConfig):

base_model_tp_plan = {
"layers.*.self_attn.q_proj": "colwise",
"layers.*.self_attn.q_a_proj": "colwise",
"layers.*.self_attn.q_b_proj": "colwise",
"layers.*.self_attn.kv_a_proj_with_mqa": "mla_kv_a_proj",
"layers.*.self_attn.kv_b_proj": "colwise",
"layers.*.self_attn.o_proj": "rowwise",
"layers.*.mlp.experts.gate_up_proj": "packed_colwise",
"layers.*.mlp.experts.down_proj": "rowwise",
"layers.*.mlp.experts": "moe_tp_experts",
"layers.*.mlp.shared_experts.gate_proj": "colwise",
"layers.*.mlp.shared_experts.up_proj": "colwise",
"layers.*.mlp.shared_experts.down_proj": "rowwise",
"layers.*.mlp.gate_proj": "colwise",
"layers.*.mlp.up_proj": "colwise",
"layers.*.mlp.down_proj": "rowwise",
}

model_type = "deepseek_v2"
@@ -384,6 +391,7 @@ def forward(
k_nope, value_states = torch.split(k_nope, [self.qk_nope_head_dim, self.v_head_dim], dim=-1)

k_pe = k_pe.view(batch_size, 1, seq_length, self.qk_rope_head_dim)

q_pe, k_pe = apply_rotary_emb(q_pe, k_pe, position_embeddings.to(q_pe.device))

k_pe = k_pe.expand(*k_nope.shape[:-1], -1)
@@ -129,8 +129,9 @@ class DeepseekV3Config(PreTrainedConfig):
model_type = "deepseek_v3"
keys_to_ignore_at_inference = ["past_key_values"]
base_model_tp_plan = {
"layers.*.mlp.experts.gate_up_proj": "rowwise",
"layers.*.mlp.experts.gate_up_proj": "packed_colwise",
"layers.*.mlp.experts.down_proj": "rowwise",
"layers.*.mlp.experts": "moe_tp_experts",
"layers.*.mlp.shared_experts.gate_proj": "colwise",
"layers.*.mlp.shared_experts.up_proj": "colwise",
"layers.*.mlp.shared_experts.down_proj": "rowwise",