Add base_model_tp_plan to OlmoeConfig#44668
Conversation
Pull request overview
Adds a default tensor-parallel (TP) sharding plan to OlmoeConfig so OLMoE models can be loaded with from_pretrained(tp_plan="auto"), and wires OLMoE’s modeling tests into the shared TP test mixin.
Changes:
- Define base_model_tp_plan on OlmoeConfig.
- Add TensorParallelTesterMixin to OLMoE’s model test class.
- Document why q/k norms need special handling for TP.
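To make the idea of a TP plan concrete, here is a minimal sketch of a plan as a mapping from wildcard module patterns to sharding styles, plus a lookup helper. The dictionary contents are illustrative only: the q_norm/k_norm entries follow the PR summary, the q_proj/k_proj entries are an assumption, and the real mapping in configuration_olmoe.py may differ. EXAMPLE_TP_PLAN and resolve_style are hypothetical names, and the matching logic is a simplification, not the library's actual code.

```python
import re

# Hypothetical, abbreviated plan for illustration only.
# The real mapping lives in configuration_olmoe.py and may differ.
EXAMPLE_TP_PLAN = {
    "layers.*.self_attn.q_proj": "colwise",  # assumed: projections sharded on the output dim
    "layers.*.self_attn.k_proj": "colwise",  # assumed
    "layers.*.self_attn.q_norm": "colwise",  # per the PR summary
    "layers.*.self_attn.k_norm": "colwise",  # per the PR summary
}


def resolve_style(module_name, plan):
    """Return the sharding style for a module name, or None if the plan has no entry.

    Simplified matching: every numeric path component (the layer index)
    is replaced by "*" before the dictionary lookup.
    """
    generic_name = re.sub(r"\.\d+(?=\.|$)", ".*", module_name)
    return plan.get(generic_name)


print(resolve_style("layers.3.self_attn.q_norm", EXAMPLE_TP_PLAN))
print(resolve_style("layers.3.input_layernorm", EXAMPLE_TP_PLAN))
```

Modules without an entry resolve to None in this sketch, which stands in for "not sharded by the plan".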
Reviewed changes
Copilot reviewed 2 out of 2 changed files in this pull request and generated 2 comments.
| File | Description |
|---|---|
| src/transformers/models/olmoe/configuration_olmoe.py | Introduces a default TP plan mapping for OLMoE modules. |
| tests/models/olmoe/test_modeling_olmoe.py | Adds TP test mixin to validate TP behavior for OLMoE. |
|
Hey @dacorvo! You'll need a small rebase, as we recently migrated configs to dataclasses! Apart from that, did you test the TP plan on real checkpoints to ensure correctness, by any chance?
Rebased onto main after configs were migrated to dataclasses. Adds base_model_tp_plan as a class attribute and TensorParallelTesterMixin to the OLMoE test suite. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
|
@Cyrilvallez I rebased the branch.
As a sanity check, I also tested the wrong initial plan I submitted, and verified that it failed for every prompt 😅 (divergence after a few tokens).
|
@bot /style |
|
Style fix bot fixed some files and pushed the changes. |
if is_torch_available():
    OlmoeModelTester.causal_lm_class = OlmoeForCausalLM
Actually, did not see this but let's not do that 😅 Let's put it directly in the class above!
Co-authored-by: Cyril Vallez <cyril.vallez@huggingface.co>
|
[For maintainers] Suggested jobs to run (before merge) run-slow: olmoe |
* Rebase: Add base_model_tp_plan to OlmoeConfig (dataclass style)
  Rebased onto main after configs were migrated to dataclasses. Adds base_model_tp_plan as a class attribute and TensorParallelTesterMixin to the OLMoE test suite.
  Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
* Apply repo consistency fixes
* review: update src/transformers/models/olmoe/configuration_olmoe.py
  Co-authored-by: Cyril Vallez <cyril.vallez@huggingface.co>
* review: use correct pattern for OlmoeModelTester class
---------
Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
Co-authored-by: Cyril Vallez <cyril.vallez@huggingface.co>
Fixes #44677
Summary
- Add base_model_tp_plan to OlmoeConfig, enabling from_pretrained(tp_plan="auto") for OLMoE models
- Add TensorParallelTesterMixin to OLMoE tests for TP validation coverage
- Use "colwise" for q_norm and k_norm because OLMoE applies these norms after the q/k projections, so the norm weight dimensions must match the sharded projection output

Design note
Qwen3-MoE uses "replicated_with_grad_allreduce" for its q/k norms, but that only works because Qwen3 applies norms before the projections (on the full hidden state). OLMoE's architecture applies norms after projections, so the norm weights must be sharded the same way as the projection output, hence "colwise".

Test plan
- python -m pytest tests/models/olmoe/test_modeling_olmoe.py::OlmoeModelTest::test_tp_plan_matches_params -xvs (passes)

AI disclosure
This PR was developed with AI assistance (Claude). All changes reviewed and validated by a human contributor.
🤖 Generated with Claude Code
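The shape argument in the design note can be illustrated with a small pure-Python sketch. It assumes a toy projection with 4 output features split across 2 ranks, and it models only the elementwise scale applied by the norm weight; the normalization statistic itself is not modeled here. All names and values (matvec, split, q_weight, norm_weight) are hypothetical and not taken from the OLMoE code.

```python
def matvec(weight_rows, x):
    """Compute y = W @ x, where each row of W produces one output feature."""
    return [sum(w * xi for w, xi in zip(row, x)) for row in weight_rows]


def split(seq, parts):
    """Split a sequence into equally sized contiguous chunks, one per rank."""
    size = len(seq) // parts
    return [seq[i * size:(i + 1) * size] for i in range(parts)]


x = [1.0, 2.0]                        # full (replicated) hidden state
q_weight = [[1.0, 0.0], [0.0, 1.0],   # toy projection: 4 output features
            [1.0, 1.0], [2.0, -1.0]]
norm_weight = [0.5, 1.0, 2.0, 4.0]    # one scale per projection output feature

# Single-device reference: project, then scale elementwise.
full = [g * y for g, y in zip(norm_weight, matvec(q_weight, x))]

# Column-wise TP: each rank holds a slice of the output features,
# so it needs the matching slice of the norm weight.
tp_size = 2
sharded = []
for w_shard, g_shard in zip(split(q_weight, tp_size), split(norm_weight, tp_size)):
    y_local = matvec(w_shard, x)
    assert len(y_local) == len(g_shard)  # shapes line up only if the scale is sharded too
    sharded.extend(g * y for g, y in zip(g_shard, y_local))

assert sharded == full
```

A replicated norm weight of length 4 would not line up with a local projection output of length 2, which is the mismatch the "colwise" entry is meant to avoid.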