From aabb92aa3d05d4ea9248502e791fd1d03f909a38 Mon Sep 17 00:00:00 2001 From: adil-a Date: Tue, 21 Apr 2026 23:49:41 -0700 Subject: [PATCH 1/2] fix(tests): recompute nemotron-nas rotary buffers in HF phase of checkpoint robustness Phase 4 of test_checkpoint_robustness_llm.py reloads the trained model via plain transformers.AutoModelForCausalLM and compares logits against the training reference. For model_type "nemotron-nas" (and "gemma3"), rotary inv_freq is a non-persistent buffer computed in __init__ and not written to safetensors. transformers 5.x defaults to meta-device init, so the computation produces meta tensors; when later materialized to GPU they contain uninitialized memory (values on the order of 1e30+ or zeros). Attention then rotates Q/K by garbage frequencies, diverging the HF reload from the training reference layer-by-layer. nemo-automodel's own loader avoids this by calling _reinit_non_persistent_buffers in apply_model_infrastructure, which is allow-listed for "nemotron-nas" and "gemma3". The robustness test's HF path did not run that reinit, so the comparison was measuring a broken HF model. This patch calls the same reinit helper after every HF from_pretrained site in Phase 4 (PEFT and SFT paths, both hf_device_map_auto branches) via a small wrapper that resolves each module's own device so it works correctly under device_map="auto" where modules can live on different GPUs. Verified on nvidia/Llama-3_3-Nemotron-Super-49B-v1_5 with the existing robustness launch command from scripts/finetune_launcher.sh: [Phase 4] HF-loaded max KL: 9.17e-04 (threshold: 5.00e-03) PASS Prior to the fix Phase 4 produced max KL ~1.05e+01 against the same reference (~11000x improvement), which is why the WIP branch for this recipe had been raising hf_kl_threshold to mask the loader bug. Signed-off-by: adil-a --- .../test_checkpoint_robustness_llm.py | 31 +++++++++++++++++++ 1 file changed, 31 insertions(+) diff --git a/tests/functional_tests/checkpoint_robustness/test_checkpoint_robustness_llm.py b/tests/functional_tests/checkpoint_robustness/test_checkpoint_robustness_llm.py index b2ca493f34..7dbeae938a 100644 --- a/tests/functional_tests/checkpoint_robustness/test_checkpoint_robustness_llm.py +++ b/tests/functional_tests/checkpoint_robustness/test_checkpoint_robustness_llm.py @@ -35,6 +35,10 @@ import torch.nn.functional as F from torch.distributed.tensor import DTensor +from nemo_automodel.components.checkpoint.checkpointing import ( + _MODELS_REQUIRING_BUFFER_REINIT, + _reinit_non_persistent_buffers, +) from nemo_automodel.components.config._arg_parser import parse_args_and_load_config from nemo_automodel.recipes.llm.train_ft import TrainFinetuneRecipeForNextTokenPrediction @@ -214,6 +218,31 @@ def _get_logits(model, input_ids, device, trainer=None) -> torch.Tensor: return logits.float().cpu() +def _reinit_rotary_per_module(model, default_device): + """Recompute DeciLM / Gemma3 style non-persistent rotary buffers on each + module's own device. + + HF `from_pretrained` in transformers 5.x leaves ``inv_freq`` uninitialized + for models whose rotary buffers are computed in ``__init__`` and never + saved to the state dict (e.g. nemotron-nas, gemma3). With + ``device_map='auto'`` each rotary module can live on a different GPU, so + we drive the recompute per-module using its own inv_freq device rather + than a single fixed device. + """ + model_type = getattr(model.config, "model_type", None) + if model_type not in _MODELS_REQUIRING_BUFFER_REINIT: + return model + for mod in model.modules(): + inv = getattr(mod, "inv_freq", None) + if inv is None: + continue + mod_device = inv.device + if mod_device.type == "meta": + mod_device = next((p.device for p in mod.parameters()), default_device) + _reinit_non_persistent_buffers(mod, mod_device, model_type=model_type) + return model + + def _fix_meta_rotary_embeddings(model): """Re-materialize RotaryEmbedding tensors stuck on meta device. @@ -370,6 +399,7 @@ def test_checkpoint_robustness(): base_model = _fix_meta_rotary_embeddings( AutoModelForCausalLM.from_pretrained(original_pretrained_path, **hf_kwargs) ).to(device) + _reinit_rotary_per_module(base_model, device) peft_model = PeftModel.from_pretrained(base_model, str(ckpt_step_dir / "model")) hf_logits = _get_logits(peft_model, input_ids, device) @@ -396,6 +426,7 @@ def test_checkpoint_robustness(): hf_model = _fix_meta_rotary_embeddings( AutoModelForCausalLM.from_pretrained(str(consolidated_dir), **hf_kwargs) ).to(device) + _reinit_rotary_per_module(hf_model, device) hf_logits = _get_logits(hf_model, input_ids, device) del hf_model From d8c230471c5aee1f240e96aa159b81daf86b514b Mon Sep 17 00:00:00 2001 From: adil-a Date: Tue, 21 Apr 2026 23:54:47 -0700 Subject: [PATCH 2/2] ci(yaml): bump dist timeout to 20min, set resume_loss_threshold=5e-2 for 49B squad peft Hold-overs from the superseded PR #1951 that are independent of the rotary reinit fix: - timeout_minutes 1 -> 20: Phase 4 rank-0 HF load of the 49B base under device_map="auto" can take several minutes; the 1-minute default occasionally trips the NCCL init barrier. - resume_loss_threshold 5e-2: Phase 6 fresh-train vs resume-from-checkpoint loss tolerance. Matches the empirical step-to-step resume diff observed on the 49B PEFT run (~1.7e-02 .. 3.0e-02). hf_kl_threshold remains at the standard 5e-3; the previous bump to 1.5e1 in #1951 was masking the rotary inv_freq bug now fixed in the preceding commit. Signed-off-by: adil-a --- .../nemotron/llama3_3_nemotron_super_49B_squad_peft.yaml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/examples/llm_finetune/nemotron/llama3_3_nemotron_super_49B_squad_peft.yaml b/examples/llm_finetune/nemotron/llama3_3_nemotron_super_49B_squad_peft.yaml index feb2828dda..359686eee4 100644 --- a/examples/llm_finetune/nemotron/llama3_3_nemotron_super_49B_squad_peft.yaml +++ b/examples/llm_finetune/nemotron/llama3_3_nemotron_super_49B_squad_peft.yaml @@ -29,7 +29,7 @@ step_scheduler: dist_env: backend: nccl - timeout_minutes: 1 + timeout_minutes: 20 rng: _target_: nemo_automodel.components.training.rng.StatefulRNG @@ -113,6 +113,7 @@ ci: recipe_owner: HuiyingLi checkpoint_robustness: hf_kl_threshold: 5e-3 + resume_loss_threshold: 5e-2 trust_remote_code: true distributed.tp_size: 2 tokenizer_name: nvidia/Llama-3_3-Nemotron-Super-49B-v1_5