Add HyperCLOVAX SEED Think 14B #44956

Open

bigshanedogg wants to merge 6 commits into huggingface:main from bigshanedogg:feat/hyperclovax

Conversation

@bigshanedogg bigshanedogg commented Mar 23, 2026

What does this PR do?

Adds native Transformers support for HyperCLOVA X SEED Think 14B, a 14.74B-parameter Korean reasoning LLM developed by NAVER Cloud.

Architecture

LLaMA-style decoder-only transformer with two modifications:

  • Peri-Layer Normalization (use_post_norm): an extra RMSNorm is applied after each
    sub-layer output (both attention and MLP), in addition to the standard pre-norm.
  • Maximal Update Parametrization (μP): four per-config scaling factors replace fixed constants:
    • attention_multiplier — replaces 1/sqrt(head_dim) in attention
    • residual_multiplier — scales each sub-layer output before adding to the residual stream
    • embedding_multiplier — scales the token embedding output
    • logits_scaling — scales final logits before softmax / sampling
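
To make the two modifications concrete, here is a minimal, self-contained sketch of how they enter a single decoder block's forward pass. All names and numeric values below are illustrative assumptions, not the PR's actual code:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Illustrative stand-ins for the real config values (assumptions, not the
# released checkpoint's numbers).
HIDDEN, HEADS, RMS_EPS = 64, 4, 1e-6
ATTENTION_MULTIPLIER = 0.0078125  # muP: replaces the fixed 1/sqrt(head_dim)
RESIDUAL_MULTIPLIER = 0.22        # muP: scales sub-layer output pre-residual
USE_POST_NORM = True              # Peri-LN on/off switch

class SketchAttention(nn.Module):
    """Toy attention; the point of interest is that the score scale is a
    config-driven constant rather than 1/sqrt(head_dim)."""
    def __init__(self):
        super().__init__()
        self.qkv = nn.Linear(HIDDEN, 3 * HIDDEN, bias=False)
        self.o = nn.Linear(HIDDEN, HIDDEN, bias=False)

    def forward(self, x):
        b, t, _ = x.shape
        q, k, v = self.qkv(x).chunk(3, dim=-1)
        q, k, v = (z.view(b, t, HEADS, -1).transpose(1, 2) for z in (q, k, v))
        out = F.scaled_dot_product_attention(
            q, k, v, scale=ATTENTION_MULTIPLIER, is_causal=True
        )
        return self.o(out.transpose(1, 2).reshape(b, t, HIDDEN))

class PeriLNBlock(nn.Module):
    """Pre-norm -> sub-layer -> optional post-norm -> scaled residual add."""
    def __init__(self):
        super().__init__()
        self.input_layernorm = nn.RMSNorm(HIDDEN, eps=RMS_EPS)  # standard pre-norm
        self.self_attn = SketchAttention()
        if USE_POST_NORM:
            # Peri-LN: extra RMSNorm applied to the sub-layer *output*
            self.post_norm = nn.RMSNorm(HIDDEN, eps=RMS_EPS)

    def forward(self, x):
        h = self.self_attn(self.input_layernorm(x))
        if USE_POST_NORM:
            h = self.post_norm(h)
        # muP: scale the sub-layer output before adding to the residual stream
        return x + h * RESIDUAL_MULTIPLIER

x = torch.randn(1, 8, HIDDEN)
print(PeriLNBlock()(x).shape)  # torch.Size([1, 8, 64])
```

The MLP path mirrors the attention path, and embedding_multiplier and logits_scaling act analogously at the model's input and output ends.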

Implementation approach

Following the maintainer's guidance in #44957, this PR uses the modular system (modular_hyperclovax.py) to minimise LOC and make the diff easy to review and iterate on. (Roughly 59% of lines are generated rather than manually maintained.)

The maintainer suggested inheriting the decoder layer with post-norms from GLM4. After evaluation, Granite was chosen as the decoder layer base instead, for the following reasons:

  • use_post_norm is optional (False by default). GLM4's decoder layer has post-norms always on — inheriting from it would require logic to conditionally disable post_self_attn_layernorm / post_mlp_layernorm, adding complexity rather than reducing it.
  • Granite's decoder layer already provides residual_multiplier (always-active MuP). When use_post_norm=False, HyperCLOVAXDecoderLayer is identical to GraniteDecoderLayer — zero extra code.
  • Using GLM4 would require adding both residual_multiplier and conditionally disabling its built-in norms — two changes in opposite directions for no net gain in code reuse.

All other modules (RMSNorm, MLP, Attention, etc.) are inherited from Granite unchanged. The modular file is a few hundred LOC as suggested.
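
As a rough sketch of the resulting inheritance shape (class and attribute names are guessed from the description above; the real modular file goes through the modular converter and will differ in detail):

```python
# Hypothetical excerpt mirroring modular_hyperclovax.py's structure.
from transformers.models.granite.modeling_granite import (
    GraniteDecoderLayer,
    GraniteRMSNorm,
)

class HyperCLOVAXRMSNorm(GraniteRMSNorm):
    pass

class HyperCLOVAXDecoderLayer(GraniteDecoderLayer):
    def __init__(self, config, layer_idx):
        # Granite already wires residual_multiplier into the residual adds.
        super().__init__(config, layer_idx)
        self.use_post_norm = config.use_post_norm
        if self.use_post_norm:
            # Peri-LN norms exist only when enabled; with use_post_norm=False
            # this layer reduces to GraniteDecoderLayer with zero extra code.
            self.post_norm1 = HyperCLOVAXRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
            self.post_norm2 = HyperCLOVAXRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
    # forward() (omitted here) would apply post_norm1 after attention and
    # post_norm2 after the MLP, otherwise mirroring Granite's flow.
```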

Benchmark validation

| Task | Metric | vLLM | This PR |
|---|---|---|---|
| hellaswag (non-think) | acc_norm | 0.6521 | 0.6666 |
| gsm8k (non-think) | flexible-extract | 0.9151 | 0.9188 |

External support

Code Agent Policy

  • I confirm that this is not a pure code agent PR.

A code agent was used for mechanical tasks such as aligning docstrings and comments. The core implementation was written directly by the submitter, who has reviewed every changed line and personally run the tests, including benchmark validation.

Before submitting

HanFa added a commit to HanFa/vllm that referenced this pull request Mar 29, 2026
Vendor the HyperCLOVAX Vision config into vLLM to fix transformers v5
compatibility. The upstream remote code config does not handle empty
initialization (text_config=None), which breaks v5's @strict config
validation added in huggingface/transformers#41250.

Fixes: vllm-project#38387

TODO: Remove vendored config once HyperCLOVAX is upstreamed to
transformers. Tracking PR: huggingface/transformers#44956

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
@bigshanedogg bigshanedogg marked this pull request as ready for review March 29, 2026 21:57
@bigshanedogg bigshanedogg changed the title from "[WIP] Add HyperCLOVAX model" to "Add HyperCLOVAX model" on Mar 29, 2026
@bigshanedogg (Author):

@zucchini-nlp,
Following your suggestion, I implemented this in a modular way by inheriting from Granite, incorporated the changes from #44957, and completed benchmark validation.

All CI checks have completed, except for one job that is still pending its status report.
Would it be okay to request a review at this stage?

@bigshanedogg (Author) left a comment

This is a self-review of the key changes in this PR.

Comment thread src/transformers/models/hyperclovax/modular_hyperclovax.py Outdated
Comment on lines +98 to +101
attention_multiplier: float | None = None
residual_multiplier: float | None = None
embedding_multiplier: float | None = None
logits_scaling: float | None = None
@bigshanedogg (Author):

These fields also exist in Granite, but they are redeclared here because their default values differ.
Even though the values are present in config.json, the dynamic defaults computed in post_init would not be applied unless the fields are explicitly declared.
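
A toy illustration of the declaration requirement being described (hypothetical names; the real logic lives in the config class and its post_init):

```python
import math

class ToyConfig:
    """If a muP field is not declared as an explicit attribute with a None
    default, the dynamic fallback in post_init can never fire."""
    def __init__(self, hidden_size=4096, num_attention_heads=32,
                 attention_multiplier=None):
        self.hidden_size = hidden_size
        self.num_attention_heads = num_attention_heads
        self.attention_multiplier = attention_multiplier
        self.post_init()

    def post_init(self):
        # Dynamic default: applied only when the field was left unset.
        if self.attention_multiplier is None:
            head_dim = self.hidden_size // self.num_attention_heads
            self.attention_multiplier = 1.0 / math.sqrt(head_dim)

print(ToyConfig().attention_multiplier)                           # dynamic default
print(ToyConfig(attention_multiplier=0.01).attention_multiplier)  # explicit wins
```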

@bigshanedogg (Author):

These fields have been removed per the modification noted in the comment below, except for attention_multiplier.

Comment on lines +165 to +168
# Peri-Layer Normalization: additional RMSNorm after each sub-layer output
if self.use_post_norm:
    self.post_norm1 = HyperCLOVAXRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
    self.post_norm2 = HyperCLOVAXRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
@bigshanedogg (Author):

When self.use_post_norm is True, separate post-norms are declared for the attention and MLP sub-layers to match the Peri-LN structure.
Because of this branch on self.use_post_norm, the layer inherits from Granite instead of GLM4
(the field overlap with Granite was also greater).

HanFa added a commit to HanFa/vllm that referenced this pull request Mar 31, 2026
Vendor the HyperCLOVAX Vision config into vLLM to fix transformers v5
compatibility. The upstream remote code config does not handle empty
initialization (text_config=None), which breaks v5's @strict config
validation added in huggingface/transformers#41250.

Fixes: vllm-project#38387

TODO: Remove vendored config once HyperCLOVAX is upstreamed to
transformers. Tracking PR: huggingface/transformers#44956

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Signed-off-by: Fang Han <fhan0520@gmail.com>
@bigshanedogg bigshanedogg changed the title from "Add HyperCLOVAX model" to "Add HyperCLOVAX SEED Think 14B" on Mar 31, 2026
@zucchini-nlp (Member) left a comment

Great work on applying modular! I left a few comments on what can be deleted because it's already auto-resolved by modular

Other than that we're fine. After addressing the comments, will request core maintainer review and we'll merge

Comment thread docs/source/en/model_doc/hyperclovax.md Outdated
Comment thread docs/source/en/model_doc/hyperclovax.md Outdated
Comment thread src/transformers/models/hyperclovax/modular_hyperclovax.py Outdated
Comment thread src/transformers/models/hyperclovax/modular_hyperclovax.py Outdated
Comment thread src/transformers/models/hyperclovax/modular_hyperclovax.py Outdated
Comment thread src/transformers/models/hyperclovax/modular_hyperclovax.py Outdated
Comment thread src/transformers/models/hyperclovax/modular_hyperclovax.py Outdated
hidden_states = outputs.last_hidden_state
slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
# MuP: multiply logits by logits_scaling (cf. GraniteForCausalLM which divides)
logits = self.lm_head(hidden_states[:, slice_indices, :]) * self.config.logits_scaling
@zucchini-nlp (Member):

Can we adjust the scaling so we can copy fully? For example, in the config: self.logits_scaling = 1 / self.logits_scaling

@bigshanedogg (Author):

Good idea!
However, I'm a bit concerned that storing the inverted value in Config.logits_scaling could cause confusion,
since users inspecting config.json would see a different value than what's actually used in the forward pass.
Would it be okay to keep the explicit * self.config.logits_scaling in forward for clarity, even if it means a small override?
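
For readers following along, a toy comparison of the two options (made-up numbers; both routes compute identical logits, the question is only which value appears in config.json):

```python
logits_scaling_in_json = 8.0             # value a user would read in config.json

# Option A (reviewer): invert once at config time so Granite's dividing
# forward can be copied verbatim.
inverted = 1.0 / logits_scaling_in_json  # config now holds 0.125, not 8.0
logits_a = 1.0 / inverted                # Granite-style: lm_head(h) / logits_scaling

# Option B (author): keep the published value and override forward to multiply.
logits_b = 1.0 * logits_scaling_in_json  # lm_head(h) * logits_scaling

assert logits_a == logits_b == 8.0       # numerically identical either way
```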

Comment thread src/transformers/models/hyperclovax/modular_hyperclovax.py Outdated
Comment thread tests/models/hyperclovax/test_modeling_hyperclovax.py Outdated
@zucchini-nlp (Member):

run-slow: hyperclovax


github-actions Bot commented Apr 2, 2026

Workflow Run ⚙️

This comment contains run-slow, running the specified jobs:

models: ["models/hyperclovax"]
quantizations: []

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.


github-actions Bot commented Apr 2, 2026

CI Results

Workflow Run ⚙️

Commit Info

| Context | Commit | Description |
|---|---|---|
| RUN | 7d1b9113 | workflow commit (merge commit) |
| PR | 6aa22bc3 | branch commit (from PR) |
| main | bb803105 | base commit (on main) |

✅ No failing test specific to this PR 🎉 👏 !

@bigshanedogg (Author):

@zucchini-nlp,
Thank you for the thorough review!
I've addressed all the feedback and removed quite a few unnecessary lines. For the logits_scaling part, I've left an additional comment as I wasn't sure if it might cause confusion.
The model behavior has been verified to remain unchanged after the edits.

Some of the failed tests appear to be outside the scope of this PR (e.g., VibeVoiceAsrForConditionalGenerationModelTest).
I will investigate the remaining cases related to HyperCLOVAX.

@zucchini-nlp (Member) left a comment

Nice! To fix the CI you need to run make fix-repo. I merged main, which will fix the unrelated failures, and requested a core maintainer's review.

@@ -0,0 +1,27 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
@zucchini-nlp (Member):

a few files left wrt 2026 😄

@zucchini-nlp zucchini-nlp requested review from vasqu and removed request for ArthurZucker and Rocketknight1 April 7, 2026 13:34
@zucchini-nlp (Member):

run-slow: hyperclovax

@zucchini-nlp (Member) left a comment

Okay, seeing a bad rebase with an unrelated diff 😄 and a tiny change in the RoPE doc. I will pass over the latest diff after the bad rebase is fixed, and a core maintainer will probably pass over soon

Comment thread .github/workflows/trl-ci-bot.yml Outdated
Comment thread docs/source/en/internal/rope_utils.md Outdated
@bigshanedogg bigshanedogg force-pushed the feat/hyperclovax branch 2 times, most recently from 331ed88 to 9600edb Compare April 10, 2026 08:46
@bigshanedogg (Author):

@zucchini-nlp,
I've incorporated the suggested changes and reverted to your last reviewed commit (c025d918).
Really appreciate you taking the time to look into this!

Comment thread src/transformers/models/blip/image_processing_blip.py Outdated
@zucchini-nlp (Member):

@bigshanedogg, one tiny unrelated diff is left over. And vasqu will come to review next week :)

@github-actions

[For maintainers] Suggested jobs to run (before merge)

run-slow: auto, hyperclovax

@github-actions

View the CircleCI Test Summary for this PR:

https://huggingface.co/spaces/transformers-community/circle-ci-viz?pr=44956&sha=d5a047


vasqu commented Apr 21, 2026

Sorry for all the delays, will be taking a look today!!

@vasqu (Contributor) left a comment

Only some nits tbh, looks overall super good! Let's sync with main and fixup the last details 🤗

tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
model_id,
dtype=torch.bfloat16,
@vasqu (Contributor):

Suggested change
dtype=torch.bfloat16,

shouldn't need this anymore; we use dtype="auto" by default nowadays

**model_inputs,
max_new_tokens=200,
tokenizer=tokenizer,
stop_strings=["<|endofturn|>", "<|stop|>"],
@vasqu (Contributor):

Nit: might be nice to add these to the generation config instead
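
For reference, a hedged sketch of what this suggestion could look like (whether the checkpoint ships these defaults is up to the model authors):

```python
from transformers import GenerationConfig

# Hypothetical defaults that could live in the checkpoint's generation_config.json
gen_config = GenerationConfig(
    max_new_tokens=200,
    stop_strings=["<|endofturn|>", "<|stop|>"],
)
# generate() still needs the tokenizer to resolve stop_strings into token matches:
# model.generate(**model_inputs, generation_config=gen_config, tokenizer=tokenizer)
```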

@vasqu (Contributor):

We changed this on main; you don't need to manually add these here anymore. Just run python utils/check_auto.py --fix_and_overwrite for the auto mapping to register these (only for the configs).

("groupvit", "CLIPTokenizer" if is_tokenizers_available() else None),
("herbert", "HerbertTokenizer" if is_tokenizers_available() else None),
("hubert", "Wav2Vec2CTCTokenizer"),
("hyperclovax", "TokenizersBackend" if is_tokenizers_available() else None),
@vasqu (Contributor):

Suggested change
("hyperclovax", "TokenizersBackend" if is_tokenizers_available() else None),

should not be needed; we automatically fall back to the tokenizers backend. Could you double-check?

Comment on lines +16 to +22
HyperCLOVAX is a decoder-only transformer based on Granite with the following modifications:

- **Maximal Update Parametrization (MuP)**: uses per-config scaling factors
(`attention_multiplier`, `residual_multiplier`, `embedding_multiplier`, `logits_scaling`)
to enable stable training across model sizes.
- **Peri-Layer Normalization** (optional): applies an extra RMSNorm after each
sub-layer output when `use_post_norm=True`.
@vasqu (Contributor):

Suggested change
HyperCLOVAX is a decoder-only transformer based on Granite with the following modifications:
- **Maximal Update Parametrization (MuP)**: uses per-config scaling factors
(`attention_multiplier`, `residual_multiplier`, `embedding_multiplier`, `logits_scaling`)
to enable stable training across model sizes.
- **Peri-Layer Normalization** (optional): applies an extra RMSNorm after each
sub-layer output when `use_post_norm=True`.

Nit: we don't really specify the architecture like this in the modular/modeling code; I think it suffices within the model_doc

@@ -0,0 +1,225 @@
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
@vasqu (Contributor):

Suggested change
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.

# Same as Granite — avoids edge cases with the causal_mask buffer during CPU offload
model_split_percents = [0.5, 0.7, 0.8]

_torch_compile_train_cls = HyperCLOVAXForCausalLM if is_torch_available() else None
@vasqu (Contributor):

Suggested change
_torch_compile_train_cls = HyperCLOVAXForCausalLM if is_torch_available() else None

shouldn't be needed tbh, can you check?

Comment on lines +107 to +115
@unittest.skip(
"In TP mode, Float8 quantization derives scales per shard rather than globally, "
"so each TP rank observes different weight magnitudes than the full-weight non-TP "
"baseline. HyperCLOVAX's Peri-Layer Normalization (post_norm1/post_norm2) amplifies "
"this discrepancy past the 75% token-match threshold. Skipped pending an upstream fix."
)
@is_tensor_parallel_test
def test_tp_generation_quantized(self):
pass
@vasqu (Contributor):

Interesting, cc @3outeille @SunMarc just for viz

expected_slice = expected_slices.get_expectation().to(torch_device)
self.assertTrue(torch.allclose(out.logits[0, 0, :15].float(), expected_slice, atol=1e-2, rtol=1e-2))

@require_torch_large_accelerator
@vasqu (Contributor):

Suggested change
@require_torch_large_accelerator


self.assertEqual(output_text, EXPECTED_TEXTS)

@require_torch_large_accelerator
@vasqu (Contributor):

Suggested change
@require_torch_large_accelerator

I don't think we need these anymore
