
fix Dtensor and tensor mismatch #42906

Merged

ArthurZucker merged 14 commits into main from fix_dtensor_tensor_moe_mismatch on Dec 16, 2025
Conversation

Member

@3outeille commented Dec 16, 2025

Bug

local_rowwise or local_colwise calls RowwiseParallel(use_dtensor=False) (resp. ColwiseParallel(use_dtensor=False)). The issue was first noticed in #42356, quoting:

> we would like to not have DTensor logic in the modeling. For example, sinks are supposed to use local_rowwise (cf. main/src/transformers/models/gpt_oss/configuration_gpt_oss.py#L41), which is supposed to not return a DTensor (cf. main/src/transformers/integrations/tensor_parallel.py#L1171), but somehow doesn't work

Fix

def convert_and_load_state_dict_in_model(
    # ...
    tp_layer = ALL_PARALLEL_STYLES[model.tp_plan[matched_tp_pattern]].__class__

was creating the bug: re-instantiating from .__class__ falls back to the class default use_dtensor value, overwriting the value we specified in local_rowwise/local_colwise.

The fix makes sure the configured use_dtensor value is actually used, so there is no longer a DTensor and tensor mismatch.
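For illustration, here is a minimal, self-contained sketch of the pitfall (the class below is a simplified stand-in, not the actual transformers implementation): rebuilding a style from .__class__ yields a fresh object carrying the class default, so any per-instance override of use_dtensor is lost.

```python
# Hypothetical stand-in for a parallel style; not the real transformers class.
class RowwiseParallel:
    def __init__(self, use_dtensor=True):
        self.use_dtensor = use_dtensor

# The tp_plan entry "local_rowwise" is configured with use_dtensor=False ...
style_instance = RowwiseParallel(use_dtensor=False)

# ... but rebuilding from __class__ silently falls back to the class default.
rebuilt = style_instance.__class__()
print(style_instance.use_dtensor)  # False
print(rebuilt.use_dtensor)         # True -> the override is gone, DTensors leak back in
```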

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

Comment thread: src/transformers/core_model_loading.py (Outdated)
Comment on lines +898 to +903
tp_layer_instance = ALL_PARALLEL_STYLES[model.tp_plan[matched_tp_pattern]]
tp_layer = tp_layer_instance.__class__
mapping.distributed_operation = tp_layer(
    device_mesh=device_mesh, rank=device_map[""].index, empty_param=empty_param.clone()
)
mapping.distributed_operation.use_dtensor = tp_layer_instance.use_dtensor
Collaborator


A nice hack, though if we can come up with a better fix let's try to avoid that please!
The kwargs should only be

device_mesh=device_mesh, rank=device_map[""].index, empty_param=empty_param.clone()

for init; the rest should not be kwargs of init, more like hardcoded for that "type".
If you see what I mean: here we should only get the class and init it -> local_colwise should get its stuff

Member Author


done
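A minimal sketch of what this suggestion (and the later commit "avoid hack and create class instead") amounts to; the class names are hypothetical stand-ins, not the code merged in this PR. Instead of copying use_dtensor onto the new instance after init, a dedicated subclass carries the right value as a class-level default, so going through .__class__ stays lossless.

```python
class ColwiseParallel:
    use_dtensor = True  # class-level default

    def __init__(self, device_mesh=None, rank=None, empty_param=None):
        self.device_mesh = device_mesh
        self.rank = rank
        self.empty_param = empty_param


class LocalColwiseParallel(ColwiseParallel):
    use_dtensor = False  # hardcoded for the "local_colwise" style


# The plan maps each style name to a template instance; re-instantiating its
# class for a given device mesh/rank now preserves the intended use_dtensor.
ALL_PARALLEL_STYLES = {
    "colwise": ColwiseParallel(),
    "local_colwise": LocalColwiseParallel(),
}

tp_layer = ALL_PARALLEL_STYLES["local_colwise"].__class__
op = tp_layer(device_mesh=None, rank=0, empty_param=None)
assert op.use_dtensor is False
```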

Collaborator

@ArthurZucker left a comment


Looks good to me otherwise, the test will come later AFAIK with your PR on fast distributed tests

router_indices == -1, num_local_experts
) # masking class for one hot
return router_scores, router_indices
return router_logits, router_scores, router_indices
Collaborator


unsure this works for all models but let's see!

Member Author


yeah I need to fix the Expert parallel anyway so we will see
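For context on the "# masking class for one hot" comment in the hunk above, here is a hedged sketch (assumed shapes and names) of the usual trick: slots where router_indices == -1 are mapped to an extra "masking" class so one_hot can run without negative indices, and that extra column is dropped afterwards.

```python
import torch

router_indices = torch.tensor([[0, 2], [1, -1]])  # -1 = no expert assigned to this slot
num_local_experts = 4

# Map the -1 sentinel to an extra class index so one_hot accepts it ...
masked = router_indices.masked_fill(router_indices == -1, num_local_experts)
one_hot = torch.nn.functional.one_hot(masked, num_classes=num_local_experts + 1)

# ... then drop the masking class so invalid slots encode to all zeros.
one_hot = one_hot[..., :num_local_experts]
print(one_hot)
```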

@github-actions
Contributor

View the CircleCI Test Summary for this PR:

https://huggingface.co/spaces/transformers-community/circle-ci-viz?pr=42906&sha=12ff9a

@3outeille enabled auto-merge (squash) on December 16, 2025 at 17:24
@ArthurZucker merged commit b1a2fba into main on Dec 16, 2025
24 of 26 checks passed
@ArthurZucker deleted the fix_dtensor_tensor_moe_mismatch branch on December 16, 2025 at 17:36
SangbumChoi pushed a commit to SangbumChoi/transformers that referenced this pull request Jan 23, 2026
* begin Moe test tensor parallel

* create tiny moe model + fix test tensor parallel Moe

eaeaae

fix tensor parallel MoE test

* fix backward pass test in tensor parallel for Dense model (huggingface#42811)

* fix

* linting

* use mixtral instead for testing

* fix dtensor and tensor mismatch

* linting

* checkout test tensor parallel to be like main

* avoid hack and create class instead
