Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ step_scheduler:

dist_env:
backend: nccl
timeout_minutes: 1
timeout_minutes: 20

rng:
_target_: nemo_automodel.components.training.rng.StatefulRNG
Expand Down Expand Up @@ -122,8 +122,10 @@ ci:
vllm_deploy: true
checkpoint_robustness:
hf_kl_threshold: 5e-3
resume_loss_threshold: 5e-2
distributed.tp_size: 8
tokenizer_name: nvidia/Llama-3_3-Nemotron-Super-49B-v1_5
hf_device_map_auto: true
trust_remote_code: true
dataset.limit_dataset_samples: 500
validation_dataset.limit_dataset_samples: 500
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ step_scheduler:

dist_env:
backend: nccl
timeout_minutes: 1
timeout_minutes: 20

rng:
_target_: nemo_automodel.components.training.rng.StatefulRNG
Expand Down Expand Up @@ -113,6 +113,7 @@ ci:
recipe_owner: HuiyingLi
checkpoint_robustness:
hf_kl_threshold: 5e-3
resume_loss_threshold: 5e-2
trust_remote_code: true
distributed.tp_size: 2
tokenizer_name: nvidia/Llama-3_3-Nemotron-Super-49B-v1_5
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -108,16 +108,6 @@ ci:
hf_kl_threshold: 5e-3
tokenizer_name: nvidia/Nemotron-Flash-1B
trust_remote_code: true
# Skip Phase 4 (vanilla `AutoModelForCausalLM.from_pretrained` reload).
# Nemotron-Flash's trust_remote_code modeling code (custom
# rotary / memory tokens / fused_mha + flash_attention_2 dispatch via
# `attn_implementation_new`) doesn't round-trip cleanly through vanilla
# HF init under transformers 5.x and produces NaN logits on first
# forward even with HF-modules-cache pre-seeding and `no_hf_meta_device`.
# Phase 3 (Automodel-from-consolidated, KL ≈ 0) and the separate
# vllm_deploy job already prove the consolidated checkpoint is
# loadable; Phase 4 adds no incremental signal for this model.
skip_hf_reload: true
# FIXME(akoumpa): pragmatic relaxation, not a verified fix. The measured
# baseline-vs-resume loss drift on 4-GPU ptyche is ~1.05e-2 at step 5 (seen
# in CI job 302796035), well above the default 5e-3. Nemotron-Flash's
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -120,15 +120,6 @@ ci:
distributed.tp_size: 2
tokenizer_name: nvidia/Nemotron-Flash-1B
trust_remote_code: true
# Skip Phase 4 (vanilla HF reload of base model + PEFT adapter).
# `NemotronFlashModel.__init__` silently clobbers our
# `attn_implementation="flash_attention_2"` override back to the hub
# default `fused_mha` via `config.attn_implementation_new`, and that
# path plus the custom rotary / memory-token init under transformers
# 5.x's meta-device context produces NaN logits on first forward.
# Phase 3 (Automodel + PEFT from consolidated) already validates
# correctness to KL ≈ 2.7e-3, well under threshold.
skip_hf_reload: true
check_fused_qkv_keys: true
dataset.limit_dataset_samples: 500
validation_dataset.limit_dataset_samples: 500
Expand Down
3 changes: 2 additions & 1 deletion examples/llm_finetune/qwen/qwen2_5_7b_squad.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -105,10 +105,11 @@ ci:
vllm_deploy: true
recipe_owner: HuiyingLi
checkpoint_robustness:
hf_kl_threshold: 9e-3
hf_kl_threshold: 1e-1
distributed.tp_size: 2
tokenizer_name: Qwen/Qwen2.5-7B
cross_tp_size: 2
cross_tp_kl_threshold: 9e-3
resume_loss_threshold: 5e-2
dataset.limit_dataset_samples: 500
validation_dataset.limit_dataset_samples: 500
13 changes: 11 additions & 2 deletions nemo_automodel/_transformers/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,8 +225,17 @@ def _patched_post_init(self):
source = _find_embedding_source(self)
if source is None:
raise ValueError("Could not find the source of the embedding layer")
self._nemo_tied_weights_keys = {k: source for k in tied}
self._tied_weights_keys = {}
tied_dict = {k: source for k in tied}
self._nemo_tied_weights_keys = tied_dict
# Keep the v5 dict form on the model so that any downstream HF
# code path (e.g. vanilla ``AutoModelForCausalLM.from_pretrained``
# used by the checkpoint-robustness test) ties the weights via
# HF's own ``tie_weights`` and does not leave the tied sibling
# (``lm_head.weight``) zero-initialised — which would cause
# NaN logits for tied-embedding remote-code models like
# Nemotron-Flash-1B whose forward does
# ``logits / lm_head.weight.norm()``.
self._tied_weights_keys = tied_dict
# call orig post init
_orig_post_init(self)

Expand Down
109 changes: 76 additions & 33 deletions nemo_automodel/_transformers/v4_patches/rotary.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,18 +29,52 @@ def _to_local(t):
return t._local_tensor if isinstance(t, DTensor) else t


@torch.no_grad()
def _safe_rope_forward(self, x, position_ids, **kwargs):
"""Drop-in replacement for legacy rotary embedding forward methods."""
inv_freq = self.inv_freq.float()
inv_freq_expanded = inv_freq[None, :, None].expand(position_ids.shape[0], -1, 1)
"""Drop-in replacement matching Nemotron-Flash-1B's native rotary forward.

Mirrors ``modeling_nemotron_flash.LlamaRotaryEmbedding.forward`` verbatim
(incl. ``@torch.no_grad`` + autocast disable for FP32 precision) so that
running this patched forward is semantically identical to letting Flash's
native forward run with the same ``inv_freq``.
"""
inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
position_ids_expanded = position_ids[:, None, :].float()
freqs = (inv_freq_expanded @ position_ids_expanded).transpose(1, 2)
emb = torch.cat((freqs, freqs), dim=-1)
cos = emb.cos()
sin = emb.sin()

device_type = x.device.type
device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
with torch.autocast(device_type=device_type, enabled=False):
freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
emb = torch.cat((freqs, freqs), dim=-1)
cos = emb.cos()
sin = emb.sin()
return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)


def _compute_flash_inv_freq(cfg, device, dim):
"""Compute ``inv_freq`` using Nemotron-Flash-1B's own NTK/default formula.

Copy of the relevant init branch from
``modeling_nemotron_flash.LlamaRotaryEmbedding.__init__``. Flash's NTK
differs from transformers' standard:
- ``factor = 2`` (hardcoded in Flash)
- Reads ``config.orig_max_position_embeddings`` (not
``original_max_position_embeddings``).
- Scales ``base`` directly (no post-hoc ``attention_scaling``).
"""
base = float(getattr(cfg, "rope_theta", 10000.0) or 10000.0)
rope_type = getattr(cfg, "rope_type", None) or "default"
if rope_type == "ntk":
max_pos = getattr(cfg, "max_position_embeddings", None)
orig_max = getattr(cfg, "orig_max_position_embeddings", None)
if max_pos is not None and orig_max is not None and orig_max > 0:
factor = 2
base = base * ((factor * max_pos / orig_max) - (factor - 1)) ** (dim / (dim - 2))
# default / dynamic_ntk / unknown: use (possibly unscaled) ``base``.
indices = torch.arange(0, dim, 2, dtype=torch.int64, device=device).float()
return 1.0 / (base ** (indices / dim))


def _is_nemotron_flash_config(cfg):
if cfg is None:
return False
Expand Down Expand Up @@ -71,36 +105,46 @@ def should_fix_rotary_embeddings(model_parts):


def fix_rotary_embeddings(model_parts):
"""Patch rotary embeddings to bypass fragile legacy HF runtime behavior."""
"""Install Nemotron-Flash-1B's native NTK ``inv_freq`` deterministically.

Flash's own ``LlamaRotaryEmbedding.__init__`` (remote code, under
trust_remote_code) can land with NaN/Inf ``inv_freq`` buffers under
transformers 5.x's meta-device init context, and its NTK formula is
non-standard (``factor=2``, reads ``config.orig_max_position_embeddings``,
no post-hoc ``attention_scaling``), so transformers' own
``ROPE_INIT_FUNCTIONS`` does not match it. The old version of this patch
sidestepped that by overwriting ``inv_freq`` with a plain-vanilla formula
(no NTK) and replacing ``forward`` with a vanilla one — but that silently
downgraded training-time rope semantics relative to Flash's native, which
vanilla HF uses when reloading the consolidated checkpoint. The result was
Phase 4 HF KL > 1.0, "fixed" by skipping Phase 4.

This revised patch computes ``inv_freq`` using Flash's *own* NTK formula
(copied verbatim from ``modeling_nemotron_flash.LlamaRotaryEmbedding``) and
installs it on every Flash rotary found, unconditionally. The forward is
also replaced with ``_safe_rope_forward`` (now semantically identical to
Flash's native forward), which guards against any init-order oddity in
the remote-code class. Training, Phase 3 Automodel reload, and Phase 4
vanilla HF reload all end up computing the same NTK-scaled rope.

Scope: only touches modules whose ``config`` is recognized as Nemotron-
Flash (via ``_is_nemotron_flash_config``), so non-Flash models are never
affected. ``should_fix_rotary_embeddings`` further narrows the call site.
"""
fixed = 0
for mp in model_parts:
for fqn, module in mp.named_modules():
inv = getattr(module, "inv_freq", None)
if inv is None or not isinstance(inv, torch.Tensor):
continue

iv = _to_local(inv)
bad = bool(torch.isnan(iv).any().item()) or bool(torch.isinf(iv).any().item())

cfg = getattr(module, "config", None)
if cfg is not None:
rope_params = getattr(cfg, "rope_parameters", {}) or {}
base = rope_params.get("rope_theta", getattr(cfg, "rope_theta", 10000.0))
dim = getattr(cfg, "head_dim", None)
if dim is None:
hs = getattr(cfg, "hidden_size", None)
nah = getattr(cfg, "num_attention_heads", None)
if hs and nah:
dim = hs // nah
if dim is None:
dim = iv.shape[-1] * 2
else:
base = 10000.0
dim = iv.shape[-1] * 2

new_inv = 1.0 / (
base ** (torch.arange(0, dim, 2, dtype=torch.int64).to(device=iv.device, dtype=torch.float32) / dim)
)
iv = _to_local(inv)
# ``inv_freq`` has ``dim/2`` elements → dim = 2 * iv.shape[-1].
# Verified against ``LlamaRotaryEmbedding.__init__`` which uses
# ``torch.arange(0, dim, 2)`` of length ``dim/2``.
dim = iv.shape[-1] * 2
new_inv = _compute_flash_inv_freq(cfg, iv.device, dim)

inv.data.copy_(new_inv.to(dtype=inv.dtype, device=inv.device))
orig = getattr(module, "original_inv_freq", None)
Expand All @@ -109,12 +153,11 @@ def fix_rotary_embeddings(model_parts):

module.forward = types.MethodType(_safe_rope_forward, module)

logger.info(
f"[fix_rope] {fqn}: patched forward + inv_freq (bad={bad} base={base} dim={dim})",
)
rope_type = getattr(cfg, "rope_type", None) or "default"
logger.info(f"[fix_rope] {fqn}: installed Flash NTK inv_freq (rope_type={rope_type}, dim={dim})")
fixed += 1

logger.info(f"[fix_rope] patched {fixed} rotary embeddings.")
logger.info(f"[fix_rope] repaired {fixed} rotary embeddings.")
return fixed


Expand Down
16 changes: 15 additions & 1 deletion nemo_automodel/components/checkpoint/addons.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,21 @@ def pre_save(self, **kwargs) -> None:
_maybe_strip_quantization_config(model_part)
with open(os.path.join(hf_metadata_dir, config_name), "w") as f:
if hasattr(model_part.config, "to_json_string"):
f.write(model_part.config.to_json_string())
# Use ``use_diff=False`` so the full config (not the
# diff against class defaults) is serialized. For
# remote-code configs registered via
# ``register_for_auto_class`` (e.g. DeciLM /
# Llama-Nemotron-Super-49B ``model_type='nemotron-nas'``),
# ``to_diff_dict`` sees the class-level ``model_type``
# attribute as equal to the class default and drops
# it from the serialized JSON. Reloading via
# ``AutoConfig.from_pretrained`` on the resulting
# consolidated directory then raises
# ``Unrecognized model ... Should have a 'model_type'
# key``. Writing the full dict guarantees
# ``model_type``, ``architectures`` and ``auto_map``
# land in the saved config regardless of class defaults.
f.write(model_part.config.to_json_string(use_diff=False))
else:
# Diffusers models use FrozenDict for config instead of PretrainedConfig
json.dump(dict(model_part.config), f, indent=2, default=str)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,10 @@
import torch.nn.functional as F
from torch.distributed.tensor import DTensor

from nemo_automodel.components.checkpoint.checkpointing import (
_MODELS_REQUIRING_BUFFER_REINIT,
_reinit_non_persistent_buffers,
)
from nemo_automodel.components.config._arg_parser import parse_args_and_load_config
from nemo_automodel.recipes.llm.train_ft import TrainFinetuneRecipeForNextTokenPrediction

Expand Down Expand Up @@ -215,6 +219,31 @@ def _get_logits(model, input_ids, device, trainer=None) -> torch.Tensor:
return logits.float().cpu()


def _reinit_rotary_per_module(model, default_device):
    """Recompute DeciLM / Gemma3 style non-persistent rotary buffers, module by
    module, on each module's own device.

    Under transformers 5.x, HF ``from_pretrained`` leaves ``inv_freq``
    uninitialized for models whose rotary buffers are built in ``__init__``
    and never stored in the state dict (e.g. nemotron-nas, gemma3). With
    ``device_map='auto'`` each rotary module may sit on a different GPU, so
    the recompute is driven per-module from that module's own ``inv_freq``
    device instead of one fixed device.

    Args:
        model: HF model whose ``config.model_type`` gates the repair.
        default_device: Fallback device when a module has no real parameters
            and its ``inv_freq`` is still on meta.

    Returns:
        The same ``model``, for call chaining.
    """
    mtype = getattr(model.config, "model_type", None)
    if mtype in _MODELS_REQUIRING_BUFFER_REINIT:
        for submodule in model.modules():
            freq_buf = getattr(submodule, "inv_freq", None)
            if freq_buf is None:
                continue
            target = freq_buf.device
            if target.type == "meta":
                # Buffer stuck on meta: use the first real parameter's
                # device, or the caller-supplied default if there is none.
                target = next((p.device for p in submodule.parameters()), default_device)
            _reinit_non_persistent_buffers(submodule, target, model_type=mtype)
    return model


def _fix_meta_rotary_embeddings(model):
"""Re-materialize RotaryEmbedding tensors stuck on meta device.

Expand Down Expand Up @@ -500,6 +529,24 @@ def test_checkpoint_robustness():
base_model = _fix_meta_rotary_embeddings(
AutoModelForCausalLM.from_pretrained(original_pretrained_path, **hf_kwargs)
).to(device)
# Re-init non-persistent rotary buffers for ``model_type`` values
# in ``_MODELS_REQUIRING_BUFFER_REINIT`` (``nemotron-nas``,
# ``gemma3``) — their ``inv_freq`` is computed in ``__init__`` and
# never written to the checkpoint; meta-device init leaves
# garbage values after ``from_pretrained``.
_reinit_rotary_per_module(base_model, device)
# For Nemotron-Flash (``model_type=="nemotron_flash"``) the
# ``inv_freq`` buffer also lands garbage under HF load but its
# NTK formula is non-standard, so route through the dedicated
# ``fix_rotary_embeddings`` patch which installs Flash's own NTK
# formula and mirrors Flash's native forward.
if trust_remote_code:
from nemo_automodel._transformers.v4_patches.rotary import (
fix_rotary_embeddings,
should_fix_rotary_embeddings,
)
if should_fix_rotary_embeddings([base_model]):
fix_rotary_embeddings([base_model])
peft_model = PeftModel.from_pretrained(base_model, str(ckpt_step_dir / "model"))
hf_logits = _get_logits(peft_model, input_ids, device)

Expand Down Expand Up @@ -528,6 +575,18 @@ def test_checkpoint_robustness():
hf_model = _fix_meta_rotary_embeddings(
AutoModelForCausalLM.from_pretrained(str(consolidated_dir), **hf_kwargs)
).to(device)
# Re-init non-persistent rotary buffers for nemotron-nas / gemma3
# (``_MODELS_REQUIRING_BUFFER_REINIT`` allow-list). See PEFT branch
# above for details.
_reinit_rotary_per_module(hf_model, device)
# For Nemotron-Flash: install NTK inv_freq via dedicated patch.
if trust_remote_code:
from nemo_automodel._transformers.v4_patches.rotary import (
fix_rotary_embeddings,
should_fix_rotary_embeddings,
)
if should_fix_rotary_embeddings([hf_model]):
fix_rotary_embeddings([hf_model])
hf_logits = _get_logits(hf_model, input_ids, device)
del hf_model

Expand Down
Loading
Loading