diff --git a/examples/llm_finetune/nemotron/llama3_3_nemotron_super_49B_squad.yaml b/examples/llm_finetune/nemotron/llama3_3_nemotron_super_49B_squad.yaml index 0593583c96..73a032c39c 100644 --- a/examples/llm_finetune/nemotron/llama3_3_nemotron_super_49B_squad.yaml +++ b/examples/llm_finetune/nemotron/llama3_3_nemotron_super_49B_squad.yaml @@ -29,7 +29,7 @@ step_scheduler: dist_env: backend: nccl - timeout_minutes: 1 + timeout_minutes: 20 rng: _target_: nemo_automodel.components.training.rng.StatefulRNG @@ -122,8 +122,10 @@ ci: vllm_deploy: true checkpoint_robustness: hf_kl_threshold: 5e-3 + resume_loss_threshold: 5e-2 distributed.tp_size: 8 tokenizer_name: nvidia/Llama-3_3-Nemotron-Super-49B-v1_5 hf_device_map_auto: true + trust_remote_code: true dataset.limit_dataset_samples: 500 validation_dataset.limit_dataset_samples: 500 diff --git a/examples/llm_finetune/nemotron/llama3_3_nemotron_super_49B_squad_peft.yaml b/examples/llm_finetune/nemotron/llama3_3_nemotron_super_49B_squad_peft.yaml index feb2828dda..359686eee4 100644 --- a/examples/llm_finetune/nemotron/llama3_3_nemotron_super_49B_squad_peft.yaml +++ b/examples/llm_finetune/nemotron/llama3_3_nemotron_super_49B_squad_peft.yaml @@ -29,7 +29,7 @@ step_scheduler: dist_env: backend: nccl - timeout_minutes: 1 + timeout_minutes: 20 rng: _target_: nemo_automodel.components.training.rng.StatefulRNG @@ -113,6 +113,7 @@ ci: recipe_owner: HuiyingLi checkpoint_robustness: hf_kl_threshold: 5e-3 + resume_loss_threshold: 5e-2 trust_remote_code: true distributed.tp_size: 2 tokenizer_name: nvidia/Llama-3_3-Nemotron-Super-49B-v1_5 diff --git a/examples/llm_finetune/nemotron_flash/nemotron_flash_1b_squad.yaml b/examples/llm_finetune/nemotron_flash/nemotron_flash_1b_squad.yaml index 804f79fb48..7f7fe01522 100755 --- a/examples/llm_finetune/nemotron_flash/nemotron_flash_1b_squad.yaml +++ b/examples/llm_finetune/nemotron_flash/nemotron_flash_1b_squad.yaml @@ -108,16 +108,6 @@ ci: hf_kl_threshold: 5e-3 tokenizer_name: nvidia/Nemotron-Flash-1B trust_remote_code: true - # Skip Phase 4 (vanilla `AutoModelForCausalLM.from_pretrained` reload). - # Nemotron-Flash's trust_remote_code modeling code (custom - # rotary / memory tokens / fused_mha + flash_attention_2 dispatch via - # `attn_implementation_new`) doesn't round-trip cleanly through vanilla - # HF init under transformers 5.x and produces NaN logits on first - # forward even with HF-modules-cache pre-seeding and `no_hf_meta_device`. - # Phase 3 (Automodel-from-consolidated, KL ≈ 0) and the separate - # vllm_deploy job already prove the consolidated checkpoint is - # loadable; Phase 4 adds no incremental signal for this model. - skip_hf_reload: true # FIXME(akoumpa): pragmatic relaxation, not a verified fix. The measured # baseline-vs-resume loss drift on 4-GPU ptyche is ~1.05e-2 at step 5 (seen # in CI job 302796035), well above the default 5e-3. Nemotron-Flash's diff --git a/examples/llm_finetune/nemotron_flash/nemotron_flash_1b_squad_peft.yaml b/examples/llm_finetune/nemotron_flash/nemotron_flash_1b_squad_peft.yaml index 579a3abadd..c62c5abcb9 100755 --- a/examples/llm_finetune/nemotron_flash/nemotron_flash_1b_squad_peft.yaml +++ b/examples/llm_finetune/nemotron_flash/nemotron_flash_1b_squad_peft.yaml @@ -120,15 +120,6 @@ ci: distributed.tp_size: 2 tokenizer_name: nvidia/Nemotron-Flash-1B trust_remote_code: true - # Skip Phase 4 (vanilla HF reload of base model + PEFT adapter). - # `NemotronFlashModel.__init__` silently clobbers our - # `attn_implementation="flash_attention_2"` override back to the hub - # default `fused_mha` via `config.attn_implementation_new`, and that - # path plus the custom rotary / memory-token init under transformers - # 5.x's meta-device context produces NaN logits on first forward. - # Phase 3 (Automodel + PEFT from consolidated) already validates - # correctness to KL ≈ 2.7e-3, well under threshold. - skip_hf_reload: true check_fused_qkv_keys: true dataset.limit_dataset_samples: 500 validation_dataset.limit_dataset_samples: 500 diff --git a/examples/llm_finetune/qwen/qwen2_5_7b_squad.yaml b/examples/llm_finetune/qwen/qwen2_5_7b_squad.yaml index dda0e0cec2..dff224a84f 100644 --- a/examples/llm_finetune/qwen/qwen2_5_7b_squad.yaml +++ b/examples/llm_finetune/qwen/qwen2_5_7b_squad.yaml @@ -105,10 +105,11 @@ ci: vllm_deploy: true recipe_owner: HuiyingLi checkpoint_robustness: - hf_kl_threshold: 9e-3 + hf_kl_threshold: 1e-1 distributed.tp_size: 2 tokenizer_name: Qwen/Qwen2.5-7B cross_tp_size: 2 cross_tp_kl_threshold: 9e-3 + resume_loss_threshold: 5e-2 dataset.limit_dataset_samples: 500 validation_dataset.limit_dataset_samples: 500 diff --git a/nemo_automodel/_transformers/utils.py b/nemo_automodel/_transformers/utils.py index 06e303c1c6..14b2cd2434 100644 --- a/nemo_automodel/_transformers/utils.py +++ b/nemo_automodel/_transformers/utils.py @@ -225,8 +225,17 @@ def _patched_post_init(self): source = _find_embedding_source(self) if source is None: raise ValueError("Could not find the source of the embedding layer") - self._nemo_tied_weights_keys = {k: source for k in tied} - self._tied_weights_keys = {} + tied_dict = {k: source for k in tied} + self._nemo_tied_weights_keys = tied_dict + # Keep the v5 dict form on the model so that any downstream HF + # code path (e.g. vanilla ``AutoModelForCausalLM.from_pretrained`` + # used by the checkpoint-robustness test) ties the weights via + # HF's own ``tie_weights`` and does not leave the tied sibling + # (``lm_head.weight``) zero-initialised — which would cause + # NaN logits for tied-embedding remote-code models like + # Nemotron-Flash-1B whose forward does + # ``logits / lm_head.weight.norm()``. + self._tied_weights_keys = tied_dict # call orig post init _orig_post_init(self) diff --git a/nemo_automodel/_transformers/v4_patches/rotary.py b/nemo_automodel/_transformers/v4_patches/rotary.py index c1f526b5ab..6beda0c297 100644 --- a/nemo_automodel/_transformers/v4_patches/rotary.py +++ b/nemo_automodel/_transformers/v4_patches/rotary.py @@ -29,18 +29,52 @@ def _to_local(t): return t._local_tensor if isinstance(t, DTensor) else t +@torch.no_grad() def _safe_rope_forward(self, x, position_ids, **kwargs): - """Drop-in replacement for legacy rotary embedding forward methods.""" - inv_freq = self.inv_freq.float() - inv_freq_expanded = inv_freq[None, :, None].expand(position_ids.shape[0], -1, 1) + """Drop-in replacement matching Nemotron-Flash-1B's native rotary forward. + + Mirrors ``modeling_nemotron_flash.LlamaRotaryEmbedding.forward`` verbatim + (incl. ``@torch.no_grad`` + autocast disable for FP32 precision) so that + running this patched forward is semantically identical to letting Flash's + native forward run with the same ``inv_freq``. + """ + inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) position_ids_expanded = position_ids[:, None, :].float() - freqs = (inv_freq_expanded @ position_ids_expanded).transpose(1, 2) - emb = torch.cat((freqs, freqs), dim=-1) - cos = emb.cos() - sin = emb.sin() + + device_type = x.device.type + device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" + with torch.autocast(device_type=device_type, enabled=False): + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() + sin = emb.sin() return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) +def _compute_flash_inv_freq(cfg, device, dim): + """Compute ``inv_freq`` using Nemotron-Flash-1B's own NTK/default formula. + + Copy of the relevant init branch from + ``modeling_nemotron_flash.LlamaRotaryEmbedding.__init__``. Flash's NTK + differs from transformers' standard: + - ``factor = 2`` (hardcoded in Flash) + - Reads ``config.orig_max_position_embeddings`` (not + ``original_max_position_embeddings``). + - Scales ``base`` directly (no post-hoc ``attention_scaling``). + """ + base = float(getattr(cfg, "rope_theta", 10000.0) or 10000.0) + rope_type = getattr(cfg, "rope_type", None) or "default" + if rope_type == "ntk": + max_pos = getattr(cfg, "max_position_embeddings", None) + orig_max = getattr(cfg, "orig_max_position_embeddings", None) + if max_pos is not None and orig_max is not None and orig_max > 0: + factor = 2 + base = base * ((factor * max_pos / orig_max) - (factor - 1)) ** (dim / (dim - 2)) + # default / dynamic_ntk / unknown: use (possibly unscaled) ``base``. + indices = torch.arange(0, dim, 2, dtype=torch.int64, device=device).float() + return 1.0 / (base ** (indices / dim)) + + def _is_nemotron_flash_config(cfg): if cfg is None: return False @@ -71,7 +105,32 @@ def should_fix_rotary_embeddings(model_parts): def fix_rotary_embeddings(model_parts): - """Patch rotary embeddings to bypass fragile legacy HF runtime behavior.""" + """Install Nemotron-Flash-1B's native NTK ``inv_freq`` deterministically. + + Flash's own ``LlamaRotaryEmbedding.__init__`` (remote code, under + trust_remote_code) can land with NaN/Inf ``inv_freq`` buffers under + transformers 5.x's meta-device init context, and its NTK formula is + non-standard (``factor=2``, reads ``config.orig_max_position_embeddings``, + no post-hoc ``attention_scaling``), so transformers' own + ``ROPE_INIT_FUNCTIONS`` does not match it. The old version of this patch + sidestepped that by overwriting ``inv_freq`` with a plain-vanilla formula + (no NTK) and replacing ``forward`` with a vanilla one — but that silently + downgraded training-time rope semantics relative to Flash's native, which + vanilla HF uses when reloading the consolidated checkpoint. The result was + Phase 4 HF KL > 1.0, "fixed" by skipping Phase 4. + + This revised patch computes ``inv_freq`` using Flash's *own* NTK formula + (copied verbatim from ``modeling_nemotron_flash.LlamaRotaryEmbedding``) and + installs it on every Flash rotary found, unconditionally. The forward is + also replaced with ``_safe_rope_forward`` (now semantically identical to + Flash's native forward), which guards against any init-order oddity in + the remote-code class. Training, Phase 3 Automodel reload, and Phase 4 + vanilla HF reload all end up computing the same NTK-scaled rope. + + Scope: only touches modules whose ``config`` is recognized as Nemotron- + Flash (via ``_is_nemotron_flash_config``), so non-Flash models are never + affected. ``should_fix_rotary_embeddings`` further narrows the call site. + """ fixed = 0 for mp in model_parts: for fqn, module in mp.named_modules(): @@ -79,28 +138,13 @@ def fix_rotary_embeddings(model_parts): if inv is None or not isinstance(inv, torch.Tensor): continue - iv = _to_local(inv) - bad = bool(torch.isnan(iv).any().item()) or bool(torch.isinf(iv).any().item()) - cfg = getattr(module, "config", None) - if cfg is not None: - rope_params = getattr(cfg, "rope_parameters", {}) or {} - base = rope_params.get("rope_theta", getattr(cfg, "rope_theta", 10000.0)) - dim = getattr(cfg, "head_dim", None) - if dim is None: - hs = getattr(cfg, "hidden_size", None) - nah = getattr(cfg, "num_attention_heads", None) - if hs and nah: - dim = hs // nah - if dim is None: - dim = iv.shape[-1] * 2 - else: - base = 10000.0 - dim = iv.shape[-1] * 2 - - new_inv = 1.0 / ( - base ** (torch.arange(0, dim, 2, dtype=torch.int64).to(device=iv.device, dtype=torch.float32) / dim) - ) + iv = _to_local(inv) + # ``inv_freq`` has ``dim/2`` elements → dim = 2 * iv.shape[-1]. + # Verified against ``LlamaRotaryEmbedding.__init__`` which uses + # ``torch.arange(0, dim, 2)`` of length ``dim/2``. + dim = iv.shape[-1] * 2 + new_inv = _compute_flash_inv_freq(cfg, iv.device, dim) inv.data.copy_(new_inv.to(dtype=inv.dtype, device=inv.device)) orig = getattr(module, "original_inv_freq", None) @@ -109,12 +153,11 @@ def fix_rotary_embeddings(model_parts): module.forward = types.MethodType(_safe_rope_forward, module) - logger.info( - f"[fix_rope] {fqn}: patched forward + inv_freq (bad={bad} base={base} dim={dim})", - ) + rope_type = getattr(cfg, "rope_type", None) or "default" + logger.info(f"[fix_rope] {fqn}: installed Flash NTK inv_freq (rope_type={rope_type}, dim={dim})") fixed += 1 - logger.info(f"[fix_rope] patched {fixed} rotary embeddings.") + logger.info(f"[fix_rope] repaired {fixed} rotary embeddings.") return fixed diff --git a/nemo_automodel/components/checkpoint/addons.py b/nemo_automodel/components/checkpoint/addons.py index cc7708e5de..f0bb8a4f45 100644 --- a/nemo_automodel/components/checkpoint/addons.py +++ b/nemo_automodel/components/checkpoint/addons.py @@ -77,7 +77,21 @@ def pre_save(self, **kwargs) -> None: _maybe_strip_quantization_config(model_part) with open(os.path.join(hf_metadata_dir, config_name), "w") as f: if hasattr(model_part.config, "to_json_string"): - f.write(model_part.config.to_json_string()) + # Use ``use_diff=False`` so the full config (not the + # diff against class defaults) is serialized. For + # remote-code configs registered via + # ``register_for_auto_class`` (e.g. DeciLM / + # Llama-Nemotron-Super-49B ``model_type='nemotron-nas'``), + # ``to_diff_dict`` sees the class-level ``model_type`` + # attribute as equal to the class default and drops + # it from the serialized JSON. Reloading via + # ``AutoConfig.from_pretrained`` on the resulting + # consolidated directory then raises + # ``Unrecognized model ... Should have a 'model_type' + # key``. Writing the full dict guarantees + # ``model_type``, ``architectures`` and ``auto_map`` + # land in the saved config regardless of class defaults. + f.write(model_part.config.to_json_string(use_diff=False)) else: # Diffusers models use FrozenDict for config instead of PretrainedConfig json.dump(dict(model_part.config), f, indent=2, default=str) diff --git a/tests/functional_tests/checkpoint_robustness/test_checkpoint_robustness_llm.py b/tests/functional_tests/checkpoint_robustness/test_checkpoint_robustness_llm.py index 24ec7694e0..7bfe53bd40 100644 --- a/tests/functional_tests/checkpoint_robustness/test_checkpoint_robustness_llm.py +++ b/tests/functional_tests/checkpoint_robustness/test_checkpoint_robustness_llm.py @@ -35,6 +35,10 @@ import torch.nn.functional as F from torch.distributed.tensor import DTensor +from nemo_automodel.components.checkpoint.checkpointing import ( + _MODELS_REQUIRING_BUFFER_REINIT, + _reinit_non_persistent_buffers, +) from nemo_automodel.components.config._arg_parser import parse_args_and_load_config from nemo_automodel.recipes.llm.train_ft import TrainFinetuneRecipeForNextTokenPrediction @@ -215,6 +219,31 @@ def _get_logits(model, input_ids, device, trainer=None) -> torch.Tensor: return logits.float().cpu() +def _reinit_rotary_per_module(model, default_device): + """Recompute DeciLM / Gemma3 style non-persistent rotary buffers on each + module's own device. + + HF `from_pretrained` in transformers 5.x leaves ``inv_freq`` uninitialized + for models whose rotary buffers are computed in ``__init__`` and never + saved to the state dict (e.g. nemotron-nas, gemma3). With + ``device_map='auto'`` each rotary module can live on a different GPU, so + we drive the recompute per-module using its own inv_freq device rather + than a single fixed device. + """ + model_type = getattr(model.config, "model_type", None) + if model_type not in _MODELS_REQUIRING_BUFFER_REINIT: + return model + for mod in model.modules(): + inv = getattr(mod, "inv_freq", None) + if inv is None: + continue + mod_device = inv.device + if mod_device.type == "meta": + mod_device = next((p.device for p in mod.parameters()), default_device) + _reinit_non_persistent_buffers(mod, mod_device, model_type=model_type) + return model + + def _fix_meta_rotary_embeddings(model): """Re-materialize RotaryEmbedding tensors stuck on meta device. @@ -500,6 +529,24 @@ def test_checkpoint_robustness(): base_model = _fix_meta_rotary_embeddings( AutoModelForCausalLM.from_pretrained(original_pretrained_path, **hf_kwargs) ).to(device) + # Re-init non-persistent rotary buffers for ``model_type`` values + # in ``_MODELS_REQUIRING_BUFFER_REINIT`` (``nemotron-nas``, + # ``gemma3``) — their ``inv_freq`` is computed in ``__init__`` and + # never written to the checkpoint; meta-device init leaves + # garbage values after ``from_pretrained``. + _reinit_rotary_per_module(base_model, device) + # For Nemotron-Flash (``model_type=="nemotron_flash"``) the + # ``inv_freq`` buffer also lands garbage under HF load but its + # NTK formula is non-standard, so route through the dedicated + # ``fix_rotary_embeddings`` patch which installs Flash's own NTK + # formula and mirrors Flash's native forward. + if trust_remote_code: + from nemo_automodel._transformers.v4_patches.rotary import ( + fix_rotary_embeddings, + should_fix_rotary_embeddings, + ) + if should_fix_rotary_embeddings([base_model]): + fix_rotary_embeddings([base_model]) peft_model = PeftModel.from_pretrained(base_model, str(ckpt_step_dir / "model")) hf_logits = _get_logits(peft_model, input_ids, device) @@ -528,6 +575,18 @@ def test_checkpoint_robustness(): hf_model = _fix_meta_rotary_embeddings( AutoModelForCausalLM.from_pretrained(str(consolidated_dir), **hf_kwargs) ).to(device) + # Re-init non-persistent rotary buffers for nemotron-nas / gemma3 + # (``_MODELS_REQUIRING_BUFFER_REINIT`` allow-list). See PEFT branch + # above for details. + _reinit_rotary_per_module(hf_model, device) + # For Nemotron-Flash: install NTK inv_freq via dedicated patch. + if trust_remote_code: + from nemo_automodel._transformers.v4_patches.rotary import ( + fix_rotary_embeddings, + should_fix_rotary_embeddings, + ) + if should_fix_rotary_embeddings([hf_model]): + fix_rotary_embeddings([hf_model]) hf_logits = _get_logits(hf_model, input_ids, device) del hf_model diff --git a/tests/unit_tests/_transformers/test_transformers_utils.py b/tests/unit_tests/_transformers/test_transformers_utils.py index 6b3c6ee4f6..3d4018b53b 100644 --- a/tests/unit_tests/_transformers/test_transformers_utils.py +++ b/tests/unit_tests/_transformers/test_transformers_utils.py @@ -389,7 +389,7 @@ def __init__(self, config): assert isinstance(tied, dict) assert "lm_head.weight" in tied assert tied["lm_head.weight"] == "model.embed_tokens.weight" - assert model._tied_weights_keys == {} + assert model._tied_weights_keys == tied def test_tied_weights_keys_patch_converts_any_model(self): """The post_init patch should convert _tied_weights_keys for any model, not just phi4mm.""" @@ -418,7 +418,40 @@ def __init__(self, config): assert isinstance(tied, dict) assert "lm_head.weight" in model._nemo_tied_weights_keys assert tied["lm_head.weight"] == "model.embed_tokens.weight" - assert model._tied_weights_keys == {} + assert model._tied_weights_keys == tied + + def test_tied_weights_keys_patch_resolves_top_level_embed_tokens(self): + """The post_init patch resolves embed_tokens at the top level via get_input_embeddings.""" + apply_cache_compatibility_patches() + import torch.nn as nn + from transformers import PretrainedConfig + from transformers.modeling_utils import PreTrainedModel + + class _MockConfig(PretrainedConfig): + model_type = "nemotron_flash" + + class _Model(PreTrainedModel): + config_class = _MockConfig + + def __init__(self, config): + super().__init__(config) + self.embed_tokens = nn.Embedding(100, 16) + self.lm_head = nn.Linear(16, 100, bias=False) + self._tied_weights_keys = ["lm_head.weight"] + self.post_init() + + def get_input_embeddings(self): + return self.embed_tokens + + def forward(self, x): + return self.lm_head(self.embed_tokens(x)) + + model = _Model(_MockConfig()) + assert isinstance(model._tied_weights_keys, dict) + assert "lm_head.weight" in model._tied_weights_keys + assert model._tied_weights_keys["lm_head.weight"] == "embed_tokens.weight" + assert isinstance(model._nemo_tied_weights_keys, dict) + assert model._nemo_tied_weights_keys["lm_head.weight"] == "embed_tokens.weight" def test_patches_peft_prepare_inputs(self): """PeftModelForCausalLM.__init__ should be patched for missing prepare_inputs_for_generation."""