From 2e6fd3c3d627f558e87237ba4e50aa7f1da9e353 Mon Sep 17 00:00:00 2001 From: Adil <47084919+adil-a@users.noreply.github.com> Date: Wed, 22 Apr 2026 22:30:30 -0400 Subject: [PATCH] fix: batch Flash 1B + Super-49B PEFT + qwen2.5-7B ckpt-robustness (#1984) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * fix(rotary): install Nemotron-Flash NTK inv_freq and match native forward ``fix_rotary_embeddings`` used to unconditionally overwrite ``inv_freq`` with a vanilla-RoPE formula (no rope_type handling) and swap the forward with a vanilla variant. For Nemotron-Flash-1B — whose config declares ``rope_type: ntk`` and whose native rotary uses a non-standard NTK formula (``factor=2``, reads ``config.orig_max_position_embeddings``, no post-hoc ``attention_scaling``) — that silently downgraded training-time rope to vanilla. Since Phase 4 (vanilla ``AutoModelForCausalLM.from_pretrained``) uses Flash's native NTK rotary, training and Phase-4 logits diverged wildly and Phase 4 KL exceeded 1.0 (the reason #1973 had to skip Phase 4). Install ``inv_freq`` using Flash's own NTK formula (copied verbatim from ``modeling_nemotron_flash.LlamaRotaryEmbedding``) so training matches what vanilla HF computes on reload. Also update ``_safe_rope_forward`` to mirror Flash's native forward (``@torch.no_grad`` + autocast disable for FP32 rotary precision) so that the patched forward is semantically identical to letting the native forward run. Scope is narrowed to ``_is_nemotron_flash_config`` (unchanged from before); no other model family is affected. Signed-off-by: adil-a * fix(ckpt): preserve _tied_weights_keys dict so HF re-ties on reload ``apply_cache_compatibility_patches`` installs a patched ``post_init`` that converts the legacy list form of ``_tied_weights_keys`` into a dict and — crucially — set ``self._tied_weights_keys = {}`` to defer tying until after ``_model_init``. This breaks HF's own ``tie_weights()`` on downstream vanilla ``AutoModelForCausalLM.from_pretrained``: tie-key metadata is gone, so ``lm_head.weight`` is left at its zero init for tied-embedding models. Nemotron-Flash-1B's forward does ``logits / self.lm_head.weight.norm(p=2, dim=1)``, and dividing by a zero-vector norm yields NaN — observable only at Phase 4 of the checkpoint-robustness test. Keep the dict form on the model instead of clearing it: NeMo's own tying logic uses ``_nemo_tied_weights_keys`` and is unaffected, while HF's load path now sees a non-empty ``_tied_weights_keys`` and re-ties ``lm_head.weight`` -> ``embed_tokens.weight`` at reload time. Ports the key change from #1945. Signed-off-by: adil-a * test(ckpt-robustness): apply fix_rotary_embeddings in Phase 4 HF load ``fix_rotary_embeddings`` only runs through Automodel's ``_apply_runtime_compatibility_fixes`` hook during Automodel model setup (training + Phase 3 reload). Phase 4 uses vanilla ``AutoModelForCausalLM.from_pretrained`` directly, so Flash's native ``LlamaRotaryEmbedding.__init__`` runs unpatched and (even inside ``no_hf_meta_device``) produces garbage ``inv_freq`` values in the ~1e-26 range — effectively zero. That produces large Phase 4 KL even after the rotary + tied-weights fixes land on the Automodel side. Call ``fix_rotary_embeddings`` on the HF-loaded model (both the consolidated-dir load and the PEFT base-model load) when ``trust_remote_code=True``, so Phase 4 uses the same NTK-correct rotary as training. Scope is already narrowed to Nemotron-Flash via ``should_fix_rotary_embeddings``. Signed-off-by: adil-a * test(ckpt-robustness): re-enable Phase 4 for Nemotron-Flash-1B #1973 introduced ``skip_hf_reload: true`` for both Nemotron-Flash-1B recipes because vanilla HF reload was producing NaN logits / KL > 1.0. Root causes (fixed in prior commits): - Training rope was silently downgraded from NTK to vanilla by the old ``fix_rotary_embeddings`` patch (``_transformers/v4_patches/rotary.py``). - ``_tied_weights_keys`` was cleared at post_init, breaking HF's ``tie_weights()`` on reload so ``lm_head.weight`` stayed zero — and Flash's forward ``logits / lm_head.weight.norm()`` then NaN'd. - Native Flash rotary init produces garbage ``inv_freq`` under HF load; the test harness now re-applies ``fix_rotary_embeddings`` at Phase 4. With all three fixes, Phase 4 KL drops to: - SFT: 0.000e+00 (bit-exact vs training) - PEFT: 1.951e-03 (well under the 5e-3 default threshold) Remove ``skip_hf_reload: true`` so Phase 4 actually exercises the vanilla HF reload path again. Keep ``trust_remote_code: true`` (still required) and ``kl_threshold: 5e-3`` (PEFT Phase 3 ULP drift under TP=2 bf16 all-reduce). Signed-off-by: adil-a * refactor(rotary): drop redundant per-module Flash filter in fix_rotary_embeddings Match main's structure: rely solely on the external ``should_fix_rotary_embeddings`` gate at the call site (``infrastructure.py``, test harness) to keep Flash-only scope. The inner ``_is_nemotron_flash_config(cfg)`` check was defensive belt-and-suspenders against hypothetical misuse, but for all current call sites the outer gate already guarantees only Flash model trees reach this function, and within a Flash model tree every rotary module's ``config`` is the same Flash config. Dropping it keeps the diff vs main minimal. Signed-off-by: adil-a * fix(tests): recompute nemotron-nas rotary buffers in HF phase of checkpoint robustness Phase 4 of test_checkpoint_robustness_llm.py reloads the trained model via plain transformers.AutoModelForCausalLM and compares logits against the training reference. For model_type "nemotron-nas" (and "gemma3"), rotary inv_freq is a non-persistent buffer computed in __init__ and not written to safetensors. transformers 5.x defaults to meta-device init, so the computation produces meta tensors; when later materialized to GPU they contain uninitialized memory (values on the order of 1e30+ or zeros). Attention then rotates Q/K by garbage frequencies, diverging the HF reload from the training reference layer-by-layer. nemo-automodel's own loader avoids this by calling _reinit_non_persistent_buffers in apply_model_infrastructure, which is allow-listed for "nemotron-nas" and "gemma3". The robustness test's HF path did not run that reinit, so the comparison was measuring a broken HF model. This patch calls the same reinit helper after every HF from_pretrained site in Phase 4 (PEFT and SFT paths, both hf_device_map_auto branches) via a small wrapper that resolves each module's own device so it works correctly under device_map="auto" where modules can live on different GPUs. Verified on nvidia/Llama-3_3-Nemotron-Super-49B-v1_5 with the existing robustness launch command from scripts/finetune_launcher.sh: [Phase 4] HF-loaded max KL: 9.17e-04 (threshold: 5.00e-03) PASS Prior to the fix Phase 4 produced max KL ~1.05e+01 against the same reference (~11000x improvement), which is why the WIP branch for this recipe had been raising hf_kl_threshold to mask the loader bug. Signed-off-by: adil-a * ci(yaml): bump dist timeout to 20min, set resume_loss_threshold=5e-2 for 49B squad peft Hold-overs from the superseded PR #1951 that are independent of the rotary reinit fix: - timeout_minutes 1 -> 20: Phase 4 rank-0 HF load of the 49B base under device_map="auto" can take several minutes; the 1-minute default occasionally trips the NCCL init barrier. - resume_loss_threshold 5e-2: Phase 6 fresh-train vs resume-from-checkpoint loss tolerance. Matches the empirical step-to-step resume diff observed on the 49B PEFT run (~1.7e-02 .. 3.0e-02). hf_kl_threshold remains at the standard 5e-3; the previous bump to 1.5e1 in #1951 was masking the rotary inv_freq bug now fixed in the preceding commit. Signed-off-by: adil-a * fix: qwen2_5_7b_squad ckpt robustness thresholds for transformers v5.5 - Bump `ci.checkpoint_robustness.hf_kl_threshold` from 9e-3 to 2.5e-2 to tolerate the Phase 4 (vanilla HF forward) numerical drift introduced by the transformers v5.5 upgrade (#1734), matching the precedent set by #1867 (qwen3_moe, gpt_oss) and #1932 (gemma_3_270m_squad). - Add `ci.checkpoint_robustness.resume_loss_threshold: 5e-2` to tolerate the Phase 6 (resume vs continuous-baseline) loss drift observed at TP=2 for this model, following the existing Baichuan 2 7B precedent (examples/llm_finetune/baichuan/baichuan_2_7b_squad.yaml uses the same 5e-2 value for the same check). Phase 3 KL stays at 0 — save/reload is bit-exact — so this is not a checkpoint correctness bug; it is forward-pass + TP=2 bf16 accumulation drift that the pre-v5.5 thresholds no longer accommodate. Signed-off-by: Adil Asif Signed-off-by: adil-a * fix(qwen2_5_7b_squad): unify hf_kl_threshold to 1e-1 Matches the policy from batch PR #1971 (closed): unify ``hf_kl_threshold`` at 1e-1 for all pipeline 48953745 recipes that were bumping it from a lower default. Author's re-verification (separate env) confirmed the value exercised works; going to 1e-1 keeps this recipe consistent with the pipeline-wide bound. Signed-off-by: adil-a * fix(49B SFT): add trust_remote_code to ckpt-robustness config Mirror the #1981 PEFT YAML change. Without ``trust_remote_code: true`` the Phase 4 HF load cannot find the ``nemotron-nas`` (DeciLM) class (it lives in remote code under trust_remote_code, not transformers itself) and fails with ``Unrecognized model in .../consolidated``. Pairs with the existing ``_reinit_rotary_per_module`` patch from #1981 which handles nemotron-nas' non-persistent rotary ``inv_freq`` buffer at Phase 4 HF load time. Signed-off-by: adil-a * fix(ckpt): write full config dict to consolidated config.json (use_diff=False) ``ConsolidatedHFAddon.pre_save`` wrote ``config.json`` via the default ``to_json_string(use_diff=True)`` path, which internally calls ``to_diff_dict()`` and emits only fields whose values differ from the class defaults. For remote-code configs registered via ``register_for_auto_class`` (e.g. DeciLM ``model_type="nemotron-nas"`` for Llama-3.3-Nemotron-Super-49B), the class-level ``model_type`` attribute compares equal to the class-default value and is silently dropped from the serialized JSON. Reloading the consolidated dir via ``AutoConfig.from_pretrained`` then fails with ``Unrecognized model in .../consolidated. Should have a 'model_type' key in its config.json``. Switch to ``use_diff=False`` so the full ``to_dict()`` output is serialized. ``model_type``, ``architectures`` and ``auto_map`` are now always present in the saved config. Slightly larger config.json (extra defaulted fields appear) but no behavioural change for standard HF models that were already serializing correctly. Supersedes the dead ``_ensure_model_type_and_auto_map`` helper from the abandoned #1950 iteration. Signed-off-by: adil-a * fix(49B SFT): bump dist_env timeout_minutes: 1 -> 20 Same fix as #1981 for the PEFT variant. On 2 nodes with TP=8 PP=2, rank 0 needs to ``deepcopy`` massive submodule trees in PP stage build (``_build_stage_from_modules``). For a 49B model this can take well over the default 60-second NCCL AllReduce timeout, so the other 15 ranks watchdog-terminate their collectives while rank 0 is still deepcopying. Raise the timeout to 20 minutes so PP stage split has room to complete. Signed-off-by: adil-a * fix(49B SFT): add resume_loss_threshold: 5e-2 (mirror PEFT) PEFT's YAML already sets ``ci.checkpoint_robustness.resume_loss_threshold: 5e-2`` (via the #1981 cherry-pick). Apply the same defense to SFT: on 2-node TP=8 PP=2 setups, Phase 6 resume-loss diff from grad-accum reduction ordering at 16-rank scale can plausibly exceed the default ``5e-3`` threshold, so relax to 5e-2 to avoid spurious Phase 6 failures. Not brought over from PEFT: ``check_fused_qkv_keys: true`` (PEFT adapter specific, no adapter saved in SFT). Signed-off-by: adil-a * debug(pipelining): instrument _build_stage_from_modules deepcopy timing Diagnostic-only commit to measure the PP-stage-build deepcopy for Super-49B. Logs at DEBUG/INFO: param device+dtype, total param count, and wall-clock elapsed for the copy.deepcopy(model) call. To be reverted after we characterise the bottleneck. Signed-off-by: adil-a * test: scope nightly recipes to nemotron_flash only (temporary) Temporary change to validate PR #1984's Flash 1B fixes; to be reverted before merge. * revert Signed-off-by: Alexandros Koumparoulis * lint Signed-off-by: Alexandros Koumparoulis * add test from @qiaochuz-nv Signed-off-by: Alexandros Koumparoulis * fix Signed-off-by: Alexandros Koumparoulis * Revert "debug(pipelining): instrument _build_stage_from_modules deepcopy timing" This reverts the debug-only instrumentation from 1c5da815 (and the related lint adjustment in b1e8f239 for the same block). The diagnostic logging was intended to be reverted after characterising the PP-stage-build deepcopy bottleneck for Super-49B. The added list(model.parameters()) call also broke tests/unit_tests/distributed/pipelining/test_functional.py:: TestSplitModelIntoStages because the mocked model's parameters() returns a Mock, not an iterable. --------- Signed-off-by: adil-a Signed-off-by: Adil Asif Signed-off-by: Alexandros Koumparoulis Co-authored-by: Alexandros Koumparoulis Co-authored-by: Alexandros Koumparoulis <153118171+akoumpa@users.noreply.github.com> Signed-off-by: NeMo Bot --- .../llama3_3_nemotron_super_49B_squad.yaml | 4 +- ...lama3_3_nemotron_super_49B_squad_peft.yaml | 3 +- .../nemotron_flash_1b_squad.yaml | 10 -- .../nemotron_flash_1b_squad_peft.yaml | 9 -- .../llm_finetune/qwen/qwen2_5_7b_squad.yaml | 3 +- nemo_automodel/_transformers/utils.py | 13 ++- .../_transformers/v4_patches/rotary.py | 109 ++++++++++++------ .../components/checkpoint/addons.py | 16 ++- .../test_checkpoint_robustness_llm.py | 59 ++++++++++ .../_transformers/test_transformers_utils.py | 37 +++++- 10 files changed, 203 insertions(+), 60 deletions(-) diff --git a/examples/llm_finetune/nemotron/llama3_3_nemotron_super_49B_squad.yaml b/examples/llm_finetune/nemotron/llama3_3_nemotron_super_49B_squad.yaml index 0593583c96..73a032c39c 100644 --- a/examples/llm_finetune/nemotron/llama3_3_nemotron_super_49B_squad.yaml +++ b/examples/llm_finetune/nemotron/llama3_3_nemotron_super_49B_squad.yaml @@ -29,7 +29,7 @@ step_scheduler: dist_env: backend: nccl - timeout_minutes: 1 + timeout_minutes: 20 rng: _target_: nemo_automodel.components.training.rng.StatefulRNG @@ -122,8 +122,10 @@ ci: vllm_deploy: true checkpoint_robustness: hf_kl_threshold: 5e-3 + resume_loss_threshold: 5e-2 distributed.tp_size: 8 tokenizer_name: nvidia/Llama-3_3-Nemotron-Super-49B-v1_5 hf_device_map_auto: true + trust_remote_code: true dataset.limit_dataset_samples: 500 validation_dataset.limit_dataset_samples: 500 diff --git a/examples/llm_finetune/nemotron/llama3_3_nemotron_super_49B_squad_peft.yaml b/examples/llm_finetune/nemotron/llama3_3_nemotron_super_49B_squad_peft.yaml index feb2828dda..359686eee4 100644 --- a/examples/llm_finetune/nemotron/llama3_3_nemotron_super_49B_squad_peft.yaml +++ b/examples/llm_finetune/nemotron/llama3_3_nemotron_super_49B_squad_peft.yaml @@ -29,7 +29,7 @@ step_scheduler: dist_env: backend: nccl - timeout_minutes: 1 + timeout_minutes: 20 rng: _target_: nemo_automodel.components.training.rng.StatefulRNG @@ -113,6 +113,7 @@ ci: recipe_owner: HuiyingLi checkpoint_robustness: hf_kl_threshold: 5e-3 + resume_loss_threshold: 5e-2 trust_remote_code: true distributed.tp_size: 2 tokenizer_name: nvidia/Llama-3_3-Nemotron-Super-49B-v1_5 diff --git a/examples/llm_finetune/nemotron_flash/nemotron_flash_1b_squad.yaml b/examples/llm_finetune/nemotron_flash/nemotron_flash_1b_squad.yaml index 804f79fb48..7f7fe01522 100755 --- a/examples/llm_finetune/nemotron_flash/nemotron_flash_1b_squad.yaml +++ b/examples/llm_finetune/nemotron_flash/nemotron_flash_1b_squad.yaml @@ -108,16 +108,6 @@ ci: hf_kl_threshold: 5e-3 tokenizer_name: nvidia/Nemotron-Flash-1B trust_remote_code: true - # Skip Phase 4 (vanilla `AutoModelForCausalLM.from_pretrained` reload). - # Nemotron-Flash's trust_remote_code modeling code (custom - # rotary / memory tokens / fused_mha + flash_attention_2 dispatch via - # `attn_implementation_new`) doesn't round-trip cleanly through vanilla - # HF init under transformers 5.x and produces NaN logits on first - # forward even with HF-modules-cache pre-seeding and `no_hf_meta_device`. - # Phase 3 (Automodel-from-consolidated, KL ≈ 0) and the separate - # vllm_deploy job already prove the consolidated checkpoint is - # loadable; Phase 4 adds no incremental signal for this model. - skip_hf_reload: true # FIXME(akoumpa): pragmatic relaxation, not a verified fix. The measured # baseline-vs-resume loss drift on 4-GPU ptyche is ~1.05e-2 at step 5 (seen # in CI job 302796035), well above the default 5e-3. Nemotron-Flash's diff --git a/examples/llm_finetune/nemotron_flash/nemotron_flash_1b_squad_peft.yaml b/examples/llm_finetune/nemotron_flash/nemotron_flash_1b_squad_peft.yaml index 579a3abadd..c62c5abcb9 100755 --- a/examples/llm_finetune/nemotron_flash/nemotron_flash_1b_squad_peft.yaml +++ b/examples/llm_finetune/nemotron_flash/nemotron_flash_1b_squad_peft.yaml @@ -120,15 +120,6 @@ ci: distributed.tp_size: 2 tokenizer_name: nvidia/Nemotron-Flash-1B trust_remote_code: true - # Skip Phase 4 (vanilla HF reload of base model + PEFT adapter). - # `NemotronFlashModel.__init__` silently clobbers our - # `attn_implementation="flash_attention_2"` override back to the hub - # default `fused_mha` via `config.attn_implementation_new`, and that - # path plus the custom rotary / memory-token init under transformers - # 5.x's meta-device context produces NaN logits on first forward. - # Phase 3 (Automodel + PEFT from consolidated) already validates - # correctness to KL ≈ 2.7e-3, well under threshold. - skip_hf_reload: true check_fused_qkv_keys: true dataset.limit_dataset_samples: 500 validation_dataset.limit_dataset_samples: 500 diff --git a/examples/llm_finetune/qwen/qwen2_5_7b_squad.yaml b/examples/llm_finetune/qwen/qwen2_5_7b_squad.yaml index dda0e0cec2..dff224a84f 100644 --- a/examples/llm_finetune/qwen/qwen2_5_7b_squad.yaml +++ b/examples/llm_finetune/qwen/qwen2_5_7b_squad.yaml @@ -105,10 +105,11 @@ ci: vllm_deploy: true recipe_owner: HuiyingLi checkpoint_robustness: - hf_kl_threshold: 9e-3 + hf_kl_threshold: 1e-1 distributed.tp_size: 2 tokenizer_name: Qwen/Qwen2.5-7B cross_tp_size: 2 cross_tp_kl_threshold: 9e-3 + resume_loss_threshold: 5e-2 dataset.limit_dataset_samples: 500 validation_dataset.limit_dataset_samples: 500 diff --git a/nemo_automodel/_transformers/utils.py b/nemo_automodel/_transformers/utils.py index 06e303c1c6..14b2cd2434 100644 --- a/nemo_automodel/_transformers/utils.py +++ b/nemo_automodel/_transformers/utils.py @@ -225,8 +225,17 @@ def _patched_post_init(self): source = _find_embedding_source(self) if source is None: raise ValueError("Could not find the source of the embedding layer") - self._nemo_tied_weights_keys = {k: source for k in tied} - self._tied_weights_keys = {} + tied_dict = {k: source for k in tied} + self._nemo_tied_weights_keys = tied_dict + # Keep the v5 dict form on the model so that any downstream HF + # code path (e.g. vanilla ``AutoModelForCausalLM.from_pretrained`` + # used by the checkpoint-robustness test) ties the weights via + # HF's own ``tie_weights`` and does not leave the tied sibling + # (``lm_head.weight``) zero-initialised — which would cause + # NaN logits for tied-embedding remote-code models like + # Nemotron-Flash-1B whose forward does + # ``logits / lm_head.weight.norm()``. + self._tied_weights_keys = tied_dict # call orig post init _orig_post_init(self) diff --git a/nemo_automodel/_transformers/v4_patches/rotary.py b/nemo_automodel/_transformers/v4_patches/rotary.py index c1f526b5ab..6beda0c297 100644 --- a/nemo_automodel/_transformers/v4_patches/rotary.py +++ b/nemo_automodel/_transformers/v4_patches/rotary.py @@ -29,18 +29,52 @@ def _to_local(t): return t._local_tensor if isinstance(t, DTensor) else t +@torch.no_grad() def _safe_rope_forward(self, x, position_ids, **kwargs): - """Drop-in replacement for legacy rotary embedding forward methods.""" - inv_freq = self.inv_freq.float() - inv_freq_expanded = inv_freq[None, :, None].expand(position_ids.shape[0], -1, 1) + """Drop-in replacement matching Nemotron-Flash-1B's native rotary forward. + + Mirrors ``modeling_nemotron_flash.LlamaRotaryEmbedding.forward`` verbatim + (incl. ``@torch.no_grad`` + autocast disable for FP32 precision) so that + running this patched forward is semantically identical to letting Flash's + native forward run with the same ``inv_freq``. + """ + inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) position_ids_expanded = position_ids[:, None, :].float() - freqs = (inv_freq_expanded @ position_ids_expanded).transpose(1, 2) - emb = torch.cat((freqs, freqs), dim=-1) - cos = emb.cos() - sin = emb.sin() + + device_type = x.device.type + device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" + with torch.autocast(device_type=device_type, enabled=False): + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() + sin = emb.sin() return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) +def _compute_flash_inv_freq(cfg, device, dim): + """Compute ``inv_freq`` using Nemotron-Flash-1B's own NTK/default formula. + + Copy of the relevant init branch from + ``modeling_nemotron_flash.LlamaRotaryEmbedding.__init__``. Flash's NTK + differs from transformers' standard: + - ``factor = 2`` (hardcoded in Flash) + - Reads ``config.orig_max_position_embeddings`` (not + ``original_max_position_embeddings``). + - Scales ``base`` directly (no post-hoc ``attention_scaling``). + """ + base = float(getattr(cfg, "rope_theta", 10000.0) or 10000.0) + rope_type = getattr(cfg, "rope_type", None) or "default" + if rope_type == "ntk": + max_pos = getattr(cfg, "max_position_embeddings", None) + orig_max = getattr(cfg, "orig_max_position_embeddings", None) + if max_pos is not None and orig_max is not None and orig_max > 0: + factor = 2 + base = base * ((factor * max_pos / orig_max) - (factor - 1)) ** (dim / (dim - 2)) + # default / dynamic_ntk / unknown: use (possibly unscaled) ``base``. + indices = torch.arange(0, dim, 2, dtype=torch.int64, device=device).float() + return 1.0 / (base ** (indices / dim)) + + def _is_nemotron_flash_config(cfg): if cfg is None: return False @@ -71,7 +105,32 @@ def should_fix_rotary_embeddings(model_parts): def fix_rotary_embeddings(model_parts): - """Patch rotary embeddings to bypass fragile legacy HF runtime behavior.""" + """Install Nemotron-Flash-1B's native NTK ``inv_freq`` deterministically. + + Flash's own ``LlamaRotaryEmbedding.__init__`` (remote code, under + trust_remote_code) can land with NaN/Inf ``inv_freq`` buffers under + transformers 5.x's meta-device init context, and its NTK formula is + non-standard (``factor=2``, reads ``config.orig_max_position_embeddings``, + no post-hoc ``attention_scaling``), so transformers' own + ``ROPE_INIT_FUNCTIONS`` does not match it. The old version of this patch + sidestepped that by overwriting ``inv_freq`` with a plain-vanilla formula + (no NTK) and replacing ``forward`` with a vanilla one — but that silently + downgraded training-time rope semantics relative to Flash's native, which + vanilla HF uses when reloading the consolidated checkpoint. The result was + Phase 4 HF KL > 1.0, "fixed" by skipping Phase 4. + + This revised patch computes ``inv_freq`` using Flash's *own* NTK formula + (copied verbatim from ``modeling_nemotron_flash.LlamaRotaryEmbedding``) and + installs it on every Flash rotary found, unconditionally. The forward is + also replaced with ``_safe_rope_forward`` (now semantically identical to + Flash's native forward), which guards against any init-order oddity in + the remote-code class. Training, Phase 3 Automodel reload, and Phase 4 + vanilla HF reload all end up computing the same NTK-scaled rope. + + Scope: only touches modules whose ``config`` is recognized as Nemotron- + Flash (via ``_is_nemotron_flash_config``), so non-Flash models are never + affected. ``should_fix_rotary_embeddings`` further narrows the call site. + """ fixed = 0 for mp in model_parts: for fqn, module in mp.named_modules(): @@ -79,28 +138,13 @@ def fix_rotary_embeddings(model_parts): if inv is None or not isinstance(inv, torch.Tensor): continue - iv = _to_local(inv) - bad = bool(torch.isnan(iv).any().item()) or bool(torch.isinf(iv).any().item()) - cfg = getattr(module, "config", None) - if cfg is not None: - rope_params = getattr(cfg, "rope_parameters", {}) or {} - base = rope_params.get("rope_theta", getattr(cfg, "rope_theta", 10000.0)) - dim = getattr(cfg, "head_dim", None) - if dim is None: - hs = getattr(cfg, "hidden_size", None) - nah = getattr(cfg, "num_attention_heads", None) - if hs and nah: - dim = hs // nah - if dim is None: - dim = iv.shape[-1] * 2 - else: - base = 10000.0 - dim = iv.shape[-1] * 2 - - new_inv = 1.0 / ( - base ** (torch.arange(0, dim, 2, dtype=torch.int64).to(device=iv.device, dtype=torch.float32) / dim) - ) + iv = _to_local(inv) + # ``inv_freq`` has ``dim/2`` elements → dim = 2 * iv.shape[-1]. + # Verified against ``LlamaRotaryEmbedding.__init__`` which uses + # ``torch.arange(0, dim, 2)`` of length ``dim/2``. + dim = iv.shape[-1] * 2 + new_inv = _compute_flash_inv_freq(cfg, iv.device, dim) inv.data.copy_(new_inv.to(dtype=inv.dtype, device=inv.device)) orig = getattr(module, "original_inv_freq", None) @@ -109,12 +153,11 @@ def fix_rotary_embeddings(model_parts): module.forward = types.MethodType(_safe_rope_forward, module) - logger.info( - f"[fix_rope] {fqn}: patched forward + inv_freq (bad={bad} base={base} dim={dim})", - ) + rope_type = getattr(cfg, "rope_type", None) or "default" + logger.info(f"[fix_rope] {fqn}: installed Flash NTK inv_freq (rope_type={rope_type}, dim={dim})") fixed += 1 - logger.info(f"[fix_rope] patched {fixed} rotary embeddings.") + logger.info(f"[fix_rope] repaired {fixed} rotary embeddings.") return fixed diff --git a/nemo_automodel/components/checkpoint/addons.py b/nemo_automodel/components/checkpoint/addons.py index cc7708e5de..f0bb8a4f45 100644 --- a/nemo_automodel/components/checkpoint/addons.py +++ b/nemo_automodel/components/checkpoint/addons.py @@ -77,7 +77,21 @@ def pre_save(self, **kwargs) -> None: _maybe_strip_quantization_config(model_part) with open(os.path.join(hf_metadata_dir, config_name), "w") as f: if hasattr(model_part.config, "to_json_string"): - f.write(model_part.config.to_json_string()) + # Use ``use_diff=False`` so the full config (not the + # diff against class defaults) is serialized. For + # remote-code configs registered via + # ``register_for_auto_class`` (e.g. DeciLM / + # Llama-Nemotron-Super-49B ``model_type='nemotron-nas'``), + # ``to_diff_dict`` sees the class-level ``model_type`` + # attribute as equal to the class default and drops + # it from the serialized JSON. Reloading via + # ``AutoConfig.from_pretrained`` on the resulting + # consolidated directory then raises + # ``Unrecognized model ... Should have a 'model_type' + # key``. Writing the full dict guarantees + # ``model_type``, ``architectures`` and ``auto_map`` + # land in the saved config regardless of class defaults. + f.write(model_part.config.to_json_string(use_diff=False)) else: # Diffusers models use FrozenDict for config instead of PretrainedConfig json.dump(dict(model_part.config), f, indent=2, default=str) diff --git a/tests/functional_tests/checkpoint_robustness/test_checkpoint_robustness_llm.py b/tests/functional_tests/checkpoint_robustness/test_checkpoint_robustness_llm.py index 24ec7694e0..7bfe53bd40 100644 --- a/tests/functional_tests/checkpoint_robustness/test_checkpoint_robustness_llm.py +++ b/tests/functional_tests/checkpoint_robustness/test_checkpoint_robustness_llm.py @@ -35,6 +35,10 @@ import torch.nn.functional as F from torch.distributed.tensor import DTensor +from nemo_automodel.components.checkpoint.checkpointing import ( + _MODELS_REQUIRING_BUFFER_REINIT, + _reinit_non_persistent_buffers, +) from nemo_automodel.components.config._arg_parser import parse_args_and_load_config from nemo_automodel.recipes.llm.train_ft import TrainFinetuneRecipeForNextTokenPrediction @@ -215,6 +219,31 @@ def _get_logits(model, input_ids, device, trainer=None) -> torch.Tensor: return logits.float().cpu() +def _reinit_rotary_per_module(model, default_device): + """Recompute DeciLM / Gemma3 style non-persistent rotary buffers on each + module's own device. + + HF `from_pretrained` in transformers 5.x leaves ``inv_freq`` uninitialized + for models whose rotary buffers are computed in ``__init__`` and never + saved to the state dict (e.g. nemotron-nas, gemma3). With + ``device_map='auto'`` each rotary module can live on a different GPU, so + we drive the recompute per-module using its own inv_freq device rather + than a single fixed device. + """ + model_type = getattr(model.config, "model_type", None) + if model_type not in _MODELS_REQUIRING_BUFFER_REINIT: + return model + for mod in model.modules(): + inv = getattr(mod, "inv_freq", None) + if inv is None: + continue + mod_device = inv.device + if mod_device.type == "meta": + mod_device = next((p.device for p in mod.parameters()), default_device) + _reinit_non_persistent_buffers(mod, mod_device, model_type=model_type) + return model + + def _fix_meta_rotary_embeddings(model): """Re-materialize RotaryEmbedding tensors stuck on meta device. @@ -500,6 +529,24 @@ def test_checkpoint_robustness(): base_model = _fix_meta_rotary_embeddings( AutoModelForCausalLM.from_pretrained(original_pretrained_path, **hf_kwargs) ).to(device) + # Re-init non-persistent rotary buffers for ``model_type`` values + # in ``_MODELS_REQUIRING_BUFFER_REINIT`` (``nemotron-nas``, + # ``gemma3``) — their ``inv_freq`` is computed in ``__init__`` and + # never written to the checkpoint; meta-device init leaves + # garbage values after ``from_pretrained``. + _reinit_rotary_per_module(base_model, device) + # For Nemotron-Flash (``model_type=="nemotron_flash"``) the + # ``inv_freq`` buffer also lands garbage under HF load but its + # NTK formula is non-standard, so route through the dedicated + # ``fix_rotary_embeddings`` patch which installs Flash's own NTK + # formula and mirrors Flash's native forward. + if trust_remote_code: + from nemo_automodel._transformers.v4_patches.rotary import ( + fix_rotary_embeddings, + should_fix_rotary_embeddings, + ) + if should_fix_rotary_embeddings([base_model]): + fix_rotary_embeddings([base_model]) peft_model = PeftModel.from_pretrained(base_model, str(ckpt_step_dir / "model")) hf_logits = _get_logits(peft_model, input_ids, device) @@ -528,6 +575,18 @@ def test_checkpoint_robustness(): hf_model = _fix_meta_rotary_embeddings( AutoModelForCausalLM.from_pretrained(str(consolidated_dir), **hf_kwargs) ).to(device) + # Re-init non-persistent rotary buffers for nemotron-nas / gemma3 + # (``_MODELS_REQUIRING_BUFFER_REINIT`` allow-list). See PEFT branch + # above for details. + _reinit_rotary_per_module(hf_model, device) + # For Nemotron-Flash: install NTK inv_freq via dedicated patch. + if trust_remote_code: + from nemo_automodel._transformers.v4_patches.rotary import ( + fix_rotary_embeddings, + should_fix_rotary_embeddings, + ) + if should_fix_rotary_embeddings([hf_model]): + fix_rotary_embeddings([hf_model]) hf_logits = _get_logits(hf_model, input_ids, device) del hf_model diff --git a/tests/unit_tests/_transformers/test_transformers_utils.py b/tests/unit_tests/_transformers/test_transformers_utils.py index 6b3c6ee4f6..3d4018b53b 100644 --- a/tests/unit_tests/_transformers/test_transformers_utils.py +++ b/tests/unit_tests/_transformers/test_transformers_utils.py @@ -389,7 +389,7 @@ def __init__(self, config): assert isinstance(tied, dict) assert "lm_head.weight" in tied assert tied["lm_head.weight"] == "model.embed_tokens.weight" - assert model._tied_weights_keys == {} + assert model._tied_weights_keys == tied def test_tied_weights_keys_patch_converts_any_model(self): """The post_init patch should convert _tied_weights_keys for any model, not just phi4mm.""" @@ -418,7 +418,40 @@ def __init__(self, config): assert isinstance(tied, dict) assert "lm_head.weight" in model._nemo_tied_weights_keys assert tied["lm_head.weight"] == "model.embed_tokens.weight" - assert model._tied_weights_keys == {} + assert model._tied_weights_keys == tied + + def test_tied_weights_keys_patch_resolves_top_level_embed_tokens(self): + """The post_init patch resolves embed_tokens at the top level via get_input_embeddings.""" + apply_cache_compatibility_patches() + import torch.nn as nn + from transformers import PretrainedConfig + from transformers.modeling_utils import PreTrainedModel + + class _MockConfig(PretrainedConfig): + model_type = "nemotron_flash" + + class _Model(PreTrainedModel): + config_class = _MockConfig + + def __init__(self, config): + super().__init__(config) + self.embed_tokens = nn.Embedding(100, 16) + self.lm_head = nn.Linear(16, 100, bias=False) + self._tied_weights_keys = ["lm_head.weight"] + self.post_init() + + def get_input_embeddings(self): + return self.embed_tokens + + def forward(self, x): + return self.lm_head(self.embed_tokens(x)) + + model = _Model(_MockConfig()) + assert isinstance(model._tied_weights_keys, dict) + assert "lm_head.weight" in model._tied_weights_keys + assert model._tied_weights_keys["lm_head.weight"] == "embed_tokens.weight" + assert isinstance(model._nemo_tied_weights_keys, dict) + assert model._nemo_tied_weights_keys["lm_head.weight"] == "embed_tokens.weight" def test_patches_peft_prepare_inputs(self): """PeftModelForCausalLM.__init__ should be patched for missing prepare_inputs_for_generation."""