[nemotron_h] respect _no_reinit flag on dt_bias and out_proj.weight #45591
vai-minzhou wants to merge 3 commits into huggingface:main
Conversation
`_init_weights()` on `NemotronHPreTrainedModel` unconditionally overwrites
`dt_bias` (with a random `inv_softplus(dt)` draw) and `out_proj.weight` (a
`kaiming_uniform` draw scaled by `1/sqrt(n_layer)`) every time it is invoked
on a Mamba block. It sets `module.dt_bias._no_reinit = True` after the copy,
but the flag is never checked on either of these paths; only the `nn.Linear`
bias branch reads it.
On `transformers>=5.0`, `_init_weights` is triggered a second time after
`from_pretrained()` has loaded the checkpoint (the post-load safety pass
that initializes tensors still on `meta`). For `NemotronHForCausalLM`
that silently overwrites the checkpoint values for `dt_bias` and
`out_proj.weight` with fresh random draws. The model then outputs
repetitive stop-word streams like ` and and and and ,` for any input.
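For context, the shape of the old init path looks roughly like this (a paraphrased sketch, not the exact source; `init_dt_bias`, `num_heads`, `dt_min`, and `dt_max` are illustrative stand-ins for the real config fields):

```python
import math
import torch

def init_dt_bias(module, num_heads=8, dt_min=0.001, dt_max=0.1):
    # Paraphrased sketch of the old Mamba2 dt_bias init: draw dt
    # log-uniformly in [dt_min, dt_max], then store inv_softplus(dt).
    dt = torch.exp(
        torch.rand(num_heads) * (math.log(dt_max) - math.log(dt_min))
        + math.log(dt_min)
    )
    inv_dt = dt + torch.log(-torch.expm1(-dt))  # inverse softplus of dt
    with torch.no_grad():
        module.dt_bias.copy_(inv_dt)
    module.dt_bias._no_reinit = True  # set here, but nothing ever reads it
```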
Minimal repro with any Nemotron-H checkpoint:

```python
from transformers import AutoConfig, AutoModelForCausalLM
from safetensors.torch import load_file
import json, pathlib

path = ".../NVIDIA-Nemotron-Cascade-2-30B-A3B-BF16"  # or Nano

cfg = AutoConfig.from_pretrained(path)
cfg._attn_implementation = "eager"
m = AutoModelForCausalLM.from_pretrained(path, config=cfg, torch_dtype="bfloat16")

# Compare the checkpoint value of layer-0 dt_bias with what is actually in memory.
idx = json.loads((pathlib.Path(path) / "model.safetensors.index.json").read_text())["weight_map"]
k = "backbone.layers.0.mixer.dt_bias"
on_disk = load_file(f"{path}/{idx[k]}")[k]
in_mem = m.backbone.layers[0].mixer.dt_bias
print((on_disk.float() - in_mem.float().cpu()).abs().max())  # ~26.8
```
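On a healthy load the printed difference is 0; on the broken path it is ~26.8 for the checkpoint above.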
This patch makes `_init_weights` honour `_no_reinit` on both `dt_bias` and
`out_proj.weight` (the only two parameters that are re-initialised
unconditionally), and sets `_no_reinit = True` on `out_proj.weight` after
the initial kaiming scale. Ordinary fresh-init training is unaffected; only
the second (post-load) invocation becomes a no-op.
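Sketched, the guarded branches end up with this shape (illustrative, not the literal diff; the residual scaling follows the `1/sqrt(n_layer)` rule described above):

```python
import math
import torch
from torch import nn

def _init_weights(self, module):
    # dt_bias branch: skip if a previous pass already initialised it.
    if isinstance(module, NemotronHMamba2Mixer):
        if not getattr(module.dt_bias, "_no_reinit", False):
            ...  # compute inv_dt and copy it into dt_bias, as before
            module.dt_bias._no_reinit = True  # now actually honoured above

    # out_proj.weight branch: same guard, plus set the flag afterwards
    # so the post-load pass becomes a no-op.
    for name, p in module.named_parameters():
        if name == "out_proj.weight" and not getattr(p, "_no_reinit", False):
            nn.init.kaiming_uniform_(p, a=math.sqrt(5))
            with torch.no_grad():
                p /= math.sqrt(self.config.num_hidden_layers)
            p._no_reinit = True
```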
Signed-off-by: Min Zhou <minzhou@virtueai.com>
Hey, I'm not sure about this PR! We already have the `_is_hf_initialized` flag that `from_pretrained` sets on checkpoint-loaded parameters; can we reuse that instead of adding an ad-hoc `_no_reinit` flag?
Per @Rocketknight1's review: replace the ad-hoc `_no_reinit` flag with the existing `_is_hf_initialized` flag that `from_pretrained` already sets on checkpoint-loaded parameters. Guard each Mamba2 init target (A_log / D / dt_bias) and the residual-scaled `out_proj.weight` independently, so parameters restored from a checkpoint survive any subsequent `_init_weights` pass.
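A minimal sketch of that guard idiom, assuming only that `from_pretrained` tags loaded tensors with `_is_hf_initialized` (the helper name `needs_init` is hypothetical):

```python
import torch

def needs_init(param: torch.nn.Parameter) -> bool:
    # from_pretrained marks checkpoint-loaded tensors with
    # _is_hf_initialized; anything without the flag still needs init.
    return not getattr(param, "_is_hf_initialized", False)

# Inside _init_weights, each Mamba2 target is then guarded independently:
#   if needs_init(module.A_log):   module.A_log.copy_(log_A_init)
#   if needs_init(module.D):       module.D.fill_(1.0)
#   if needs_init(module.dt_bias): module.dt_bias.copy_(inv_dt)
#   if needs_init(out_proj.weight): kaiming-init + 1/sqrt(n_layer) scale
```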
Thanks for the review! You're right — switched the guards over to the existing `_is_hf_initialized` flag.
Left the pre-existing repro that motivated the original bug: with the old code, loading a finetuned NemotronH checkpoint left `dt_bias` and `out_proj.weight` holding fresh random values instead of the checkpoint weights.
The code is still referencing `_no_reinit` here - can we just replace with `_is_hf_initialized` everywhere in the file, or will that cause problems?
Just pushed: `_no_reinit` is now replaced with `_is_hf_initialized` everywhere in the file.
[For maintainers] Suggested jobs to run (before merge): `run-slow: nemotron_h`
Summary
`NemotronHPreTrainedModel._init_weights` unconditionally overwrites two trained parameters every time it is invoked:

- `NemotronHMamba2Mixer.dt_bias` — reset to a fresh `inv_softplus(random dt)` draw
- `out_proj.weight` — reset to a kaiming-uniform draw scaled by `1/sqrt(num_hidden_layers)`

It sets `module.dt_bias._no_reinit = True` after the copy, but that flag is only checked for the `nn.Linear.bias` branch of the same function — never for `dt_bias` itself — and `out_proj.weight` doesn't set the flag at all.

On `transformers>=5.0`, `_init_weights` runs a second time after `from_pretrained` has finished loading the checkpoint (the post-load pass that initialises tensors still on `meta`). For `NemotronHForCausalLM` that silently overwrites the on-disk values for `dt_bias` and `out_proj.weight` with fresh random ones, while all other tensors keep their trained values.

The resulting model outputs repetitive filler streams like `and and and , and and ,` for any input — output is sane only when loading through vLLM (which bypasses `_init_weights`) or via an older transformers release.

Reproduction
Prompting `"Hello, how are you? I am"` on an unpatched load returns `' and' ' in' ' the' ' first' ','` as the top-5 next tokens — a symptom of Mamba2 running with a randomised `dt_bias` and a mis-scaled `out_proj`. After the patch, trained values are preserved and the model generates normally.
The fix

Both changes live in `NemotronHPreTrainedModel._init_weights`:

- `dt_bias` branch: early-return if `dt_bias._no_reinit` is already set (the flag is set at the end of the current branch, so the first pass initialises normally and the second pass becomes a no-op).
- `out_proj.weight` branch: skip when `p._no_reinit` is set, and set `p._no_reinit = True` after the initial kaiming scale so a second invocation is a no-op.

Fresh-init training is unaffected — only the second (post-load) invocation is made idempotent. The same edit is mirrored into `modular_nemotron_h.py` and `modeling_nemotron_h.py`.
Test plan

- Before the patch: `|on_disk - in_mem|.max()` for layer-0 `dt_bias` ≈ 26.8, and next-token logits return stop-word garbage.
- `tests/models/nemotron_h/` — no behaviour change for fresh-init; only the idempotence of the re-init pass changes.

Please let me know if you'd like the fix to take a different shape (e.g. short-circuit `_init_weights` entirely when the module's parameters are all materialised, or move the guard to a shared utility in `modeling_utils`). Happy to adjust.
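As a concrete form of that idempotence check, something like the following could serve as a regression test (hypothetical sketch, not part of this PR; `m` is the loaded model from the repro above):

```python
import torch

# The post-load pass must not touch checkpoint-restored parameters.
before = m.backbone.layers[0].mixer.dt_bias.detach().clone()
m.apply(m._init_weights)  # simulate a second _init_weights pass
after = m.backbone.layers[0].mixer.dt_bias.detach()
assert torch.equal(before, after), "dt_bias was re-initialised on the second pass"
```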