@@ -29,7 +29,7 @@ step_scheduler:
 
 dist_env:
   backend: nccl
-  timeout_minutes: 1
+  timeout_minutes: 20
 
 rng:
   _target_: nemo_automodel.components.training.rng.StatefulRNG
@@ -113,6 +113,7 @@ ci:
   recipe_owner: HuiyingLi
   checkpoint_robustness:
     hf_kl_threshold: 5e-3
+    resume_loss_threshold: 5e-2
   trust_remote_code: true
   distributed.tp_size: 2
   tokenizer_name: nvidia/Llama-3_3-Nemotron-Super-49B-v1_5
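
For context: `hf_kl_threshold` bounds the divergence between logits from the Hugging Face reload and the recipe's model, and the new `resume_loss_threshold` bounds loss drift after resuming from a checkpoint. A minimal sketch of the kind of check these thresholds gate; the function name, signature, and exact comparisons are illustrative, not the test's actual code:

```python
import torch
import torch.nn.functional as F


def check_checkpoint_robustness(
    hf_logits: torch.Tensor,
    recipe_logits: torch.Tensor,
    original_loss: float,
    resumed_loss: float,
    hf_kl_threshold: float = 5e-3,
    resume_loss_threshold: float = 5e-2,
) -> None:
    # KL(recipe || HF), summed over positions and averaged over the batch;
    # both tensors are [batch, seq, vocab] logits for the same prompt.
    p = F.log_softmax(recipe_logits.float(), dim=-1)
    q = F.log_softmax(hf_logits.float(), dim=-1)
    kl = F.kl_div(q, p, log_target=True, reduction="batchmean").item()
    assert kl < hf_kl_threshold, f"logit KL {kl:.2e} >= {hf_kl_threshold}"
    # A run resumed from the checkpoint should land near the loss of the
    # uninterrupted run.
    assert abs(resumed_loss - original_loss) < resume_loss_threshold
```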
@@ -35,6 +35,10 @@
 import torch.nn.functional as F
 from torch.distributed.tensor import DTensor
 
+from nemo_automodel.components.checkpoint.checkpointing import (
+    _MODELS_REQUIRING_BUFFER_REINIT,
+    _reinit_non_persistent_buffers,
+)
 from nemo_automodel.components.config._arg_parser import parse_args_and_load_config
 from nemo_automodel.recipes.llm.train_ft import TrainFinetuneRecipeForNextTokenPrediction
 
@@ -214,6 +218,31 @@ def _get_logits(model, input_ids, device, trainer=None) -> torch.Tensor:
     return logits.float().cpu()
 
 
+def _reinit_rotary_per_module(model, default_device):
+    """Recompute DeciLM / Gemma3 style non-persistent rotary buffers on each
+    module's own device.
+
+    HF `from_pretrained` in transformers 5.x leaves ``inv_freq`` uninitialized
+    for models whose rotary buffers are computed in ``__init__`` and never
+    saved to the state dict (e.g. nemotron-nas, gemma3). With
+    ``device_map='auto'`` each rotary module can live on a different GPU, so
+    we drive the recompute per-module using its own inv_freq device rather
+    than a single fixed device.
+    """
+    model_type = getattr(model.config, "model_type", None)
+    if model_type not in _MODELS_REQUIRING_BUFFER_REINIT:
+        return model
+    for mod in model.modules():
+        inv = getattr(mod, "inv_freq", None)
+        if inv is None:
+            continue
+        mod_device = inv.device
+        if mod_device.type == "meta":
+            mod_device = next((p.device for p in mod.parameters()), default_device)
+        _reinit_non_persistent_buffers(mod, mod_device, model_type=model_type)
+    return model
+
+
 def _fix_meta_rotary_embeddings(model):
     """Re-materialize RotaryEmbedding tensors stuck on meta device.
 
@@ -370,6 +399,7 @@ def test_checkpoint_robustness():
     base_model = _fix_meta_rotary_embeddings(
         AutoModelForCausalLM.from_pretrained(original_pretrained_path, **hf_kwargs)
     ).to(device)
+    _reinit_rotary_per_module(base_model, device)
     peft_model = PeftModel.from_pretrained(base_model, str(ckpt_step_dir / "model"))
     hf_logits = _get_logits(peft_model, input_ids, device)
 
@@ -396,6 +426,7 @@ def test_checkpoint_robustness():
     hf_model = _fix_meta_rotary_embeddings(
         AutoModelForCausalLM.from_pretrained(str(consolidated_dir), **hf_kwargs)
     ).to(device)
+    _reinit_rotary_per_module(hf_model, device)
     hf_logits = _get_logits(hf_model, input_ids, device)
     del hf_model
 
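
For reference, a hedged usage sketch of the new helper outside the test harness. `_fix_meta_rotary_embeddings` and `_reinit_rotary_per_module` are the functions in the diff above; the checkpoint name comes from the config, while `device_map='auto'` and the fallback device are illustrative:

```python
import torch
from transformers import AutoModelForCausalLM

# Illustrative: with device_map='auto' the rotary modules may be spread
# across several GPUs, which is the case the per-module reinit handles.
model = AutoModelForCausalLM.from_pretrained(
    "nvidia/Llama-3_3-Nemotron-Super-49B-v1_5",
    device_map="auto",
    trust_remote_code=True,
)
model = _fix_meta_rotary_embeddings(model)
# inv_freq is recomputed on each rotary module's own device; modules still
# on meta fall back to their first parameter's device, then to the default.
model = _reinit_rotary_per_module(model, torch.device("cuda:0"))
```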