
fix: preserve rotary_pct across save/load cycle in GPTNeoX configs#44985

Merged
zucchini-nlp merged 2 commits into huggingface:main from Krishnachaitanyakc:fix/gpt-neox-rotary-pct-roundtrip
Mar 27, 2026
Conversation

@Krishnachaitanyakc
Contributor

Summary

Fixes #44913

When creating a GPTNeoXConfig (or GPTNeoXJapaneseConfig) with a non-default rotary_pct, the value is lost after a save_pretrained / from_pretrained round-trip. This happens because convert_rope_params_to_dict unconditionally overwrites partial_rotary_factor with kwargs.pop("rotary_pct", <default>). On reload, rotary_pct is absent from kwargs (it was saved inside rope_parameters), so the default silently replaces the correct value.

The fix uses the setdefault pattern recommended by @zucchini-nlp, matching modeling_rope_utils.py L646-648: set partial_rotary_factor only when rotary_pct is explicitly passed; otherwise use setdefault so any value already present in rope_parameters is preserved.

Both GPTNeoXConfig and GPTNeoXJapaneseConfig had the same bug and are fixed together.
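The failure mode, and why setdefault fixes it, can be sketched with plain dicts. The convert_buggy/convert_fixed helpers below are hypothetical stand-ins for illustration, not the actual convert_rope_params_to_dict implementation:

```python
# Hypothetical stand-ins for the conversion step (not the real transformers code).

def convert_buggy(rope_parameters: dict, kwargs: dict) -> dict:
    # Unconditional overwrite: on reload, rotary_pct is absent from kwargs
    # (it was saved inside rope_parameters), so the default clobbers the
    # saved value.
    rope_parameters["partial_rotary_factor"] = kwargs.pop("rotary_pct", 0.25)
    return rope_parameters

def convert_fixed(rope_parameters: dict, kwargs: dict) -> dict:
    # setdefault only fills the key when it is missing, so a value restored
    # from the saved config survives.
    rope_parameters.setdefault("partial_rotary_factor", kwargs.pop("rotary_pct", 0.25))
    return rope_parameters

# First construction: rotary_pct=0.5 arrives via kwargs.
saved = convert_fixed({}, {"rotary_pct": 0.5})
assert saved["partial_rotary_factor"] == 0.5

# Reload: the value now lives inside the saved rope_parameters, not in kwargs.
reloaded_buggy = convert_buggy(dict(saved), {})
reloaded_fixed = convert_fixed(dict(saved), {})
print(reloaded_buggy["partial_rotary_factor"])  # 0.25 -- value silently lost
print(reloaded_fixed["partial_rotary_factor"])  # 0.5  -- value preserved
```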

Changes

  • src/transformers/models/gpt_neox/configuration_gpt_neox.py: use setdefault for partial_rotary_factor instead of unconditional assignment
  • src/transformers/models/gpt_neox_japanese/configuration_gpt_neox_japanese.py: same fix (default 1.0 instead of 0.25)

Test plan

  • Verified rotary_pct=0.5 survives save/load round-trip for both GPTNeoXConfig and GPTNeoXJapaneseConfig
  • Verified default rotary_pct (0.25 for GPTNeoX, 1.0 for Japanese) survives save/load round-trip
  • Ran pytest tests/models/gpt_neox/test_modeling_gpt_neox.py -k config -- all passed
  • Ran pytest tests/models/gpt_neox_japanese/test_modeling_gpt_neox_japanese.py -k config -- all passed
  • ruff check and ruff format --check pass on both files

AI assistance was used in drafting; all changes and tests were reviewed and validated by the submitter.

Use setdefault instead of unconditional assignment for
partial_rotary_factor in GPTNeoXConfig and GPTNeoXJapaneseConfig,
so the value saved in rope_parameters is not overwritten with the
default on reload.
@Rocketknight1
Member

LGTM! cc @zucchini-nlp to confirm

```python
# Standardize and validate the correctness of rotary position embeddings parameters
# Model uses non-standard naming for rope params, overwrite!
self.rope_parameters.setdefault("rope_theta", kwargs.pop("rotary_emb_base", self.default_theta))
self.rope_parameters["partial_rotary_factor"] = kwargs.pop("rotary_pct", 0.25)
```
Member

nit: self.rope_parameters.setdefault("partial_rotary_factor", kwargs.pop("rotary_pct", 0.25)) does the same no?
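The suggested one-liner behaves as intended because setdefault's default argument is evaluated eagerly: kwargs.pop still consumes rotary_pct even when partial_rotary_factor is already set, and the existing value wins. A minimal sketch with plain dicts (not the actual config code):

```python
# Plain-dict sketch of the one-line setdefault semantics.
rope_parameters = {"partial_rotary_factor": 0.5}  # value restored from the saved config
kwargs = {"rotary_pct": 0.9}

# The default expression runs eagerly, so the kwarg is popped either way,
# but the pre-existing key is left untouched.
result = rope_parameters.setdefault("partial_rotary_factor", kwargs.pop("rotary_pct", 0.25))
print(result)                   # 0.5 -- existing value preserved
print("rotary_pct" in kwargs)   # False -- kwarg consumed regardless
```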

Replace the 4-line if/else block with a single setdefault call,
matching the pattern already used for rope_theta on the line above.
As suggested by @zucchini-nlp in PR review.

@zucchini-nlp left a comment


thank you

@zucchini-nlp enabled auto-merge Mar 25, 2026 13:36
@github-actions
Contributor

[For maintainers] Suggested jobs to run (before merge)

run-slow: gpt_neox, gpt_neox_japanese

@zucchini-nlp added this pull request to the merge queue Mar 25, 2026
@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@github-merge-queue (bot) removed this pull request from the merge queue due to failed status checks Mar 25, 2026
@zucchini-nlp added this pull request to the merge queue Mar 25, 2026
@github-merge-queue (bot) removed this pull request from the merge queue due to failed status checks Mar 25, 2026
@zucchini-nlp added this pull request to the merge queue Mar 27, 2026
Merged via the queue into huggingface:main with commit 7b00e3b Mar 27, 2026
22 checks passed
zucchini-nlp pushed a commit to zucchini-nlp/transformers that referenced this pull request Mar 27, 2026
…uggingface#44985)

* fix: preserve rotary_pct across save/load cycle in GPTNeoX configs

Use setdefault instead of unconditional assignment for
partial_rotary_factor in GPTNeoXConfig and GPTNeoXJapaneseConfig,
so the value saved in rope_parameters is not overwritten with the
default on reload.

* refactor: simplify partial_rotary_factor to use setdefault per review

Replace the 4-line if/else block with a single setdefault call,
matching the pattern already used for rope_theta on the line above.
As suggested by @zucchini-nlp in PR review.
NielsRogge pushed a commit to NielsRogge/transformers that referenced this pull request Mar 30, 2026
…uggingface#44985)



Development

Successfully merging this pull request may close these issues.

In GPTNeoXConfig, rotary_pct silently reverts to default on reload

4 participants