Add Gemma 4 attention features: QK/value norms, shared KV, partial RoPE, embedding scale, logit softcap #492

Open
jlamypoirier wants to merge 12 commits into main from worktree-gemma

Conversation


@jlamypoirier (Collaborator) commented Apr 25, 2026

Summary

Language model

  • embedding_scale (LanguageModelEmbeddingsConfig): multiplicative scale applied to word embeddings after lookup. Gemma 4 uses sqrt(hidden_size). Zero overhead for the default value of 1.0 via a runtime branch. Necessary as a runtime op (not a weight init change) because tied embeddings share the weight with the LM head — baking the scale into weights would also scale logits.
  • final_logit_softcap (LanguageModelHeadConfig): applies tanh(logits / cap) * cap before the loss. Gemma 4 uses cap=30. Forward and backward are each @torch.compile-decorated for op fusion. The gradient propagates through the Jacobian (1 - (softcapped / cap)²) before the output-linear backward. Both features are sketched after this list.
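
A minimal sketch of both features, with illustrative names (the actual Fast-LLM wiring differs):

```python
import torch

def embed(word_embeddings, input_ids, embedding_scale: float = 1.0):
    # Runtime scale, not baked into the weight: with tied embeddings the same
    # weight feeds the LM head, so scaling it would also scale the logits.
    embeddings = word_embeddings(input_ids)
    if embedding_scale != 1.0:  # free for the default value of 1.0
        embeddings = embeddings * embedding_scale
    return embeddings

@torch.compile
def softcap_forward(logits: torch.Tensor, cap: float) -> torch.Tensor:
    return torch.tanh(logits / cap) * cap

@torch.compile
def softcap_backward(grad_output: torch.Tensor, softcapped: torch.Tensor, cap: float) -> torch.Tensor:
    # d/dx [cap * tanh(x / cap)] = 1 - tanh(x / cap)^2 = 1 - (softcapped / cap)^2
    return grad_output * (1 - (softcapped / cap) ** 2)
```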

Normalization

  • FixedRMSNormConfig / FixedRMSNormalization: no-weight RMS normalization. The Triton kernel is extended with a has_weight: tl.constexpr flag to skip the weight load/multiply when weight=None; the torch fallback uses torch.rms_norm(..., weight=None).
  • post_mixer_normalization / post_mlp_normalization (DecoderBlockConfig): optional normalization applied to mixer/MLP outputs before the residual add. Gemma 4 applies RMSNorm at both positions.
  • pre_mixer_normalization / pre_mlp_normalization (DecoderBlockConfig): independent overrides for the pre-norm at each sub-layer. Both default to normalization when unset, enabling different norm types per sub-layer (e.g. none for pre-MLP in certain Gemma 4 block variants). The norm placement is sketched after this list.
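
A rough sketch, under assumed function names, of the no-weight norm's stated torch fallback and the pre/post placement around a sub-layer:

```python
import torch

def fixed_rms_norm(x: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # Torch fallback for FixedRMSNormalization: RMS norm with no learnable weight.
    return torch.rms_norm(x, [x.shape[-1]], weight=None, eps=eps)

def sublayer_with_norms(x, sublayer, pre_norm, post_norm=None):
    # Pre-norm on the sub-layer input; optional post-norm on its output,
    # applied before the residual add (Gemma 4 enables it at both positions).
    out = sublayer(pre_norm(x))
    if post_norm is not None:
        out = post_norm(out)
    return x + out
```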

Attention

  • query_norm / key_norm (AttentionConfig): optional per-head RMSNorm applied to query and key vectors before RoPE. Gradient handled via a local autograd subgraph inside wrap_forward_backward.
  • value_norm (AttentionConfig): optional per-head normalization applied to value projections before attention. Gemma 4 uses fixed_rms_norm (no learnable weight).
  • shared_key_value (AttentionConfig): single key projection reused as value. Gradients from both key and value paths are summed back to the projection in the backward pass.
  • In-place rotary fix: triton_rotary_ wrote results in-place, silently corrupting the saved norm output when query_norm was active (both tensors shared storage via .detach()). Added an output_ptr to the Triton kernel and threaded an inplace_query flag through the rotary layer so the query gets a fresh allocation when a query_norm context is live.
  • ProportionalRotaryConfig / ProportionalRotary: partial RoPE where only the first partial_rotary_factor fraction of head dimensions receive positional encoding (NoPE for the rest, via zero angle scales). Gemma 4 global-attention layers use partial_rotary_factor=0.5. The angle-scale construction is sketched after this list.
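
One way the zero-angle-scale scheme could look, as a sketch (not the actual ProportionalRotary code):

```python
import torch

def proportional_angle_scales(head_dim: int, partial_rotary_factor: float,
                              theta: float = 10000.0) -> torch.Tensor:
    # Standard RoPE frequencies for the first `partial_rotary_factor` fraction
    # of dimension pairs; zero elsewhere, so those pairs rotate by angle 0
    # (cos=1, sin=0) and pass through unchanged (NoPE).
    rotary_dim = int(head_dim * partial_rotary_factor)
    scales = torch.zeros(head_dim // 2)
    scales[: rotary_dim // 2] = theta ** (
        -torch.arange(0, rotary_dim, 2, dtype=torch.float32) / rotary_dim
    )
    return scales  # angle[pos, i] = pos * scales[i]
```

With head_dim=256 and partial_rotary_factor=0.5, only the first 64 of 128 dimension pairs receive nonzero frequencies.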

MoE

  • HybridMoEMLPConfig / HybridMoEMLP: new MLP variant combining an always-active dense MLP with top-K routed experts. Each branch has optional pre/post norms (dense_pre_norm, dense_post_norm, moe_pre_norm, moe_post_norm). Gemma 4 uses this layout with separate intermediate sizes for the dense and expert paths; the composition is sketched below.
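
The exact composition is not spelled out above, so the following is a naive sketch: the names mirror the config fields, the softmax-then-top-K gating and the summed branch outputs are assumptions, and the per-expert loop stands in for the fused kernels:

```python
import torch

def hybrid_moe_forward(x, dense_mlp, experts, router, top_k,
                       dense_pre_norm, dense_post_norm,
                       moe_pre_norm, moe_post_norm):
    # Always-active dense branch with its own optional pre/post norms.
    dense_out = dense_post_norm(dense_mlp(dense_pre_norm(x)))

    # Routed branch: top-K expert selection over the (pre-normed) input.
    h = moe_pre_norm(x)
    weights, indices = router(h).softmax(dim=-1).topk(top_k, dim=-1)
    moe_out = torch.zeros_like(x)
    for e, expert in enumerate(experts):
        for k in range(top_k):
            mask = indices[..., k] == e  # tokens routed to expert e at rank k
            if mask.any():
                moe_out[mask] += weights[..., k][mask].unsqueeze(-1) * expert(h[mask])

    # Assumed combination: the two branch outputs sum into the block output.
    return dense_out + moe_post_norm(moe_out)
```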

Checkpoint converter

  • Gemma 4 HuggingFace checkpoint converter (fast_llm/models/gpt/conversion/gemma4.py): full import/export support for the Gemma 4 text model family including sliding-window and full-attention pattern blocks, per-head norms, partial RoPE, hybrid MoE blocks, and tied embeddings.
  • attention_k_eq_v → shared_key_value: Gemma 4 26B-A4B sets attention_k_eq_v=True for full-attention layers; the converter maps this to AttentionConfig.shared_key_value=True and routes through a single k_proj weight (no v_proj).
  • MoE weight layout: HF stores expert weights as [num_experts, out, in] batched tensors; the converter reshapes to Fast-LLM's flat [num_experts * out, in] layout for gate_up_proj and handles the additional transpose for down_proj. The import-direction reshape is sketched after this list.
  • use_bidirectional_attention: exported as None (Fast-LLM is text-only; bidirectional attention for vision tokens is not implemented).
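
A sketch of the import direction of the expert-weight reshape (export is the inverse), using the shapes given in the commit notes below:

```python
import torch

def import_gate_up_proj(hf_weight: torch.Tensor) -> torch.Tensor:
    # HF [num_experts, 2 * intermediate, hidden] -> Fast-LLM [E * 2I, H]
    e, out_dim, hidden = hf_weight.shape
    return hf_weight.reshape(e * out_dim, hidden)

def import_down_proj(hf_weight: torch.Tensor) -> torch.Tensor:
    # HF [num_experts, hidden, intermediate] -> Fast-LLM [E * I, H]:
    # transpose each expert's [H, I] to [I, H], then flatten the expert dim.
    e, hidden, intermediate = hf_weight.shape
    return hf_weight.permute(0, 2, 1).reshape(e * intermediate, hidden)
```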

Not yet implemented

  • Per-Layer Embeddings (PLE): Gemma 4 feeds an auxiliary per-layer embedding signal (from a separate 262k-entry table) into each decoder block. Exported as hidden_size_per_layer_input: 0 to disable the feature until it is implemented in Fast-LLM. Round-tripping a real Gemma 4 checkpoint will lose PLE weights. Follow-up work needed.

Tests

  • tests/layers/test_embedding.py: 3 base cases × 4 variants = 12 cases.
  • tests/layers/test_lm_head.py: adds final_logit_softcap=2.0 case.
  • tests/layers/test_attention.py: rewritten as a single parametrized suite — 5 base cases × 6 norm variants + rotary cases + shared_key_value cases = 160 cases covering forward reference, packing equivalence, and flash equivalence.
  • tests/layers/test_decoder_block.py: 4 cases (no post-norms, post-mixer only, post-MLP only, both).
  • tests/layers/test_mlp.py: hybrid MoE cases added. Test dimensions bumped to 128 to satisfy Triton block-size compile-time assertions under FAST_LLM_SKIP_TRITON_AUTOTUNE.
  • tests/layers/test_rotary.py: rewritten as a single parametrized suite — default, big-theta, llama3, yarn, 2d, and proportional variants across head sizes and sequence lengths = 54 cases. Fixed a bug where triton_rotary_ (in-place) corrupted the reference input; now clones query before forward.
  • tests/models/: 17 model tests pass for gemma4 (simple, bf16, fp16, checkpoint, resume, conversion, round-trip, load-pretrained, huggingface, frozen-weights, dtype variants).
  • tests/models/test_hf_roundtrip.py: new test_hf_roundtrip[gemma4] using the real google/gemma-4-26B-A4B config with scaled-down dims to exercise the full import/export cycle.

Test plan

  • pytest -v -n 8 tests/layers/test_rotary.py — 54 passed
  • pytest -v -n 8 tests/layers/test_attention.py — 160 passed
  • pytest -v tests/layers/test_embedding.py — 12 passed
  • pytest -v tests/layers/test_lm_head.py — passed
  • pytest -v tests/layers/test_decoder_block.py — 4 passed
  • pytest -v -n 8 tests/layers/test_mlp.py — 8 passed
  • pytest -v -n 8 --models gemma4 tests/models/ — 18 passed (17 + roundtrip)
  • pytest -v -n 8 tests/ — full suite (running)

🤖 Generated with Claude Code

@jlamypoirier changed the title from "Add embedding_scale and final_logit_softcap (Gemma 4 prep)" to "Add QK norm, post-block norms, embedding scale, and logit softcap (Gemma 4 prep)" on Apr 27, 2026
@jlamypoirier changed the title from "Add QK norm, post-block norms, embedding scale, and logit softcap (Gemma 4 prep)" to "Add Gemma 4 attention features: QK/value norms, shared KV, partial RoPE, embedding scale, logit softcap" on Apr 28, 2026
jlamypoirier and others added 12 commits April 30, 2026 01:40
- `LanguageModelEmbeddingsConfig.embedding_scale`: multiplicative scale
  applied to word embeddings after lookup (Gemma 4 uses sqrt(hidden_size)).
  Zero overhead for the default value of 1.0 via a compile-time branch in
  the @torch.compile-decorated _forward.
- `LanguageModelHeadConfig.final_logit_softcap`: applies
  tanh(logits / cap) * cap before the loss. Forward and backward are
  each wrapped in @torch.compile for op fusion. Gradient back-propagates
  through the Jacobian (1 - (softcapped / cap)^2) before the output
  linear backward.
- New test_embedding.py: generic parametrized embedding layer test
  covering scale, dtype, full_precision_residual, position embeddings,
  and padding (3 base cases x 4 variants).
- Adds final_logit_softcap case to test_lm_head.py.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
- AttentionConfig: add query_norm and key_norm fields (NormalizationConfig | None)
- Attention: apply QK norms before RoPE in forward/backward, with wrap_forward_backward-compatible gradient handling
- DecoderBlockConfig: add post_mixer_normalization and post_mlp_normalization fields
- DecoderBlock: apply post-norms to mixer/MLP outputs before residual add
- Tests: test_qk_norm (4 cases) and test_post_norms (4 cases)

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
…tion

Replaces test_qk_norm and test_attention_implementations with a single
parametrized test_attention that covers all combinations of causal/noncausal/
window attention × QK norm variants, checking packing equivalence (backup, float32,
precise gradient comparison) and flash vs backup equivalence (bfloat16, output only).
Also removes the AttentionConfig case from test_varlen since attention is now
covered in test_attention.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
The independent reference uses F.linear + torch.rms_norm + a plain per-document
einsum loop to verify the full forward pass, catching any norm applied to the wrong
slice or attention computed incorrectly — not just internal consistency.

Also adds [20, 32, 10, 11, 9, 18] (100 tokens) to the length parametrize set to
cover the flash attention range that the old test_attention_implementations had.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Add MQA (kv_heads=1), MHA (kv_heads=heads), rotary, and query/key norm
variants to the parametrized attention test, bringing it to 96 cases.
The independent reference (plain F.linear + per-head einsum loop) now
covers all combinations. Run entirely on GPU with TF32 disabled via a
_no_tf32() context manager to keep precision tight without CPU-Triton
conflicts.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Adds ProportionalRotaryConfig/ProportionalRotary for partial RoPE
(partial_rotary_factor<1), where NoPE dimensions pass through via zero
angle scales.  Replaces the ad-hoc test_rotary with a single
parametrized test covering default, big-theta, llama3, yarn, 2d, and
proportional variants across multiple head sizes and sequence lengths.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Adds FixedRMSNormConfig/FixedRMSNormalization, a no-weight RMS norm
with triton (has_weight constexpr) and torch paths.  Wires it into
AttentionConfig as value_norm (NormalizationConfig|None), applying
fixed-scale RMS norm to value projections per head.  Also adds
shared_key_value, which uses a single key projection reused as value
with gradients summed back in the backward pass.

Extends test_attention with value_norm and all_norms norm variants
across all base cases, plus a shared_key_value case family.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
triton_rotary_ wrote results in-place, silently corrupting the saved
norm output when query_norm was active (both tensors shared storage via
.detach()). Add output_ptr to the Triton kernel and inplace_query flag
through the rotary layer so the query gets a fresh allocation when a
query_norm context is live.

Also rename *_norm_ctx -> *_norm_context for consistency.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
- Add HybridMoEMLPConfig/HybridMoEMLP: always-active dense MLP + top-K
  routed experts with optional per-path pre/post norms
- Add pre_mixer_normalization and pre_mlp_normalization to
  DecoderBlockConfig so norm_1 and norm_2 can be configured independently;
  normalization remains the shared default when either is unset
- Add tests/layers/test_mlp.py covering HybridMoEMLP composition and norms

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Adds import/export support for Gemma 4 text models (gemma4 format):
- Pattern decoder with alternating sliding-window and global attention
- Per-head query/key/value norms, post-attention and post-MLP norms
- Partial RoPE for global attention layers
- Hybrid dense+MoE blocks with pre/post norms
- Tied embeddings with sqrt(hidden_size) embedding scale
- Logit softcapping

Exports `hidden_size_per_layer_input: 0` to disable Per-Layer Embeddings
(PLE) in the native HuggingFace model; TODO to implement PLE in Fast-LLM.

Adds `gemma4` model testing config and registers the format with
GPTModelConfig and AutoGPTHuggingfaceCheckpointHandler.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
…rip test

- Map HF attention_k_eq_v=True to AttentionConfig.shared_key_value=True for
  full-attention layers in the 26B-A4B model (K projection is reused as V;
  only a single k_proj weight exists, no v_proj)
- Add Gemma4MoELayer1Converter / Gemma4MoELayer2Converter to correctly reshape
  batched expert weights: gate_up_proj [E,2I,H] ↔ [E*2I,H] and
  down_proj [E,H,I] ↔ [E*I,H] (permute+reshape)
- Export use_bidirectional_attention=None (text-only; vision tokens not supported)
- Add test_hf_roundtrip[gemma4] using google/gemma-4-26B-A4B config

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
test_rotary: triton_rotary_ modifies query in-place, so clone before
calling forward to avoid feeding the already-rotated tensor to the
reference implementation (which caused a double-rotation mismatch).

test_mlp: increase _NUM_TOKENS/_HIDDEN_SIZE/_INTERMEDIATE_SIZE from 16/64/32
to 128/128/128 so dimensions satisfy the block_size_row=128, block_size_col=128
compile-time assertions in output_sparse_matmul_kernel when
FAST_LLM_SKIP_TRITON_AUTOTUNE is set.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>