diff --git a/README.md b/README.md index 9fdd9db6f..004fdcfd3 100644 --- a/README.md +++ b/README.md @@ -22,6 +22,7 @@ ## 📣 News and Discussions - [04/25/2026][**DeepSeek V4 Flash**](https://huggingface.co/deepseek-ai/DeepSeek-V4-Flash) We now support finetuning `deepseek-ai/DeepSeek-V4-Flash`, thanks to [@Khazic](https://github.com/khazic). Check out our [recipe](https://github.com/NVIDIA-NeMo/Automodel/blob/main/examples/llm_finetune/deepseek_v4/deepseek_v4_flash_hellaswag.yaml) and [guide](https://github.com/NVIDIA-NeMo/Automodel/blob/main/docs/guides/llm/dsv4-flash.md). +- [04/28/2026][**Hy3-preview**](https://huggingface.co/tencent/Hy3-preview) We now support finetuning `tencent/Hy3-preview`, thanks to [@Khazic](https://github.com/khazic). Check out our [recipe](https://github.com/NVIDIA-NeMo/Automodel/blob/main/examples/llm_finetune/hy_v3/hy3_preview_deepep.yaml). - [04/22/2026][**Qwen3.6-27B**](https://huggingface.co/Qwen/Qwen3.6-27B) We now support finetuning `Qwen/Qwen3.6-27B`. Check out our [recipe](https://github.com/NVIDIA-NeMo/Automodel/blob/main/examples/vlm_finetune/qwen3_5/qwen3_6_27b.yaml). - [04/20/2026][**Qwen-Image**](https://huggingface.co/Qwen/Qwen-Image) We now support finetuning `Qwen/Qwen-Image`, thanks to [@harshareddy832](https://github.com/harshareddy832). Check out our [recipe](https://github.com/NVIDIA-NeMo/Automodel/blob/main/examples/diffusion/finetune/qwen_image_t2i_flow.yaml). - [04/16/2026][**Qwen3.6 MoE**](https://huggingface.co/Qwen/Qwen3.6-35B-A3B) We now support finetuning `Qwen/Qwen3.6-35B-A3B`. Check out our [recipe](https://github.com/NVIDIA-NeMo/Automodel/blob/main/examples/vlm_finetune/qwen3_5_moe/qwen3_6_35b.yaml). diff --git a/docs/model-coverage/latest-models.md b/docs/model-coverage/latest-models.md index efbdbe5ba..dcd27080f 100644 --- a/docs/model-coverage/latest-models.md +++ b/docs/model-coverage/latest-models.md @@ -6,6 +6,7 @@ See the [Model Coverage Overview](overview.md) for release summaries, and the [L | Date | Model | HF Model ID | Modality | Recipe | Try on Brev | |------|-------|-------------|----------|--------|------| +| 2026-04-28 | Hy3-preview | [`tencent/Hy3-preview`](https://huggingface.co/tencent/Hy3-preview) | LLM | [hy3_preview_deepep.yaml](https://github.com/NVIDIA-NeMo/Automodel/blob/main/examples/llm_finetune/hy_v3/hy3_preview_deepep.yaml) | 🚧 | | 2026-04-25 | DeepSeek V4 Flash | [`deepseek-ai/DeepSeek-V4-Flash`](https://huggingface.co/deepseek-ai/DeepSeek-V4-Flash) | LLM | [deepseek_v4_flash_hellaswag.yaml](https://github.com/NVIDIA-NeMo/Automodel/blob/main/examples/llm_finetune/deepseek_v4/deepseek_v4_flash_hellaswag.yaml) | 🚧 | | 2026-04-22 | Qwen3.6-27B | [`Qwen/Qwen3.6-27B`](https://huggingface.co/Qwen/Qwen3.6-27B) | VLM | [qwen3_6_27b.yaml](https://github.com/NVIDIA-NeMo/Automodel/blob/main/examples/vlm_finetune/qwen3_5/qwen3_6_27b.yaml) | 🚧 | | 2026-04-16 | LLaVA-OneVision-1.5 (4B / 8B) | [`lmms-lab/LLaVA-OneVision-1.5-4B-Instruct`](https://huggingface.co/lmms-lab/LLaVA-OneVision-1.5-4B-Instruct) | VLM | [llava_ov_1_5_4b_finetune.yaml](https://github.com/NVIDIA-NeMo/Automodel/blob/main/examples/vlm_finetune/llava_onevision/llava_ov_1_5_4b_finetune.yaml) | 🚧 | diff --git a/docs/model-coverage/llm/index.md b/docs/model-coverage/llm/index.md index 6122aaf10..cc75e2b91 100644 --- a/docs/model-coverage/llm/index.md +++ b/docs/model-coverage/llm/index.md @@ -11,7 +11,7 @@ To run LLMs with NeMo AutoModel, make sure you're using NeMo container version [ pip3 install --upgrade 
git+git@github.com:NVIDIA-NeMo/AutoModel.git ``` -For other installation options (e.g., uv), please see our [Installation Guide](../../guides/installation.md). +For other installation options (e.g., uv), see the [NeMo AutoModel Installation Guide](../../guides/installation.md). ## Supported Models @@ -71,6 +71,7 @@ NeMo AutoModel supports the [AutoModelForCausalLM](https://huggingface.co/transf | Stability AI | [StableLM](stabilityai/stablelm.md) | `StableLmForCausalLM` | | Stepfun AI | [Step-3.5](stepfun-ai/step-3-5.md) | `Step3p5ForCausalLM` | | Parasail AI | [GritLM](parasail-ai/gritlm.md) | `GritLM` | +| Tencent | [Hy3-preview](tencent/hy3.md) | `HYV3ForCausalLM` | ## Fine-Tuning LLMs with NeMo AutoModel @@ -79,7 +80,7 @@ The models listed above can be fine-tuned using NeMo AutoModel. We support two p 1. **Parameter-Efficient Fine-Tuning (PEFT)**: Updates only a small subset of parameters (typically <1%) using techniques like Low-Rank Adaptation (LoRA). 2. **Supervised Fine-Tuning (SFT)**: Updates all or most model parameters for deeper adaptation. -Please see our [Fine-Tuning Guide](../../guides/llm/finetune.md) to learn how to apply both methods to your data. +See the [Fine-Tuning Guide](../../guides/llm/finetune.md) to learn how to apply both methods to your data. :::{tip} In these guides, we use the `SQuAD v1.1` dataset for demonstration purposes, but you can use your own data. Update the recipe YAML `dataset` / `validation_dataset` sections accordingly. See [LLM datasets](../../guides/llm/dataset.md) and [dataset overview](../../guides/dataset-overview.md). @@ -140,4 +141,5 @@ orionstar/orion stabilityai/stablelm stepfun-ai/step-3-5 parasail-ai/gritlm +tencent/hy3 ``` diff --git a/docs/model-coverage/llm/tencent/hy3.md b/docs/model-coverage/llm/tencent/hy3.md new file mode 100644 index 000000000..ea2d8be1e --- /dev/null +++ b/docs/model-coverage/llm/tencent/hy3.md @@ -0,0 +1,63 @@ +# Hy3 (HunyuanLarge) + +[Hy3-preview](https://huggingface.co/tencent/Hy3-preview) is a 295B Mixture-of-Experts language model from Tencent. It features 80 transformer layers (layer 0 dense, layers 1–79 MoE), 192 routed experts plus 1 shared expert with top-8 sigmoid routing, Grouped Query Attention (64 Q / 8 KV heads), per-head QK RMSNorm, RoPE, and an `e_score_correction_bias` gate buffer for expert-load correction. It supports a 256K context window. + +:::{card} +| | | +|---|---| +| **Task** | Text Generation (MoE) | +| **Architecture** | `HYV3ForCausalLM` | +| **Parameters** | 295B total | +| **HF Org** | [tencent](https://huggingface.co/tencent) | +::: + +## Available Models + +- **Hy3-preview**: 295B total, top-8 routed experts activated per token + +## Architectures + +- `HYV3ForCausalLM` + +## Example HF Models + +| Model | HF ID | +|---|---| +| Hy3-preview | [`tencent/Hy3-preview`](https://huggingface.co/tencent/Hy3-preview) | + +## Example Recipes + +| Recipe | Description | +|---|---| +| {download}`hy3_preview_deepep.yaml <../../../../examples/llm_finetune/hy_v3/hy3_preview_deepep.yaml>` | SFT — Hy3-preview with DeepEP | + +## Try with NeMo AutoModel + +**1. Install** ([NeMo AutoModel](../../../guides/installation.md)): + +```bash +pip install nemo-automodel +``` + +**2. Clone the repo** to get the example recipes: + +```bash +git clone https://github.com/NVIDIA-NeMo/Automodel.git +cd Automodel +``` + +**3. 
Run the recipe** from inside the repo: + +```bash +automodel --nproc-per-node=8 examples/llm_finetune/hy_v3/hy3_preview_deepep.yaml +``` + +See the [NeMo AutoModel Installation Guide](../../../guides/installation.md) and [LLM Fine-Tuning Guide](../../../guides/llm/finetune.md). + +## Fine-Tuning + +See the [LLM Fine-Tuning Guide](../../../guides/llm/finetune.md) and the [Large MoE Fine-Tuning Guide](../../../guides/llm/large-moe-finetune.md). + +## Hugging Face Model Cards + +- [tencent/Hy3-preview](https://huggingface.co/tencent/Hy3-preview) diff --git a/examples/llm_finetune/hy_v3/hy3_preview_deepep.yaml b/examples/llm_finetune/hy_v3/hy3_preview_deepep.yaml new file mode 100644 index 000000000..b9ce4553b --- /dev/null +++ b/examples/llm_finetune/hy_v3/hy3_preview_deepep.yaml @@ -0,0 +1,134 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# SFT recipe for tencent/Hy3-preview (295B MoE, 192 experts top-8, 256K context). +# Requires transformers >= 5.6.0 for the hy_v3 architecture. +# +# Hardware: 8 GPUs (80 GB+ each) for LoRA; 32 GPUs for full fine-tuning. +# automodel examples/llm_finetune/hy_v3/hy3_preview_deepep.yaml --nproc-per-node 8 +# +# EP size must divide num_experts (192): e.g. ep_size=8 gives 24 experts/rank. 
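+# With ep_size=32 as set under `distributed` below, each EP rank holds 192 / 32 = 6 experts.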
+ +recipe: TrainFinetuneRecipeForNextTokenPrediction + +step_scheduler: + global_batch_size: 256 + local_batch_size: 8 + ckpt_every_steps: 500 + val_every_steps: 500 + num_epochs: 1 + max_steps: 100 + +dist_env: + backend: nccl + timeout_minutes: 30 + +rng: + _target_: nemo_automodel.components.training.rng.StatefulRNG + seed: 1111 + ranked: true + +model: + _target_: nemo_automodel.NeMoAutoModelForCausalLM.from_pretrained + pretrained_model_name_or_path: tencent/Hy3-preview + torch_dtype: bfloat16 + backend: + _target_: nemo_automodel.components.models.common.BackendConfig + attn: te + linear: torch + rms_norm: torch_fp32 + experts: torch_mm + dispatcher: deepep + fake_balanced_gate: false + gate_precision: float32 + enable_hf_state_dict_adapter: true + enable_fsdp_optimizations: true + +checkpoint: + enabled: true + checkpoint_dir: /tmp/checkpoints/hy3_preview/ + model_save_format: safetensors + save_consolidated: true + +distributed: + strategy: fsdp2 + tp_size: 1 + cp_size: 1 + pp_size: 4 + ep_size: 32 # 192 experts / 8 ranks = 24 experts per rank + + sequence_parallel: false + activation_checkpointing: true + + pipeline: + pp_schedule: 1f1b + pp_microbatch_size: 1 + round_virtual_stages_to_pp_multiple: down + scale_grads_in_schedule: false + patch_inner_model: false + patch_causal_lm_model: false + + moe: + reshard_after_forward: false + wrap_outer_model: false + +loss_fn: + _target_: nemo_automodel.components.loss.masked_ce.MaskedCrossEntropy + +dataset: + _target_: nemo_automodel.components.datasets.llm.hellaswag.HellaSwag + path_or_dataset: rowan/hellaswag + split: train + tokenizer: + _target_: transformers.AutoTokenizer.from_pretrained + pretrained_model_name_or_path: tencent/Hy3-preview + +packed_sequence: + packed_sequence_size: 0 + +dataloader: + _target_: torchdata.stateful_dataloader.StatefulDataLoader + collate_fn: + _target_: nemo_automodel.components.datasets.utils.default_collater + pad_seq_len_divisible: 64 + shuffle: true + +validation_dataset: + _target_: nemo_automodel.components.datasets.llm.hellaswag.HellaSwag + path_or_dataset: rowan/hellaswag + split: validation + num_samples_limit: 64 + tokenizer: + _target_: transformers.AutoTokenizer.from_pretrained + pretrained_model_name_or_path: tencent/Hy3-preview + +validation_dataloader: + _target_: torchdata.stateful_dataloader.StatefulDataLoader + collate_fn: + _target_: nemo_automodel.components.datasets.utils.default_collater + pad_seq_len_divisible: 64 + shuffle: false + drop_last: true + +optimizer: + _target_: torch.optim.AdamW + betas: [0.9, 0.95] + eps: 1e-8 + lr: 1e-5 + weight_decay: 0.0 + +# Uncomment for W&B logging +# wandb: +# project: hy3-preview-sft +# name: hy3_preview_deepep diff --git a/nemo_automodel/_transformers/registry.py b/nemo_automodel/_transformers/registry.py index 1a0015560..0c1eaa44d 100644 --- a/nemo_automodel/_transformers/registry.py +++ b/nemo_automodel/_transformers/registry.py @@ -141,6 +141,10 @@ "LLaVAOneVision1_5_ForConditionalGeneration", ), ), + ( + "HYV3ForCausalLM", + ("nemo_automodel.components.models.hy_v3.model", "HYV3ForCausalLM"), + ), ( "Qwen2ForCausalLM", ("nemo_automodel.components.models.qwen2.model", "Qwen2ForCausalLM"), @@ -182,6 +186,7 @@ _CUSTOM_CONFIG_REGISTRATIONS: Dict[str, Tuple[str, str]] = { "baichuan": ("nemo_automodel.components.models.baichuan.configuration", "BaichuanConfig"), "deepseek_v4": ("nemo_automodel.components.models.deepseek_v4.config", "DeepseekV4Config"), + "hy_v3": ("nemo_automodel.components.models.hy_v3.config", "HYV3Config"), "kimi_k25": 
("nemo_automodel.components.models.kimi_k25_vl.model", "KimiK25VLConfig"), "kimi_vl": ("nemo_automodel.components.models.kimivl.model", "KimiVLConfig"), "llavaonevision1_5": ("nemo_automodel.components.models.llava_onevision.model", "Llavaonevision1_5Config"), diff --git a/nemo_automodel/components/models/hy_v3/__init__.py b/nemo_automodel/components/models/hy_v3/__init__.py new file mode 100644 index 000000000..0aba0c161 --- /dev/null +++ b/nemo_automodel/components/models/hy_v3/__init__.py @@ -0,0 +1,17 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from nemo_automodel.components.models.hy_v3.model import HYV3ForCausalLM + +__all__ = ["HYV3ForCausalLM"] diff --git a/nemo_automodel/components/models/hy_v3/config.py b/nemo_automodel/components/models/hy_v3/config.py new file mode 100644 index 000000000..479ce3a7f --- /dev/null +++ b/nemo_automodel/components/models/hy_v3/config.py @@ -0,0 +1,100 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from transformers import PretrainedConfig + + +class HYV3Config(PretrainedConfig): + """Configuration class for Tencent Hy3-preview (295B MoE). 
+ + Architecture: + - 80 transformer layers; layer 0 is dense, layers 1-79 are MoE + - MoE: 192 routed experts + 1 shared expert, top-8 activated + - Sigmoid routing with expert-bias correction (e_score_correction_bias) + - GQA: 64 Q heads, 8 KV heads, head_dim=128 + - Per-head QK RMSNorm before RoPE + - 256K context, rope_theta=11158840 + """ + + model_type = "hy_v3" + keys_to_ignore_at_inference = ["past_key_values"] + + def __init__( + self, + vocab_size: int = 129280, + hidden_size: int = 4096, + intermediate_size: int = 1536, + moe_intermediate_size: int = 1536, + num_hidden_layers: int = 80, + num_attention_heads: int = 64, + num_key_value_heads: int = 8, + head_dim: int = 128, + # MoE routing + num_experts: int = 192, + num_shared_experts: int = 1, + num_experts_per_tok: int = 8, + router_scaling_factor: float = 1.0, + route_norm: bool = False, + moe_router_enable_expert_bias: bool = True, + # Dense layers + first_k_dense_replace: int = 1, + # Position encoding + max_position_embeddings: int = 262144, + rope_theta: float = 11158840.0, + rope_scaling: dict | None = None, + # Standard options + rms_norm_eps: float = 1e-6, + attention_bias: bool = False, + hidden_act: str = "silu", + use_cache: bool = True, + pad_token_id: int | None = None, + bos_token_id: int = 1, + eos_token_id: int = 2, + tie_word_embeddings: bool = False, + torch_dtype: str = "bfloat16", + **kwargs, + ): + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.moe_intermediate_size = moe_intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.num_key_value_heads = num_key_value_heads + self.head_dim = head_dim + self.num_experts = num_experts + self.num_shared_experts = num_shared_experts + self.num_experts_per_tok = num_experts_per_tok + self.router_scaling_factor = router_scaling_factor + self.route_norm = route_norm + self.moe_router_enable_expert_bias = moe_router_enable_expert_bias + self.first_k_dense_replace = first_k_dense_replace + self.max_position_embeddings = max_position_embeddings + self.rope_theta = rope_theta + self.rope_scaling = rope_scaling + self.rms_norm_eps = rms_norm_eps + self.attention_bias = attention_bias + self.hidden_act = hidden_act + self.torch_dtype = torch_dtype + + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + tie_word_embeddings=tie_word_embeddings, + use_cache=use_cache, + **kwargs, + ) diff --git a/nemo_automodel/components/models/hy_v3/layers.py b/nemo_automodel/components/models/hy_v3/layers.py new file mode 100644 index 000000000..c0865dbf0 --- /dev/null +++ b/nemo_automodel/components/models/hy_v3/layers.py @@ -0,0 +1,138 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
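+
+"""HYV3 (Hy3-preview) attention building blocks: GQA with per-head Q/K RMSNorm and RoPE."""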
+ +from typing import Any + +import torch +from torch import nn + +from nemo_automodel.components.attention.utils import ( + initialize_attn_module_and_func, + postprocess_output_for_attn, + preprocess_args_and_kwargs_for_attn, +) +from nemo_automodel.components.models.common import ( + BackendConfig, + initialize_linear_module, + initialize_rms_norm_module, +) +from nemo_automodel.components.models.gpt_oss.rope_utils import apply_rotary_emb_qk + + +class HYV3Attention(nn.Module): + """HYV3 attention with GQA, per-head Q/K RMSNorm, and RoPE. + + Architecture: + - q_proj: [hidden, n_heads * head_dim] + - k_proj / v_proj: [hidden, n_kv_heads * head_dim] + - q_norm / k_norm: RMSNorm applied per-head before RoPE + - RoPE applied after per-head norm + """ + + def __init__(self, config: Any, backend: BackendConfig): + super().__init__() + self.backend = backend + + self.num_heads = config.num_attention_heads + self.num_kv_heads = config.num_key_value_heads + self.head_dim = getattr(config, "head_dim", config.hidden_size // self.num_heads) + + attention_bias = getattr(config, "attention_bias", False) + + self.q_proj = initialize_linear_module( + backend.linear, config.hidden_size, self.num_heads * self.head_dim, attention_bias + ) + self.k_proj = initialize_linear_module( + backend.linear, config.hidden_size, self.num_kv_heads * self.head_dim, attention_bias + ) + self.v_proj = initialize_linear_module( + backend.linear, config.hidden_size, self.num_kv_heads * self.head_dim, attention_bias + ) + self.o_proj = initialize_linear_module( + backend.linear, self.num_heads * self.head_dim, config.hidden_size, attention_bias + ) + + # Per-head RMSNorm on Q and K (qk_norm=true in HYV3 config) + self.q_norm = initialize_rms_norm_module(backend.rms_norm, self.head_dim, eps=config.rms_norm_eps) + self.k_norm = initialize_rms_norm_module(backend.rms_norm, self.head_dim, eps=config.rms_norm_eps) + + softmax_scale = self.head_dim**-0.5 + self.attn_module, self.attn_func = initialize_attn_module_and_func( + attn_impl=backend.attn, + num_attention_heads=self.num_heads, + num_qk_channels=self.head_dim, + num_v_channels=self.head_dim, + softmax_scale=softmax_scale, + num_gqa_groups=self.num_kv_heads, + ) + + def forward( + self, + x: torch.Tensor, + *, + freqs_cis: torch.Tensor, + attention_mask: torch.Tensor | None = None, + **attn_kwargs: Any, + ) -> torch.Tensor: + if len(x.shape) == 2: + qkv_format = "thd" + num_tokens = x.shape[0] + else: + qkv_format = "bshd" + bsz, seqlen, _ = x.size() + + q = self.q_proj(x) + k = self.k_proj(x) + v = self.v_proj(x) + + if qkv_format == "thd": + q = q.view(num_tokens, self.num_heads, self.head_dim) + k = k.view(num_tokens, self.num_kv_heads, self.head_dim) + v = v.view(num_tokens, self.num_kv_heads, self.head_dim) + else: + q = q.view(bsz, seqlen, self.num_heads, self.head_dim) + k = k.view(bsz, seqlen, self.num_kv_heads, self.head_dim) + v = v.view(bsz, seqlen, self.num_kv_heads, self.head_dim) + + q = self.q_norm(q) + k = self.k_norm(k) + + q, k = apply_rotary_emb_qk( + q, + k, + freqs_cis, + format=qkv_format, + rope_fusion=self.backend.rope_fusion, + cu_seqlens=attn_kwargs.get("cu_seqlens", None), + cp_size=attn_kwargs.get("cp_size", 1), + cp_rank=attn_kwargs.get("cp_rank", 0), + ) + + q, k, v, _attn_kwargs = preprocess_args_and_kwargs_for_attn( + q, k, v, attention_mask, self.backend.attn, **attn_kwargs + ) + out = self.attn_func(q, k, v, **_attn_kwargs) + out = postprocess_output_for_attn(out, self.backend.attn) + + flatten_dim = 2 if qkv_format == "bshd" else 1 + out 
= self.o_proj(out.flatten(flatten_dim)) + return out + + def init_weights(self, buffer_device: torch.device, init_std: float = 0.02): + for linear in (self.q_proj, self.k_proj, self.v_proj, self.o_proj): + nn.init.trunc_normal_(linear.weight, mean=0.0, std=init_std) + if hasattr(linear, "bias") and linear.bias is not None: + nn.init.zeros_(linear.bias) + for norm in (self.q_norm, self.k_norm): + norm.reset_parameters() diff --git a/nemo_automodel/components/models/hy_v3/model.py b/nemo_automodel/components/models/hy_v3/model.py new file mode 100644 index 000000000..09bc496f4 --- /dev/null +++ b/nemo_automodel/components/models/hy_v3/model.py @@ -0,0 +1,332 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""HYV3ForCausalLM — Tencent Hy3-preview (295B MoE) SFT support. + +Architecture (from tencent/Hy3-preview config.json): + - 80 transformer layers; layer 0 is dense, layers 1-79 are MoE + - MoE: 192 routed experts + 1 shared expert, top-8 activated + - Sigmoid routing with expert-bias correction (e_score_correction_bias) + - GQA: 64 Q heads, 8 KV heads, head_dim=128 + - Per-head QK RMSNorm before RoPE + - 256K context, rope_theta=11158840 +""" + +from typing import Any + +import torch +import torch.nn as nn + +from nemo_automodel.components.models.common import ( + BackendConfig, + get_rope_config, + initialize_linear_module, + initialize_rms_norm_module, +) +from nemo_automodel.components.models.common.hf_checkpointing_mixin import HFCheckpointingMixin +from nemo_automodel.components.models.common.utils import cast_model_to_dtype +from nemo_automodel.components.models.gpt_oss.rope_utils import RotaryEmbedding, position_ids_to_freqs_cis +from nemo_automodel.components.models.hy_v3.layers import HYV3Attention +from nemo_automodel.components.models.hy_v3.state_dict_adapter import HYV3StateDictAdapter +from nemo_automodel.components.moe.config import MoEConfig +from nemo_automodel.components.moe.fsdp_mixin import MoEFSDPSyncMixin +from nemo_automodel.components.moe.layers import MLP, MoE +from nemo_automodel.components.utils.model_utils import squeeze_input_for_thd +from nemo_automodel.shared.utils import dtype_from_str as get_dtype + + +class Block(nn.Module): + def __init__(self, layer_idx: int, config: Any, moe_config: MoEConfig, backend: BackendConfig): + super().__init__() + self.self_attn = HYV3Attention(config, backend) + + # Layers 0..(first_k_dense_replace-1) are dense; the rest are MoE. 
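+        # For Hy3-preview, first_k_dense_replace=1: only layer 0 uses a dense MLP; layers 1-79 route through the MoE.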
+ first_k_dense = getattr(config, "first_k_dense_replace", 1) + if layer_idx < first_k_dense: + self.mlp = MLP(config.hidden_size, config.intermediate_size, backend.linear) + else: + self.mlp = MoE(moe_config, backend) + + self.input_layernorm = initialize_rms_norm_module(backend.rms_norm, config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = initialize_rms_norm_module( + backend.rms_norm, config.hidden_size, eps=config.rms_norm_eps + ) + self.layer_idx = layer_idx + + def forward( + self, + x: torch.Tensor, + *, + freqs_cis: torch.Tensor, + attention_mask: torch.Tensor | None = None, + padding_mask: torch.Tensor | None = None, + **attn_kwargs: Any, + ) -> torch.Tensor: + if attention_mask is not None and padding_mask is None: + padding_mask = attention_mask.bool().logical_not() + + attn_out = self.self_attn( + x=self.input_layernorm(x), + freqs_cis=freqs_cis, + attention_mask=attention_mask, + **attn_kwargs, + ) + x = x + attn_out + + mlp_out = self._mlp(x=self.post_attention_layernorm(x), padding_mask=padding_mask) + x = x + mlp_out + return x + + def _mlp(self, x: torch.Tensor, padding_mask: torch.Tensor | None) -> torch.Tensor: + if isinstance(self.mlp, MLP): + return self.mlp(x) + assert isinstance(self.mlp, MoE) + return self.mlp(x, padding_mask) + + def init_weights(self, buffer_device: torch.device): + for norm in (self.input_layernorm, self.post_attention_layernorm): + norm.reset_parameters() + self.self_attn.init_weights(buffer_device) + self.mlp.init_weights(buffer_device) + + +class HYV3Model(nn.Module): + def __init__( + self, + config: Any, + backend: BackendConfig, + *, + moe_config: MoEConfig | None = None, + moe_overrides: dict | None = None, + ): + super().__init__() + self.backend = backend + self.config = config + if moe_config is not None and moe_overrides is not None: + raise ValueError("Cannot pass both moe_config and moe_overrides.") + + moe_defaults = dict( + dim=config.hidden_size, + inter_dim=config.intermediate_size, + moe_inter_dim=config.moe_intermediate_size, + n_routed_experts=config.num_experts, + n_shared_experts=getattr(config, "num_shared_experts", 0), + n_activated_experts=config.num_experts_per_tok, + n_expert_groups=0, + n_limited_groups=0, + train_gate=True, + gate_bias_update_factor=0.0, + score_func="sigmoid", + route_scale=getattr(config, "router_scaling_factor", 1.0), + aux_loss_coeff=0.0, + norm_topk_prob=getattr(config, "route_norm", False), + expert_bias=False, + router_bias=False, + expert_activation="swiglu", + softmax_before_topk=False, + # Ensures e_score_correction_bias buffer is created so HF checkpoints load cleanly + force_e_score_correction_bias=getattr(config, "moe_router_enable_expert_bias", False), + ) + if moe_overrides: + moe_defaults.update(moe_overrides) + self.moe_config = moe_config or MoEConfig(**moe_defaults) + + self.embed_tokens = nn.Embedding( + config.vocab_size, config.hidden_size, dtype=get_dtype(config.torch_dtype, torch.bfloat16) + ) + self.layers = torch.nn.ModuleDict() + for layer_id in range(config.num_hidden_layers): + self.layers[str(layer_id)] = Block(layer_id, config, self.moe_config, backend) + self.norm = initialize_rms_norm_module(backend.rms_norm, config.hidden_size, eps=config.rms_norm_eps) + + self.max_seq_len = config.max_position_embeddings + self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) + + base, rope_scaling, _ = get_rope_config(config) + + self.rotary_emb = RotaryEmbedding( + head_dim=self.head_dim, + base=base, + 
dtype=torch.float32, + initial_context_length=rope_scaling.get("original_max_position_embeddings", 4096), + scaling_factor=rope_scaling.get("factor", 1.0), + ntk_alpha=rope_scaling.get("beta_slow", 1.0), + ntk_beta=rope_scaling.get("beta_fast", 32.0), + device=torch.device(f"cuda:{torch.cuda.current_device()}"), + ) + + def forward( + self, + input_ids: torch.Tensor, + *, + position_ids: torch.Tensor | None = None, + attention_mask: torch.Tensor | None = None, + padding_mask: torch.Tensor | None = None, + **attn_kwargs: Any, + ) -> torch.Tensor: + if position_ids is None: + position_ids = ( + torch.arange(0, input_ids.shape[1], device=input_ids.device).unsqueeze(0).expand(input_ids.shape[0], -1) + ) + + freqs_cis = position_ids_to_freqs_cis( + self.rotary_emb, + position_ids, + qkv_format=attn_kwargs.get("qkv_format", "bshd"), + for_fused_rope=self.backend.rope_fusion, + cp_size=attn_kwargs.get("cp_size", 1), + ) + + h = self.embed_tokens(input_ids) if self.embed_tokens is not None else input_ids + + for layer in self.layers.values(): + h = layer( + x=h, + freqs_cis=freqs_cis, + attention_mask=attention_mask, + padding_mask=padding_mask, + **attn_kwargs, + ) + + h = self.norm(h) if self.norm else h + return h + + @torch.no_grad() + def init_weights(self, buffer_device: torch.device | None = None) -> None: + buffer_device = buffer_device or torch.device(f"cuda:{torch.cuda.current_device()}") + with buffer_device: + if self.embed_tokens is not None: + nn.init.normal_(self.embed_tokens.weight) + if self.norm is not None: + self.norm.reset_parameters() + self.rotary_emb.device = buffer_device + + for layer in self.layers.values(): + if layer is not None: + layer.init_weights(buffer_device=buffer_device) + + +class HYV3ForCausalLM(HFCheckpointingMixin, nn.Module, MoEFSDPSyncMixin): + @classmethod + def from_config( + cls, + config: Any, + moe_config: MoEConfig | None = None, + backend: BackendConfig | None = None, + **kwargs, + ): + return cls(config, moe_config, backend, **kwargs) + + @classmethod + def from_pretrained( + cls, + pretrained_model_name_or_path: str, + *model_args, + **kwargs, + ): + from transformers import AutoConfig + + config = AutoConfig.from_pretrained(pretrained_model_name_or_path, trust_remote_code=False) + return cls.from_config(config, *model_args, **kwargs) + + def __init__( + self, + config: Any, + moe_config: MoEConfig | None = None, + backend: BackendConfig | None = None, + **kwargs, + ): + super().__init__() + self.config = config + self.backend = backend or BackendConfig() + moe_overrides = kwargs.pop("moe_overrides", None) + self.model = HYV3Model(config, backend=self.backend, moe_config=moe_config, moe_overrides=moe_overrides) + self.lm_head = initialize_linear_module(self.backend.linear, config.hidden_size, config.vocab_size, bias=False) + if self.backend.enable_hf_state_dict_adapter: + self.state_dict_adapter = HYV3StateDictAdapter( + self.config, + self.model.moe_config, + self.backend, + dtype=get_dtype(config.torch_dtype, torch.bfloat16), + ) + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def forward( + self, + input_ids: torch.Tensor, + *, + position_ids: torch.Tensor | None = None, + attention_mask: torch.Tensor | None = None, + padding_mask: torch.Tensor | None = None, + **attn_kwargs: Any, + ) -> torch.Tensor: 
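+        """Return vocabulary logits for ``input_ids``; packed ("thd") inputs are squeezed first and the batch dim is restored on the logits."""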
+ if "qkv_format" in attn_kwargs and attn_kwargs["qkv_format"] == "thd": + input_ids, position_ids, padding_mask, attn_kwargs = squeeze_input_for_thd( + input_ids, position_ids, padding_mask, attn_kwargs + ) + attention_mask = None + + hidden = self.model( + input_ids, + position_ids=position_ids, + attention_mask=attention_mask, + padding_mask=padding_mask, + **attn_kwargs, + ) + logits = self.lm_head(hidden) if self.lm_head else hidden + if "qkv_format" in attn_kwargs and attn_kwargs["qkv_format"] == "thd": + logits = logits.unsqueeze(0) + return logits + + def update_moe_gate_bias(self) -> None: + with torch.no_grad(): + for block in self.model.layers.values(): + if isinstance(block.mlp, MoE) and block.mlp.gate.bias_update_factor > 0: + block.mlp.gate.update_bias() + + @torch.no_grad() + def initialize_weights( + self, buffer_device: torch.device | None = None, dtype: torch.dtype = torch.bfloat16 + ) -> None: + buffer_device = buffer_device or torch.device(f"cuda:{torch.cuda.current_device()}") + with buffer_device: + self.model.init_weights(buffer_device=buffer_device) + final_out_std = self.config.hidden_size**-0.5 + cutoff_factor = 3 + if self.lm_head is not None: + nn.init.trunc_normal_( + self.lm_head.weight, + mean=0.0, + std=final_out_std, + a=-cutoff_factor * final_out_std, + b=cutoff_factor * final_out_std, + ) + + cast_model_to_dtype(self, dtype) + with buffer_device: + self.model.rotary_emb.device = buffer_device + + +ModelClass = HYV3ForCausalLM diff --git a/nemo_automodel/components/models/hy_v3/state_dict_adapter.py b/nemo_automodel/components/models/hy_v3/state_dict_adapter.py new file mode 100644 index 000000000..a75d77a91 --- /dev/null +++ b/nemo_automodel/components/models/hy_v3/state_dict_adapter.py @@ -0,0 +1,205 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""State dict conversion between the on-disk tencent/Hy3-preview HF checkpoint +and Automodel's native (grouped-experts) format. + +On-disk HF format (what tencent/Hy3-preview safetensors actually contain): + model.layers.{L}.mlp.expert_bias # [n_experts] + model.layers.{L}.mlp.router.gate.weight # [n_experts, hidden] + model.layers.{L}.mlp.experts.{E}.gate_proj.weight # [moe_inter, hidden] + model.layers.{L}.mlp.experts.{E}.up_proj.weight # [moe_inter, hidden] + model.layers.{L}.mlp.experts.{E}.down_proj.weight # [hidden, moe_inter] + model.layers.{L}.mlp.shared_mlp.{gate,up,down}_proj.weight # [moe_inter, hidden] / [hidden, moe_inter] + +Automodel native format (matches the rest of the MoE stack): + model.layers.{L}.mlp.gate.e_score_correction_bias # [n_local] (on Gate, not MoE) + model.layers.{L}.mlp.gate.weight # [n_experts, hidden] + model.layers.{L}.mlp.experts.gate_and_up_projs # [n_local, hidden, 2*moe_inter] + model.layers.{L}.mlp.experts.down_projs # [n_local, moe_inter, hidden] + model.layers.{L}.mlp.shared_experts.{gate,up,down}_proj.weight # unchanged shapes + +Differences (vs. every other Automodel MoE adapter): + 1. 
Per-expert split tensors -> grouped (handled by MoESplitExpertsStateDictMixin). + 2. Three HYV3-specific name renames: expert_bias <-> gate.e_score_correction_bias, + router.gate.weight <-> gate.weight, shared_mlp.* <-> shared_experts.*. + 3. MTP layers (indices >= num_hidden_layers) on disk must be filtered out on load. + +Why the renames live in the adapter rather than in the storage reader's key_mapping: +nemo_automodel/components/checkpoint/checkpointing.py:507 deliberately passes +``reader_key_mapping=None`` when a model has a state_dict_adapter (to avoid +double-translation). So the adapter's ``to_hf`` / ``from_hf`` must produce keys +that match the actual on-disk strings. +""" + +import logging +import re +from typing import Any, Optional + +import torch +from torch.distributed.device_mesh import DeviceMesh + +from nemo_automodel.components.checkpoint.state_dict_adapter import StateDictAdapter +from nemo_automodel.components.models.common import BackendConfig +from nemo_automodel.components.moe.config import MoEConfig +from nemo_automodel.components.moe.state_dict_mixin import MoESplitExpertsStateDictMixin + +logger = logging.getLogger(__name__) + + +# Pre-compiled HYV3-specific name renames (anchored to end-of-key for safety). +_NATIVE_TO_HF_RENAMES: tuple[tuple[re.Pattern[str], str], ...] = ( + (re.compile(r"\.mlp\.gate\.e_score_correction_bias$"), ".mlp.expert_bias"), + (re.compile(r"\.mlp\.gate\.weight$"), ".mlp.router.gate.weight"), + (re.compile(r"\.mlp\.shared_experts\."), ".mlp.shared_mlp."), +) +_HF_TO_NATIVE_RENAMES: tuple[tuple[re.Pattern[str], str], ...] = ( + (re.compile(r"\.mlp\.expert_bias$"), ".mlp.gate.e_score_correction_bias"), + (re.compile(r"\.mlp\.router\.gate\.weight$"), ".mlp.gate.weight"), + (re.compile(r"\.mlp\.shared_mlp\."), ".mlp.shared_experts."), +) + + +class HYV3StateDictAdapter(MoESplitExpertsStateDictMixin, StateDictAdapter): + """Bridges Automodel native (grouped experts) and tencent/Hy3-preview on-disk HF. + + Inherits the per-expert split/merge logic from ``MoESplitExpertsStateDictMixin``; + only the three HYV3-specific name renames + MTP-layer filtering live here. + """ + + def __init__( + self, + config: Any, + moe_config: MoEConfig, + backend: BackendConfig, + dtype: torch.dtype = torch.bfloat16, + ): + self.config = config + self.moe_config = moe_config + self.backend = backend + self.dtype = dtype + self._uses_model_prefix = True + + # ------------------------------------------------------------------ + # Native -> on-disk HF + # ------------------------------------------------------------------ + + def to_hf( + self, + state_dict: dict[str, Any], + exclude_key_regex: Optional[str] = None, + **kwargs, + ) -> dict[str, Any]: + """Convert native state dict back to the on-disk Tencent format. + + Steps: + 1. Split grouped expert tensors into per-expert HF keys (mixin). + 2. Apply HYV3 name renames (gate.e_score_correction_bias -> expert_bias, + gate.weight -> router.gate.weight, shared_experts. -> shared_mlp.). + """ + # Step 1: per-expert split via the mixin. Pass-through for non-expert keys. + hf_split: dict[str, Any] = self._to_hf_w_split_experts(state_dict) + + # Step 2: rename native -> on-disk Tencent. 
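+        #   e.g. layers.{L}.mlp.gate.e_score_correction_bias -> layers.{L}.mlp.expert_bias
+        #        layers.{L}.mlp.gate.weight                  -> layers.{L}.mlp.router.gate.weight
+        #        layers.{L}.mlp.shared_experts.*             -> layers.{L}.mlp.shared_mlp.*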
+ out: dict[str, Any] = {} + for k, v in hf_split.items(): + new_k = k + for pat, repl in _NATIVE_TO_HF_RENAMES: + new_k, n = pat.subn(repl, new_k) + if n: + break + if exclude_key_regex and re.match(exclude_key_regex, new_k): + continue + out[new_k] = v + return out + + # ------------------------------------------------------------------ + # On-disk HF -> native + # ------------------------------------------------------------------ + + def from_hf( + self, + hf_state_dict: dict[str, Any], + device_mesh: Optional[DeviceMesh] = None, + **kwargs, + ) -> dict[str, Any]: + """Convert the on-disk Tencent state dict to native format. + + Steps: + 1. Drop MTP (multi-token prediction) layer keys. + 2. Apply HYV3 name renames (on-disk -> native HF naming). + 3. Merge per-expert split tensors into grouped form via the mixin + (validates expert availability against the rank's EP slice). + """ + # Step 1 + 2: filter MTP, rename to native names, in a single pass. + renamed: dict[str, Any] = {} + for k, v in hf_state_dict.items(): + if self._is_mtp_key(k): + continue + new_k = k + for pat, repl in _HF_TO_NATIVE_RENAMES: + new_k, n = pat.subn(repl, new_k) + if n: + break + renamed[new_k] = v + + # Step 3: per-expert merge + EP slicing via the mixin. + return self._from_hf_w_merged_experts(renamed, device_mesh) + + # ------------------------------------------------------------------ + # Single-tensor variant required by the abstract base class. + # ------------------------------------------------------------------ + + def convert_single_tensor_to_hf( + self, + fqn: str, + tensor: Any, + **kwargs, + ) -> list[tuple[str, Any]]: + """Per-tensor variant of ``to_hf`` (used by save paths that stream tensors). + + Mirrors ``to_hf`` but operating on one (fqn, tensor) at a time: + 1. Try the mixin's per-expert split. Returns multiple (key, tensor) pairs + when *fqn* names a grouped expert tensor; otherwise returns ``None``. + 2. Apply HYV3 name renames to whichever key set we end up with. + """ + exclude_key_regex = kwargs.get("exclude_key_regex", None) + + expert_split = self._convert_single_merged_expert_to_hf_split_experts(fqn, tensor, **kwargs) + if expert_split is not None: + pairs = expert_split + else: + pairs = [(fqn, tensor)] + + out: list[tuple[str, Any]] = [] + for k, v in pairs: + new_k = k + for pat, repl in _NATIVE_TO_HF_RENAMES: + new_k, n = pat.subn(repl, new_k) + if n: + break + if exclude_key_regex and re.match(exclude_key_regex, new_k): + continue + out.append((new_k, v)) + return out + + # ------------------------------------------------------------------ + # Helpers + # ------------------------------------------------------------------ + + def _is_mtp_key(self, key: str) -> bool: + """Return True if *key* belongs to an MTP layer (index >= num_hidden_layers).""" + num_hidden = getattr(self.config, "num_hidden_layers", 80) + m = re.match(r"(?:model\.)?layers\.(\d+)\.", key) + return bool(m and int(m.group(1)) >= num_hidden) diff --git a/tests/unit_tests/_transformers/test_registry_hy_v3.py b/tests/unit_tests/_transformers/test_registry_hy_v3.py new file mode 100644 index 000000000..9e733cccc --- /dev/null +++ b/tests/unit_tests/_transformers/test_registry_hy_v3.py @@ -0,0 +1,68 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Verify HYV3 model + config are registered in nemo_automodel._transformers.registry.""" + +import pytest + + +class TestArchMapping: + def test_hyv3_arch_registered(self): + from nemo_automodel._transformers.registry import MODEL_ARCH_MAPPING + + assert "HYV3ForCausalLM" in MODEL_ARCH_MAPPING + + def test_hyv3_arch_points_at_correct_module(self): + from nemo_automodel._transformers.registry import MODEL_ARCH_MAPPING + + entry = MODEL_ARCH_MAPPING["HYV3ForCausalLM"] + assert entry[0] == "nemo_automodel.components.models.hy_v3.model" + assert entry[1] == "HYV3ForCausalLM" + + def test_hyv3_arch_resolves_to_class(self): + """Walk the mapping path -- importable + the named class exists.""" + import importlib + + from nemo_automodel._transformers.registry import MODEL_ARCH_MAPPING + + mod_path, cls_name, *_ = MODEL_ARCH_MAPPING["HYV3ForCausalLM"] + mod = importlib.import_module(mod_path) + assert hasattr(mod, cls_name) + + +class TestCustomConfigRegistration: + def test_hy_v3_config_registered(self): + from nemo_automodel._transformers.registry import _CUSTOM_CONFIG_REGISTRATIONS + + assert "hy_v3" in _CUSTOM_CONFIG_REGISTRATIONS + + def test_hy_v3_config_resolves_to_class(self): + import importlib + + from nemo_automodel._transformers.registry import _CUSTOM_CONFIG_REGISTRATIONS + + mod_path, cls_name = _CUSTOM_CONFIG_REGISTRATIONS["hy_v3"] + mod = importlib.import_module(mod_path) + cls = getattr(mod, cls_name) + assert cls.__name__ == "HYV3Config" + assert cls.model_type == "hy_v3" + + +class TestSupportedBackbonesIntact: + """Sanity check that hy_v3 registration didn't disturb existing backbones.""" + + def test_llama_still_in_supported_backbones(self): + from nemo_automodel._transformers.retrieval import SUPPORTED_BACKBONES + + assert "llama" in SUPPORTED_BACKBONES diff --git a/tests/unit_tests/models/hy_v3/test_hy_v3_config.py b/tests/unit_tests/models/hy_v3/test_hy_v3_config.py new file mode 100644 index 000000000..5ebebb031 --- /dev/null +++ b/tests/unit_tests/models/hy_v3/test_hy_v3_config.py @@ -0,0 +1,139 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Unit tests for ``HYV3Config``.""" + +import pytest +from transformers import PretrainedConfig + +from nemo_automodel.components.models.hy_v3.config import HYV3Config + + +class TestDefaults: + def test_model_type(self): + assert HYV3Config.model_type == "hy_v3" + + def test_inherits_pretrained_config(self): + cfg = HYV3Config() + assert isinstance(cfg, PretrainedConfig) + + def test_default_attributes_match_295b(self): + cfg = HYV3Config() + # Architecture defaults from the published Hy3-preview spec. + assert cfg.vocab_size == 129280 + assert cfg.hidden_size == 4096 + assert cfg.intermediate_size == 1536 + assert cfg.moe_intermediate_size == 1536 + assert cfg.num_hidden_layers == 80 + assert cfg.num_attention_heads == 64 + assert cfg.num_key_value_heads == 8 + assert cfg.head_dim == 128 + assert cfg.num_experts == 192 + assert cfg.num_shared_experts == 1 + assert cfg.num_experts_per_tok == 8 + assert cfg.first_k_dense_replace == 1 + assert cfg.max_position_embeddings == 262144 + assert cfg.rope_theta == 11158840.0 + assert cfg.rms_norm_eps == 1e-6 + assert cfg.attention_bias is False + assert cfg.hidden_act == "silu" + # torch_dtype is auto-coerced by PretrainedConfig (deprecated -> dtype); + # accept either the string we set or whatever the base class normalizes to. + assert cfg.torch_dtype in ("bfloat16", None) or str(cfg.torch_dtype).endswith("bfloat16") + assert cfg.tie_word_embeddings is False + assert cfg.moe_router_enable_expert_bias is True + + def test_keys_to_ignore_at_inference(self): + assert HYV3Config.keys_to_ignore_at_inference == ["past_key_values"] + + +class TestOverrides: + def test_override_attention_dims(self): + cfg = HYV3Config( + num_attention_heads=8, + num_key_value_heads=2, + head_dim=64, + hidden_size=512, + ) + assert cfg.num_attention_heads == 8 + assert cfg.num_key_value_heads == 2 + assert cfg.head_dim == 64 + assert cfg.hidden_size == 512 + + def test_override_moe_routing(self): + cfg = HYV3Config(num_experts=64, num_experts_per_tok=4, num_shared_experts=2, router_scaling_factor=1.5) + assert cfg.num_experts == 64 + assert cfg.num_experts_per_tok == 4 + assert cfg.num_shared_experts == 2 + assert cfg.router_scaling_factor == 1.5 + + def test_truncated_layer_count(self): + cfg = HYV3Config(num_hidden_layers=4) + assert cfg.num_hidden_layers == 4 + + def test_first_k_dense_replace(self): + cfg = HYV3Config(first_k_dense_replace=3) + assert cfg.first_k_dense_replace == 3 + + def test_router_flags(self): + cfg = HYV3Config(route_norm=True, moe_router_enable_expert_bias=False) + assert cfg.route_norm is True + assert cfg.moe_router_enable_expert_bias is False + + def test_rope_overrides(self): + cfg = HYV3Config(rope_theta=500000.0, max_position_embeddings=4096) + assert cfg.rope_theta == 500000.0 + assert cfg.max_position_embeddings == 4096 + + def test_rope_scaling_dict(self): + scaling = {"factor": 8.0, "rope_type": "yarn"} + cfg = HYV3Config(rope_scaling=scaling) + assert cfg.rope_scaling == scaling + + def test_token_ids(self): + cfg = HYV3Config(pad_token_id=0, bos_token_id=10, eos_token_id=11) + assert cfg.pad_token_id == 0 + assert cfg.bos_token_id == 10 + assert cfg.eos_token_id == 11 + + def test_super_init_kwargs_accepted(self): + # Verify that PretrainedConfig-recognized kwargs (here: use_cache, + # tie_word_embeddings) flow through __init__ without raising. 
+ HYV3Config(use_cache=False, tie_word_embeddings=True) + + def test_extra_kwargs_pass_through_super_init(self): + # PretrainedConfig **kwargs in newer transformers no longer attaches + # arbitrary fields to the instance, but the call should still succeed. + cfg = HYV3Config(custom_field="abc") + assert isinstance(cfg, HYV3Config) + + +class TestSerialization: + def test_to_dict_round_trip(self): + cfg = HYV3Config(num_hidden_layers=4, num_experts=8, hidden_size=256) + d = cfg.to_dict() + assert d["model_type"] == "hy_v3" + assert d["num_hidden_layers"] == 4 + assert d["num_experts"] == 8 + + rebuilt = HYV3Config(**{k: v for k, v in d.items() if k != "model_type"}) + assert rebuilt.num_hidden_layers == 4 + assert rebuilt.num_experts == 8 + assert rebuilt.hidden_size == 256 + + def test_model_type_class_attribute_not_overridden_by_instance(self): + cfg = HYV3Config() + # model_type is a class-level attribute that AutoConfig dispatches on. + assert cfg.model_type == "hy_v3" + assert HYV3Config.model_type == "hy_v3" diff --git a/tests/unit_tests/models/hy_v3/test_hy_v3_layers.py b/tests/unit_tests/models/hy_v3/test_hy_v3_layers.py new file mode 100644 index 000000000..7a3eb5468 --- /dev/null +++ b/tests/unit_tests/models/hy_v3/test_hy_v3_layers.py @@ -0,0 +1,185 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Unit tests for ``HYV3Attention``.""" + +from unittest.mock import patch + +import pytest +import torch + +from nemo_automodel.components.models.common import BackendConfig +from nemo_automodel.components.models.hy_v3.config import HYV3Config +from nemo_automodel.components.models.hy_v3.layers import HYV3Attention + + +pytestmark = pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") + + +HIDDEN = 64 +N_HEADS = 8 +N_KV = 2 +HEAD_DIM = 16 + + +@pytest.fixture +def device(): + return torch.device(f"cuda:{torch.cuda.current_device()}") + + +@pytest.fixture +def config(): + return HYV3Config( + vocab_size=128, + hidden_size=HIDDEN, + intermediate_size=128, + moe_intermediate_size=64, + num_hidden_layers=1, + num_attention_heads=N_HEADS, + num_key_value_heads=N_KV, + head_dim=HEAD_DIM, + max_position_embeddings=128, + rope_theta=10000.0, + rms_norm_eps=1e-6, + ) + + +@pytest.fixture +def sdpa_backend(): + return BackendConfig( + linear="torch", + attn="sdpa", + rms_norm="torch", + experts="torch", + dispatcher="torch", + fake_balanced_gate=False, + enable_hf_state_dict_adapter=False, + rope_fusion=False, + ) + + +def _make_freqs_cis(seq_len: int, device: torch.device) -> torch.Tensor: + """Synthesize a freqs_cis tensor of shape (1, seq, head_dim) — matches + ``apply_rotary_emb_qk(format='bshd', rope_fusion=False)`` which splits + [cos, sin] along the last dim.""" + return torch.zeros(1, seq_len, HEAD_DIM, device=device) + + +# --------------------------------------------------------------------------- +# Initialization +# --------------------------------------------------------------------------- + + +class TestInit: + def test_module_attributes(self, config, sdpa_backend): + attn = HYV3Attention(config, backend=sdpa_backend) + assert attn.num_heads == N_HEADS + assert attn.num_kv_heads == N_KV + assert attn.head_dim == HEAD_DIM + assert attn.backend is sdpa_backend + + def test_projection_shapes(self, config, sdpa_backend): + attn = HYV3Attention(config, backend=sdpa_backend) + assert attn.q_proj.weight.shape == (N_HEADS * HEAD_DIM, HIDDEN) + assert attn.k_proj.weight.shape == (N_KV * HEAD_DIM, HIDDEN) + assert attn.v_proj.weight.shape == (N_KV * HEAD_DIM, HIDDEN) + assert attn.o_proj.weight.shape == (HIDDEN, N_HEADS * HEAD_DIM) + + def test_q_k_norm_per_head_dim(self, config, sdpa_backend): + """qk_norm is applied per-head, so the norm weight is sized to head_dim.""" + attn = HYV3Attention(config, backend=sdpa_backend) + assert attn.q_norm.weight.shape == (HEAD_DIM,) + assert attn.k_norm.weight.shape == (HEAD_DIM,) + + def test_no_attention_bias_by_default(self, config, sdpa_backend): + attn = HYV3Attention(config, backend=sdpa_backend) + assert attn.q_proj.bias is None + assert attn.k_proj.bias is None + assert attn.v_proj.bias is None + assert attn.o_proj.bias is None + + def test_attention_bias_enabled(self, config, sdpa_backend): + config.attention_bias = True + attn = HYV3Attention(config, backend=sdpa_backend) + assert attn.q_proj.bias is not None + assert attn.k_proj.bias is not None + assert attn.v_proj.bias is not None + assert attn.o_proj.bias is not None + + +# --------------------------------------------------------------------------- +# Forward +# --------------------------------------------------------------------------- + + +class TestForward: + def test_output_shape_bshd(self, config, sdpa_backend, device): + attn = HYV3Attention(config, backend=sdpa_backend).to(device) + bsz, seqlen = 2, 4 + x = torch.randn(bsz, seqlen, HIDDEN, device=device, 
dtype=torch.bfloat16) + freqs = _make_freqs_cis(seqlen, device) + + out = attn(x, freqs_cis=freqs) + assert out.shape == (bsz, seqlen, HIDDEN) + + def test_calls_q_k_v_o_projections(self, config, sdpa_backend, device): + attn = HYV3Attention(config, backend=sdpa_backend).to(device) + bsz, seqlen = 1, 3 + x = torch.randn(bsz, seqlen, HIDDEN, device=device, dtype=torch.bfloat16) + freqs = _make_freqs_cis(seqlen, device) + + with patch.object(attn.q_proj, "forward", wraps=attn.q_proj.forward) as q, \ + patch.object(attn.k_proj, "forward", wraps=attn.k_proj.forward) as k, \ + patch.object(attn.v_proj, "forward", wraps=attn.v_proj.forward) as v, \ + patch.object(attn.o_proj, "forward", wraps=attn.o_proj.forward) as o: + attn(x, freqs_cis=freqs) + q.assert_called_once() + k.assert_called_once() + v.assert_called_once() + o.assert_called_once() + + def test_attention_mask_to_padding_mask_unused_in_sdpa(self, config, sdpa_backend, device): + """With sdpa+causal the attention mask is converted via preprocess but never + propagates to the kernel (is_causal=True suffices); we just check forward + completes without raising.""" + attn = HYV3Attention(config, backend=sdpa_backend).to(device) + bsz, seqlen = 1, 4 + x = torch.randn(bsz, seqlen, HIDDEN, device=device, dtype=torch.bfloat16) + freqs = _make_freqs_cis(seqlen, device) + mask = torch.ones(bsz, seqlen, dtype=torch.float32, device=device) + out = attn(x, freqs_cis=freqs, attention_mask=mask) + assert out.shape == x.shape + + +# --------------------------------------------------------------------------- +# Init weights +# --------------------------------------------------------------------------- + + +class TestInitWeights: + def test_resets_norms_and_linears(self, config, sdpa_backend, device): + attn = HYV3Attention(config, backend=sdpa_backend).to(device) + with patch.object(attn.q_norm, "reset_parameters") as qn, \ + patch.object(attn.k_norm, "reset_parameters") as kn: + attn.init_weights(buffer_device=device, init_std=0.01) + qn.assert_called_once() + kn.assert_called_once() + + def test_uses_provided_init_std(self, config, sdpa_backend, device): + attn = HYV3Attention(config, backend=sdpa_backend).to(device) + attn.init_weights(buffer_device=device, init_std=0.5) + # trunc_normal_ has the requested std (within stat tolerance for a small tensor). + # Just verify the linears were re-initialized to non-default values. + for lin in (attn.q_proj, attn.k_proj, attn.v_proj, attn.o_proj): + assert lin.weight.abs().max().item() > 0 diff --git a/tests/unit_tests/models/hy_v3/test_hy_v3_model.py b/tests/unit_tests/models/hy_v3/test_hy_v3_model.py new file mode 100644 index 000000000..1023d1412 --- /dev/null +++ b/tests/unit_tests/models/hy_v3/test_hy_v3_model.py @@ -0,0 +1,380 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Unit tests for the HYV3 Block / HYV3Model / HYV3ForCausalLM layers.""" + +from unittest.mock import MagicMock, patch + +import pytest +import torch + +from nemo_automodel.components.models.common import BackendConfig +from nemo_automodel.components.models.hy_v3.config import HYV3Config +from nemo_automodel.components.models.hy_v3.model import ( + Block, + HYV3ForCausalLM, + HYV3Model, + ModelClass, +) +from nemo_automodel.components.moe.config import MoEConfig +from nemo_automodel.components.moe.layers import MLP, MoE + + +pytestmark = pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") + + +HIDDEN = 64 +INTER = 128 +MOE_INTER = 64 +N_HEADS = 8 +N_KV = 2 +HEAD_DIM = 16 +N_EXPERTS = 4 + + +@pytest.fixture +def device(): + return torch.device(f"cuda:{torch.cuda.current_device()}") + + +@pytest.fixture +def config(): + return HYV3Config( + vocab_size=128, + hidden_size=HIDDEN, + intermediate_size=INTER, + moe_intermediate_size=MOE_INTER, + num_hidden_layers=2, + num_attention_heads=N_HEADS, + num_key_value_heads=N_KV, + head_dim=HEAD_DIM, + num_experts=N_EXPERTS, + num_experts_per_tok=2, + num_shared_experts=1, + first_k_dense_replace=1, + max_position_embeddings=128, + rope_theta=10000.0, + rms_norm_eps=1e-6, + ) + + +@pytest.fixture +def backend_config(): + return BackendConfig( + linear="torch", + attn="sdpa", + rms_norm="torch", + experts="torch", + dispatcher="torch", + fake_balanced_gate=False, + gate_precision="float32", + rope_fusion=False, + enable_hf_state_dict_adapter=False, + enable_fsdp_optimizations=False, + ) + + +@pytest.fixture +def moe_config(config): + return MoEConfig( + dim=config.hidden_size, + inter_dim=config.intermediate_size, + moe_inter_dim=config.moe_intermediate_size, + n_routed_experts=config.num_experts, + n_shared_experts=config.num_shared_experts, + n_activated_experts=config.num_experts_per_tok, + n_expert_groups=0, + n_limited_groups=0, + train_gate=True, + gate_bias_update_factor=0.0, + score_func="sigmoid", + route_scale=1.0, + aux_loss_coeff=0.0, + norm_topk_prob=False, + expert_bias=False, + router_bias=False, + expert_activation="swiglu", + softmax_before_topk=False, + force_e_score_correction_bias=True, + ) + + +# --------------------------------------------------------------------------- +# Block +# --------------------------------------------------------------------------- + + +class TestBlock: + def test_dense_layer_uses_mlp_when_idx_below_first_k_dense(self, config, moe_config, backend_config): + """layer_idx < first_k_dense_replace -> dense MLP, not MoE.""" + config.first_k_dense_replace = 1 + block = Block(layer_idx=0, config=config, moe_config=moe_config, backend=backend_config) + assert isinstance(block.mlp, MLP) + assert not isinstance(block.mlp, MoE) + + def test_moe_layer_uses_moe_when_idx_at_or_above_first_k_dense(self, config, moe_config, backend_config): + config.first_k_dense_replace = 1 + block = Block(layer_idx=1, config=config, moe_config=moe_config, backend=backend_config) + assert isinstance(block.mlp, MoE) + + def test_first_k_dense_replace_higher_threshold(self, config, moe_config, backend_config): + """If first_k_dense_replace=3, layers 0-2 are dense and 3+ are MoE.""" + config.first_k_dense_replace = 3 + for i in (0, 1, 2): + block = Block(layer_idx=i, config=config, moe_config=moe_config, backend=backend_config) + assert isinstance(block.mlp, MLP), f"layer {i} should be dense" + block = Block(layer_idx=3, config=config, moe_config=moe_config, backend=backend_config) + assert 
isinstance(block.mlp, MoE) + + def test_block_has_required_submodules(self, config, moe_config, backend_config): + block = Block(layer_idx=1, config=config, moe_config=moe_config, backend=backend_config) + assert hasattr(block, "self_attn") + assert hasattr(block, "mlp") + assert hasattr(block, "input_layernorm") + assert hasattr(block, "post_attention_layernorm") + assert block.layer_idx == 1 + + def test_forward_residual_calls_attn_then_mlp(self, config, moe_config, backend_config, device): + block = Block(layer_idx=0, config=config, moe_config=moe_config, backend=backend_config).to(device) + bsz, seq = 2, 4 + x = torch.randn(bsz, seq, HIDDEN, device=device, dtype=torch.bfloat16) + freqs = torch.zeros(1, seq, HEAD_DIM, device=device) + with patch.object(block.self_attn, "forward", return_value=torch.zeros_like(x)) as mock_attn, \ + patch.object(block, "_mlp", return_value=torch.zeros_like(x)) as mock_mlp: + out = block(x, freqs_cis=freqs) + assert out.shape == x.shape + mock_attn.assert_called_once() + mock_mlp.assert_called_once() + + def test_padding_mask_built_from_attention_mask(self, config, moe_config, backend_config, device): + block = Block(layer_idx=0, config=config, moe_config=moe_config, backend=backend_config).to(device) + x = torch.randn(1, 3, HIDDEN, device=device, dtype=torch.bfloat16) + freqs = torch.zeros(1, 3, HEAD_DIM, device=device) + mask = torch.tensor([[1, 1, 0]], dtype=torch.bool, device=device) + with patch.object(block.self_attn, "forward", return_value=torch.zeros_like(x)), \ + patch.object(block, "_mlp", return_value=torch.zeros_like(x)) as mock_mlp: + block(x, freqs_cis=freqs, attention_mask=mask) + _, kwargs = mock_mlp.call_args + torch.testing.assert_close(kwargs["padding_mask"], mask.logical_not()) + + def test_mlp_wrapper_dense_path(self, config, moe_config, backend_config, device): + config.first_k_dense_replace = 1 + block = ( + Block(layer_idx=0, config=config, moe_config=moe_config, backend=backend_config) + .to(device) + .to(torch.bfloat16) + ) + x = torch.randn(2, 4, HIDDEN, device=device, dtype=torch.bfloat16) + out = block._mlp(x, padding_mask=None) + assert out.shape == x.shape + + def test_init_weights_invokes_subcomponents(self, config, moe_config, backend_config, device): + block = Block(layer_idx=1, config=config, moe_config=moe_config, backend=backend_config).to(device) + with patch.object(block.input_layernorm, "reset_parameters") as in_norm, \ + patch.object(block.post_attention_layernorm, "reset_parameters") as post_norm, \ + patch.object(block.self_attn, "init_weights") as attn_init, \ + patch.object(block.mlp, "init_weights") as mlp_init: + block.init_weights(buffer_device=device) + in_norm.assert_called_once() + post_norm.assert_called_once() + attn_init.assert_called_once() + mlp_init.assert_called_once() + + +# --------------------------------------------------------------------------- +# HYV3Model +# --------------------------------------------------------------------------- + + +class TestHYV3Model: + def test_construction_sets_components(self, config, backend_config): + model = HYV3Model(config, backend=backend_config) + assert len(model.layers) == config.num_hidden_layers + assert model.embed_tokens.num_embeddings == config.vocab_size + assert model.norm is not None + assert model.rotary_emb.head_dim == config.head_dim + assert isinstance(model.moe_config, MoEConfig) + + def test_dense_then_moe_layer_structure(self, config, backend_config): + config.first_k_dense_replace = 1 + config.num_hidden_layers = 3 + model = 
HYV3Model(config, backend=backend_config) + assert isinstance(model.layers["0"].mlp, MLP) + assert isinstance(model.layers["1"].mlp, MoE) + assert isinstance(model.layers["2"].mlp, MoE) + + def test_moe_config_inferred_from_config(self, config, backend_config): + model = HYV3Model(config, backend=backend_config) + mc = model.moe_config + assert mc.dim == config.hidden_size + assert mc.moe_inter_dim == config.moe_intermediate_size + assert mc.n_routed_experts == config.num_experts + assert mc.n_activated_experts == config.num_experts_per_tok + assert mc.n_shared_experts == config.num_shared_experts + assert mc.score_func == "sigmoid" + assert mc.expert_activation == "swiglu" + + def test_moe_overrides_take_effect(self, config, backend_config): + model = HYV3Model(config, backend=backend_config, moe_overrides={"score_func": "softmax", "route_scale": 2.0}) + assert model.moe_config.score_func == "softmax" + assert model.moe_config.route_scale == 2.0 + + def test_explicit_moe_config_passes_through(self, config, backend_config, moe_config): + model = HYV3Model(config, backend=backend_config, moe_config=moe_config) + assert model.moe_config is moe_config + + def test_explicit_moe_config_and_overrides_conflict(self, config, backend_config, moe_config): + with pytest.raises(ValueError, match="Cannot pass both"): + HYV3Model(config, backend=backend_config, moe_config=moe_config, moe_overrides={"score_func": "softmax"}) + + def test_forward_runs_all_layers(self, config, backend_config, device): + model = HYV3Model(config, backend=backend_config).to(device) + bsz, seq = 1, 4 + input_ids = torch.randint(0, config.vocab_size, (bsz, seq), device=device) + with patch.object(Block, "forward", side_effect=lambda x=None, **kw: x if x is not None else kw["x"]) as mock_block: + out = model(input_ids) + assert out.shape == (bsz, seq, HIDDEN) + assert mock_block.call_count == config.num_hidden_layers + + def test_forward_accepts_explicit_position_ids(self, config, backend_config, device): + model = HYV3Model(config, backend=backend_config).to(device) + input_ids = torch.randint(0, config.vocab_size, (1, 4), device=device) + position_ids = torch.arange(4, device=device).unsqueeze(0) + with patch.object(Block, "forward", side_effect=lambda x=None, **kw: x if x is not None else kw["x"]): + out = model(input_ids, position_ids=position_ids) + assert out.shape == (1, 4, HIDDEN) + + def test_init_weights_resets_layers_norm_embeddings(self, config, backend_config, device): + model = HYV3Model(config, backend=backend_config).to(device) + embed_before = model.embed_tokens.weight.detach().clone() + with patch.object(model.norm, "reset_parameters") as mock_norm, \ + patch.object(Block, "init_weights") as mock_layer_init: + model.init_weights(buffer_device=device) + mock_norm.assert_called_once() + assert mock_layer_init.call_count == config.num_hidden_layers + # Embedding weights are re-initialized. 
+ assert not torch.equal(model.embed_tokens.weight.detach(), embed_before) + + +# --------------------------------------------------------------------------- +# HYV3ForCausalLM +# --------------------------------------------------------------------------- + + +class TestHYV3ForCausalLM: + def test_construction_attaches_model_and_lm_head(self, config, backend_config): + model = HYV3ForCausalLM(config, backend=backend_config) + assert isinstance(model.model, HYV3Model) + assert model.lm_head.weight.shape == (config.vocab_size, config.hidden_size) + assert model.config is config + + def test_state_dict_adapter_attached_when_enabled(self, config, backend_config): + backend_config.enable_hf_state_dict_adapter = True + model = HYV3ForCausalLM(config, backend=backend_config) + from nemo_automodel.components.models.hy_v3.state_dict_adapter import HYV3StateDictAdapter + + assert hasattr(model, "state_dict_adapter") + assert isinstance(model.state_dict_adapter, HYV3StateDictAdapter) + + def test_state_dict_adapter_not_attached_when_disabled(self, config, backend_config): + backend_config.enable_hf_state_dict_adapter = False + model = HYV3ForCausalLM(config, backend=backend_config) + assert not hasattr(model, "state_dict_adapter") + + def test_default_backend_built_when_omitted(self, config): + model = HYV3ForCausalLM(config) + assert isinstance(model.backend, BackendConfig) + + def test_get_set_input_embeddings(self, config, backend_config): + model = HYV3ForCausalLM(config, backend=backend_config) + new_emb = torch.nn.Embedding(config.vocab_size, config.hidden_size) + model.set_input_embeddings(new_emb) + assert model.get_input_embeddings() is new_emb + + def test_get_set_output_embeddings(self, config, backend_config): + model = HYV3ForCausalLM(config, backend=backend_config) + new_head = torch.nn.Linear(config.hidden_size, config.vocab_size, bias=False) + model.set_output_embeddings(new_head) + assert model.get_output_embeddings() is new_head + + def test_forward_returns_logits_shape(self, config, backend_config, device): + model = HYV3ForCausalLM(config, backend=backend_config).to(device) + bsz, seq = 2, 6 + input_ids = torch.randint(0, config.vocab_size, (bsz, seq), device=device) + fake_hidden = torch.randn(bsz, seq, config.hidden_size, device=device, dtype=torch.bfloat16) + with patch.object(model.model, "forward", return_value=fake_hidden): + logits = model(input_ids) + assert logits.shape == (bsz, seq, config.vocab_size) + + def test_initialize_weights_invokes_inner_init(self, config, backend_config, device): + model = HYV3ForCausalLM(config, backend=backend_config).to(device) + with patch.object(model.model, "init_weights") as mock_init: + model.initialize_weights(buffer_device=device, dtype=torch.float32) + mock_init.assert_called_once() + assert model.lm_head.weight.dtype == torch.float32 + + def test_update_moe_gate_bias_no_op_when_factor_zero(self, config, backend_config, device): + """gate_bias_update_factor defaults to 0.0; update_moe_gate_bias must NOT call + gate.update_bias() when the factor is zero (the bug fixed in 564ff4f2).""" + model = HYV3ForCausalLM(config, backend=backend_config).to(device) + for layer in model.model.layers.values(): + if isinstance(layer.mlp, MoE): + with patch.object(layer.mlp.gate, "update_bias") as mock: + model.update_moe_gate_bias() + mock.assert_not_called() + + def test_update_moe_gate_bias_calls_when_factor_positive(self, config, backend_config, device): + model = HYV3ForCausalLM( + config, backend=backend_config, 
moe_overrides={"gate_bias_update_factor": 1e-3} + ).to(device) + called = 0 + for layer in model.model.layers.values(): + if isinstance(layer.mlp, MoE): + with patch.object(layer.mlp.gate, "update_bias") as mock: + model.update_moe_gate_bias() + if mock.called: + called += 1 + assert called > 0 + + def test_from_config_classmethod_passes_through(self, config, backend_config): + model = HYV3ForCausalLM.from_config(config, backend=backend_config) + assert isinstance(model, HYV3ForCausalLM) + + def test_from_pretrained_resolves_config_then_delegates(self, config, backend_config): + with patch("transformers.AutoConfig.from_pretrained", return_value=config) as mock_acfg, \ + patch.object(HYV3ForCausalLM, "from_config", wraps=HYV3ForCausalLM.from_config) as mock_fc: + model = HYV3ForCausalLM.from_pretrained("tencent/Hy3-preview", backend=backend_config) + mock_acfg.assert_called_once() + mock_fc.assert_called_once() + assert isinstance(model, HYV3ForCausalLM) + + def test_modelclass_alias(self): + assert ModelClass is HYV3ForCausalLM + + +# --------------------------------------------------------------------------- +# Module-level export +# --------------------------------------------------------------------------- + + +class TestModuleExports: + def test_init_exports_hyv3_for_causal_lm(self): + from nemo_automodel.components.models.hy_v3 import HYV3ForCausalLM as exported + + assert exported is HYV3ForCausalLM + + def test_module_class_pointer(self): + from nemo_automodel.components.models.hy_v3 import model as mod + + assert mod.ModelClass is HYV3ForCausalLM diff --git a/tests/unit_tests/models/hy_v3/test_hy_v3_state_dict_adapter.py b/tests/unit_tests/models/hy_v3/test_hy_v3_state_dict_adapter.py new file mode 100644 index 000000000..675c890b3 --- /dev/null +++ b/tests/unit_tests/models/hy_v3/test_hy_v3_state_dict_adapter.py @@ -0,0 +1,536 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unit tests for ``HYV3StateDictAdapter``. + +Covers the four behaviors that distinguish HYV3 from the shared +``MoESplitExpertsStateDictMixin``: + + 1. HYV3-specific name renames (router.gate, expert_bias, shared_mlp.). + 2. Per-expert split / merge inherited from the mixin. + 3. MTP-layer filtering (drops keys for layer index >= num_hidden_layers). + 4. Round-trip integrity: ``from_hf(to_hf(x))`` recovers ``x``. 
+""" + +from unittest.mock import Mock + +import pytest +import torch + +from nemo_automodel.components.models.common import BackendConfig +from nemo_automodel.components.models.hy_v3.state_dict_adapter import ( + _HF_TO_NATIVE_RENAMES, + _NATIVE_TO_HF_RENAMES, + HYV3StateDictAdapter, +) +from nemo_automodel.components.moe.config import MoEConfig + + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + + +N_EXPERTS = 4 +HIDDEN = 16 +MOE_INTER = 8 +NUM_LAYERS = 2 # layer 0 dense, layer 1 MoE +NUM_MTP = 1 + + +@pytest.fixture +def config(): + cfg = Mock() + cfg.num_hidden_layers = NUM_LAYERS + cfg.hidden_size = HIDDEN + cfg.intermediate_size = 32 + cfg.moe_intermediate_size = MOE_INTER + cfg.num_attention_heads = 4 + cfg.num_key_value_heads = 2 + cfg.num_experts = N_EXPERTS + cfg.num_experts_per_tok = 2 + cfg.num_shared_experts = 1 + cfg.first_k_dense_replace = 1 + cfg.num_nextn_predict_layers = NUM_MTP + return cfg + + +@pytest.fixture +def moe_config(): + return MoEConfig( + dim=HIDDEN, + inter_dim=32, + moe_inter_dim=MOE_INTER, + n_routed_experts=N_EXPERTS, + n_shared_experts=1, + n_activated_experts=2, + n_expert_groups=0, + n_limited_groups=0, + train_gate=True, + gate_bias_update_factor=0.0, + score_func="sigmoid", + route_scale=1.0, + aux_loss_coeff=0.0, + norm_topk_prob=False, + expert_bias=False, + router_bias=False, + expert_activation="swiglu", + softmax_before_topk=False, + force_e_score_correction_bias=True, + ) + + +@pytest.fixture +def backend_config(): + return BackendConfig( + linear="torch", + attn="sdpa", + rms_norm="torch", + experts="torch", + dispatcher="torch", + fake_balanced_gate=False, + enable_hf_state_dict_adapter=False, + ) + + +@pytest.fixture +def adapter(config, moe_config, backend_config): + return HYV3StateDictAdapter( + config=config, moe_config=moe_config, backend=backend_config, dtype=torch.float32 + ) + + +def _make_disk_state_dict(*, with_mtp: bool = True): + """Synthesize a state dict matching the on-disk Tencent Hy3-preview format. + + Layer 0: dense MLP. Layer 1: MoE with N_EXPERTS experts + shared MLP + + router gate + expert_bias. Optionally include one MTP layer at index + NUM_LAYERS that should be filtered out by ``from_hf``. + """ + sd: dict[str, torch.Tensor] = { + # Top-level (passes through unchanged on both directions). + "model.embed_tokens.weight": torch.randn(32, HIDDEN), + "model.norm.weight": torch.randn(HIDDEN), + "lm_head.weight": torch.randn(32, HIDDEN), + # Layer 0: dense MLP + attention. + "model.layers.0.input_layernorm.weight": torch.randn(HIDDEN), + "model.layers.0.post_attention_layernorm.weight": torch.randn(HIDDEN), + "model.layers.0.self_attn.q_proj.weight": torch.randn(HIDDEN, HIDDEN), + "model.layers.0.self_attn.k_proj.weight": torch.randn(HIDDEN // 2, HIDDEN), + "model.layers.0.self_attn.v_proj.weight": torch.randn(HIDDEN // 2, HIDDEN), + "model.layers.0.self_attn.o_proj.weight": torch.randn(HIDDEN, HIDDEN), + "model.layers.0.mlp.gate_proj.weight": torch.randn(32, HIDDEN), + "model.layers.0.mlp.up_proj.weight": torch.randn(32, HIDDEN), + "model.layers.0.mlp.down_proj.weight": torch.randn(HIDDEN, 32), + # Layer 1: MoE -- on-disk format with Tencent-internal names. 
+        "model.layers.1.input_layernorm.weight": torch.randn(HIDDEN),
+        "model.layers.1.post_attention_layernorm.weight": torch.randn(HIDDEN),
+        "model.layers.1.self_attn.q_proj.weight": torch.randn(HIDDEN, HIDDEN),
+        "model.layers.1.self_attn.k_proj.weight": torch.randn(HIDDEN // 2, HIDDEN),
+        "model.layers.1.self_attn.v_proj.weight": torch.randn(HIDDEN // 2, HIDDEN),
+        "model.layers.1.self_attn.o_proj.weight": torch.randn(HIDDEN, HIDDEN),
+        "model.layers.1.mlp.router.gate.weight": torch.randn(N_EXPERTS, HIDDEN),
+        "model.layers.1.mlp.expert_bias": torch.randn(N_EXPERTS),
+        "model.layers.1.mlp.shared_mlp.gate_proj.weight": torch.randn(MOE_INTER, HIDDEN),
+        "model.layers.1.mlp.shared_mlp.up_proj.weight": torch.randn(MOE_INTER, HIDDEN),
+        "model.layers.1.mlp.shared_mlp.down_proj.weight": torch.randn(HIDDEN, MOE_INTER),
+    }
+    for e in range(N_EXPERTS):
+        sd[f"model.layers.1.mlp.experts.{e}.gate_proj.weight"] = torch.randn(MOE_INTER, HIDDEN)
+        sd[f"model.layers.1.mlp.experts.{e}.up_proj.weight"] = torch.randn(MOE_INTER, HIDDEN)
+        sd[f"model.layers.1.mlp.experts.{e}.down_proj.weight"] = torch.randn(HIDDEN, MOE_INTER)
+    if with_mtp:
+        # MTP layer: layer index NUM_LAYERS, must be dropped on from_hf.
+        sd[f"model.layers.{NUM_LAYERS}.input_layernorm.weight"] = torch.randn(HIDDEN)
+        sd[f"model.layers.{NUM_LAYERS}.mlp.expert_bias"] = torch.randn(N_EXPERTS)
+    return sd
+
+
+def _make_native_state_dict():
+    """Synthesize a state dict matching the Automodel native HYV3 format."""
+    sd: dict[str, torch.Tensor] = {
+        "model.embed_tokens.weight": torch.randn(32, HIDDEN),
+        "model.norm.weight": torch.randn(HIDDEN),
+        "lm_head.weight": torch.randn(32, HIDDEN),
+        "model.layers.0.input_layernorm.weight": torch.randn(HIDDEN),
+        "model.layers.0.post_attention_layernorm.weight": torch.randn(HIDDEN),
+        "model.layers.0.self_attn.q_proj.weight": torch.randn(HIDDEN, HIDDEN),
+        "model.layers.0.self_attn.k_proj.weight": torch.randn(HIDDEN // 2, HIDDEN),
+        "model.layers.0.self_attn.v_proj.weight": torch.randn(HIDDEN // 2, HIDDEN),
+        "model.layers.0.self_attn.o_proj.weight": torch.randn(HIDDEN, HIDDEN),
+        "model.layers.0.mlp.gate_proj.weight": torch.randn(32, HIDDEN),
+        "model.layers.0.mlp.up_proj.weight": torch.randn(32, HIDDEN),
+        "model.layers.0.mlp.down_proj.weight": torch.randn(HIDDEN, 32),
+        "model.layers.1.input_layernorm.weight": torch.randn(HIDDEN),
+        "model.layers.1.post_attention_layernorm.weight": torch.randn(HIDDEN),
+        "model.layers.1.self_attn.q_proj.weight": torch.randn(HIDDEN, HIDDEN),
+        "model.layers.1.self_attn.k_proj.weight": torch.randn(HIDDEN // 2, HIDDEN),
+        "model.layers.1.self_attn.v_proj.weight": torch.randn(HIDDEN // 2, HIDDEN),
+        "model.layers.1.self_attn.o_proj.weight": torch.randn(HIDDEN, HIDDEN),
+        # Native MoE keys (post-rename + post-merge):
+        "model.layers.1.mlp.gate.weight": torch.randn(N_EXPERTS, HIDDEN),
+        "model.layers.1.mlp.gate.e_score_correction_bias": torch.randn(N_EXPERTS),
+        "model.layers.1.mlp.experts.gate_and_up_projs": torch.randn(N_EXPERTS, HIDDEN, 2 * MOE_INTER),
+        "model.layers.1.mlp.experts.down_projs": torch.randn(N_EXPERTS, MOE_INTER, HIDDEN),
+        "model.layers.1.mlp.shared_experts.gate_proj.weight": torch.randn(MOE_INTER, HIDDEN),
+        "model.layers.1.mlp.shared_experts.up_proj.weight": torch.randn(MOE_INTER, HIDDEN),
+        "model.layers.1.mlp.shared_experts.down_proj.weight": torch.randn(HIDDEN, MOE_INTER),
+    }
+    return sd
+
+
+# ---------------------------------------------------------------------------
+# Initialization +# --------------------------------------------------------------------------- + + +class TestInitialization: + def test_attributes_set(self, config, moe_config, backend_config): + a = HYV3StateDictAdapter(config=config, moe_config=moe_config, backend=backend_config, dtype=torch.float16) + assert a.config is config + assert a.moe_config is moe_config + assert a.backend is backend_config + assert a.dtype == torch.float16 + assert a._uses_model_prefix is True + + def test_default_dtype_is_bfloat16(self, config, moe_config, backend_config): + a = HYV3StateDictAdapter(config=config, moe_config=moe_config, backend=backend_config) + assert a.dtype == torch.bfloat16 + + def test_inherits_mixin(self, adapter): + from nemo_automodel.components.moe.state_dict_mixin import MoESplitExpertsStateDictMixin + + assert isinstance(adapter, MoESplitExpertsStateDictMixin) + + +# --------------------------------------------------------------------------- +# Rename tables +# --------------------------------------------------------------------------- + + +class TestRenameTables: + """Sanity-check the static rename tables: each native pattern must round-trip.""" + + @pytest.mark.parametrize( + "native, hf", + [ + ("model.layers.5.mlp.gate.e_score_correction_bias", "model.layers.5.mlp.expert_bias"), + ("model.layers.5.mlp.gate.weight", "model.layers.5.mlp.router.gate.weight"), + ("model.layers.5.mlp.shared_experts.gate_proj.weight", "model.layers.5.mlp.shared_mlp.gate_proj.weight"), + ("model.layers.5.mlp.shared_experts.up_proj.weight", "model.layers.5.mlp.shared_mlp.up_proj.weight"), + ("model.layers.5.mlp.shared_experts.down_proj.weight", "model.layers.5.mlp.shared_mlp.down_proj.weight"), + ], + ) + def test_native_to_hf_round_trip(self, native, hf): + # Apply native->HF + nk = native + for pat, repl in _NATIVE_TO_HF_RENAMES: + nk, n = pat.subn(repl, nk) + if n: + break + assert nk == hf + + # Apply HF->native + hk = hf + for pat, repl in _HF_TO_NATIVE_RENAMES: + hk, n = pat.subn(repl, hk) + if n: + break + assert hk == native + + def test_unrelated_keys_pass_through(self): + """Renames must not touch attention, embed, lm_head, layernorm keys.""" + for k in ( + "model.embed_tokens.weight", + "lm_head.weight", + "model.layers.0.self_attn.q_proj.weight", + "model.layers.0.input_layernorm.weight", + "model.layers.0.mlp.gate_proj.weight", # dense MLP gate_proj must NOT match + "model.norm.weight", + ): + for tab in (_NATIVE_TO_HF_RENAMES, _HF_TO_NATIVE_RENAMES): + v = k + for pat, repl in tab: + v, n = pat.subn(repl, v) + if n: + break + assert v == k, f"{k} unexpectedly renamed to {v}" + + +# --------------------------------------------------------------------------- +# from_hf: on-disk -> native +# --------------------------------------------------------------------------- + + +class TestFromHF: + def test_renames_router_gate(self, adapter): + hf = _make_disk_state_dict(with_mtp=False) + native = adapter.from_hf(hf, device_mesh=None) + assert "model.layers.1.mlp.gate.weight" in native + assert "model.layers.1.mlp.router.gate.weight" not in native + + def test_renames_expert_bias_to_gate_bias(self, adapter): + hf = _make_disk_state_dict(with_mtp=False) + native = adapter.from_hf(hf, device_mesh=None) + assert "model.layers.1.mlp.gate.e_score_correction_bias" in native + assert "model.layers.1.mlp.expert_bias" not in native + + def test_renames_shared_mlp_to_shared_experts(self, adapter): + hf = _make_disk_state_dict(with_mtp=False) + native = adapter.from_hf(hf, device_mesh=None) + for proj in 
("gate_proj", "up_proj", "down_proj"): + assert f"model.layers.1.mlp.shared_experts.{proj}.weight" in native + assert f"model.layers.1.mlp.shared_mlp.{proj}.weight" not in native + + def test_merges_experts_into_grouped_form(self, adapter): + hf = _make_disk_state_dict(with_mtp=False) + native = adapter.from_hf(hf, device_mesh=None) + # Per-expert split keys must be gone. + for e in range(N_EXPERTS): + for proj in ("gate_proj", "up_proj", "down_proj"): + assert f"model.layers.1.mlp.experts.{e}.{proj}.weight" not in native + # Grouped tensors present. + assert "model.layers.1.mlp.experts.gate_and_up_projs" in native + assert "model.layers.1.mlp.experts.down_projs" in native + + def test_merged_shapes_are_native_layout(self, adapter): + hf = _make_disk_state_dict(with_mtp=False) + native = adapter.from_hf(hf, device_mesh=None) + # Native gate_and_up_projs: [E, hidden, 2*moe_inter] + assert tuple(native["model.layers.1.mlp.experts.gate_and_up_projs"].shape) == ( + N_EXPERTS, + HIDDEN, + 2 * MOE_INTER, + ) + # Native down_projs: [E, moe_inter, hidden] + assert tuple(native["model.layers.1.mlp.experts.down_projs"].shape) == ( + N_EXPERTS, + MOE_INTER, + HIDDEN, + ) + + def test_merged_values_match_per_expert_inputs(self, adapter): + """The stacked native tensors must contain the per-expert tensors transposed + and concatenated in the well-defined gate-then-up order.""" + hf = _make_disk_state_dict(with_mtp=False) + native = adapter.from_hf(hf, device_mesh=None) + + gate_up = native["model.layers.1.mlp.experts.gate_and_up_projs"] + down = native["model.layers.1.mlp.experts.down_projs"] + for e in range(N_EXPERTS): + g_hf = hf[f"model.layers.1.mlp.experts.{e}.gate_proj.weight"] # [moe_inter, hidden] + u_hf = hf[f"model.layers.1.mlp.experts.{e}.up_proj.weight"] + d_hf = hf[f"model.layers.1.mlp.experts.{e}.down_proj.weight"] # [hidden, moe_inter] + + # gate half: [hidden, moe_inter] -> first MOE_INTER columns of gate_and_up_projs + assert torch.allclose(gate_up[e, :, :MOE_INTER], g_hf.transpose(0, 1).to(adapter.dtype)) + # up half: last MOE_INTER columns + assert torch.allclose(gate_up[e, :, MOE_INTER:], u_hf.transpose(0, 1).to(adapter.dtype)) + # down: [hidden, moe_inter] HF -> [moe_inter, hidden] native + assert torch.allclose(down[e], d_hf.transpose(0, 1).to(adapter.dtype)) + + def test_drops_mtp_layer_keys(self, adapter): + hf = _make_disk_state_dict(with_mtp=True) + # MTP keys are present in input. + assert any(k.startswith(f"model.layers.{NUM_LAYERS}.") for k in hf) + native = adapter.from_hf(hf, device_mesh=None) + # MTP keys must not survive. 
+ assert not any(k.startswith(f"model.layers.{NUM_LAYERS}.") for k in native) + assert not any(k.startswith(f"layers.{NUM_LAYERS}.") for k in native) + + def test_passes_through_unrelated_keys(self, adapter): + hf = _make_disk_state_dict(with_mtp=False) + native = adapter.from_hf(hf, device_mesh=None) + for k in ( + "model.embed_tokens.weight", + "model.norm.weight", + "lm_head.weight", + "model.layers.0.self_attn.q_proj.weight", + "model.layers.0.input_layernorm.weight", + "model.layers.0.mlp.gate_proj.weight", + "model.layers.1.input_layernorm.weight", + "model.layers.1.self_attn.q_proj.weight", + ): + assert k in native + assert torch.equal(native[k], hf[k]) + + +# --------------------------------------------------------------------------- +# to_hf: native -> on-disk +# --------------------------------------------------------------------------- + + +class TestToHF: + def test_renames_native_back_to_on_disk(self, adapter): + native = _make_native_state_dict() + hf = adapter.to_hf(native) + # On-disk keys present. + assert "model.layers.1.mlp.router.gate.weight" in hf + assert "model.layers.1.mlp.expert_bias" in hf + assert "model.layers.1.mlp.shared_mlp.gate_proj.weight" in hf + # Native keys gone. + assert "model.layers.1.mlp.gate.weight" not in hf + assert "model.layers.1.mlp.gate.e_score_correction_bias" not in hf + assert "model.layers.1.mlp.shared_experts.gate_proj.weight" not in hf + + def test_splits_experts_into_per_expert_keys(self, adapter): + native = _make_native_state_dict() + hf = adapter.to_hf(native) + # Grouped keys gone. + assert "model.layers.1.mlp.experts.gate_and_up_projs" not in hf + assert "model.layers.1.mlp.experts.down_projs" not in hf + # Per-expert split keys present. + for e in range(N_EXPERTS): + for proj in ("gate_proj", "up_proj", "down_proj"): + assert f"model.layers.1.mlp.experts.{e}.{proj}.weight" in hf + + def test_per_expert_shapes_match_disk_layout(self, adapter): + native = _make_native_state_dict() + hf = adapter.to_hf(native) + for e in range(N_EXPERTS): + assert tuple(hf[f"model.layers.1.mlp.experts.{e}.gate_proj.weight"].shape) == (MOE_INTER, HIDDEN) + assert tuple(hf[f"model.layers.1.mlp.experts.{e}.up_proj.weight"].shape) == (MOE_INTER, HIDDEN) + assert tuple(hf[f"model.layers.1.mlp.experts.{e}.down_proj.weight"].shape) == (HIDDEN, MOE_INTER) + + def test_exclude_key_regex(self, adapter): + native = _make_native_state_dict() + native["custom.exclude.weight"] = torch.randn(2, 2) + hf = adapter.to_hf(native, exclude_key_regex=r"^custom\.exclude") + assert "custom.exclude.weight" not in hf + # Renames still applied. 
+ assert "model.layers.1.mlp.router.gate.weight" in hf + + +# --------------------------------------------------------------------------- +# convert_single_tensor_to_hf +# --------------------------------------------------------------------------- + + +class TestConvertSingleTensorToHF: + def test_non_expert_key_renamed(self, adapter): + t = torch.randn(N_EXPERTS, HIDDEN) + out = adapter.convert_single_tensor_to_hf("model.layers.1.mlp.gate.weight", t) + assert len(out) == 1 + assert out[0][0] == "model.layers.1.mlp.router.gate.weight" + assert torch.equal(out[0][1], t) + + def test_non_expert_key_pass_through(self, adapter): + t = torch.randn(HIDDEN, HIDDEN) + out = adapter.convert_single_tensor_to_hf("model.layers.0.self_attn.q_proj.weight", t) + assert len(out) == 1 + assert out[0][0] == "model.layers.0.self_attn.q_proj.weight" + assert torch.equal(out[0][1], t) + + def test_expert_tensor_split_and_renamed(self, adapter): + # gate_and_up_projs -> per-expert gate_proj.weight + up_proj.weight (split + transposed) + t = torch.randn(N_EXPERTS, HIDDEN, 2 * MOE_INTER) + pairs = adapter.convert_single_tensor_to_hf("model.layers.1.mlp.experts.gate_and_up_projs", t) + keys = {k for k, _ in pairs} + assert len(pairs) == 2 * N_EXPERTS + for e in range(N_EXPERTS): + assert f"model.layers.1.mlp.experts.{e}.gate_proj.weight" in keys + assert f"model.layers.1.mlp.experts.{e}.up_proj.weight" in keys + + def test_expert_down_proj_split_and_transposed(self, adapter): + t = torch.randn(N_EXPERTS, MOE_INTER, HIDDEN) + pairs = adapter.convert_single_tensor_to_hf("model.layers.1.mlp.experts.down_projs", t) + keys = {k for k, _ in pairs} + assert len(pairs) == N_EXPERTS + for e in range(N_EXPERTS): + assert f"model.layers.1.mlp.experts.{e}.down_proj.weight" in keys + # Shape per expert: [hidden, moe_inter] + for k, v in pairs: + assert tuple(v.shape) == (HIDDEN, MOE_INTER) + + def test_exclude_regex_applied_after_rename(self, adapter): + t = torch.randn(N_EXPERTS) + out = adapter.convert_single_tensor_to_hf( + "model.layers.1.mlp.gate.e_score_correction_bias", + t, + exclude_key_regex=r".*\.expert_bias$", # matches the renamed-to name + ) + assert out == [] + + def test_exclude_regex_when_no_match(self, adapter): + t = torch.randn(2, 2) + out = adapter.convert_single_tensor_to_hf("custom.weight", t, exclude_key_regex=r"^never") + assert len(out) == 1 + assert out[0][0] == "custom.weight" + + +# --------------------------------------------------------------------------- +# Round-trip: from_hf(to_hf(native)) == native +# --------------------------------------------------------------------------- + + +class TestRoundTrip: + def test_native_to_hf_to_native(self, adapter): + native = _make_native_state_dict() + hf = adapter.to_hf(native) + recovered = adapter.from_hf(hf, device_mesh=None) + + assert set(recovered.keys()) == set(native.keys()) + for k in native: + a, b = native[k].float(), recovered[k].float() + assert a.shape == b.shape, f"{k}: shape {a.shape} != {b.shape}" + assert torch.allclose(a, b, atol=1e-5, rtol=1e-5), f"{k} differs after round-trip" + + def test_disk_to_native_to_disk(self, adapter): + """Loading from disk then re-saving must round-trip every per-expert key. + MTP keys (which from_hf drops) are excluded from the comparison.""" + hf = _make_disk_state_dict(with_mtp=True) + native = adapter.from_hf(hf, device_mesh=None) + re_hf = adapter.to_hf(native) + + expected_keys = {k for k in hf if not adapter._is_mtp_key(k)} + # MTP layer keys must not show up in re_hf either. 
+ assert set(re_hf.keys()) == expected_keys + + for k in expected_keys: + a, b = hf[k].float(), re_hf[k].float() + assert a.shape == b.shape, f"{k}: shape mismatch" + assert torch.allclose(a, b, atol=1e-5, rtol=1e-5), f"{k} differs after round-trip" + + +# --------------------------------------------------------------------------- +# _is_mtp_key +# --------------------------------------------------------------------------- + + +class TestIsMTPKey: + @pytest.mark.parametrize( + "key, expected", + [ + ("model.layers.0.self_attn.q_proj.weight", False), + ("model.layers.1.mlp.expert_bias", False), + (f"model.layers.{NUM_LAYERS}.input_layernorm.weight", True), + (f"model.layers.{NUM_LAYERS}.mlp.expert_bias", True), + (f"model.layers.{NUM_LAYERS + 5}.self_attn.q_proj.weight", True), + ("model.embed_tokens.weight", False), + ("model.norm.weight", False), + ("lm_head.weight", False), + # Without the model. prefix + ("layers.0.foo", False), + (f"layers.{NUM_LAYERS}.foo", True), + ], + ) + def test_layer_index_classification(self, adapter, key, expected): + assert adapter._is_mtp_key(key) is expected + + def test_threshold_uses_config_num_hidden_layers(self, config, moe_config, backend_config): + config.num_hidden_layers = 80 + a = HYV3StateDictAdapter(config=config, moe_config=moe_config, backend=backend_config) + assert a._is_mtp_key("model.layers.79.foo") is False + assert a._is_mtp_key("model.layers.80.foo") is True
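+
+
+# ---------------------------------------------------------------------------
+# Illustrative round-trip sketch (not collected by pytest)
+# ---------------------------------------------------------------------------
+# A minimal sketch of the end-to-end conversion flow that the tests above
+# verify piecewise; it relies only on ``from_hf`` / ``to_hf`` and the
+# synthetic helpers defined in this module. Treat it as documentation of the
+# intended usage rather than as an additional test.
+
+
+def _example_round_trip(adapter: HYV3StateDictAdapter) -> None:
+    """Sketch: on-disk HF keys -> native grouped keys -> back to on-disk keys."""
+    hf_sd = _make_disk_state_dict(with_mtp=True)           # per-expert weights, Tencent names
+    native_sd = adapter.from_hf(hf_sd, device_mesh=None)   # grouped experts; MTP keys dropped
+    restored_sd = adapter.to_hf(native_sd)                  # per-expert weights, Tencent names again
+    assert all(not adapter._is_mtp_key(k) for k in restored_sd)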