fix: nemotron_flash_1b_squad_peft checkpoint robustness (#1953)
Closed
Conversation
Orchestrator-killed subagent; pushing WIP state for morning review. Verification test hangs (18+ min CPU/rank, no pytest passed/failed). 4 files + 1 new v4_patch; same code-patch set as PR #1945 plus additional infrastructure.py changes and a new lm_head_norm patch for PEFT variant. Signed-off-by: adil-a <adil.asif2000@hotmail.com>
…on relax + class-level lm_head_norm patch
- Add `_patch_attn_implementation_validator()` in utils.py to downgrade
transformers v5 `get_correct_attn_implementation` whitelist check to a
pass-through when the class is remote-code
(`__module__.startswith('transformers_modules.')`). The hub config for
`nvidia/Nemotron-Flash-1B` pins `attn_implementation='fused_mha'`;
vanilla HF v5 rejected it before the custom class's own
`_set_attn_implementation` could take over, blocking Phase 4 of the
checkpoint-robustness test.
- Fix `fix_lm_head_norm` to patch `NemotronFlashForCausalLM.forward` at
the CLASS level (scanning `sys.modules` for the dynamic remote-code
module and walking the MRO of model parts). The prior instance-only
binding didn't reach the FSDP2 dynamic subclass's forward. Class-level
patching also covers Phase 4's vanilla `AutoModelForCausalLM.from_pretrained`
instance automatically (same remote-code class is reused from `sys.modules`).
- Restrict the instance-level fallback loop to actual
`NemotronFlashForCausalLM` modules (previously it could patch every
submodule sharing the same config, breaking `NemotronFlashModel.forward`).
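The class-level patching strategy described above can be sketched as follows. The module path and class names mirror the commit message, but the stub module, the replacement `forward`, and the helper name `patch_remote_code_forward` are illustrative stand-ins, not the actual nemo_automodel implementation:

```python
import sys
import types

def _patched_forward(self, *args, **kwargs):
    # Stand-in for the DTensor-safe forward the real patch installs.
    return "patched"

def patch_remote_code_forward(class_name="NemotronFlashForCausalLM"):
    """Scan sys.modules for the dynamically imported remote-code module and
    replace forward at the CLASS level, so any dynamic subclass created later
    (e.g. by FSDP2) inherits the patched method."""
    patched = []
    for mod_name, mod in list(sys.modules.items()):
        if not (mod_name.startswith("transformers_modules.")
                and mod_name.endswith("modeling_nemotron_flash")):
            continue
        cls = getattr(mod, class_name, None)
        if cls is None:
            continue
        # Walk the MRO and patch every class that defines its own forward.
        for klass in cls.__mro__:
            if "forward" in vars(klass):
                klass.forward = _patched_forward
                patched.append(klass.__name__)
    return patched

# Simulated remote-code module, as transformers' dynamic loader registers it:
stub = types.ModuleType("transformers_modules.demo.modeling_nemotron_flash")

class NemotronFlashForCausalLM:
    def forward(self):
        return "original"

stub.NemotronFlashForCausalLM = NemotronFlashForCausalLM
sys.modules[stub.__name__] = stub

patch_remote_code_forward()
# A dynamic subclass created afterwards inherits the class-level patch,
# which an instance-only binding on a pre-existing object would miss:
Dynamic = type("FSDPNemotronFlashForCausalLM", (NemotronFlashForCausalLM,), {})
assert Dynamic().forward() == "patched"
```

This is why class-level patching also covers the later vanilla `from_pretrained` load: the same remote-code class object is reused from `sys.modules`, so every new instance and subclass sees the replaced method.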
Verified on cw-dfw:
torchrun --nproc-per-node=8 -m pytest \
tests/functional_tests/checkpoint_robustness/test_checkpoint_robustness_llm.py \
--config examples/llm_finetune/nemotron_flash/nemotron_flash_1b_squad_peft.yaml \
...
================ 1 passed, 25 warnings in 232.91s (0:03:52) ================
Signed-off-by: adil-a <adil.asif2000@hotmail.com>
This was referenced Apr 21, 2026
Superseded by #1984. The same root-cause fixes as in #1945 also restore PEFT Phase 4: KL = 1.95e-3, under the default 5e-3 threshold. Not needed from this PR:
Summary
Unblocks task #15 (`nemotron_flash_1b_squad_peft`, CI job 301287631) from pipeline 48953745. The `nvidia/Nemotron-Flash-1B` PEFT checkpoint-robustness test failed under transformers v5 with two root causes, both fixed here:

1. `attn_implementation=fused_mha` rejection. The hub config pins `fused_mha`, but v5's `PreTrainedModel.get_correct_attn_implementation` validates against a canonical whitelist (`["eager"] + ALL_ATTENTION_FUNCTIONS.valid_keys()`) and raises `ValueError` before the remote-code class's own `_set_attn_implementation` can override it. Phase 4 of the test (vanilla `AutoModelForCausalLM.from_pretrained`) hit this.

   Fix: a new `_patch_attn_implementation_validator()` in `nemo_automodel/_transformers/utils.py` downgrades the whitelist check to a pass-through only when the class is remote-code (`__module__.startswith("transformers_modules.")`). Canonical classes are unaffected. Registered in `apply_cache_compatibility_patches` with the same `_nemo_attn_patched` idempotency guard used by the sibling patches.
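A minimal sketch of the pass-through. The validator and whitelist below are stand-ins for the transformers v5 internals the PR describes; the real signatures may differ:

```python
_VALID = ["eager", "sdpa", "flash_attention_2"]  # illustrative whitelist

def get_correct_attn_implementation(cls, requested):
    """Stand-in for transformers v5's canonical whitelist check."""
    if requested not in _VALID:
        raise ValueError(f"Unknown attn_implementation: {requested!r}")
    return requested

def patched_get_correct_attn_implementation(cls, requested):
    """Pass-through for remote-code classes; canonical classes keep the
    original validation, mirroring the guard described in this PR."""
    if cls.__module__.startswith("transformers_modules."):
        # Defer to the remote-code class's own _set_attn_implementation.
        return requested
    return get_correct_attn_implementation(cls, requested)

# Remote-code class (dynamic module path) vs. a canonical class:
RemoteClass = type("NemotronFlashForCausalLM", (), {})
RemoteClass.__module__ = "transformers_modules.nvidia.modeling_nemotron_flash"

class CanonicalModel:
    pass

# fused_mha passes through for the remote-code class only:
assert patched_get_correct_attn_implementation(RemoteClass, "fused_mha") == "fused_mha"
rejected = False
try:
    patched_get_correct_attn_implementation(CanonicalModel, "fused_mha")
except ValueError:
    rejected = True
assert rejected
```

Keying the pass-through on the `transformers_modules.` module prefix means hub models with custom attention backends are exempted without loosening validation for any in-tree class.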
2. DTensor/plain-tensor mix in `logits / lm_head.weight.norm(...)`. The remote-code forward's trailing norm-divide dispatched on a mix of DTensor and plain `torch.Tensor` under FSDP2 + LoRA. The existing `fix_lm_head_norm` monkey-patch targeted only `NemotronFlashForCausalLM` instances via `_is_nemotron_flash_causallm`, but FSDP2's dynamic subclass (`FSDPNemotronFlashForCausalLM`) shadowed the instance `forward`, so the patched method was never called.

   Fix: patch at the class level. `fix_lm_head_norm` now:
   - scans `sys.modules` for `transformers_modules.*.modeling_nemotron_flash` and replaces the class `forward` directly (this also covers the Phase 4 vanilla HF load, since the same remote-code class is reused from `sys.modules`), and
   - restricts the instance-level fallback loop to actual `NemotronFlashForCausalLM` modules (the prior over-broad loop patched every submodule sharing the same config, breaking `NemotronFlashModel.forward`).
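The restricted fallback can be illustrated with a toy module tree. The classes below are stand-ins (the real loop walks `model.modules()` on a torch model); the point is that only the causal-LM class is rebound, so inner modules sharing the same config keep their original `forward`:

```python
import types

def _patched_forward(self, *args, **kwargs):
    return "patched"

class StubModule:
    """Minimal stand-in for nn.Module with a modules() walk."""
    def __init__(self, *children):
        self._children = list(children)
    def forward(self):
        return "original"
    def modules(self):
        yield self
        for child in self._children:
            yield from child.modules()

class NemotronFlashModel(StubModule):
    pass

class NemotronFlashForCausalLM(StubModule):
    pass

def patch_causal_lm_instances(model):
    """Instance-level fallback restricted to the causal-LM class only, so
    submodules like NemotronFlashModel are left untouched."""
    for module in model.modules():
        if type(module).__name__ == "NemotronFlashForCausalLM":
            module.forward = types.MethodType(_patched_forward, module)

inner = NemotronFlashModel()
causal_lm = NemotronFlashForCausalLM(inner)
patch_causal_lm_instances(causal_lm)
assert causal_lm.forward() == "patched"
assert inner.forward() == "original"
```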
Also widens `should_fix_lm_head_norm` to detect the model via `cfg.model_type == "nemotron_flash"` / `architectures`, matching how `should_fix_rotary_embeddings` is already written, so the check is robust to FSDP wrapping.
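A sketch of the config-based detection (hedged: the real helper lives in nemo_automodel and its exact signature may differ). Checking the config rather than the model class means an FSDP-wrapped or dynamically subclassed model is still recognized:

```python
from types import SimpleNamespace

def should_fix_lm_head_norm(config):
    """Detect Nemotron-Flash via the config rather than the (possibly
    FSDP-wrapped) model class, mirroring should_fix_rotary_embeddings."""
    if getattr(config, "model_type", None) == "nemotron_flash":
        return True
    architectures = getattr(config, "architectures", None) or []
    return any("NemotronFlash" in arch for arch in architectures)

assert should_fix_lm_head_norm(SimpleNamespace(model_type="nemotron_flash"))
assert should_fix_lm_head_norm(
    SimpleNamespace(model_type=None, architectures=["NemotronFlashForCausalLM"]))
assert not should_fix_lm_head_norm(
    SimpleNamespace(model_type="llama", architectures=None))
```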
The YAML thresholds (`kl_threshold=5e-3`, `hf_kl_threshold=1e1`, `trust_remote_code`, `no_check_resume`, `timeout_minutes=20`) and the other `utils.py` / `infrastructure.py` / `model_init.py` / `nemotron_flash_lm_head_norm.py` pieces were already present in the WIP commit on this branch (mirroring PR #1945 for the SFT sibling task #14).

Test plan

`1 passed, 25 warnings in 232.91s (0:03:52)`.