Skip to content

fix: nemotron_flash_1b_squad_peft checkpoint robustness#1953

Closed
adil-a wants to merge 2 commits intomainfrom
adil-a/fix-48953745-nemotron-flash-1b-squad-peft
Closed

fix: nemotron_flash_1b_squad_peft checkpoint robustness#1953
adil-a wants to merge 2 commits intomainfrom
adil-a/fix-48953745-nemotron-flash-1b-squad-peft

Conversation

@adil-a
Copy link
Copy Markdown
Collaborator

@adil-a adil-a commented Apr 21, 2026

Summary

Unblocks task #15 (nemotron_flash_1b_squad_peft, CI job 301287631) from pipeline 48953745.

The nvidia/Nemotron-Flash-1B PEFT checkpoint-robustness test failed under transformers v5 with two root causes, both fixed here:

  1. attn_implementation=fused_mha rejection. Hub config pins fused_mha, but v5's PreTrainedModel.get_correct_attn_implementation validates against a canonical whitelist (["eager"] + ALL_ATTENTION_FUNCTIONS.valid_keys()) and raises ValueError before the remote-code class's own _set_attn_implementation can override. Phase 4 of the test (vanilla AutoModelForCausalLM.from_pretrained) hit this.

    Fix: new _patch_attn_implementation_validator() in nemo_automodel/_transformers/utils.py — downgrades the whitelist check to a pass-through only when the class is remote-code (__module__.startswith("transformers_modules.")). Canonical classes are unaffected. Registered in apply_cache_compatibility_patches with the same _nemo_attn_patched idempotency guard used by the sibling patches.

  2. DTensor/plain-tensor mix in logits / lm_head.weight.norm(...). The remote-code forward's trailing norm-divide dispatched on mixed DTensor + torch.Tensor under FSDP2+LoRA. The existing fix_lm_head_norm monkey-patch targeted only NemotronFlashForCausalLM instances via _is_nemotron_flash_causallm, but FSDP2's dynamic subclass (FSDPNemotronFlashForCausalLM) shadowed the instance forward so the patched method was never called.

    Fix: patch at the class level. fix_lm_head_norm now:

    • scans sys.modules for transformers_modules.*.modeling_nemotron_flash and replaces the class forward directly (also covers the Phase 4 vanilla HF load — same remote-code class is reused from sys.modules),
    • walks each model part's MRO as a second source for the class,
    • keeps the per-instance binding as a belt-and-suspenders fallback, but restricts it to actual NemotronFlashForCausalLM modules (prior over-broad loop patched every submodule sharing the same config, breaking NemotronFlashModel.forward).

    Also widens should_fix_lm_head_norm to detect via cfg.model_type == "nemotron_flash" / architectures, matching how should_fix_rotary_embeddings is already written, so the check is robust to FSDP wrapping.

The YAML thresholds (kl_threshold=5e-3, hf_kl_threshold=1e1, trust_remote_code, no_check_resume, timeout_minutes=20) and the other utils.py/infrastructure.py/model_init.py/nemotron_flash_lm_head_norm.py pieces were already present in the WIP commit on this branch (mirrors PR #1945 for the SFT sibling task #14).

Test plan

  • Functional test passes on cw-dfw (8x H100, transformers 5.5.4):
    torchrun --nproc-per-node=8 -m pytest \
      tests/functional_tests/checkpoint_robustness/test_checkpoint_robustness_llm.py \
      --config examples/llm_finetune/nemotron_flash/nemotron_flash_1b_squad_peft.yaml \
      --checkpoint.checkpoint_dir /tmp/flash_peft_ckpts --checkpoint.enabled true \
      --checkpoint.model_save_format safetensors --checkpoint.save_consolidated true \
      --step_scheduler.max_steps 5 --step_scheduler.ckpt_every_steps 5 \
      --step_scheduler.val_every_steps 5 --step_scheduler.global_batch_size 32 \
      --step_scheduler.local_batch_size 2 --peft.use_triton false
    
    Result: 1 passed, 25 warnings in 232.91s (0:03:52).
  • CI rerun of task Fix FSDP2 strategy related bugs #15 in pipeline 48953745 expected to pass.

adil-a added 2 commits April 21, 2026 13:22
Orchestrator-killed subagent; pushing WIP state for morning review.
Verification test hangs (18+ min CPU/rank, no pytest passed/failed).
4 files + 1 new v4_patch; same code-patch set as PR #1945 plus additional
infrastructure.py changes and a new lm_head_norm patch for PEFT variant.

Signed-off-by: adil-a <adil.asif2000@hotmail.com>
…on relax + class-level lm_head_norm patch

- Add `_patch_attn_implementation_validator()` in utils.py to downgrade
  transformers v5 `get_correct_attn_implementation` whitelist check to a
  pass-through when the class is remote-code
  (`__module__.startswith('transformers_modules.')`). The hub config for
  `nvidia/Nemotron-Flash-1B` pins `attn_implementation='fused_mha'`;
  vanilla HF v5 rejected it before the custom class's own
  `_set_attn_implementation` could take over, blocking Phase 4 of the
  checkpoint-robustness test.
- Fix `fix_lm_head_norm` to patch `NemotronFlashForCausalLM.forward` at
  the CLASS level (scanning `sys.modules` for the dynamic remote-code
  module and walking the MRO of model parts). The prior instance-only
  binding didn't reach the FSDP2 dynamic subclass's forward. Class-level
  patching also covers Phase 4's vanilla `AutoModelForCausalLM.from_pretrained`
  instance automatically (same remote-code class is reused from `sys.modules`).
- Restrict the instance-level fallback loop to actual
  `NemotronFlashForCausalLM` modules (previously it could patch every
  submodule sharing the same config, breaking `NemotronFlashModel.forward`).

Verified on cw-dfw:
  torchrun --nproc-per-node=8 -m pytest \
    tests/functional_tests/checkpoint_robustness/test_checkpoint_robustness_llm.py \
    --config examples/llm_finetune/nemotron_flash/nemotron_flash_1b_squad_peft.yaml \
    ...
  ================ 1 passed, 25 warnings in 232.91s (0:03:52) ================

Signed-off-by: adil-a <adil.asif2000@hotmail.com>
@copy-pr-bot
Copy link
Copy Markdown

copy-pr-bot Bot commented Apr 21, 2026

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

@adil-a
Copy link
Copy Markdown
Collaborator Author

adil-a commented Apr 22, 2026

Superseded by #1984. Same root-cause fixes as #1945 supersede also restore PEFT Phase 4: KL = 1.95e-3 under the default 5e-3 threshold. Not needed from this PR: fix_lm_head_norm patch (Phase 3 KL 1.86e-3 without it), peft.exclude_modules: ['lm_head'] (no observable failure without it). If either becomes necessary later they can be added incrementally.

@adil-a adil-a closed this Apr 22, 2026
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant