
feat: DeepSeek V4 Flash support #2039

Open
khazic wants to merge 24 commits into NVIDIA-NeMo:main from khazic:feat/deepseek-v4-flash

Conversation

Contributor

khazic commented Apr 24, 2026

Status — DeepSeek V4 Flash full finetuning is now working ✅

End-to-end full finetuning of DeepSeek V4 Flash trains correctly on this branch. The forward-numerics gap that remained earlier in this PR is closed — five independent layer-parity bugs were uncovered by per-tensor dump bisection vs the DeepSeek inference reference (dsv4flash/inference/model.py) and all five are now fixed.

The loss curve at the bottom of this PR is from a real 43-layer full-finetune run with the full attention zoo (SWA + CSA + HCA) live — not the 4-layer smoke.

Many thanks to @HuiyingLi for the strong support landing this — the layer-parity work, and fixes to RoPE / HC / CSA / PP stage handling are what got this over the finish line.

Bugs fixed (post-bringup)

  1. RoPE convention — released DSV4-Flash uses INTERLEAVED-pair RoPE (view_as_complex on (2k, 2k+1) packing); HF transformers PR 45616 reused Llama-style rotate_half (pairs (d, d+rd/2)), which is the wrong dim-to-frequency mapping for these weights. Fixed via _apply_partial_rope_interleaved. Effect: kv_post_rope cos 0.866 → 0.999 after one block.
  2. Hyper-Connections post / comb formulas — post needs 2 * sigmoid(...) (no +eps); the trained weights expect range (0, 2), not (eps, 1+eps). comb needs softmax(dim=-1) + eps followed by col-norm-first Sinkhorn (iters - 1 alternating row/col passes); HF PR 45616's sigmoid+sinkhorn variant converges to a different distribution.
  3. Dual RoPE base on main attention — V4 has two RoPE bases (theta=10000 for compress_ratio==0, theta=160000 with YaRN for compress_ratio!=0), and the compress-rope must apply to the main attention Q/KV on compress_ratio>0 layers, not just the compressor sub-module. Reference: dsv4flash/inference/model.py:476-501 builds freqs_cis with compress_rope_theta whenever compress_ratio != 0.
  4. Layer-2 CSA pool mask leak — the pre-gathered [B, 1, S*topk, D] + F.pad(value=0) path made every query attend to every compressed column unconditionally, defeating the Indexer's per-query top-k selection and stacking every query's gathered slice side-by-side (non-causal cross-talk). Fixed: DeepseekV4Compressor.forward now returns (pooled, indexer_topk) with -1 for causally-invalid entries; DeepseekV4Attention.forward builds an explicit additive [B, 1, S, P_total] compressed-position mask (scatter zeros at valid topk, -inf elsewhere; deterministic p < (q+1) // ratio for compress_ratio==128) and concatenates onto the existing causal mask. Effect: L02_attn_out cos 0.741 → 0.994, block_02 cos 0.965 → 0.991.
  5. DSV4Model.forward not stage-aware — the original input_ids XOR inputs_embeds validation broke on PP stage ≥1, because the stage-trim pass nulls embed_tokens and the upstream activation arrives as a 4D HC tensor in the input_ids slot. Fixed: detect on_first_stage via self.embed_tokens is not None; on later stages treat the upstream as [B, S, hc_mult, hidden]; apply hc_head / norm only when present; pass layer_input_ids only on the first stage. Unblocks PP=2 / PP=4 runs of the validate and full yamls.
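
The convention mismatch in bug 1 can be sketched in a few lines of plain Python (hypothetical helper names, a single position and head, not the actual _apply_partial_rope_interleaved code): both pairings consume the same frequencies but rotate different dimension pairs, so they disagree when applied to the same weights.

```python
import math

def rope_interleaved(x, theta=10000.0, pos=1):
    # Interleaved-pair RoPE: rotate (x[2k], x[2k+1]) as one complex number,
    # matching view_as_complex on consecutive-pair packing.
    d = len(x)
    out = list(x)
    for k in range(d // 2):
        freq = theta ** (-2 * k / d)
        rot = complex(math.cos(pos * freq), math.sin(pos * freq))
        z = complex(x[2 * k], x[2 * k + 1]) * rot
        out[2 * k], out[2 * k + 1] = z.real, z.imag
    return out

def rope_rotate_half(x, theta=10000.0, pos=1):
    # Llama-style rotate_half: pairs (k, k + d/2) share one frequency.
    d = len(x)
    half = d // 2
    out = [0.0] * d
    for k in range(half):
        freq = theta ** (-2 * k / d)
        c, s = math.cos(pos * freq), math.sin(pos * freq)
        out[k] = x[k] * c - x[k + half] * s
        out[k + half] = x[k + half] * c + x[k] * s
    return out
```

Both are valid RoPE parameterizations (each preserves the vector norm), but they assign frequencies to different dimension pairs, which is exactly why weights trained under one convention break under the other.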

4-layer parity harness (HellaSwag prompt, num_hash_layers=2, compress_ratios=[0,0,4,128], PP=1, EP=8) after the 5-bug fix: final-logits cos 0.998 vs reference, top-1 token matches, every block cos ≥ 0.987.


Test roadmap

Bringing DeepSeek V4 Flash up incrementally on an 8×A100 single node with a compressed 4-layer smoke configuration, then scaling out.

| Phase | What is tested | Model shape | Status |
|---|---|---|---|
| P0 | End-to-end forward/backward/PP/DCP health (full causal attention + score-based gate) | 4 layers, num_hash_layers=0, full attn, pp=2 ep=4 dp=1, 100 train steps + 1 val step | Complete — loss decreases, memory stable at ~32 GiB, no NaN/Inf |
| P2 | Hash-routing gate forward | 4 layers, num_hash_layers=2, DeepseekV4HashGate wired into the first 2 blocks | Structural — gate class instantiated, tid2eid loads, input_ids threaded through DeepseekV4Model and the V4-aware PP forward; runs without crash |
| P3 | Full 43-layer full-finetune | 43 layers, real compress_ratios, num_hash_layers=3 | Complete — see loss curve below |

Things aligned with the reference (landed in this PR)

  • Attention port (HF transformers PR 45616) with two intentional divergences from HF that match the released DSV4-Flash inference reference: per-head non-learnable rsqrt on Q after wq_b, and Compressor overlap mode for compress_ratio==4. Both are documented inline.
  • Hash gate (DeepseekV4HashGate) wired in for layer_idx < num_hash_layers; input_ids threaded through the Block and PP forward; tid2eid stored as a persistent int64 buffer so FSDP can shard non-float tensors.
  • Dual-base rotary (rope_theta / compress_rope_theta) via two DeepseekV4RotaryEmbedding modules; both threaded through the Block to Attention so compress_ratio>0 layers get the correct phase on main Q/KV.
  • Inverse RoPE on attention output before the grouped output projection (matches reference).
  • Plain head_dim**-0.5 softmax_scale (not the V3 mscale*mscale factor).
  • sqrtsoftplus branch in the shared Gate.forward.
  • Clamped SwiGLU for routed experts (MoEConfig.swiglu_limit, new swiglu_clamped_deepep activation) — V4 Flash config ships swiglu_limit=10.0.
  • attn_sink (per-head learnable) appended as an extra softmax column on every attention call (eager_attention_with_sink).
  • Hyper-Connections with the corrected pre / post / comb formulas; col-norm-first Sinkhorn produces a doubly-stochastic comb matrix.
  • Compressor + Indexer (CSA) with overlap mode for compress_ratio==4 and non-overlap for compress_ratio==128; explicit additive compressed-position mask for the Indexer top-k selection.
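
The attn_sink item above can be illustrated with a minimal pure-Python sketch (illustrative only — the real path is the fp32 eager_attention_with_sink kernel): the per-head sink logit joins the softmax denominator as a virtual column and is then dropped, so the weights over real tokens sum to less than one.

```python
import math

def softmax_with_sink(scores, sink):
    # Sink acts as an extra virtual column: include exp(sink) in the
    # denominator, then drop its slot, so real-token weights sum to
    # sum_j exp(s_j) / (sum_j exp(s_j) + exp(sink)) < 1.
    m = max(max(scores), sink)  # subtract max for numerical stability
    exps = [math.exp(s - m) for s in scores]
    denom = sum(exps) + math.exp(sink - m)
    return [e / denom for e in exps]
```

A head trained to attend to "nothing" can push its sink logit high, draining attention mass away from every real token without contributing anything to the output.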

Summary

Adds Automodel support for DeepSeek V4 Flash (deepseek-ai/DeepSeek-V4-Flash) — model definition, state dict adapter, pipeline-parallel forward, checkpoint loader dtype support, validate + HellaSwag recipes, unit tests, plus a suite of reference-aligned numerical corrections.

V4 diverges from V3/V3.2 in several load-bearing ways:

  • Attention: GQA with Q-LoRA and grouped O-LoRA (not MLA); per-layer SWA / CSA / HCA via compress_ratios; per-head learnable attention sink.
  • All-MoE stack: no dense MLP layers; first num_hash_layers use hash-clustering attention with a tid2eid routing table.
  • MoE routing: sqrtsoftplus + noaux_tc.
  • Clamped SwiGLU: routed experts gate/up clamped in FP32 (V4 Flash ships swiglu_limit=10.0).
  • Hyper-Connections (HC): every block maintains hc_mult=4 copies of the hidden state, mixed via a learned Sinkhorn router (hc_split_sinkhorn) before each sub-layer.
  • Multi-token prediction (MTP) layers.

Checkpoint format specifics the state dict adapter handles:

  • Routed-expert weights ship as FP4 e2m1fn packed two values per int8 byte, with per-row / 32-col FP8 e8m0fnu scales — unpacks on load, emits matching packed placeholders on to_hf so DCP's shape/dtype validation lines up with on-disk layout.
  • Shared experts + non-expert weights keep the standard FP8 e4m3fn 128×128 block path.
  • First num_hash_layers layers have no gate.bias on disk; the adapter reads num_hash_layers directly from the checkpoint's config.json and drops the corresponding bias keys before DCP load.
  • Indexer is a sibling of Compressor on disk with its own nested compressor; HF PR 45616 flattened them into Compressor.indexer.{ape,kv_norm,wgate,wkv,...}. Adapter renames indexer.compressor.{ape,norm,wgate,wkv} and indexer.{wq_b,weights_proj,...} to land at compressor.indexer.*.
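
The FP4 e2m1fn packing in the first bullet can be sketched with a 16-entry lookup table (a hedged illustration, not the adapter's code — the nibble order is an assumption here, and the real path also applies the per-row e8m0fnu scales afterward):

```python
# e2m1fn nibble -> value (1 sign bit, 2 exponent bits, 1 mantissa bit);
# the 8 non-negative code points are 0, 0.5, 1, 1.5, 2, 3, 4, 6.
_E2M1 = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]
_LUT = _E2M1 + [-v for v in _E2M1]

def unpack_fp4_byte(b):
    # Two FP4 values per int8 byte; low nibble first (assumed order).
    return _LUT[b & 0xF], _LUT[(b >> 4) & 0xF]
```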

Pipeline parallel: V4-specific pipeline_forward reproduces the non-PP forward per stage, threads input_ids through to hash-routing layers on stage 0, builds (cos, sin) from the rotary modules, threads position_embeddings_compress + rotary_compress to each block, applies build_causal_padding_mask + hc_head collapse, and _precompute_stage_shapes carries the extra hc_mult axis for inter-stage meta tensors. model.rotary_emb_compress is kept on every PP stage and model.hc_head on the last stage so the V4-specific forward survives module-pruning. DSV4Model.forward detects on_first_stage via self.embed_tokens is not None and treats the upstream as a 4D HC tensor on later stages.

Commit layout

  1. feat(deepseek_v4): add DeepSeek V4 Flash model, state dict adapter, tests, and recipes
  2. feat(checkpoint): recognize F8_E8M0 / F8_E5M2 dtypes in HF storage backport
  3. feat(pipelining): V4-aware pipeline_forward with hc_mult axis support
  4. fix(datasets): drop trust_remote_code for datasets>=4.0 in HellaSwag loader
  5. feat(v4): wire DeepseekV4HashGate for first num_hash_layers layers
  6. fix(v4): register HashGate.tid2eid as a buffer, not a Parameter
  7. fix(v4): apply inverse RoPE to attention output before wo_a
  8. fix(v4): drop YaRN mscale correction from attention softmax_scale
  9. feat(moe): clamped SwiGLU path for DeepSeek V4 routed experts
  10. feat(moe): add sqrtsoftplus branch to shared Gate.forward
  11. feat(v4): wire attn_sink into DeepseekV4Attention forward
  12. feat(v4): pure-torch HC Sinkhorn mixing for Block forward
  13. feat(v4): swap KAutomodel HC to HuggingFace transformers PR 45616 classes
  14. feat(v4): swap attention to HF PR 45616 + add released-checkpoint compat fixes
  15. fix(dsv4): five layer-parity bugs found by per-tensor dump bisection
  16. chore(dsv4): remove DSV4_DEBUG_DUMP per-tensor instrumentation

Test plan

  • Local unit tests under tests/unit_tests/models/deepseek_v4/ pass.
  • P1: full attention zoo (SWA + CSA + HCA) runs end-to-end on the 4-layer parity harness; layer-parity bugs fixed; per-block cos ≥ 0.987 vs reference.
  • P2 structural: num_hash_layers=2 path runs without crash with the hash gate instantiated; swiglu_clamped_deepep dispatch fires on all 8 ranks with limit=10.0.
  • HC Sinkhorn numerics: col-norm-first Sinkhorn produces a doubly-stochastic comb matrix.
  • Forward-numerics parity: 4-layer parity test final-logits cos 0.998 vs reference, top-1 token matches, every block cos ≥ 0.987.
  • P3 full finetuning (43-layer): full schedule full-finetune trains correctly — see loss curve below.
  • CI: linting + unit test suite.

Full-finetune loss curve

The plot below is from the 43-layer full-finetune run with all five bugs fixed and the full attention zoo (SWA + CSA + HCA) live.

(image: 43-layer full-finetune loss curve)

khazic added 4 commits April 24, 2026 19:45
…ests, and recipes

Introduces Automodel support for DeepSeek V4 Flash (deepseek-ai/DeepSeek-V4-Flash).
V4 diverges from V3/V3.2 in several load-bearing ways:

* Attention: GQA with Q-LoRA and grouped O-LoRA (not MLA); per-layer
  sliding/compressed variants driven by compress_ratios.
* No dense MLP layers — every transformer block is MoE; first
  num_hash_layers use hash-clustering (HC) attention with a tid2eid
  routing table.
* MoE routing: sqrtsoftplus + noaux_tc.
* Multi-token-prediction (MTP) layers.

Routed expert weights ship as FP4 e2m1fn packed two values per int8
byte, with per-row / 32-col FP8 e8m0fnu scales — the state dict
adapter unpacks these at load time, and emits matching packed
placeholders on the to_hf path so DCP's shape/dtype validation lines
up with the on-disk checkpoint layout.  Shared experts and non-expert
weights keep the standard FP8 e4m3fn 128x128 block path already used
by V3.

The first num_hash_layers layers on disk have no gate.bias tensor
(hash-routing uses tid2eid, not an additive bias).  The adapter reads
num_hash_layers directly from the checkpoint's config.json and drops
the corresponding bias keys from the model-side state dict before
DCP load so the generic Gate (which always materializes the bias
buffer for training) does not trip 'missing key' on hash layers.

Adds:
* nemo_automodel/components/models/deepseek_v4/{config,layers,model,
  state_dict_adapter,__init__}.py
* examples/llm_finetune/deepseek_v4/{deepseek_v4_flash_validate,
  deepseek_v4_flash_hellaswag}.yaml
* tests/unit_tests/models/deepseek_v4/{test_dsv4_layers,
  test_dsv4_model_smoke,test_dsv4_state_dict_adapter}.py
* Registers DeepseekV4ForCausalLM / DeepseekV4Config in
  _transformers/registry.py.

Signed-off-by: khazic <khazzz1c@gmail.com>
…ckport

V4 Flash stores per-row FP8 scales with dtype string 'F8_E8M0'
(e8m0fnu, 1 byte per element).  The backported DCP HF storage
reader's DTYPE_MAP was missing this mapping, so _get_dtype fell back
to the default float32 (4 bytes).  At read time the backport sizes
tensors by length / dtype.itemsize, which then yields a quarter of
the real element count, causing:

  RuntimeError: shape '[4, 32]' is invalid for input of size 32

Map F8_E8M0 to torch.float8_e8m0fnu and F8_E5M2 to torch.float8_e5m2
so V4 FP8 scales (and any future E5M2 FP8 weights) parse correctly.

Signed-off-by: khazic <khazzz1c@gmail.com>
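
The size arithmetic behind that RuntimeError is easy to reproduce (a toy sketch with assumed dtype-string keys, standing in for the reader's DTYPE_MAP):

```python
# dtype string -> bytes per element; an unmapped string falls back to
# the 4-byte float32 default, as the backported reader did for F8_E8M0.
ITEMSIZE = {"F8_E8M0": 1, "F8_E5M2": 1, "F32": 4}

def element_count(num_bytes, dtype_str):
    # The reader sizes tensors as byte_length / dtype.itemsize.
    return num_bytes // ITEMSIZE.get(dtype_str, 4)
```

A [4, 32] block of 1-byte e8m0fnu scales occupies 128 bytes; read with the float32 fallback it yields only 32 elements, matching the reported "shape '[4, 32]' is invalid for input of size 32".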
The generic hf_utils.pipeline_forward assumes HF-style models with a
model.rotary_emb module and (hidden_states, position_embeddings)
attention blocks.  V4 breaks all of those assumptions:

  - No rotary_emb module; RoPE comes from a shared freqs_cis buffer.
  - Hidden state is expanded to [B, S, hc_mult, dim] before the first
    block, preserved across blocks, then averaged + normed at the
    tail.
  - Blocks take (x, freqs_cis, attention_mask, padding_mask) kwargs.

Add DeepSeek V4 detection (config.model_type == 'deepseek_v4') and
two dedicated PP forwards — one for DeepseekV4Model's inner backbone
and one for DeepseekV4ForCausalLM's outer wrapper — that replicate
V4's non-PP forward per stage: the stage with embed_tokens does the
initial embed + hc_mult expand; the stage with self.norm averages
over hc_mult and applies the final RMSNorm; intermediate stages just
iterate their local layers with a stage-local freqs_cis.

Also update _precompute_stage_shapes so inter-stage meta tensors
carry the extra hc_mult axis until the final-norm stage folds it
back.  Other models keep the existing 3D meta shapes unchanged.

Signed-off-by: khazic <khazzz1c@gmail.com>
…loader

datasets 4.x removed the 'trust_remote_code' parameter entirely and
now rejects it with an error.  Detect its presence via inspect and
only pass it to load_dataset when still supported.  rowan/hellaswag
now also exposes a parquet export, so loading without the kwarg works
on all versions.

Signed-off-by: khazic <khazzz1c@gmail.com>
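
The inspect-based compatibility check described above can be sketched like this (hypothetical wrapper name; the real fix lives in the HellaSwag loader):

```python
import inspect

def load_dataset_compat(load_dataset, path, **kwargs):
    # Only pass trust_remote_code when the installed datasets version
    # still accepts it (the parameter was removed in datasets>=4.0).
    params = inspect.signature(load_dataset).parameters
    if "trust_remote_code" not in params:
        kwargs.pop("trust_remote_code", None)
    return load_dataset(path, **kwargs)
```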
@copy-pr-bot

copy-pr-bot Bot commented Apr 24, 2026

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

khazic added 8 commits April 24, 2026 21:33
…2 smoke)

Hash-routing layers bypass score-based topk and look up expert ids
from a static tid2eid table keyed by input token id.  The HashGate
class already existed but was never instantiated.  This change:

- Rewrites DeepseekV4HashGate.forward to match the generic Gate
  signature '(x, token_mask, cp_mesh) -> (weights, indices, aux_loss)'
  so the stock MoE module can call it interchangeably.
- Adds a per-forward set_input_ids() entry point + an ephemeral
  _pending_input_ids slot; DeepseekV4Block stashes the current batch's
  input_ids on the gate immediately before self.mlp(...) so the
  tid2eid lookup has what it needs.
- Adds no-op update_bias() and init_weights() so optimizer/sync
  walkers that probe MoE gates continue to work.
- In DeepseekV4Block.__init__, swaps self.mlp.gate with a
  DeepseekV4HashGate when layer_idx < num_hash_layers; the swap
  happens after MoE construction so experts/shared_experts/latent
  projections stay as built.
- Threads input_ids through DeepseekV4Model.forward and
  DeepseekV4Block.forward, and through the V4-aware pipeline_forward
  on the first PP stage (non-first stages pass None — hash layers are
  expected to be packed onto stage 0 where input_ids is available).

Also bumps the validate YAML's num_hash_layers to 2 so the first two
layers are hash-routing; with pp_size=2 on 4 layers, stage 0 owns
layers [0, 1] and therefore all hash layers.

Signed-off-by: khazic <khazzz1c@gmail.com>
nn.Parameter(int_tensor, requires_grad=False) is accepted at module
construction, but FSDP's fully_shard path re-wraps each Parameter via

  nn.Parameter(data.detach().requires_grad_(False))

and .requires_grad_() refuses anything that is not a floating-point
dtype:

  RuntimeError: only Tensors of floating point dtype can require gradients

tid2eid is a static lookup table (token id -> expert ids) that we
never train, so register it as a persistent buffer instead.  Buffers
are not sharded by FSDP as params and thus avoid the requires_grad_
call entirely.  Also switch the dtype from int32 to int64 to match
the V4 Flash checkpoint's on-disk I64 layout.

Signed-off-by: khazic <khazzz1c@gmail.com>
Reference model.py line 534:
    apply_rotary_emb(o[..., -rd:], freqs_cis, True)  # inverse=True
runs after sparse_attn and before self.wo_a, undoing the rotary
rotation that was applied to q_pe / kv_pe earlier.  Without this, the
output projection sees a representation that its trained weights do
not expect, and the logits collapse toward a small set of attractor
tokens — observed at step 0 on the V4 validate smoke:

  logits std=3.5  argmax_match=0%  mean_correct_logp=-19

which exactly matches the reported training loss of 19.

To multiply by conj(freqs_cis) we need the complex form, so also flip
the V4 model's freqs_cis_from_position_ids call (and the V4 PP inner
forward's matching call) to for_fused_rope=False, and pass
rope_fusion=False into apply_rotary_emb_qk so the q/k rotation goes
through the non-fused complex path as well.  Perf impact on the smoke
test is negligible; correctness now matches the reference forward.

Signed-off-by: khazic <khazzz1c@gmail.com>
Reference model.py line 464 defines the V4 attention scale as just

    self.softmax_scale = self.head_dim ** -0.5

with NO YaRN mscale correction — unlike V3, which folds
mscale*mscale into the scale when rope_scaling['factor'] > 1.

The V4 adapter had carried over V3's behaviour and, for the V4 Flash
config (factor=16, default mscale=1), ended up multiplying the scale
by ~1.63x.  Attention scores were therefore ~1.6x sharper than the
trained model expects, pushing each layer's output off the learned
distribution and producing step-0 logits std=3.5 with
argmax_match=0%.

Use the plain head_dim**-0.5 scale to match the reference.

Signed-off-by: khazic <khazzz1c@gmail.com>
Reference V4 Expert.forward applies gate/up clamping in FP32 before
silu(gate) * up whenever 'swiglu_limit > 0':

    gate = self.w1(x).float()
    up   = self.w3(x).float()
    if self.swiglu_limit > 0:
        up   = torch.clamp(up,   min=-swiglu_limit, max=swiglu_limit)
        gate = torch.clamp(gate,                     max=swiglu_limit)
    y = F.silu(gate) * up

V4 Flash ships with swiglu_limit=10.0 in its config; without the clamp
the bf16 routed-expert projections can spike into a regime the trained
weights were not calibrated for, contributing to the step-0 loss gap.

Adds:
* MoEConfig.swiglu_limit field (default 0.0 = disabled, preserves the
  existing weighted_bias_swiglu_impl path for every other model).
* swiglu_clamped_deepep(x, permuted_probs, limit) in components.moe.experts,
  dispatched from get_expert_activation_for_deepep when swiglu_limit>0.
* DeepseekV4Model wires config.swiglu_limit into MoEConfig so routed
  experts on every GroupedExperts* path pick up the clamp automatically.

Shared experts stay on the non-clamped path, matching the reference
(only routed Expert instances pass swiglu_limit in model.py MoE.__init__).

Signed-off-by: khazic <khazzz1c@gmail.com>
DeepSeek V4 uses score_func='sqrtsoftplus' — sqrt(softplus(x)) =
sqrt(log(1 + exp(x))) — for its MoE router.  Without this branch,
the generic Gate falls through to the sigmoid fallback and mis-scales
the routing scores.

Adds the 'sqrtsoftplus' elif between softmax_with_bias and the sigmoid
fallback, mirroring the official reference:

    scores = sqrt(softplus(scores.float()))
    if e_score_correction_bias is not None:
        scores_for_choice = scores + e_score_correction_bias
    indices = topk(scores_for_choice)
    weights = original_scores.gather(1, indices)

V4 is already configured with score_func='sqrtsoftplus' in
DeepseekV4Model moe_defaults, so this lights up automatically.

Signed-off-by: khazic <khazzz1c@gmail.com>
- Flip attn_sink to requires_grad=True so loaded checkpoint values
  actually participate in attention.
- Replace preprocess/attn_func/postprocess abstraction with an inline
  fp32 manual SDPA for V4, since that path forces is_causal=True and
  leaves no hook for a per-head bias.
- Append per-head attn_sink[h] as an extra softmax column and drop it
  after softmax: exp(s_ij) / (sum_k exp(s_ik) + exp(sink_h)). The sink
  slot contributes nothing to the output but lets attention mass drain
  away from real tokens for heads trained to attend to "nothing".

Closes most of the forward-loss gap on the 4-layer V4 Flash smoke
test: step 0 ~19 converges to ~11 within 15 steps (was stuck at ~19
with argmax_match=0% without attn_sink).

Signed-off-by: khazic <khazzz1c@gmail.com>
Replaces the mean-pool / broadcast-add approximations for hc_pre / hc_post
with a pure-torch port of the reference hc_split_sinkhorn kernel
(radixark/miles PR 1045, miles_plugins/models/deepseek_v4/ops/hyper_connection.py
+ kernel/sinkhorn.py).

- _hc_split_sinkhorn: pure-torch equivalent of the tilelang kernel. Slices
  the mixer tensor into (pre, post, comb) heads, applies the matching
  sigmoid / 2*sigmoid / row-softmax activations, and runs sinkhorn_iters
  alternating row/col normalizations to produce a doubly-stochastic comb
  matrix. eps placement matches the kernel byte-for-byte (added after
  row-norm on iter 0, inside the denominator on col-norm and subsequent
  iterations).

- hc_pre / hc_post: thin wrappers around the sinkhorn call plus the
  weighted-sum / residual-mix described in the reference's hc_pre_raw /
  hc_post_raw. The mixer itself runs under torch.no_grad() to match the
  reference's _HYPER_CONNECTION_MIXER_NO_GRAD flag; gradient only flows
  through the weighted sum of x / residual, not through the HC router
  parameters (they are frozen, loaded from the checkpoint as-is).

- hc_pre casts hc_fn / hc_scale / hc_base to float() locally so the mixer
  math stays in fp32 regardless of outer cast_model_to_dtype / FSDP
  mixed-precision policies. Mirrors the reference's per-param
  _keep_fp32 = True marker, but without requiring that hook here.

- DeepseekV4Block.forward: wires up hc_attn_{fn,scale,base} on the
  attention sub-block and hc_ffn_{fn,scale,base} on the FFN sub-block,
  using config.hc_sinkhorn_iters (default 20) and config.rms_norm_eps
  for the pre-mixer RMS normalization.

Impact on the 4-layer V4 Flash smoke test: step 0 forward loss drops
from ~19.3 (mean-pool approximation + attn_sink) to ~16.2. Remaining
~5-nat gap to the expected ~11 nats (log(vocab_size) for a 4-layer
pretrained slice) likely comes from the missing Compressed Sparse
Attention (CSA) pathway — see PR description for the running punch
list of known-gap items.

Signed-off-by: khazic <khazzz1c@gmail.com>
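
The Sinkhorn mixing at the heart of the port above can be sketched with a simplified pure-Python loop (square matrix, column pass first each iteration; this does not reproduce the kernel's exact eps placement):

```python
def sinkhorn(mat, iters=20, eps=1e-6):
    # Alternating column/row normalization drives a positive matrix
    # toward doubly stochastic (all row sums and column sums ~= 1).
    n = len(mat)
    m = [row[:] for row in mat]
    for _ in range(iters):
        # column-normalization pass
        col = [sum(m[r][c] for r in range(n)) + eps for c in range(n)]
        for r in range(n):
            for c in range(n):
                m[r][c] /= col[c]
        # row-normalization pass
        for r in range(n):
            s = sum(m[r]) + eps
            m[r] = [v / s for v in m[r]]
    return m
```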
@HuiyingLi HuiyingLi force-pushed the feat/deepseek-v4-flash branch from e464385 to fb0a58f Compare April 25, 2026 07:00
HuiyingLi and others added 4 commits April 25, 2026 00:01
…sses

Replace the pure-torch free-function HC port (hc_pre / hc_post /
_hc_split_sinkhorn / hc_head_collapse in layers.py) with HF's two
self-contained nn.Module classes, ported verbatim from
transformers/src/transformers/models/deepseek_v4/modular_deepseek_v4.py:

  - DeepseekV4HyperConnection  (per-site attn_hc / ffn_hc mixer)
  - DeepseekV4HyperHead        (final collapse before the shared RMSNorm)

The previous port had three silent correctness divergences vs HF that
the earlier audit surfaced:

  1. comb used torch.softmax(...) + eps where HF uses sigmoid(...) + eps.
  2. post had a 2.0 prefactor and was missing + hc_eps.
  3. The mixer ran under torch.no_grad(), blocking gradient flow through
     the HC routing params.

Those are fixed by adopting HF's code unchanged. Sinkhorn iteration
count, +eps placement, and the fp32 defensive casts all match HF.

Swap mechanics:

  - layers.py: delete _hc_split_sinkhorn / hc_pre / hc_post /
    hc_head_collapse; add DeepseekV4HyperConnection + DeepseekV4HyperHead
    (unchanged from HF except for explicit kwarg plumbing in __init__
    instead of a config object, since layers.py stays config-agnostic).

  - model.py DeepseekV4Block: drop flat hc_attn_* / hc_ffn_* Parameters
    and _hc_param_shape; instantiate self.attn_hc + self.ffn_hc.
    forward() collapses with (pre.unsqueeze(-1) * x).sum(dim=2) and
    expands with post.unsqueeze(-1) * out.unsqueeze(-2) + torch.matmul(
    comb, x) — matching HF's DeepseekV4DecoderLayer.forward line-for-line.

  - model.py DeepseekV4Model: add self.hc_head = DeepseekV4HyperHead(...);
    replace h.mean(dim=2) at the final collapse site with self.hc_head(h).

  - model.py DeepseekV4ForCausalLM: expand
    _keep_in_fp32_modules_strict from just ["e_score_correction_bias"]
    to the 9 HC submodule keys (attn_hc.{fn,base,scale},
    ffn_hc.{fn,base,scale}, hc_head.hc_{fn,base,scale}) plus
    e_score_correction_bias, matching HF's _keep_in_fp32_modules_strict
    at modular_deepseek_v4.py lines 890-900.

  - state_dict_adapter.py: the checkpoint on disk uses flat HC names
    (layers.{i}.hc_attn_{fn,base,scale}, hc_head_{fn,base,scale}). Add
    rename rules that route those into HF's submodule tree:
      layers.X.hc_attn_{fn,base,scale}   -> model.layers.X.attn_hc.{fn,base,scale}
      layers.X.hc_ffn_{fn,base,scale}    -> model.layers.X.ffn_hc.{fn,base,scale}
      hc_head_{fn,base,scale}            -> model.hc_head.hc_{fn,base,scale}
    (note HF uses hc_* prefix inside HyperHead vs plain {fn,base,scale}
    inside HyperConnection). Inverse rules added to the to-HF table too.

End-to-end impact on the 8-GPU ep=4, pp=2 validate run (4 layers,
compress_ratios=[0,0,0,0], num_hash_layers=2, HellaSwag SFT):

  before (broken HC math): loss stuck at 16.2 for 5 steps, argmax_match=0
  after  (HF HC classes):  loss trajectory 16.88 -> 13.60 (step 9) ->
                           10.47 (step 19, crosses below log(vocab)=11.77)
                           -> 8.09 (step 49) -> 6.98 (step 99)

State dict loading is now correct for the HC paths. Remaining residual
loss (~7 at step 99 vs expected ~3-4 for a pretrained slice) tracks the
still-missing Compressor + Indexer + sliding-window + attn_sink fold-in
— addressed in the next PR.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: HuiyingLi <willwin.lee@gmail.com>
…pat fixes

End-to-end attention swap to the HF transformers PR 45616 layout, plus the
fixes needed to actually load and run the released DSV4-Flash safetensors
under PP=2 EP=4 on the 4-layer validate harness.

Attention path (layers.py / model.py):
- Replace the legacy KAutomodel attention with the HF DeepseekV4Attention port:
  HF DeepseekV4GroupedLinear (`wo_a` as nn.Linear with per-group bmm forward),
  eager_attention_with_sink (per-head learnable sink folded as virtual key),
  partial RoPE on last qk_rope_head_dim, inverse-RoPE on attn output before
  the grouped output projection, dual-base rotary (rope_theta /
  compress_rope_theta) via two DeepseekV4RotaryEmbedding modules.
- Add a 4D causal+padding+sliding-window mask (build_causal_padding_mask)
  built once at the model/PP level; matches HF's create_sliding_window_causal_mask
  on every layer so attention sees the same band-diagonal pattern the
  released weights were trained under.
- Per-head non-learnable rsqrt on Q after wq_b (between wq_b and partial_rope).
  Present in the official inference reference, missing from HF PR 45616 —
  required to match the magnitudes the released wq_b weights expect.

Compressor / Indexer overlap mode (layers.py):
- compress_ratio==4 doubles wkv/wgate/ape feature dim to 2*head_dim.
- _overlap_transform reshapes [B, S/r, r, 2d] -> [B, S/r, 2r, d] so each
  compressed token aggregates 2*ratio raw tokens (current window plus the
  previous window's tail). _pool_windows takes an `overlap` flag.
- DeepseekV4Compressor uses overlap when ratio==4; DeepseekV4Indexer always
  overlaps (its compress_ratio is fixed at 4). compress_ratio==128 stays flat.

Pipeline-parallel wiring (pipelining/functional.py, pipelining/hf_utils.py):
- pipeline_forward_deepseek_v4 builds (cos, sin) from the rotary modules,
  threads position_embeddings + position_embeddings_compress + rotary_compress
  to each block, and uses build_causal_padding_mask + hc_head collapse.
- Keep model.rotary_emb_compress on every PP stage and model.hc_head on the
  last stage so the V4-specific forward survives module-pruning.

State-dict adapter (state_dict_adapter.py):
- Rename attn.attn_sink -> self_attn.sinks (HF renamed the param).
- Outer compressor: on-disk `compressor.norm` -> module `compressor.kv_norm`.
- Indexer is a sibling of compressor on disk with its own nested compressor;
  HF flattened them into Compressor.indexer.{ape,kv_norm,wgate,wkv,...}.
  Adapter renames `indexer.compressor.{ape,norm,wgate,wkv}` and
  `indexer.{wq_b,weights_proj,...}` to land at `compressor.indexer.*`.
- Add the new compressor/indexer projections to NON_QUANTIZED_PATTERNS so
  to_hf doesn't fabricate a `.scale` companion for the BF16 weights (only
  indexer.wq_b is FP8 on disk).

Checkpointing (checkpointing.py):
- Force the in-tree backport HuggingFaceStorageReader for `is_init_step=True`.
  The upstream reader silently drops F8_E8M0/F8_E5M2 dtypes in safetensors
  metadata decoding, leaving DCP with metadata=None on every rank.

Validate YAML:
- compress_ratios: [0,0,4,128] for the 4-layer truncated harness — exercises
  pure-SWA (layer 0,1), CSA with Indexer (layer 2), HCA without Indexer
  (layer 3) on every forward.

Result: 8 GPU PP=2 EP=4 with compress_ratios=[0,0,4,128] runs the full
attention zoo end-to-end (10 train + 1 val step), no errors.  Step-0 loss
~19, step-9 ~14.7 — same as the SWA-only baseline; the dominant ~14-nat
excess vs zero-shot expectation is upstream of the attention path and is
the next debugging target.

Two divergences from HF PR 45616 are intentional and documented inline
(per-head Q-norm and Compressor overlap mode).  Both are required to load
and run the released DSV4-Flash checkpoint; both are absent from the open
HF PR.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: HuiyingLi <willwin.lee@gmail.com>
Closing the per-layer cosine gap vs the DeepSeek inference reference
(``dsv4flash/inference/model.py``).  Each fix below is a stand-alone
bug — they were uncovered by progressively dumping intermediate tensors
to ``$DSV4_DEBUG_DUMP/`` and bisecting where divergence enters.

End-to-end effect on the 4-layer parity test (HellaSwag prompt,
``num_hidden_layers=4 num_hash_layers=2 compress_ratios=[0,0,4,128]``,
PP=1 EP=8): final-logits cos vs reference 0.998, top-1 token matches.
Per-block cosines all >= 0.987.

================================================================
Bug 1 — RoPE convention: Llama-style ``rotate_half`` -> interleaved pairs
----------------------------------------------------------------
The released DSV4-Flash checkpoint encodes RoPE as INTERLEAVED pairs
(``(2k, 2k+1)``) — see reference ``apply_rotary_emb`` which uses
``view_as_complex`` on consecutive-pair packing.  HF transformers PR
45616/45643 reuses Llama-style ``rotate_half`` (pairs ``(d, d+rd/2)``).
Same algebra, different dim-to-frequency mapping — wrong on these
weights.

Fix: new helper ``_apply_partial_rope_interleaved`` does pair-rotation
on the last ``rope_head_dim`` slice using cos/sin on the FIRST half of
the Llama-style ``cat([freqs, freqs], -1)`` rotary output (the second
half is the duplicate the Llama helper needs and we don't).
``_apply_partial_rope`` now delegates to it.

Effect: kv_post_rope cos 0.866 -> 0.999 after one block.
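The interleaved-pair rotation can be sketched standalone as below (a minimal reconstruction — the function name mirrors `_apply_partial_rope_interleaved` from the description above, but the exact slicing and tensor layout in the branch may differ):

```python
import torch

def apply_partial_rope_interleaved(x, cos, sin, rope_dim):
    """Rotate the last `rope_dim` channels of x as INTERLEAVED pairs (2k, 2k+1).

    x:        [B, H, S, D]; only the trailing rope_dim channels rotate.
    cos, sin: [S, rope_dim // 2] -- the FIRST half of the Llama-style
              cat([freqs, freqs], -1) rotary output; the duplicated second
              half is only needed by rotate_half and is dropped here.
    """
    x_pass, x_rot = x[..., :-rope_dim], x[..., -rope_dim:]
    # Pack consecutive pairs as complex numbers: (2k, 2k+1) -> re + i*im,
    # matching the reference's view_as_complex convention.
    xc = torch.view_as_complex(
        x_rot.float().contiguous().view(*x_rot.shape[:-1], -1, 2)
    )
    fc = torch.complex(cos.float(), sin.float())        # [S, rope_dim // 2]
    out = torch.view_as_real(xc * fc).flatten(-2).to(x.dtype)
    return torch.cat([x_pass, out], dim=-1)
```

Because the rotation factors have unit modulus, the rotated slice keeps its norm, which makes a quick sanity check easy.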

================================================================
Bug 2 — Hyper-Connections ``post`` and ``comb`` formulas
----------------------------------------------------------------
HC's ``compute_weights`` had two formulas wrong vs reference
(``dsv4flash/inference/kernel.py:hc_split_sinkhorn_kernel`` 391-413):

  * ``post``: was ``sigmoid(...) + eps``; reference is ``2 * sigmoid(...)``
    (NO ``+eps``, AND a 2x prefactor).  This means the trained weights
    expect a range of ``(0, 2)``, not ``(eps, 1+eps)`` — without this fix
    the post branch runs at half magnitude during training.

  * ``comb``: was ``sigmoid(logit) + eps`` followed by sinkhorn; reference
    is ``softmax(logit, dim=-1) + eps`` followed by col-norm-first
    sinkhorn (``sum(dim=-2)`` BEFORE the iter loop, then ``iters - 1``
    alternating row/col passes).  Empirically the row-softmax converges
    to a much sharper distribution than sigmoid+sinkhorn on the same
    logits.
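In code, the corrected formulas look roughly like this (a sketch over plain tensors; the real reference is a fused kernel, and the exact loop-body ordering here is my reading of the description above, not copied code):

```python
import torch

def hc_post(logits):
    # post-branch gate in (0, 2): 2 * sigmoid, with NO +eps
    return 2.0 * torch.sigmoid(logits)

def hc_comb(logits, iters=3, eps=1e-6):
    # row-softmax + eps, then col-norm-first Sinkhorn: one column pass up
    # front (sum over dim=-2 BEFORE the loop), then iters - 1 alternating
    # row/col normalization passes
    w = torch.softmax(logits, dim=-1) + eps
    w = w / w.sum(dim=-2, keepdim=True)           # column normalization first
    for i in range(iters - 1):
        if i % 2 == 0:
            w = w / w.sum(dim=-1, keepdim=True)   # row pass
        else:
            w = w / w.sum(dim=-2, keepdim=True)   # column pass
    return w
```

With `iters=2` the final pass is a row normalization, so each row of the output sums to one — a cheap invariant to test against.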

================================================================
Bug 3 — Dual RoPE bases: main attn at compress_ratio>0 layers
----------------------------------------------------------------
The released DSV4-Flash uses TWO RoPE bases:
  * theta=10000 (no YaRN) for layers with compress_ratio == 0
  * theta=160000 (with YaRN) for layers with compress_ratio != 0
and crucially, on compress_ratio>0 layers the compress-rope is applied
to the MAIN attention Q/KV, not just to the compressor sub-module.
Reference proof: ``model.py:476-501`` builds ``self.freqs_cis`` with
``compress_rope_theta`` whenever ``compress_ratio != 0``.

Before this fix DSV4Attention only consumed ``position_embeddings``
(theta=10000), so layer 2 / layer 3's main attention saw the wrong
phase.  Fix: pick ``position_embeddings_compress`` when
``self.compress_ratio`` is truthy and the caller supplied it.
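A minimal sketch of the per-layer base selection (YaRN scaling omitted for brevity; the function and argument names are illustrative, not the branch's API):

```python
import torch

def build_freqs_cis(seq_len, rope_dim, compress_ratio,
                    theta=10000.0, compress_theta=160000.0):
    """Per-layer rotary table: theta=10000 on compress_ratio == 0 layers,
    theta=160000 on compress_ratio != 0 layers.  Per the fix above, the
    selected table must feed the MAIN attention Q/KV on compressed layers,
    not only the compressor sub-module."""
    base = compress_theta if compress_ratio != 0 else theta
    inv_freq = 1.0 / (base ** (torch.arange(0, rope_dim, 2).float() / rope_dim))
    t = torch.arange(seq_len).float()
    angles = torch.outer(t, inv_freq)               # [seq_len, rope_dim // 2]
    return torch.polar(torch.ones_like(angles), angles)
```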

================================================================
Bug 4 — Layer-2 attention compressed-pool mask leak (THIS SESSION)
----------------------------------------------------------------
On compress_ratio==4 layers (CSA + Indexer), reference uses ``sparse_attn``
with per-query ``topk_idxs`` whose ``-1`` entries are causally masked out
(``model.py:472-475``).  Our previous dense-attention port:

  1. Pre-gathered the compressor pool by indexer topk inside
     ``Compressor.forward``, returning ``[B, 1, S*topk, D]``.
  2. Concatenated that to ``full_kv`` and ran dense attention with
     ``F.pad(value=0.0)`` to extend the causal mask.

Two leaks:
  a. ``F.pad(value=0)`` makes every query attend to every compressed
     column unconditionally — defeats the indexer's per-query selection.
  b. The pre-gathered ``[S*topk, D]`` tensor stacks every query's
     gathered slice side-by-side; under (a), query q sees query q′'s
     gathered slice — pure non-causal cross-talk.

Fix:
  * ``DeepseekV4Compressor.forward`` now returns ``(pooled, indexer_topk)``
    where ``pooled`` is ``[B, 1, P_total, D]`` (no per-query gather) and
    ``indexer_topk`` is ``[B, S, K]`` with ``-1`` for causally-invalid
    entries.  Causal mask applied here matches reference (``mask =
    raw_topk >= arange(1, seq_len+1).unsqueeze(1) // ratio``).
  * ``DeepseekV4Attention.forward`` builds an EXPLICIT additive
    compressed-position mask of shape ``[B, 1, S, P_total]``:
      - if Indexer present: scatter zeros at the non-(-1) topk positions,
        -inf everywhere else.
      - no Indexer (compress_ratio==128): deterministic
        ``allowed = p < (q+1) // ratio`` (matches reference's
        ``get_compress_topk_idxs``).
    Then concat onto the existing causal mask along the last dim.
  * ``F.pad`` fallback kept only as defense-in-depth for callers that
    bypass the compressor branch.

Effect on the 4-layer parity test:
  A02_attn_o_pre_inv_rope cos: 0.807 -> 0.999
  L02_attn_out cos:            0.741 -> 0.994
  block_02 cos:                0.965 -> 0.991
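The explicit additive mask can be sketched like this (shapes and the ``-1`` sentinel follow the description above; the function name and exact construction are illustrative, not the branch code):

```python
import torch

def compressed_position_mask(indexer_topk, seq_len, p_total, ratio=None,
                             dtype=torch.float32):
    """Additive mask [B, 1, S, P_total] over the compressed-KV columns.

    indexer_topk: [B, S, K] with -1 marking causally invalid entries, or
    None for the no-Indexer case (e.g. compress_ratio == 128)."""
    neg_inf = torch.finfo(dtype).min
    if indexer_topk is not None:
        bsz = indexer_topk.shape[0]
        mask = torch.full((bsz, seq_len, p_total), neg_inf, dtype=dtype)
        valid = indexer_topk >= 0
        b, s, _ = torch.nonzero(valid, as_tuple=True)
        # zeros only at the valid topk columns; everything else stays -inf,
        # so each query attends only its OWN selected compressed positions
        mask[b, s, indexer_topk[valid]] = 0.0
        return mask.unsqueeze(1)
    # no Indexer: deterministic causal pooling -- query q may attend pooled
    # column p iff p < (q + 1) // ratio
    q = torch.arange(seq_len).unsqueeze(-1)
    p = torch.arange(p_total).unsqueeze(0)
    allowed = p < (q + 1) // ratio
    return torch.where(allowed, 0.0, neg_inf).to(dtype).view(1, 1, seq_len, p_total)
```

This mask is then concatenated onto the regular causal mask along the last dim, replacing the all-zeros `F.pad` extension that caused leak (a).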

================================================================
Bug 5 — DSV4Model.forward not stage-aware
----------------------------------------------------------------
Original forward required ``input_ids XOR inputs_embeds``, which fails on
PP stage 1+ where ``embed_tokens`` is nulled by the stage-trim pass and
the upstream activation arrives as a 4D HC tensor in the ``input_ids``
slot (DSv3 / DSv3.2 pattern — confirmed against
``DeepseekV3Model.forward``).

Fix:
  * Detect ``on_first_stage`` via ``self.embed_tokens is not None`` (NOT
    via input dtype — the trim pass nulls the attribute).
  * On non-first stages, treat ``input_ids`` (or the kwarg
    ``inputs_embeds``) as the upstream 4D ``[B, S, hc_mult, hidden]``
    activation; derive a 3D ``shape_ref`` for rotary / mask sizing.
  * ``hc_head`` and ``norm`` are applied only when present (last stage).
  * ``layer_input_ids`` is passed only on the first stage — hash-routing
    layers live there (``num_hash_layers <= layers per stage 0``).
  * Skip PP-trimmed slots (``layer is None``) in the block loop.

This unblocks PP=2 / PP=4 runs of the validate and full yamls (custom-
model pattern; ``patch_inner_model: false`` and ``patch_causal_lm_model:
false`` in the recipes).
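The stage-detection logic can be sketched standalone as below (attribute and argument names follow the description, not the exact model code):

```python
import torch

def resolve_stage_inputs(embed_tokens, input_ids, inputs_embeds=None):
    """Stage-aware input handling, illustrating the fix above.

    First PP stage: embed_tokens is present -> embed the token ids.
    Later stages: the stage-trim pass has nulled embed_tokens, and the
    upstream 4D [B, S, hc_mult, hidden] HC activation arrives in the
    input_ids slot (or as inputs_embeds)."""
    on_first_stage = embed_tokens is not None     # NOT via input dtype
    if on_first_stage:
        hidden = embed_tokens(input_ids) if inputs_embeds is None else inputs_embeds
        shape_ref = hidden
    else:
        hidden = inputs_embeds if inputs_embeds is not None else input_ids
        # derive a 3D [B, S, hidden] reference for rotary / mask sizing
        shape_ref = hidden[..., 0, :]
    return on_first_stage, hidden, shape_ref
```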

================================================================
Debug instrumentation (always-off; gated on ``DSV4_DEBUG_DUMP``)
----------------------------------------------------------------
Added per-tensor dumps under ``$DSV4_DEBUG_DUMP/rank{R}/`` matching the
reference's filename schema (``A{N}_q_post_rope.pt``, ``C{N}_pooled_*``,
``L{N}_attn_out``, ``embed``, ``block_NN``, ``hc_head_out``,
``post_norm``, ``logits.pt``).  Pair with ``tools/dsv4_state_parity/``
to bisect divergence.  Zero overhead when the env var is unset.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: HuiyingLi <willwin.lee@gmail.com>
Strip the env-var-gated dump helpers from ``DeepseekV4Compressor``,
``DeepseekV4Attention``, ``DeepseekV4Block``, ``DeepseekV4Model``, and
``DeepseekV4ForCausalLM``.  They were added to bisect per-layer cosine
divergence vs the DeepSeek inference reference; that work is done.

The corresponding driver/diff scripts under
``tools/dsv4_state_parity/`` remain available for future debugging if
this code needs to be re-instrumented.

No functional change: every removed block was a no-op when
``$DSV4_DEBUG_DUMP`` was unset.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: HuiyingLi <willwin.lee@gmail.com>
@HuiyingLi HuiyingLi force-pushed the feat/deepseek-v4-flash branch from fb0a58f to 126585e Compare April 25, 2026 07:02
This recipe was a single-node infra-validation harness used to bring up
DSV4-Flash on a small cluster before the full-model run.  It overlaps in
purpose with ``deepseek_v4_flash_parity.yaml`` and ``deepseek_v4_flash.yaml``
(parity testing and full-scale training, respectively) — keeping it
makes the example list noisier without adding signal.

Removing the file from the branch.  Local copies can stay outside git
for personal smoke-testing.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: HuiyingLi <willwin.lee@gmail.com>
@HuiyingLi
Contributor

/ok to test ab2d7a0

HuiyingLi previously approved these changes Apr 25, 2026
Contributor

@HuiyingLi HuiyingLi left a comment


Thank you so much!!

@HuiyingLi HuiyingLi enabled auto-merge (squash) April 25, 2026 08:37
The CICD ``linting`` job runs ``ruff format --check``, not just
``ruff check`` — three files I touched earlier still needed reformatting,
plus four tests that the previous lint-fix commit left in a
not-quite-formatted state (whitespace, line wrapping).

No semantic changes; ``ruff check`` is also clean and the 375 tests
across the touched files still pass.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: HuiyingLi <willwin.lee@gmail.com>
@HuiyingLi
Contributor

/ok to test 11d0b1a

@svcnvidia-nemo-ci svcnvidia-nemo-ci added the waiting-on-customer Waiting on the original author to respond label Apr 25, 2026
sharonyu-115 added a commit to sharonyu-115/RL that referenced this pull request Apr 25, 2026
…sformers to 5.5

Vendor in-flight DSV4-Flash support from NVIDIA-NeMo/Automodel#2039.

.gitmodules + submodule: Automodel -> khazic/Automodel_lao @
feat/deepseek-v4-flash, gitlink at ab2d7a08 (PR NVIDIA-NeMo#2039 head, 24
commits). The PR registers DeepseekV4ForCausalLM natively in
nemo_automodel/_transformers/registry.py, ships FP4-expert + FP8-
attention loaders, and a state_dict_adapter with
convert_single_tensor_to_hf so refit-to-vLLM works through the
existing dtensor_params_generator path.

transformers 5.3.0 -> 5.5.0 (pyproject + uv.lock). Required
because Automodel main forwarded past our previous pin and now
imports transformers.models.gemma4.modeling_gemma4 unconditionally
in components/distributed/parallelizer.py:49 (gemma4 ships in
transformers >= 5.5.0.dev). vLLM stays pinned to the local DSV4
wheel; the override on transformers in [tool.uv] supersedes the
wheel's <5 metadata declaration. Runtime smoke on the rebaked sqsh
confirms vllm._C loads clean against transformers 5.5 and DSV4
returns "Paris." end-to-end via the standard NeMo-RL vllm_worker
init path.

Signed-off-by: Shuang Yu <shuangy@nvidia.com>
khazic added a commit to khazic/Automodel_lao that referenced this pull request Apr 25, 2026
The {download} directive on the recipe yaml fails the Sphinx build
with `download.not_readable` because
examples/llm_finetune/deepseek_v4/deepseek_v4_flash_hellaswag.yaml
is added by the model PR (NVIDIA-NeMo#2039), which has not yet landed on main.
Use a plain GitHub link until NVIDIA-NeMo#2039 merges; a follow-up can switch
back to {download} once the file is on main.

Signed-off-by: khazic <khazzz1c@gmail.com>
HuiyingLi added a commit that referenced this pull request Apr 26, 2026
…2054)

* docs(llm): drop validate-yaml reference from DeepSeek V4 Flash guide

Removes the validate-yaml bullet under "Launch Training" and the
"Quick infrastructure validation" subsection.  The validate harness
is an internal smoke-test config, not a user-facing finetune recipe;
the guide should advertise only the HellaSwag recipe.

Follow-up to #2053 (the original change was force-pushed after the
PR had already merged, so the deletion did not land on main).

Signed-off-by: khazic <khazzz1c@gmail.com>

* docs(llm): add DeepSeek V4 Flash to README + model-coverage index

Mirrors the per-model rollout pattern used for MiniMax-M2.7 (#1785):
news entry at the top of the README, a dedicated model-coverage page
under deepseek-ai/, and registration of the new page in the LLM index
(architecture table + toctree).

- README.md                                            (news entry)
- docs/model-coverage/llm/deepseek-ai/dsv4-flash.md    (new)
- docs/model-coverage/llm/index.md                     (table + toctree)

Signed-off-by: Huiying Li <willwin.lee@gmail.com>

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

* docs(llm): use plain link for hellaswag yaml until model PR lands

The {download} directive on the recipe yaml fails the Sphinx build
with `download.not_readable` because
examples/llm_finetune/deepseek_v4/deepseek_v4_flash_hellaswag.yaml
is added by the model PR (#2039), which has not yet landed on main.
Use a plain GitHub link until #2039 merges; a follow-up can switch
back to {download} once the file is on main.

Signed-off-by: khazic <khazzz1c@gmail.com>

---------

Signed-off-by: khazic <khazzz1c@gmail.com>
Co-authored-by: Huiying Li <willwin.lee@gmail.com>
Co-authored-by: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
@HuiyingLi
Contributor

/ok to test 2f979fb

@copy-pr-bot

copy-pr-bot Bot commented Apr 26, 2026

/ok to test 2f979fb

@HuiyingLi, there was an error processing your request: E2

See the following link for more information: https://docs.gha-runners.nvidia.com/cpr/e/2/


Labels: community-request, waiting-on-customer (Waiting on the original author to respond)
