diff --git a/examples/llm_finetune/nemotron/llama3_3_nemotron_super_49B_squad_peft.yaml b/examples/llm_finetune/nemotron/llama3_3_nemotron_super_49B_squad_peft.yaml
index feb2828dda..359686eee4 100644
--- a/examples/llm_finetune/nemotron/llama3_3_nemotron_super_49B_squad_peft.yaml
+++ b/examples/llm_finetune/nemotron/llama3_3_nemotron_super_49B_squad_peft.yaml
@@ -29,7 +29,7 @@ step_scheduler:
 
 dist_env:
   backend: nccl
-  timeout_minutes: 1
+  timeout_minutes: 20
 
 rng:
   _target_: nemo_automodel.components.training.rng.StatefulRNG
@@ -113,6 +113,7 @@ ci:
   recipe_owner: HuiyingLi
   checkpoint_robustness:
     hf_kl_threshold: 5e-3
+    resume_loss_threshold: 5e-2
     trust_remote_code: true
     distributed.tp_size: 2
     tokenizer_name: nvidia/Llama-3_3-Nemotron-Super-49B-v1_5
diff --git a/tests/functional_tests/checkpoint_robustness/test_checkpoint_robustness_llm.py b/tests/functional_tests/checkpoint_robustness/test_checkpoint_robustness_llm.py
index b2ca493f34..7dbeae938a 100644
--- a/tests/functional_tests/checkpoint_robustness/test_checkpoint_robustness_llm.py
+++ b/tests/functional_tests/checkpoint_robustness/test_checkpoint_robustness_llm.py
@@ -35,6 +35,10 @@
 import torch.nn.functional as F
 from torch.distributed.tensor import DTensor
 
+from nemo_automodel.components.checkpoint.checkpointing import (
+    _MODELS_REQUIRING_BUFFER_REINIT,
+    _reinit_non_persistent_buffers,
+)
 from nemo_automodel.components.config._arg_parser import parse_args_and_load_config
 from nemo_automodel.recipes.llm.train_ft import TrainFinetuneRecipeForNextTokenPrediction
 
@@ -214,6 +218,31 @@ def _get_logits(model, input_ids, device, trainer=None) -> torch.Tensor:
     return logits.float().cpu()
 
 
+def _reinit_rotary_per_module(model, default_device):
+    """Recompute DeciLM / Gemma3 style non-persistent rotary buffers on each
+    module's own device.
+
+    HF ``from_pretrained`` in transformers 5.x leaves ``inv_freq`` uninitialized
+    for models whose rotary buffers are computed in ``__init__`` and never
+    saved to the state dict (e.g. nemotron-nas, gemma3). With
+    ``device_map='auto'`` each rotary module can live on a different GPU, so
+    we drive the recompute per-module using its own ``inv_freq`` device rather
+    than a single fixed device.
+    """
+    model_type = getattr(model.config, "model_type", None)
+    if model_type not in _MODELS_REQUIRING_BUFFER_REINIT:
+        return model
+    for mod in model.modules():
+        inv = getattr(mod, "inv_freq", None)
+        if inv is None:
+            continue
+        mod_device = inv.device
+        if mod_device.type == "meta":
+            mod_device = next((p.device for p in mod.parameters()), default_device)
+        _reinit_non_persistent_buffers(mod, mod_device, model_type=model_type)
+    return model
+
+
 def _fix_meta_rotary_embeddings(model):
     """Re-materialize RotaryEmbedding tensors stuck on meta device.
 
@@ -370,6 +399,7 @@
         base_model = _fix_meta_rotary_embeddings(
             AutoModelForCausalLM.from_pretrained(original_pretrained_path, **hf_kwargs)
         ).to(device)
+        _reinit_rotary_per_module(base_model, device)
         peft_model = PeftModel.from_pretrained(base_model, str(ckpt_step_dir / "model"))
         hf_logits = _get_logits(peft_model, input_ids, device)
 
@@ -396,6 +426,7 @@
         hf_model = _fix_meta_rotary_embeddings(
             AutoModelForCausalLM.from_pretrained(str(consolidated_dir), **hf_kwargs)
         ).to(device)
+        _reinit_rotary_per_module(hf_model, device)
         hf_logits = _get_logits(hf_model, input_ids, device)
         del hf_model
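For reviewers: a minimal, self-contained sketch of the per-module buffer-reinit pattern that `_reinit_rotary_per_module` implements. `ToyRotaryEmbedding` and `reinit_rotary_buffers` are hypothetical stand-ins invented for illustration; they are not the nemo_automodel helpers, and the real `_reinit_non_persistent_buffers` (which dispatches on `model_type`) is internal and not shown in this diff.

```python
import torch
import torch.nn as nn


class ToyRotaryEmbedding(nn.Module):
    """Hypothetical stand-in for an HF rotary module: ``inv_freq`` is a
    non-persistent buffer, so it never reaches the state dict and must be
    recomputed after loading."""

    def __init__(self, dim: int = 64, base: float = 10000.0):
        super().__init__()
        self.dim = dim
        self.base = base
        # persistent=False: excluded from state_dict, normally computed in __init__.
        self.register_buffer("inv_freq", self._compute_inv_freq(), persistent=False)

    def _compute_inv_freq(self) -> torch.Tensor:
        return 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float() / self.dim))


def reinit_rotary_buffers(model: nn.Module, default_device: torch.device) -> nn.Module:
    """Recompute every inv_freq buffer on its owning module's device,
    mirroring the per-module traversal in the diff above."""
    for mod in model.modules():
        inv = getattr(mod, "inv_freq", None)
        if inv is None:
            continue
        device = inv.device
        if device.type == "meta":
            # Buffer was never materialized; borrow a parameter's device,
            # falling back to the caller-supplied default.
            device = next((p.device for p in mod.parameters()), default_device)
        mod.inv_freq = mod._compute_inv_freq().to(device)
    return model


if __name__ == "__main__":
    model = nn.Sequential(ToyRotaryEmbedding(), nn.Linear(4, 4))
    # Simulate a load path that leaves the buffer stranded on the meta device.
    model[0].inv_freq = torch.empty_like(model[0].inv_freq, device="meta")
    reinit_rotary_buffers(model, torch.device("cpu"))
    assert model[0].inv_freq.device.type == "cpu"
```

The design point the sketch demonstrates is the one the docstring argues for: with `device_map='auto'` there is no single correct device, so each module's recompute target is resolved from its own `inv_freq` (or its own parameters) rather than from one fixed device passed in by the caller.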