Merged

Changes from all commits (28 commits)
dc6225b
feat(llm): add Hy3-preview (HYV3) SFT support
khazic Apr 27, 2026
e1a1056
ci(llm): add HYV3 phased test configs (P0/P1/P2)
khazic Apr 27, 2026
51520db
ci(llm): rewrite HYV3 test configs to use real checkpoint (DSV4 pattern)
khazic Apr 27, 2026
3e37aeb
fix(llm): align HYV3 configs and model with official Hy3-preview specs
khazic Apr 28, 2026
a555be5
ci(llm): set Hy3-preview checkpoint path to /llm-align/open_models/hu…
khazic Apr 28, 2026
e54e6ae
feat(llm): add HYV3Config and register hy_v3 with AutoConfig
khazic Apr 28, 2026
14895dc
fix(checkpoint): allow partial load when loading HF base checkpoint v…
khazic Apr 28, 2026
fafa76e
fix(checkpoint): skip EP slicing in from_hf for standard DCP load path
khazic Apr 28, 2026
c662e72
ci(llm): use public tencent/Hy3-preview HF path in HYV3 test yamls
khazic Apr 28, 2026
0c71b04
ci(llm): remove P2 DeepEP yaml (not yet validated)
khazic Apr 28, 2026
665ddb5
fix(llm): disable e_score_correction_bias EMA update for HYV3
khazic Apr 28, 2026
564ff4f
fix(llm): guard update_moe_gate_bias against disabled bias update
khazic Apr 28, 2026
f631d36
docs(llm): add model-coverage page for Hy3-preview (HYV3ForCausalLM)
khazic Apr 28, 2026
b7b01c8
docs(llm): add tencent/Hy3-preview to LLM model-coverage toctree
khazic Apr 28, 2026
4eb8133
chore(llm): tune Hy3-preview DeepEP recipe for 16-node run
HuiyingLi Apr 28, 2026
cb30a59
fix(llm): load Hy3-preview MoE expert weights from HF checkpoint
HuiyingLi Apr 28, 2026
91d268c
chore(llm): env-gated parity dumps in train_ft.py setup()
HuiyingLi Apr 28, 2026
a023500
Revert "chore(llm): env-gated parity dumps in train_ft.py setup()"
HuiyingLi Apr 28, 2026
0265f01
Apply suggestion from @jgerh
HuiyingLi Apr 28, 2026
f60a2da
Apply suggestion from @jgerh
HuiyingLi Apr 28, 2026
b99084c
Apply suggestion from @jgerh
HuiyingLi Apr 28, 2026
3a1b2a4
Apply suggestion from @jgerh
HuiyingLi Apr 28, 2026
2b24efd
chore(llm): drop Hy3 4-layer smoke/ckpt example yamls
HuiyingLi Apr 28, 2026
47c25db
chore(checkpoint): drop redundant changes already on main
HuiyingLi Apr 28, 2026
66780c9
test(llm): unit tests for HYV3StateDictAdapter
HuiyingLi Apr 28, 2026
8f56853
test(llm): cover remaining HYV3 + recipe changes in this PR
HuiyingLi Apr 28, 2026
1d511f8
docs(llm): announce Hy3-preview in README and latest-models log
HuiyingLi Apr 28, 2026
fc62a07
Merge remote-tracking branch 'origin/main' into feat/hy3-sft
HuiyingLi Apr 28, 2026
1 change: 1 addition & 0 deletions README.md
@@ -22,6 +22,7 @@

## 📣 News and Discussions
- [04/25/2026][**DeepSeek V4 Flash**](https://huggingface.co/deepseek-ai/DeepSeek-V4-Flash) We now support finetuning `deepseek-ai/DeepSeek-V4-Flash`, thanks to [@Khazic](https://github.com/khazic). Check out our [recipe](https://github.com/NVIDIA-NeMo/Automodel/blob/main/examples/llm_finetune/deepseek_v4/deepseek_v4_flash_hellaswag.yaml) and [guide](https://github.com/NVIDIA-NeMo/Automodel/blob/main/docs/guides/llm/dsv4-flash.md).
- [04/28/2026][**Hy3-preview**](https://huggingface.co/tencent/Hy3-preview) We now support finetuning `tencent/Hy3-preview`, thanks to [@Khazic](https://github.com/khazic). Check out our [recipe](https://github.com/NVIDIA-NeMo/Automodel/blob/main/examples/llm_finetune/hy_v3/hy3_preview_deepep.yaml).
- [04/22/2026][**Qwen3.6-27B**](https://huggingface.co/Qwen/Qwen3.6-27B) We now support finetuning `Qwen/Qwen3.6-27B`. Check out our [recipe](https://github.com/NVIDIA-NeMo/Automodel/blob/main/examples/vlm_finetune/qwen3_5/qwen3_6_27b.yaml).
- [04/20/2026][**Qwen-Image**](https://huggingface.co/Qwen/Qwen-Image) We now support finetuning `Qwen/Qwen-Image`, thanks to [@harshareddy832](https://github.com/harshareddy832). Check out our [recipe](https://github.com/NVIDIA-NeMo/Automodel/blob/main/examples/diffusion/finetune/qwen_image_t2i_flow.yaml).
- [04/16/2026][**Qwen3.6 MoE**](https://huggingface.co/Qwen/Qwen3.6-35B-A3B) We now support finetuning `Qwen/Qwen3.6-35B-A3B`. Check out our [recipe](https://github.com/NVIDIA-NeMo/Automodel/blob/main/examples/vlm_finetune/qwen3_5_moe/qwen3_6_35b.yaml).
1 change: 1 addition & 0 deletions docs/model-coverage/latest-models.md
@@ -6,6 +6,7 @@

| Date | Model | HF Model ID | Modality | Recipe | Try on Brev |
|------|-------|-------------|----------|--------|------|
| 2026-04-28 | Hy3-preview | [`tencent/Hy3-preview`](https://huggingface.co/tencent/Hy3-preview) | LLM | [hy3_preview_deepep.yaml](https://github.com/NVIDIA-NeMo/Automodel/blob/main/examples/llm_finetune/hy_v3/hy3_preview_deepep.yaml) | 🚧 |
| 2026-04-25 | DeepSeek V4 Flash | [`deepseek-ai/DeepSeek-V4-Flash`](https://huggingface.co/deepseek-ai/DeepSeek-V4-Flash) | LLM | [deepseek_v4_flash_hellaswag.yaml](https://github.com/NVIDIA-NeMo/Automodel/blob/main/examples/llm_finetune/deepseek_v4/deepseek_v4_flash_hellaswag.yaml) | 🚧 |
| 2026-04-22 | Qwen3.6-27B | [`Qwen/Qwen3.6-27B`](https://huggingface.co/Qwen/Qwen3.6-27B) | VLM | [qwen3_6_27b.yaml](https://github.com/NVIDIA-NeMo/Automodel/blob/main/examples/vlm_finetune/qwen3_5/qwen3_6_27b.yaml) | 🚧 |
| 2026-04-16 | LLaVA-OneVision-1.5 (4B / 8B) | [`lmms-lab/LLaVA-OneVision-1.5-4B-Instruct`](https://huggingface.co/lmms-lab/LLaVA-OneVision-1.5-4B-Instruct) | VLM | [llava_ov_1_5_4b_finetune.yaml](https://github.com/NVIDIA-NeMo/Automodel/blob/main/examples/vlm_finetune/llava_onevision/llava_ov_1_5_4b_finetune.yaml) | 🚧 |
6 changes: 4 additions & 2 deletions docs/model-coverage/llm/index.md
@@ -11,7 +11,7 @@
pip3 install --upgrade git+https://github.com/NVIDIA-NeMo/Automodel.git
```

-For other installation options (e.g., uv), please see our [Installation Guide](../../guides/installation.md).
+For other installation options (e.g., uv), see the [NeMo AutoModel Installation Guide](../../guides/installation.md).

## Supported Models

@@ -71,6 +71,7 @@
| Stability AI | [StableLM](stabilityai/stablelm.md) | `StableLmForCausalLM` |
| Stepfun AI | [Step-3.5](stepfun-ai/step-3-5.md) | `Step3p5ForCausalLM` |
| Parasail AI | [GritLM](parasail-ai/gritlm.md) | `GritLM` |
| Tencent | [Hy3-preview](tencent/hy3.md) | `HYV3ForCausalLM` |

## Fine-Tuning LLMs with NeMo AutoModel

@@ -79,7 +80,7 @@
The models listed above can be fine-tuned using NeMo AutoModel. We support two primary approaches:
1. **Parameter-Efficient Fine-Tuning (PEFT)**: Updates only a small subset of parameters (typically <1%) using techniques like Low-Rank Adaptation (LoRA).
2. **Supervised Fine-Tuning (SFT)**: Updates all or most model parameters for deeper adaptation.

-Please see our [Fine-Tuning Guide](../../guides/llm/finetune.md) to learn how to apply both methods to your data.
+See the [Fine-Tuning Guide](../../guides/llm/finetune.md) to learn how to apply both methods to your data.

:::{tip}
In these guides, we use the `SQuAD v1.1` dataset for demonstration purposes, but you can use your own data. Update the recipe YAML `dataset` / `validation_dataset` sections accordingly. See [LLM datasets](../../guides/llm/dataset.md) and [dataset overview](../../guides/dataset-overview.md).
@@ -140,4 +141,5 @@
orionstar/orion
stabilityai/stablelm
stepfun-ai/step-3-5
parasail-ai/gritlm
tencent/hy3
```
63 changes: 63 additions & 0 deletions docs/model-coverage/llm/tencent/hy3.md
@@ -0,0 +1,63 @@
# Hy3 (HunyuanLarge)

[Hy3-preview](https://huggingface.co/tencent/Hy3-preview) is a 295B Mixture-of-Experts language model from Tencent. It features 80 transformer layers (layer 0 dense, layers 1–79 MoE), 192 routed experts plus 1 shared expert with top-8 sigmoid routing, Grouped Query Attention (64 Q / 8 KV heads), per-head QK RMSNorm, RoPE, and an `e_score_correction_bias` gate buffer for expert-load correction. It supports a 256K context window.
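As a rough mental model of the routing described above, here is a minimal, self-contained sketch (illustrative only, not the NeMo AutoModel implementation; the exact score and combine-weight handling in Hy3-preview may differ):

```python
import torch

def route_tokens(hidden, gate_weight, correction_bias, top_k=8, route_norm=False):
    """Sigmoid top-k routing with an expert-bias correction term.

    hidden: [tokens, hidden_size]; gate_weight: [num_experts, hidden_size];
    correction_bias: [num_experts], the e_score_correction_bias analogue.
    """
    # Gate in float32, mirroring the recipe's `gate_precision: float32`.
    scores = torch.sigmoid(hidden.float() @ gate_weight.float().t())  # [tokens, experts]
    # The bias shifts expert *selection* only; combine weights use the raw scores.
    _, expert_ids = torch.topk(scores + correction_bias, top_k, dim=-1)
    weights = torch.gather(scores, -1, expert_ids)                    # [tokens, top_k]
    if route_norm:  # the config exposes `route_norm` (default False)
        weights = weights / weights.sum(dim=-1, keepdim=True)
    return expert_ids, weights

# Toy example: 4 tokens, hidden_size 4096, 192 experts, top-8.
ids, wts = route_tokens(torch.randn(4, 4096), torch.randn(192, 4096), torch.zeros(192))
print(ids.shape, wts.shape)  # torch.Size([4, 8]) torch.Size([4, 8])
```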

:::{card}
| | |
|---|---|
| **Task** | Text Generation (MoE) |
| **Architecture** | `HYV3ForCausalLM` |
| **Parameters** | 295B total |
| **HF Org** | [tencent](https://huggingface.co/tencent) |
:::

## Available Models

- **Hy3-preview**: 295B total parameters; top 8 of 192 routed experts activated per token

## Architectures

- `HYV3ForCausalLM`

## Example HF Models

| Model | HF ID |
|---|---|
| Hy3-preview | [`tencent/Hy3-preview`](https://huggingface.co/tencent/Hy3-preview) |

## Example Recipes

| Recipe | Description |
|---|---|
| {download}`hy3_preview_deepep.yaml <../../../../examples/llm_finetune/hy_v3/hy3_preview_deepep.yaml>` | SFT — Hy3-preview with DeepEP |

## Try with NeMo AutoModel

**1. Install** ([NeMo AutoModel](../../../guides/installation.md)):

```bash
pip install nemo-automodel
```

**2. Clone the repo** to get the example recipes:

```bash
git clone https://github.com/NVIDIA-NeMo/Automodel.git
cd Automodel
```

**3. Run the recipe** from inside the repo:

```bash
automodel --nproc-per-node=8 examples/llm_finetune/hy_v3/hy3_preview_deepep.yaml
```

See the [NeMo AutoModel Installation Guide](../../../guides/installation.md) and [LLM Fine-Tuning Guide](../../../guides/llm/finetune.md).

## Fine-Tuning

See the [LLM Fine-Tuning Guide](../../../guides/llm/finetune.md) and the [Large MoE Fine-Tuning Guide](../../../guides/llm/large-moe-finetune.md).

## Hugging Face Model Cards

- [tencent/Hy3-preview](https://huggingface.co/tencent/Hy3-preview)
134 changes: 134 additions & 0 deletions examples/llm_finetune/hy_v3/hy3_preview_deepep.yaml
@@ -0,0 +1,134 @@
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# SFT recipe for tencent/Hy3-preview (295B MoE, 192 experts top-8, 256K context).
# Requires transformers >= 5.6.0 for the hy_v3 architecture.
#
# Hardware: 8 GPUs (80 GB+ each) for LoRA; 32 GPUs for full fine-tuning.
# automodel examples/llm_finetune/hy_v3/hy3_preview_deepep.yaml --nproc-per-node 8
#
# EP size must divide num_experts (192): e.g. ep_size=8 gives 24 experts/rank.
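#
# Illustrative EP sizing arithmetic (world-size requirements are assumptions,
# verify against your cluster layout):
#   ep_size=8  -> 192 / 8  = 24 experts per rank
#   ep_size=32 -> 192 / 32 = 6 experts per rank
#   ep_size=64 -> 192 / 64 = 3 experts per rank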

recipe: TrainFinetuneRecipeForNextTokenPrediction

step_scheduler:
global_batch_size: 256
local_batch_size: 8
ckpt_every_steps: 500
val_every_steps: 500
num_epochs: 1
max_steps: 100

dist_env:
backend: nccl
timeout_minutes: 30

rng:
_target_: nemo_automodel.components.training.rng.StatefulRNG
seed: 1111
ranked: true

model:
_target_: nemo_automodel.NeMoAutoModelForCausalLM.from_pretrained
pretrained_model_name_or_path: tencent/Hy3-preview
torch_dtype: bfloat16
backend:
_target_: nemo_automodel.components.models.common.BackendConfig
attn: te
linear: torch
rms_norm: torch_fp32
experts: torch_mm
dispatcher: deepep
fake_balanced_gate: false
gate_precision: float32
enable_hf_state_dict_adapter: true
enable_fsdp_optimizations: true

checkpoint:
enabled: true
checkpoint_dir: /tmp/checkpoints/hy3_preview/
model_save_format: safetensors
save_consolidated: true

distributed:
strategy: fsdp2
tp_size: 1
cp_size: 1
pp_size: 4
ep_size: 32 # 192 experts / 32 EP ranks = 6 experts per rank

sequence_parallel: false
activation_checkpointing: true

pipeline:
pp_schedule: 1f1b
pp_microbatch_size: 1
round_virtual_stages_to_pp_multiple: down
scale_grads_in_schedule: false
patch_inner_model: false
patch_causal_lm_model: false

moe:
reshard_after_forward: false
wrap_outer_model: false

loss_fn:
_target_: nemo_automodel.components.loss.masked_ce.MaskedCrossEntropy

dataset:
_target_: nemo_automodel.components.datasets.llm.hellaswag.HellaSwag
path_or_dataset: rowan/hellaswag
split: train
tokenizer:
_target_: transformers.AutoTokenizer.from_pretrained
pretrained_model_name_or_path: tencent/Hy3-preview

packed_sequence:
packed_sequence_size: 0

dataloader:
_target_: torchdata.stateful_dataloader.StatefulDataLoader
collate_fn:
_target_: nemo_automodel.components.datasets.utils.default_collater
pad_seq_len_divisible: 64
shuffle: true

validation_dataset:
_target_: nemo_automodel.components.datasets.llm.hellaswag.HellaSwag
path_or_dataset: rowan/hellaswag
split: validation
num_samples_limit: 64
tokenizer:
_target_: transformers.AutoTokenizer.from_pretrained
pretrained_model_name_or_path: tencent/Hy3-preview

validation_dataloader:
_target_: torchdata.stateful_dataloader.StatefulDataLoader
collate_fn:
_target_: nemo_automodel.components.datasets.utils.default_collater
pad_seq_len_divisible: 64
shuffle: false
drop_last: true

optimizer:
_target_: torch.optim.AdamW
betas: [0.9, 0.95]
eps: 1e-8
lr: 1e-5
weight_decay: 0.0

# Uncomment for W&B logging
# wandb:
# project: hy3-preview-sft
# name: hy3_preview_deepep
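The recipe's sections are wired together through `_target_` entries: each names a dotted import path, and the sibling keys become constructor arguments. A minimal sketch of that pattern (hypothetical helper, not the actual NeMo AutoModel loader, which also handles nested nodes):

```python
import importlib

def instantiate(node: dict, **extra):
    """Import the dotted ``_target_`` path and call it with the remaining keys."""
    module_path, _, attr = node["_target_"].rpartition(".")
    target = getattr(importlib.import_module(module_path), attr)
    kwargs = {k: v for k, v in node.items() if k != "_target_"}
    return target(**kwargs, **extra)

optimizer_cfg = {
    "_target_": "torch.optim.AdamW",
    "betas": [0.9, 0.95], "eps": 1e-8, "lr": 1e-5, "weight_decay": 0.0,
}
# Usage (needs a real model): opt = instantiate(optimizer_cfg, params=model.parameters())
```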
5 changes: 5 additions & 0 deletions nemo_automodel/_transformers/registry.py
@@ -141,6 +141,10 @@
"LLaVAOneVision1_5_ForConditionalGeneration",
),
),
(
"HYV3ForCausalLM",
("nemo_automodel.components.models.hy_v3.model", "HYV3ForCausalLM"),
),
(
"Qwen2ForCausalLM",
("nemo_automodel.components.models.qwen2.model", "Qwen2ForCausalLM"),
@@ -182,6 +186,7 @@
_CUSTOM_CONFIG_REGISTRATIONS: Dict[str, Tuple[str, str]] = {
"baichuan": ("nemo_automodel.components.models.baichuan.configuration", "BaichuanConfig"),
"deepseek_v4": ("nemo_automodel.components.models.deepseek_v4.config", "DeepseekV4Config"),
"hy_v3": ("nemo_automodel.components.models.hy_v3.config", "HYV3Config"),
"kimi_k25": ("nemo_automodel.components.models.kimi_k25_vl.model", "KimiK25VLConfig"),
"kimi_vl": ("nemo_automodel.components.models.kimivl.model", "KimiVLConfig"),
"llavaonevision1_5": ("nemo_automodel.components.models.llava_onevision.model", "Llavaonevision1_5Config"),
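The `_CUSTOM_CONFIG_REGISTRATIONS` table above maps a Hugging Face `model_type` string to a `(module_path, class_name)` pair. A sketch of how such entries could be applied (illustrative; the repository's actual registration code may differ):

```python
import importlib
from transformers import AutoConfig

def register_custom_configs(registrations):
    """Resolve each (module_path, class_name) pair and register it, so that
    AutoConfig can construct e.g. model_type="hy_v3" from a config.json."""
    for model_type, (module_path, class_name) in registrations.items():
        config_cls = getattr(importlib.import_module(module_path), class_name)
        AutoConfig.register(model_type, config_cls)
```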
17 changes: 17 additions & 0 deletions nemo_automodel/components/models/hy_v3/__init__.py
@@ -0,0 +1,17 @@
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from nemo_automodel.components.models.hy_v3.model import HYV3ForCausalLM

__all__ = ["HYV3ForCausalLM"]
100 changes: 100 additions & 0 deletions nemo_automodel/components/models/hy_v3/config.py
@@ -0,0 +1,100 @@
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

from transformers import PretrainedConfig


class HYV3Config(PretrainedConfig):
"""Configuration class for Tencent Hy3-preview (295B MoE).

Architecture:
- 80 transformer layers; layer 0 is dense, layers 1-79 are MoE
- MoE: 192 routed experts + 1 shared expert, top-8 activated
- Sigmoid routing with expert-bias correction (e_score_correction_bias)
- GQA: 64 Q heads, 8 KV heads, head_dim=128
- Per-head QK RMSNorm before RoPE
- 256K context, rope_theta=11158840
"""

model_type = "hy_v3"
keys_to_ignore_at_inference = ["past_key_values"]

def __init__(
self,
vocab_size: int = 129280,
hidden_size: int = 4096,
intermediate_size: int = 1536,
moe_intermediate_size: int = 1536,
num_hidden_layers: int = 80,
num_attention_heads: int = 64,
num_key_value_heads: int = 8,
head_dim: int = 128,
# MoE routing
num_experts: int = 192,
num_shared_experts: int = 1,
num_experts_per_tok: int = 8,
router_scaling_factor: float = 1.0,
route_norm: bool = False,
moe_router_enable_expert_bias: bool = True,
# Dense layers
first_k_dense_replace: int = 1,
# Position encoding
max_position_embeddings: int = 262144,
rope_theta: float = 11158840.0,
rope_scaling: dict | None = None,
# Standard options
rms_norm_eps: float = 1e-6,
attention_bias: bool = False,
hidden_act: str = "silu",
use_cache: bool = True,
pad_token_id: int | None = None,
bos_token_id: int = 1,
eos_token_id: int = 2,
tie_word_embeddings: bool = False,
torch_dtype: str = "bfloat16",
**kwargs,
):
self.vocab_size = vocab_size
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.moe_intermediate_size = moe_intermediate_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.num_key_value_heads = num_key_value_heads
self.head_dim = head_dim
self.num_experts = num_experts
self.num_shared_experts = num_shared_experts
self.num_experts_per_tok = num_experts_per_tok
self.router_scaling_factor = router_scaling_factor
self.route_norm = route_norm
self.moe_router_enable_expert_bias = moe_router_enable_expert_bias
self.first_k_dense_replace = first_k_dense_replace
self.max_position_embeddings = max_position_embeddings
self.rope_theta = rope_theta
self.rope_scaling = rope_scaling
self.rms_norm_eps = rms_norm_eps
self.attention_bias = attention_bias
self.hidden_act = hidden_act
self.torch_dtype = torch_dtype

super().__init__(
pad_token_id=pad_token_id,
bos_token_id=bos_token_id,
eos_token_id=eos_token_id,
tie_word_embeddings=tie_word_embeddings,
use_cache=use_cache,
**kwargs,
)
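A quick sanity check of the defaults above (a sketch assuming the package is installed; the expected values mirror the class docstring):

```python
from nemo_automodel.components.models.hy_v3.config import HYV3Config

cfg = HYV3Config()
assert cfg.model_type == "hy_v3"
assert cfg.num_hidden_layers == 80 and cfg.first_k_dense_replace == 1  # layer 0 dense
assert cfg.num_experts % 8 == 0 and cfg.num_experts % 32 == 0  # recipe EP sizes divide evenly
print(cfg.num_experts, cfg.num_experts_per_tok, cfg.max_position_embeddings)
# 192 8 262144
```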