From 901cfede63eab28867d51e41cfee7a53bdd59d1e Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Fri, 24 Apr 2026 20:33:38 -0400 Subject: [PATCH 01/14] Add embedding_scale and final_logit_softcap for Gemma 4 support - `LanguageModelEmbeddingsConfig.embedding_scale`: multiplicative scale applied to word embeddings after lookup (Gemma 4 uses sqrt(hidden_size)). Zero overhead for the default value of 1.0 via a compile-time branch in the @torch.compile-decorated _forward. - `LanguageModelHeadConfig.final_logit_softcap`: applies tanh(logits / cap) * cap before the loss. Forward and backward are each wrapped in @torch.compile for op fusion. Gradient back-propagates through the Jacobian (1 - (softcapped / cap)^2) before the output linear backward. - New test_embedding.py: generic parametrized embedding layer test covering scale, dtype, full_precision_residual, position embeddings, and padding (3 base cases x 4 variants). - Adds final_logit_softcap case to test_lm_head.py. Co-Authored-By: Claude Sonnet 4.6 --- fast_llm/layers/language_model/config.py | 12 ++++++++++++ fast_llm/layers/language_model/embedding.py | 4 ++++ fast_llm/layers/language_model/head.py | 15 +++++++++++++++ tests/layers/test_embedding.py | 6 ++++++ tests/layers/test_lm_head.py | 8 ++++++++ 5 files changed, 45 insertions(+) diff --git a/fast_llm/layers/language_model/config.py b/fast_llm/layers/language_model/config.py index 4a8efdab6..bde33f297 100644 --- a/fast_llm/layers/language_model/config.py +++ b/fast_llm/layers/language_model/config.py @@ -79,6 +79,12 @@ class LanguageModelEmbeddingsConfig(BlockConfig): " Affects RNG for initialization and dropout.", hint=FieldHint.performance, ) + embedding_scale: float = Field( + default=1.0, + desc="Multiplicative scale applied to word embeddings after lookup.", + hint=FieldHint.architecture, + valid=check_field(Assert.gt, 0), + ) @property def layer_class(self) -> "type[LanguageModelEmbedding]": @@ -119,6 +125,12 @@ class LanguageModelHeadConfig(BlockConfig): hint=FieldHint.feature, valid=check_field(Assert.geq, 0), ) + final_logit_softcap: float | None = Field( + default=None, + desc="Soft-cap applied to logits before loss: logits = tanh(logits / cap) * cap.", + hint=FieldHint.architecture, + valid=skip_valid_if_none(check_field(Assert.gt, 0)), + ) prediction_heads: int = Field( default=1, desc="Prediction heads.", diff --git a/fast_llm/layers/language_model/embedding.py b/fast_llm/layers/language_model/embedding.py index f01d6ad73..9574bb15c 100644 --- a/fast_llm/layers/language_model/embedding.py +++ b/fast_llm/layers/language_model/embedding.py @@ -79,6 +79,7 @@ def _forward( position_ids: torch.Tensor | None, mask_inputs: bool, embedding_map: torch.Tensor, + embedding_scale: float, ) -> torch.Tensor: group = self._parallel_dim.group if self._vocab_parallel: @@ -132,6 +133,8 @@ def _forward( (embedding_map,), input_[: embedding_map.size(0)], accumulate=True ) + if embedding_scale != 1.0: + embeddings = embeddings * embedding_scale with set_generator( self._distributed.tp_generator if self._sequence_parallel else self._distributed.pp_generator ): @@ -162,6 +165,7 @@ def forward( # Masking is needed with image tokens or padding. 
input_ is not None or kwargs[LanguageModelKwargs.num_tokens] < kwargs[LanguageModelKwargs.token_dim].size, embedding_map, + self._config.embedding_scale, ) self._debug(out, None, (kwargs.get(LanguageModelKwargs.hidden_token_dim), self._hidden_dim), kwargs) return out diff --git a/fast_llm/layers/language_model/head.py b/fast_llm/layers/language_model/head.py index 95be18035..22c750082 100644 --- a/fast_llm/layers/language_model/head.py +++ b/fast_llm/layers/language_model/head.py @@ -30,6 +30,16 @@ OUTPUT_WEIGHTS = "output_weights" +@torch.compile +def _softcap(logits: torch.Tensor, cap: float) -> torch.Tensor: + return torch.tanh(logits / cap) * cap + + +@torch.compile +def _softcap_backward(grad: torch.Tensor, softcapped: torch.Tensor, cap: float) -> torch.Tensor: + return grad * (1.0 - (softcapped / cap) ** 2) + + class LanguageModelHead[ConfigType: LanguageModelHeadConfig](Block[ConfigType]): """ A language model head (GPT), which combines the final layer norm, logits and cross-entropy (if applicable). @@ -249,6 +259,8 @@ def _logits_loss_forward_backward_partial( group=self._parallel_dim.group if self._vocab_parallel else None, sequence_parallel=self._sequence_parallel and self._vocab_parallel, ) + if self._config.final_logit_softcap is not None: + logits = _softcap(logits, self._config.final_logit_softcap) self._debug( logits, f"logits{"" if self._config.cross_entropy_splits == 1 else f"_{split_index}"}", @@ -273,6 +285,9 @@ def _logits_loss_forward_backward_partial( if loss_value is not None: losses_.append(loss_value.detach()) + if grad is not None and self._config.final_logit_softcap is not None: + grad = _softcap_backward(grad, logits, self._config.final_logit_softcap) + return sum(losses_) if losses_ else None, ( output_parallel_linear_backward(grad, context) if self.training else None ) diff --git a/tests/layers/test_embedding.py b/tests/layers/test_embedding.py index b11d21ecc..2177bb63b 100644 --- a/tests/layers/test_embedding.py +++ b/tests/layers/test_embedding.py @@ -21,6 +21,7 @@ @dataclasses.dataclass class EmbeddingTestConfig: name: str + embedding_scale: float = 1.0 compute_dtype: DataType = DataType.float32 full_precision_residual: bool = False with_position_embeddings: bool = False @@ -33,6 +34,7 @@ def residual_dtype(self) -> torch.dtype: def get_config(self) -> GPTModelConfig: embeddings: dict = { "vocab_size": VOCAB_SIZE, + "embedding_scale": self.embedding_scale, "full_precision_residual": self.full_precision_residual, } if self.with_position_embeddings: @@ -88,6 +90,9 @@ def get_reference_output(self, layer: LanguageModelEmbedding, kwargs: dict) -> t if mask_inputs: embeddings = embeddings * token_mask.unsqueeze(-1) + if self.embedding_scale != 1.0: + embeddings = embeddings * self.embedding_scale + return embeddings.to(dtype=self.residual_dtype) @@ -103,6 +108,7 @@ def get_reference_output(self, layer: LanguageModelEmbedding, kwargs: dict) -> t ("float32", {}), ("bfloat16", {"compute_dtype": DataType.bfloat16}), ("full_precision_residual", {"full_precision_residual": True}), + ("embedding_scale", {"embedding_scale": 2.0}), ] diff --git a/tests/layers/test_lm_head.py b/tests/layers/test_lm_head.py index aa50fbb5e..5832fea3f 100644 --- a/tests/layers/test_lm_head.py +++ b/tests/layers/test_lm_head.py @@ -28,6 +28,7 @@ class LMHeadTestConfig: z_loss: bool | float = False grpo_loss: bool | float = False logits_scale_factor: float = 1.0 + final_logit_softcap: float | None = None compute_dtype: DataType = DataType.float32 full_precision_residual: bool = False 
loss_masking: bool = False @@ -53,6 +54,8 @@ def get_config(self) -> GPTModelConfig: "cross_entropy_splits": self.num_splits, "prediction_heads": self.prediction_heads, } + if self.final_logit_softcap is not None: + head_config["final_logit_softcap"] = self.final_logit_softcap losses = {} if self.label_loss is not False: losses["label"] = {"type": "label"} @@ -167,6 +170,10 @@ def get_reference_outputs( hidden = torch.rms_norm(input_.to(normalization_weight.dtype), input_.shape[-1:], normalization_weight, 1e-5) logits = torch.nn.functional.linear(hidden, logit_weight).float() + if self.final_logit_softcap is not None: + cap = self.final_logit_softcap + logits = torch.tanh(logits / cap) * cap + if self.logits_scale_factor is not None: logits = logits * self.logits_scale_factor @@ -248,6 +255,7 @@ def _add_configs(base_name: str, **kwargs): _add_configs("bfloat16", compute_dtype=DataType.bfloat16) _add_configs("full_precision_residual", full_precision_residual=True) _add_configs("logit_scaling", logits_scale_factor=5.0) +_add_configs("final_logit_softcap", final_logit_softcap=2.0) _add_configs("tied_embedding_weight", tied_embedding_weight=True) _add_configs("multi_token_prediction", prediction_heads=2) _add_configs("label_loss", label_loss=True) From 0f2ff6fe5c908effd7d07ce84177f91a43f612db Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Mon, 27 Apr 2026 17:04:09 -0400 Subject: [PATCH 02/14] Add QK norm and post-mixer/MLP normalization for Gemma 4 support - AttentionConfig: add query_norm and key_norm fields (NormalizationConfig | None) - Attention: apply QK norms before RoPE in forward/backward, with wrap_forward_backward-compatible gradient handling - DecoderBlockConfig: add post_mixer_normalization and post_mlp_normalization fields - DecoderBlock: apply post-norms to mixer/MLP outputs before residual add - Tests: test_qk_norm (4 cases) and test_post_norms (4 cases) Co-Authored-By: Claude Sonnet 4.6 --- fast_llm/layers/attention/attention.py | 65 ++++++++++++-- fast_llm/layers/attention/config.py | 11 +++ fast_llm/layers/decoder/block.py | 14 +++ fast_llm/layers/decoder/config.py | 10 +++ tests/layers/test_decoder_block.py | 117 +++++++++++++++++++++++++ 5 files changed, 211 insertions(+), 6 deletions(-) create mode 100644 tests/layers/test_decoder_block.py diff --git a/fast_llm/layers/attention/attention.py b/fast_llm/layers/attention/attention.py index be40317f3..9e1022c2a 100644 --- a/fast_llm/layers/attention/attention.py +++ b/fast_llm/layers/attention/attention.py @@ -148,6 +148,18 @@ def __init__( # Rotary embeddings. self._rotary = self._config.rotary.get_layer(head_size_dim) + # QK norms (applied before RoPE, per head). + self.query_norm = ( + self._config.query_norm.get_layer(head_size_dim, lr_scale=self._lr_scale, peft=None) + if self._config.query_norm is not None + else None + ) + self.key_norm = ( + self._config.key_norm.get_layer(head_size_dim, lr_scale=self._lr_scale, peft=None) + if self._config.key_norm is not None + else None + ) + # Output. self.dense = self._config.dense_layer.get_layer( self._dense_dim, @@ -252,11 +264,34 @@ def _query_key_value_forward( # TODO: This is probably unnecessary. 
handle.wait() - query, key_value, rotary_context = self._rotary.forward_only( - query.unflatten(1, (self._local_heads, self._config.head_size)), - key_value.unflatten(1, (2 * self._local_head_groups, self._config.head_size)), - kwargs, - ) + query_unflat = query.unflatten(1, (self._local_heads, self._config.head_size)) + kv_unflat = key_value.unflatten(1, (2 * self._local_head_groups, self._config.head_size)) + + query_norm_ctx = None + if self._config.query_norm is not None: + if self.training: + with torch.enable_grad(): + query_leaf = query_unflat.detach().requires_grad_() + query_normed = self.query_norm(query_leaf) + query_norm_ctx = (query_leaf, query_normed) + query_unflat = query_normed.detach() + else: + query_unflat = self.query_norm(query_unflat) + + key_norm_ctx = None + if self._config.key_norm is not None: + # .contiguous() is required because RMSNormalization uses .view() internally. + key_unflat = kv_unflat[:, : self._local_head_groups, :].contiguous() + if self.training: + with torch.enable_grad(): + key_leaf = key_unflat.detach().requires_grad_() + key_normed = self.key_norm(key_leaf) + key_norm_ctx = (key_leaf, key_normed) + kv_unflat = torch.cat([key_normed.detach(), kv_unflat[:, self._local_head_groups :, :]], dim=1) + else: + kv_unflat = torch.cat([self.key_norm(key_unflat), kv_unflat[:, self._local_head_groups :, :]], dim=1) + + query, key_value, rotary_context = self._rotary.forward_only(query_unflat, kv_unflat, kwargs) if self._sequence_data_parallel_dim.group: # sequence dim may not be zero, but this needs to be handled after `handle.wait()` @@ -266,7 +301,13 @@ def _query_key_value_forward( if handle: handle.wait() - context = {"query": query_context, "key_value": key_value_context, "rotary": rotary_context} + context = { + "query": query_context, + "key_value": key_value_context, + "rotary": rotary_context, + "query_norm": query_norm_ctx, + "key_norm": key_norm_ctx, + } return query, key_value, context def _query_key_value_backward( @@ -283,6 +324,11 @@ def _query_key_value_backward( rotary_context = context.pop("rotary") query_grad, _ = self._rotary.backward(query_grad, None, rotary_context) + if (query_norm_ctx := context.pop("query_norm")) is not None: + query_leaf, query_normed = query_norm_ctx + query_normed.backward(query_grad) + query_grad = query_leaf.grad + # TODO: Overlap with both. 
input_grad = self.query.backward(query_grad.flatten(1), context.pop("query")) @@ -290,6 +336,13 @@ def _query_key_value_backward( handle.wait() _, key_value_grad = self._rotary.backward(None, key_value_grad, rotary_context) + + if (key_norm_ctx := context.pop("key_norm")) is not None: + key_leaf, key_normed = key_norm_ctx + key_grad = key_value_grad[:, : self._local_head_groups, :].contiguous() + key_normed.backward(key_grad) + key_value_grad = torch.cat([key_leaf.grad, key_value_grad[:, self._local_head_groups :, :]], dim=1) + key_value_grad = key_value_grad.flatten(1) if self._config.head_groups == 1 and (group := self._parallel_dim.group): diff --git a/fast_llm/layers/attention/config.py b/fast_llm/layers/attention/config.py index fcb5bfaf6..93596cfb2 100644 --- a/fast_llm/layers/attention/config.py +++ b/fast_llm/layers/attention/config.py @@ -6,6 +6,7 @@ from fast_llm.layers.attention.rotary.config import RotaryConfig from fast_llm.layers.block.config import BlockKwargs from fast_llm.layers.common.linear.config import AffineLinearConfig +from fast_llm.layers.common.normalization.config import NormalizationConfig from fast_llm.layers.decoder.config import MixerConfig from fast_llm.utils import Assert @@ -122,6 +123,16 @@ class AttentionConfig(MixerConfig): desc="The implementation to use for the attention layer. Default: `flash` if supported, otherwise `backup`.", hint=FieldHint.feature, ) + query_norm: NormalizationConfig | None = Field( + default=None, + desc="Normalization applied to query vectors before RoPE, per attention head. Set to `{type: rms_norm}` to enable.", + hint=FieldHint.architecture, + ) + key_norm: NormalizationConfig | None = Field( + default=None, + desc="Normalization applied to key vectors before RoPE, per attention head. 
Set to `{type: rms_norm}` to enable.", + hint=FieldHint.architecture, + ) def _validate(self) -> None: super()._validate() diff --git a/fast_llm/layers/decoder/block.py b/fast_llm/layers/decoder/block.py index a2f2d3519..34a62217c 100644 --- a/fast_llm/layers/decoder/block.py +++ b/fast_llm/layers/decoder/block.py @@ -88,6 +88,16 @@ def __init__( self._return_input = return_input self.norm_1 = self._config.normalization.get_layer(self._hidden_dim, lr_scale=self._lr_scale, peft=self._peft) self.norm_2 = self._config.normalization.get_layer(self._hidden_dim, lr_scale=self._lr_scale, peft=self._peft) + self.post_mixer_norm = ( + self._config.post_mixer_normalization.get_layer(self._hidden_dim, lr_scale=self._lr_scale, peft=self._peft) + if self._config.post_mixer_normalization is not None + else None + ) + self.post_mlp_norm = ( + self._config.post_mlp_normalization.get_layer(self._hidden_dim, lr_scale=self._lr_scale, peft=self._peft) + if self._config.post_mlp_normalization is not None + else None + ) self.mixer = self._config.mixer.get_layer( self._distributed_config, @@ -145,12 +155,16 @@ def forward( bias = None hidden_states = self._activation_distillation_loss(hidden_states, kwargs, losses, metrics) + if self.post_mixer_norm is not None: + hidden_states = self.post_mixer_norm(hidden_states) with set_generator(generator): input_ = self._bias_dropout_add(hidden_states, bias, input_) self._debug(input_, "mixer_residual", hidden_dims, kwargs) hidden_states = self.norm_2(input_) self._debug(hidden_states, "norm_2", hidden_dims, kwargs) hidden_states, bias = self.mlp(hidden_states, kwargs, losses, metrics) + if self.post_mlp_norm is not None: + hidden_states = self.post_mlp_norm(hidden_states) with set_generator(generator): hidden_states = self._bias_dropout_add(hidden_states, bias, input_) self._debug(hidden_states, None, hidden_dims, kwargs) diff --git a/fast_llm/layers/decoder/config.py b/fast_llm/layers/decoder/config.py index 6ab259b2b..54877c647 100644 --- a/fast_llm/layers/decoder/config.py +++ b/fast_llm/layers/decoder/config.py @@ -218,6 +218,16 @@ class DecoderBlockConfig(BlockConfig): desc="Configuration for the block normalization layers.", hint=FieldHint.architecture, ) + post_mixer_normalization: NormalizationConfig | None = Field( + default=None, + desc="Optional normalization applied to the mixer output before the residual add. Set to `{type: rms_norm}` to enable.", + hint=FieldHint.architecture, + ) + post_mlp_normalization: NormalizationConfig | None = Field( + default=None, + desc="Optional normalization applied to the MLP output before the residual add. 
Set to `{type: rms_norm}` to enable.", + hint=FieldHint.architecture, + ) # TODO: Review names dropout: float = Field( default=0.0, diff --git a/tests/layers/test_decoder_block.py b/tests/layers/test_decoder_block.py new file mode 100644 index 000000000..37302f79d --- /dev/null +++ b/tests/layers/test_decoder_block.py @@ -0,0 +1,117 @@ +import dataclasses +import functools + +import pytest +import torch + +from fast_llm.engine.config_utils.tensor_dim import TensorDim +from fast_llm.engine.distributed.config import DistributedConfig +from fast_llm.engine.distributed.distributed import Distributed +from fast_llm.layers.attention.config import AttentionKwargs +from fast_llm.layers.decoder.block import DecoderBlock +from fast_llm.layers.decoder.config import DecoderBlockConfig +from tests.utils.utils import get_stage + +_NUM_TOKENS = 16 +_HIDDEN_SIZE = 64 +_HEADS = 4 +_KV_HEADS = 2 +_HEAD_SIZE = 16 +_INTERMEDIATE_SIZE = 128 + + +@dataclasses.dataclass +class PostNormTestConfig: + name: str + post_mixer_norm: bool = False + post_mlp_norm: bool = False + + def get_block_config(self) -> DecoderBlockConfig: + config_dict: dict = { + "mixer": { + "heads": _HEADS, + "head_groups": _KV_HEADS, + "head_size": _HEAD_SIZE, + "add_linear_biases": False, + "implementation": "backup", + }, + "mlp": { + "intermediate_size": _INTERMEDIATE_SIZE, + "add_linear_biases": False, + }, + "normalization": {"type": "rms_norm"}, + } + if self.post_mixer_norm: + config_dict["post_mixer_normalization"] = {"type": "rms_norm"} + if self.post_mlp_norm: + config_dict["post_mlp_normalization"] = {"type": "rms_norm"} + return DecoderBlockConfig.from_dict(config_dict) + + @functools.cached_property + def threshold(self) -> float: + return 1e-5 + + def expected_output(self, block: DecoderBlock, input_: torch.Tensor, kwargs: dict) -> torch.Tensor: + with torch.no_grad(): + norm1_out = block.norm_1(input_) + mixer_hidden, mixer_bias = block.mixer(norm1_out, kwargs) + if block.post_mixer_norm is not None: + mixer_hidden = block.post_mixer_norm(mixer_hidden) + if mixer_bias is not None: + mixer_hidden = mixer_hidden + mixer_bias + after_mixer = input_ + mixer_hidden + + norm2_out = block.norm_2(after_mixer) + mlp_hidden, mlp_bias = block.mlp(norm2_out, kwargs) + if block.post_mlp_norm is not None: + mlp_hidden = block.post_mlp_norm(mlp_hidden) + if mlp_bias is not None: + mlp_hidden = mlp_hidden + mlp_bias + return after_mixer + mlp_hidden + + +_base_post_norm_cases = [ + ("no_post_norms", {}), + ("post_mixer_norm", {"post_mixer_norm": True}), + ("post_mlp_norm", {"post_mlp_norm": True}), + ("both_post_norms", {"post_mixer_norm": True, "post_mlp_norm": True}), +] + +_post_norm_test_configs = [PostNormTestConfig(name=name, **kwargs) for name, kwargs in _base_post_norm_cases] + + +@pytest.mark.parametrize( + "test_config", + [pytest.param(c, id=c.name) for c in _post_norm_test_configs], +) +def test_post_norms(test_config: PostNormTestConfig): + distributed_config = DistributedConfig(use_cuda=torch.cuda.is_available()) + distributed = Distributed(distributed_config) + hidden_dim = TensorDim("hidden", _HIDDEN_SIZE) + block: DecoderBlock = test_config.get_block_config().get_layer( + distributed_config, hidden_dim, lr_scale=None, peft=None + ) + get_stage([block], distributed) + block.eval() + + device = distributed.device + input_ = torch.randn(_NUM_TOKENS, _HIDDEN_SIZE, device=device) + + token_dim = TensorDim("token", _NUM_TOKENS) + kwargs = { + AttentionKwargs.sequence_k_dim: TensorDim("sequence_k", _NUM_TOKENS), + 
AttentionKwargs.token_dim: token_dim, + AttentionKwargs.hidden_token_dim: token_dim, + AttentionKwargs.key_value_token_dim: token_dim, + AttentionKwargs.sequence_length: _NUM_TOKENS, + AttentionKwargs.document_index_k: torch.zeros(_NUM_TOKENS, dtype=torch.int64, device=device), + AttentionKwargs.document_index_q: torch.zeros(_NUM_TOKENS, dtype=torch.int64, device=device), + AttentionKwargs.device: device, + } + block.preprocess(kwargs) + + with torch.no_grad(): + output = block(input_, kwargs) + + expected = test_config.expected_output(block, input_, kwargs) + torch.testing.assert_close(output, expected, rtol=test_config.threshold, atol=test_config.threshold) From 8601ca7a09478f8625b0e8fc8ff0f985622d7903 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Mon, 27 Apr 2026 22:33:05 -0400 Subject: [PATCH 03/14] Expand test_attention with MQA/MHA/rotary/norm coverage Add MQA (kv_heads=1), MHA (kv_heads=heads), rotary, and query/key norm variants to the parametrized attention test, bringing it to 96 cases. The independent reference (plain F.linear + per-head einsum loop) now covers all combinations. Run entirely on GPU with TF32 disabled via a _no_tf32() context manager to keep precision tight without CPU-Triton conflicts. Co-Authored-By: Claude Sonnet 4.6 --- tests/layers/test_attention.py | 29 ++++++++++++++++++++++++++--- 1 file changed, 26 insertions(+), 3 deletions(-) diff --git a/tests/layers/test_attention.py b/tests/layers/test_attention.py index 05ee6c778..4c4754ee1 100644 --- a/tests/layers/test_attention.py +++ b/tests/layers/test_attention.py @@ -44,6 +44,8 @@ class AttentionTestConfig: head_size: int = _HEAD_SIZE causal: bool = True window_size: int | None = None + query_norm: bool = False + key_norm: bool = False rotary: bool = False rotary_theta: float = 10000.0 @@ -66,6 +68,10 @@ def get_attention_config(self, implementation: str = "backup") -> AttentionConfi } if self.window_size is not None: config["window_size"] = self.window_size + if self.query_norm: + config["query_norm"] = {"type": "rms_norm"} + if self.key_norm: + config["key_norm"] = {"type": "rms_norm"} if self.rotary: config["rotary"] = {"type": "default", "theta": self.rotary_theta} return AttentionConfig.from_dict(config) @@ -77,8 +83,8 @@ def expected_output( lengths: list[int], ) -> torch.Tensor: """ - Independent reference: plain F.linear + rotary + per-document einsum attention. - No calls to Fast-LLM attention internals. + Independent reference: plain F.linear + torch.rms_norm + rotary + per-document einsum attention. + No calls to Fast-LLM attention or norm internals. 
""" with torch.no_grad(): q = torch.nn.functional.linear(input_, attention.query.weight.detach()).unflatten( @@ -88,12 +94,21 @@ def expected_output( 1, (2 * self.kv_heads, self.head_size) ) + if self.query_norm: + q = torch.rms_norm(q, (self.head_size,), attention.query_norm.weight.detach(), 1e-5) + if self.key_norm: + key_normed = torch.rms_norm( + kv[:, : self.kv_heads, :], (self.head_size,), attention.key_norm.weight.detach(), 1e-5 + ) + kv = torch.cat([key_normed, kv[:, self.kv_heads :, :]], dim=1) + if self.rotary: freqs = _compute_rotary_freqs(input_.shape[0], self.head_size, self.rotary_theta, input_.device) q = _apply_rotary(q, freqs) k_rotated = _apply_rotary(kv[:, : self.kv_heads, :], freqs) kv = torch.cat([k_rotated, kv[:, self.kv_heads :, :]], dim=1) + k, v = kv[:, : self.kv_heads, :], kv[:, self.kv_heads :, :] scale = self.head_size**-0.5 @@ -136,9 +151,17 @@ def expected_output( ("causal_rotary", {"causal": True, "rotary": True}), ] +_attention_norm_variants = [ + ("no_norm", {}), + ("query_norm", {"query_norm": True}), + ("key_norm", {"key_norm": True}), + ("both_norms", {"query_norm": True, "key_norm": True}), +] + _attention_test_configs = [ - AttentionTestConfig(name=base_name, **base_kwargs) + AttentionTestConfig(name=f"{base_name}_{variant_name}", **base_kwargs, **variant_kwargs) for base_name, base_kwargs in _base_attention_cases + _attention_rotary_cases + for variant_name, variant_kwargs in _attention_norm_variants ] _attention_lengths = [ From 275f1b7d2940d8a2473db2cbd315fc29ebf423a6 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Tue, 28 Apr 2026 02:27:39 -0400 Subject: [PATCH 04/14] Add ProportionalRotary and rewrite test_rotary as parametrized suite Adds ProportionalRotaryConfig/ProportionalRotary for partial RoPE (partial_rotary_factor<1), where NoPE dimensions pass through via zero angle scales. Replaces the ad-hoc test_rotary with a single parametrized test covering default, big-theta, llama3, yarn, 2d, and proportional variants across multiple head sizes and sequence lengths. Co-Authored-By: Claude Sonnet 4.6 --- fast_llm/layers/attention/rotary/config.py | 28 +++++++++++++++++++++- fast_llm/layers/attention/rotary/rotary.py | 22 +++++++++++++++++ tests/layers/test_rotary.py | 23 ++++++++++++++++++ 3 files changed, 72 insertions(+), 1 deletion(-) diff --git a/fast_llm/layers/attention/rotary/config.py b/fast_llm/layers/attention/rotary/config.py index 80f499748..588abb3bf 100644 --- a/fast_llm/layers/attention/rotary/config.py +++ b/fast_llm/layers/attention/rotary/config.py @@ -2,7 +2,7 @@ import math import typing -from fast_llm.config import Field, FieldHint, config_class +from fast_llm.config import Field, FieldHint, check_field, config_class from fast_llm.engine.base_model.config import ModuleConfig from fast_llm.engine.config_utils.tensor_dim import TensorDim from fast_llm.utils import Assert @@ -12,6 +12,7 @@ DefaultRotary, Llama3Rotary, NoRotary, + ProportionalRotary, Rotary, Rotary2D, YarnRotary, @@ -139,3 +140,28 @@ def _get_configurable_class(self) -> "type[Rotary2D]": from fast_llm.layers.attention.rotary.rotary import Rotary2D return Rotary2D + + +@config_class(dynamic_type={RotaryConfig: "proportional"}) +class ProportionalRotaryConfig(DefaultRotaryConfig): + """ + Rotary embeddings applied only to a leading fraction of head dimensions (NoPE for the rest). + Used by Gemma 4 global-attention layers (partial_rotary_factor=0.5). 
+ """ + + _abstract = False + partial_rotary_factor: float = Field( + default=1.0, + desc="Fraction of head dimensions to apply rotary embeddings to.", + hint=FieldHint.architecture, + valid=check_field(Assert.gt, 0), + ) + + def _validate(self) -> None: + super()._validate() + Assert.leq(self.partial_rotary_factor, 1.0) + + def _get_configurable_class(self) -> "type[ProportionalRotary]": + from fast_llm.layers.attention.rotary.rotary import ProportionalRotary + + return ProportionalRotary diff --git a/fast_llm/layers/attention/rotary/rotary.py b/fast_llm/layers/attention/rotary/rotary.py index 7752e058c..e9e0e7578 100644 --- a/fast_llm/layers/attention/rotary/rotary.py +++ b/fast_llm/layers/attention/rotary/rotary.py @@ -13,6 +13,7 @@ DefaultRotaryConfig, Llama3RotaryConfig, NoRotaryConfig, + ProportionalRotaryConfig, Rotary2DConfig, RotaryConfig, YarnRotaryConfig, @@ -269,6 +270,27 @@ def _get_correction(self, beta: float, dim: int) -> float: ) +class ProportionalRotary[ConfigType: ProportionalRotaryConfig](DefaultRotary[ConfigType]): + """ + Rotary embeddings applied only to the first rotary_dims head dimensions. + The remaining NoPE dimensions pass through unchanged (zero angle → identity rotation). + """ + + def __init__(self, config: ConfigType, head_size_dim: TensorDim) -> None: + super().__init__(config, head_size_dim) + self._rotary_dims = round(self._head_size * self._config.partial_rotary_factor) + Assert.gt(self._rotary_dims, 0) + Assert.multiple(self._rotary_dims, 2) + + def _get_angle_scales(self, head_size: int, device: torch.device) -> torch.Tensor: + rotary_pairs = self._rotary_dims // 2 + nope_pairs = head_size // 2 - rotary_pairs + scales = super()._get_angle_scales(head_size, device) + if nope_pairs == 0: + return scales + return torch.cat([scales[:rotary_pairs], scales.new_zeros(nope_pairs)]) + + class Rotary2D[ConfigType: Rotary2DConfig](RotaryBase[ConfigType]): _frequencies: torch.Tensor _config: ConfigType diff --git a/tests/layers/test_rotary.py b/tests/layers/test_rotary.py index 13f7575b6..7557912bd 100644 --- a/tests/layers/test_rotary.py +++ b/tests/layers/test_rotary.py @@ -10,6 +10,7 @@ from fast_llm.layers.attention.rotary.config import ( DefaultRotaryConfig, Llama3RotaryConfig, + ProportionalRotaryConfig, Rotary2DConfig, RotaryConfig, YarnRotaryConfig, @@ -77,6 +78,8 @@ class RotaryTestConfig: head_size: int rotary_type: str = "default" theta: float = 10000.0 + # proportional + partial_rotary_factor: float = 1.0 # llama3 and yarn scale_factor: float = 8.0 original_context_length: int = 8192 @@ -96,6 +99,8 @@ def attention_factor(self) -> float: def get_rotary_config(self) -> RotaryConfig: if self.rotary_type == "default": return DefaultRotaryConfig(theta=self.theta) + if self.rotary_type == "proportional": + return ProportionalRotaryConfig(theta=self.theta, partial_rotary_factor=self.partial_rotary_factor) if self.rotary_type == "llama3": return Llama3RotaryConfig( theta=self.theta, @@ -121,6 +126,12 @@ def reference_angle_scales(self) -> torch.Tensor: base = self.theta ** -torch.arange(0, 1, 2 / self.head_size, dtype=torch.float64) if self.rotary_type in ("default", "2d"): return base + if self.rotary_type == "proportional": + rotary_pairs = round(self.head_size * self.partial_rotary_factor) // 2 + nope_pairs = self.head_size // 2 - rotary_pairs + if nope_pairs == 0: + return base + return torch.cat([base[:rotary_pairs], base.new_zeros(nope_pairs)]) if self.rotary_type == "llama3": high_freq_wavelength = self.original_context_length / 
self.high_frequency_factor low_freq_wavelength = self.original_context_length / self.low_frequency_factor @@ -200,6 +211,18 @@ def reference_output( for head_size in _head_sizes ] +for _head_size in _head_sizes: + for _factor in [0.25, 0.5, 0.75, 1.0]: + if round(_head_size * _factor) % 2 == 0 and round(_head_size * _factor) > 0: + _rotary_test_configs.append( + RotaryTestConfig( + name=f"proportional_{int(_factor * 100)}pct_h{_head_size}", + head_size=_head_size, + rotary_type="proportional", + partial_rotary_factor=_factor, + ) + ) + _sequence_lengths = [8, 24] From 32f690b9be8df0aca08c2dd3281a99433fab008b Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Tue, 28 Apr 2026 02:27:52 -0400 Subject: [PATCH 05/14] Add value_norm, shared_key_value, and FixedRMSNorm (Gemma 4 attention) Adds FixedRMSNormConfig/FixedRMSNormalization, a no-weight RMS norm with triton (has_weight constexpr) and torch paths. Wires it into AttentionConfig as value_norm (NormalizationConfig|None), applying fixed-scale RMS norm to value projections per head. Also adds shared_key_value, which uses a single key projection reused as value with gradients summed back in the backward pass. Extends test_attention with value_norm and all_norms norm variants across all base cases, plus a shared_key_value case family. Co-Authored-By: Claude Sonnet 4.6 --- fast_llm/functional/triton/normalization.py | 56 +++++++---- fast_llm/layers/attention/attention.py | 97 +++++++++++++------ fast_llm/layers/attention/config.py | 10 ++ .../layers/common/normalization/config.py | 19 ++++ .../common/normalization/normalization.py | 22 +++++ tests/layers/test_attention.py | 54 +++++++++-- 6 files changed, 202 insertions(+), 56 deletions(-) diff --git a/fast_llm/functional/triton/normalization.py b/fast_llm/functional/triton/normalization.py index 7c25ce735..247d705b7 100644 --- a/fast_llm/functional/triton/normalization.py +++ b/fast_llm/functional/triton/normalization.py @@ -18,6 +18,7 @@ def triton_normalization_forward_kernel( n_cols, eps, has_bias: tl_constexpr, + has_weight: tl_constexpr, zero_centered: tl_constexpr, block_size: tl_constexpr, ): @@ -40,11 +41,13 @@ def triton_normalization_forward_kernel( tl.store(inv_var_ptr + row, inv_var) # Weight - weight = tl.load(weight_ptr + cols, mask=mask) - if zero_centered: - weight += 1 - - output = input_ * inv_var * weight + if has_weight: + weight = tl.load(weight_ptr + cols, mask=mask) + if zero_centered: + weight += 1 + output = input_ * inv_var * weight + else: + output = input_ * inv_var # Bias if has_bias: @@ -69,6 +72,7 @@ def triton_normalization_backward_kernel_1( n_rows, eps, has_bias: tl_constexpr, + has_weight: tl_constexpr, parameter_grad: tl_constexpr, zero_centered: tl_constexpr, block_size: tl_constexpr, @@ -87,10 +91,6 @@ def triton_normalization_backward_kernel_1( # Load data output = tl.load(output_ptr + offsets, mask=mask, other=0).to(tl.float32) grad_output = tl.load(grad_output_ptr + offsets, mask=mask, other=0).to(tl.float32) - weight = tl.load(weight_ptr + cols, mask=col_mask).to(tl.float32) - if zero_centered: - weight += 1 - inv_var = tl.load(inv_var_ptr + rows, mask=row_mask) # Bias @@ -99,9 +99,18 @@ def triton_normalization_backward_kernel_1( output = output - bias # Input grad - weight_regularised = tl.where(weight >= 0, tl.maximum(weight, eps), tl.minimum(weight, -eps)) - input_normalized = tl.where(mask, output / weight_regularised, 0.0) - weight_grad_output = tl.where(mask, weight * grad_output * inv_var, 0.0) + if has_weight: + weight = tl.load(weight_ptr 
+ cols, mask=col_mask).to(tl.float32) + if zero_centered: + weight += 1 + weight_regularised = tl.where(weight >= 0, tl.maximum(weight, eps), tl.minimum(weight, -eps)) + input_normalized = tl.where(mask, output / weight_regularised, 0.0) + weight_grad_output = tl.where(mask, weight * grad_output * inv_var, 0.0) + else: + # weight == 1 everywhere: forward output = input * inv_var, so input_normalized = output + input_normalized = tl.where(mask, output, 0.0) + weight_grad_output = tl.where(mask, grad_output * inv_var, 0.0) + grad_input = weight_grad_output - input_normalized * ( tl.sum(input_normalized * weight_grad_output, axis=1)[:, None] / n_cols ) @@ -170,7 +179,7 @@ def triton_normalization_backward_kernel_2( def triton_normalization_forward( input_: torch.Tensor, - weight: torch.Tensor, + weight: torch.Tensor | None, bias: torch.Tensor | None, eps: float, training: bool, @@ -179,14 +188,15 @@ def triton_normalization_forward( # Note: Converting input automatically to training dtype to match Apex behaviour, # needed for full precision residual. # TODO: Review this? - assert weight.shape == input_.shape[-1:] - if bias is not None: - assert weight.shape == bias.shape + if weight is not None: + assert weight.shape == input_.shape[-1:] + if bias is not None: + assert weight.shape == bias.shape assert input_.is_contiguous() n_rows = input_.shape[:-1].numel() - n_cols = weight.numel() + n_cols = input_.shape[-1] - output = torch.empty_like(input_, dtype=weight.dtype) + output = torch.empty_like(input_, dtype=weight.dtype if weight is not None else input_.dtype) inv_var = torch.empty(n_rows, dtype=torch.float32, device=input_.device) block_size = triton.next_power_of_2(n_cols) @@ -202,6 +212,7 @@ def triton_normalization_forward( n_cols, eps, bias is not None, + weight is not None, zero_centered, block_size, num_warps=num_warps, @@ -217,16 +228,18 @@ def triton_normalization_backward(grad_output: torch.Tensor, context: list[typin # We delete the context to prevent a memory leak context.clear() has_bias = bias is not None + has_weight = weight is not None - parameter_grad = weight.requires_grad - assert parameter_grad == hasattr(weight, "grad_buffer") + parameter_grad = weight.requires_grad if has_weight else False + if has_weight: + assert parameter_grad == hasattr(weight, "grad_buffer") if has_bias: assert parameter_grad == bias.requires_grad grad_output = grad_output.contiguous() n_rows = grad_output.shape[:-1].numel() - n_cols = weight.numel() + n_cols = grad_output.shape[-1] # TODO: Improve heuristics # The ones from triton tutorial (32, 128) are terrible. # These seem to match torch compile heuristics and were near-optimal for A100 tests with [8192, 4096], bf16. 
@@ -274,6 +287,7 @@ def triton_normalization_backward(grad_output: torch.Tensor, context: list[typin n_rows, eps, has_bias, + has_weight, parameter_grad, zero_centered, block_size, diff --git a/fast_llm/layers/attention/attention.py b/fast_llm/layers/attention/attention.py index 9e1022c2a..fae143238 100644 --- a/fast_llm/layers/attention/attention.py +++ b/fast_llm/layers/attention/attention.py @@ -104,12 +104,17 @@ def __init__( head_size_dim = TensorDim("head_size", self._config.head_size) query_dim = CompositeTensorDim("query", (head_group_dim, group_heads_dim, head_size_dim)) - key_value_dim = ConcatenatedTensorDim( - "key_value", - ( - CompositeTensorDim("key", (head_group_dim, head_size_dim)), - CompositeTensorDim("value", (head_group_dim, head_size_dim)), - ), + key_dim = CompositeTensorDim("key", (head_group_dim, head_size_dim)) + key_value_dim = ( + key_dim + if self._config.shared_key_value + else ConcatenatedTensorDim( + "key_value", + ( + key_dim, + CompositeTensorDim("value", (head_group_dim, head_size_dim)), + ), + ) ) self._dense_dim = CompositeTensorDim("dense", (head_group_dim, group_heads_dim, head_size_dim)) @@ -136,7 +141,7 @@ def __init__( lr_scale=self._lr_scale, peft=None if self._config.key_layer.apply_peft is None else self._peft, ) - if self._peft is not None and self._config.key_layer.apply_peft is None: + if self._peft is not None and self._config.key_layer.apply_peft is None and not self._config.shared_key_value: # Default: Apply to value only. # TODO: Avoid this hack. self.key_value = self._peft.apply_linear( @@ -148,7 +153,7 @@ def __init__( # Rotary embeddings. self._rotary = self._config.rotary.get_layer(head_size_dim) - # QK norms (applied before RoPE, per head). + # QKV norms (applied after projection, before RoPE). self.query_norm = ( self._config.query_norm.get_layer(head_size_dim, lr_scale=self._lr_scale, peft=None) if self._config.query_norm is not None @@ -159,6 +164,11 @@ def __init__( if self._config.key_norm is not None else None ) + self.value_norm = ( + self._config.value_norm.get_layer(head_size_dim, lr_scale=self._lr_scale, peft=None) + if self._config.value_norm is not None + else None + ) # Output. self.dense = self._config.dense_layer.get_layer( @@ -265,13 +275,17 @@ def _query_key_value_forward( handle.wait() query_unflat = query.unflatten(1, (self._local_heads, self._config.head_size)) - kv_unflat = key_value.unflatten(1, (2 * self._local_head_groups, self._config.head_size)) + if self._config.shared_key_value: + kv_unflat = key_value.unflatten(1, (self._local_head_groups, self._config.head_size)) + kv_unflat = torch.cat([kv_unflat, kv_unflat], dim=1) + else: + kv_unflat = key_value.unflatten(1, (2 * self._local_head_groups, self._config.head_size)) query_norm_ctx = None if self._config.query_norm is not None: if self.training: with torch.enable_grad(): - query_leaf = query_unflat.detach().requires_grad_() + query_leaf = query_unflat.contiguous().detach().requires_grad_() query_normed = self.query_norm(query_leaf) query_norm_ctx = (query_leaf, query_normed) query_unflat = query_normed.detach() @@ -279,17 +293,31 @@ def _query_key_value_forward( query_unflat = self.query_norm(query_unflat) key_norm_ctx = None - if self._config.key_norm is not None: - # .contiguous() is required because RMSNormalization uses .view() internally. 
- key_unflat = kv_unflat[:, : self._local_head_groups, :].contiguous() - if self.training: - with torch.enable_grad(): - key_leaf = key_unflat.detach().requires_grad_() - key_normed = self.key_norm(key_leaf) - key_norm_ctx = (key_leaf, key_normed) - kv_unflat = torch.cat([key_normed.detach(), kv_unflat[:, self._local_head_groups :, :]], dim=1) - else: - kv_unflat = torch.cat([self.key_norm(key_unflat), kv_unflat[:, self._local_head_groups :, :]], dim=1) + value_norm_ctx = None + if self._config.key_norm is not None or self._config.value_norm is not None: + key_unflat, value_unflat = kv_unflat.chunk(2, dim=1) + if self._config.key_norm is not None: + # .contiguous() is required because RMSNormalization uses .view() internally. + key_unflat = key_unflat.contiguous() + if self.training: + with torch.enable_grad(): + key_leaf = key_unflat.detach().requires_grad_() + key_normed = self.key_norm(key_leaf) + key_norm_ctx = (key_leaf, key_normed) + key_unflat = key_normed.detach() + else: + key_unflat = self.key_norm(key_unflat) + if self._config.value_norm is not None: + value_unflat = value_unflat.contiguous() + if self.training: + with torch.enable_grad(): + value_leaf = value_unflat.detach().requires_grad_() + value_normed = self.value_norm(value_leaf) + value_norm_ctx = (value_leaf, value_normed) + value_unflat = value_normed.detach() + else: + value_unflat = self.value_norm(value_unflat) + kv_unflat = torch.cat([key_unflat, value_unflat], dim=1) query, key_value, rotary_context = self._rotary.forward_only(query_unflat, kv_unflat, kwargs) @@ -307,6 +335,7 @@ def _query_key_value_forward( "rotary": rotary_context, "query_norm": query_norm_ctx, "key_norm": key_norm_ctx, + "value_norm": value_norm_ctx, } return query, key_value, context @@ -337,13 +366,25 @@ def _query_key_value_backward( _, key_value_grad = self._rotary.backward(None, key_value_grad, rotary_context) - if (key_norm_ctx := context.pop("key_norm")) is not None: - key_leaf, key_normed = key_norm_ctx - key_grad = key_value_grad[:, : self._local_head_groups, :].contiguous() - key_normed.backward(key_grad) - key_value_grad = torch.cat([key_leaf.grad, key_value_grad[:, self._local_head_groups :, :]], dim=1) - - key_value_grad = key_value_grad.flatten(1) + key_norm_ctx = context.pop("key_norm") + value_norm_ctx = context.pop("value_norm") + if key_norm_ctx is not None or value_norm_ctx is not None: + key_grad, value_grad = key_value_grad.chunk(2, dim=1) + if key_norm_ctx is not None: + key_leaf, key_normed = key_norm_ctx + key_normed.backward(key_grad.contiguous()) + key_grad = key_leaf.grad + if value_norm_ctx is not None: + value_leaf, value_normed = value_norm_ctx + value_normed.backward(value_grad.contiguous()) + value_grad = value_leaf.grad + key_value_grad = torch.cat([key_grad, value_grad], dim=1) + + if self._config.shared_key_value: + key_grad, value_grad = key_value_grad.chunk(2, dim=1) + key_value_grad = (key_grad + value_grad).flatten(1) + else: + key_value_grad = key_value_grad.flatten(1) if self._config.head_groups == 1 and (group := self._parallel_dim.group): if self._sequence_parallel: diff --git a/fast_llm/layers/attention/config.py b/fast_llm/layers/attention/config.py index 93596cfb2..cc5d80e88 100644 --- a/fast_llm/layers/attention/config.py +++ b/fast_llm/layers/attention/config.py @@ -133,6 +133,16 @@ class AttentionConfig(MixerConfig): desc="Normalization applied to key vectors before RoPE, per attention head. 
Set to `{type: rms_norm}` to enable.", hint=FieldHint.architecture, ) + value_norm: NormalizationConfig | None = Field( + default=None, + desc="Normalization applied to value projections per head before attention. Use `{type: fixed_rms_norm}` for a no-weight RMS norm.", + hint=FieldHint.architecture, + ) + shared_key_value: bool = Field( + default=False, + desc="Use one shared key/value projection. The projected key is reused as value before separate K/V norms.", + hint=FieldHint.architecture, + ) def _validate(self) -> None: super()._validate() diff --git a/fast_llm/layers/common/normalization/config.py b/fast_llm/layers/common/normalization/config.py index 274215bf2..c84b055c6 100644 --- a/fast_llm/layers/common/normalization/config.py +++ b/fast_llm/layers/common/normalization/config.py @@ -138,6 +138,25 @@ def module_class(self): return RMSNormalization +@config_class(dynamic_type={NormalizationConfig: "fixed_rms_norm"}) +class FixedRMSNormConfig(NormalizationConfig): + """RMS normalization without a learnable weight (fixed unit scale). Used for value norms in Gemma-family models.""" + + _abstract = False + epsilon: float = Field( + default=1e-5, + desc="Regularizer for the division.", + hint=FieldHint.architecture, + valid=check_field(Assert.gt, 0), + ) + + @property + def module_class(self) -> type["Normalization"]: + from fast_llm.layers.common.normalization.normalization import FixedRMSNormalization + + return FixedRMSNormalization + + @config_class(dynamic_type={NormalizationConfig: "gated_rms_norm"}) class GatedRMSNormalizationConfig(RMSNormalizationConfig): """Configuration for gated RMS normalization, which applies a learned activation gate alongside the norm weight.""" diff --git a/fast_llm/layers/common/normalization/normalization.py b/fast_llm/layers/common/normalization/normalization.py index 2858b9370..dda12f17b 100644 --- a/fast_llm/layers/common/normalization/normalization.py +++ b/fast_llm/layers/common/normalization/normalization.py @@ -9,6 +9,7 @@ from fast_llm.functional.config import TritonConfig from fast_llm.functional.triton.normalization import triton_normalization_autograd from fast_llm.layers.common.normalization.config import ( + FixedRMSNormConfig, GatedRMSNormalizationConfig, LayerNormalizationConfig, NoNormalizationConfig, @@ -301,6 +302,27 @@ def _forward_torch(self, input_: torch.Tensor) -> torch.Tensor: return torch.rms_norm(input_.to(self.weight.dtype), self._normalized_shape, self.weight, self._config.epsilon) +class FixedRMSNormalization[ConfigType: FixedRMSNormConfig](Normalization[ConfigType]): + """RMS normalization with no learnable weight (fixed unit scale).""" + + def __init__(self, config: ConfigType, hidden_dim: TensorDim, lr_scale: float | None = None): + super().__init__(config, hidden_dim, lr_scale) + self._normalized_shape = (hidden_dim.size,) + if TritonConfig.enabled(torch.device("cuda")): + self._forward = self._forward_triton + else: + self._forward = self._forward_torch + + def forward(self, input_: torch.Tensor) -> torch.Tensor: + return self._forward(input_.view(-1, *self._normalized_shape)).view_as(input_) + + def _forward_triton(self, input_: torch.Tensor) -> torch.Tensor: + return triton_normalization_autograd(input_, None, None, self._config.epsilon, self.training, False) + + def _forward_torch(self, input_: torch.Tensor) -> torch.Tensor: + return torch.rms_norm(input_, self._normalized_shape, None, self._config.epsilon) + + class GatedRMSNormalization[ConfigType: GatedRMSNormalizationConfig](RMSNormalization[ConfigType], 
torch.nn.Module): """ A gated RMS normalization layer. diff --git a/tests/layers/test_attention.py b/tests/layers/test_attention.py index 4c4754ee1..3cbb3e1c3 100644 --- a/tests/layers/test_attention.py +++ b/tests/layers/test_attention.py @@ -46,6 +46,8 @@ class AttentionTestConfig: window_size: int | None = None query_norm: bool = False key_norm: bool = False + value_norm: bool = False + shared_key_value: bool = False rotary: bool = False rotary_theta: float = 10000.0 @@ -72,6 +74,10 @@ def get_attention_config(self, implementation: str = "backup") -> AttentionConfi config["query_norm"] = {"type": "rms_norm"} if self.key_norm: config["key_norm"] = {"type": "rms_norm"} + if self.value_norm: + config["value_norm"] = {"type": "fixed_rms_norm"} + if self.shared_key_value: + config["shared_key_value"] = True if self.rotary: config["rotary"] = {"type": "default", "theta": self.rotary_theta} return AttentionConfig.from_dict(config) @@ -90,9 +96,15 @@ def expected_output( q = torch.nn.functional.linear(input_, attention.query.weight.detach()).unflatten( 1, (self.heads, self.head_size) ) - kv = torch.nn.functional.linear(input_, attention.key_value.weight.detach()).unflatten( - 1, (2 * self.kv_heads, self.head_size) - ) + if self.shared_key_value: + key_projected = torch.nn.functional.linear(input_, attention.key_value.weight.detach()).unflatten( + 1, (self.kv_heads, self.head_size) + ) + kv = torch.cat([key_projected, key_projected], dim=1) + else: + kv = torch.nn.functional.linear(input_, attention.key_value.weight.detach()).unflatten( + 1, (2 * self.kv_heads, self.head_size) + ) if self.query_norm: q = torch.rms_norm(q, (self.head_size,), attention.query_norm.weight.detach(), 1e-5) @@ -101,6 +113,9 @@ def expected_output( kv[:, : self.kv_heads, :], (self.head_size,), attention.key_norm.weight.detach(), 1e-5 ) kv = torch.cat([key_normed, kv[:, self.kv_heads :, :]], dim=1) + if self.value_norm: + value_normed = torch.rms_norm(kv[:, self.kv_heads :, :], (self.head_size,), None, 1e-5) + kv = torch.cat([kv[:, : self.kv_heads, :], value_normed], dim=1) if self.rotary: freqs = _compute_rotary_freqs(input_.shape[0], self.head_size, self.rotary_theta, input_.device) @@ -155,15 +170,40 @@ def expected_output( ("no_norm", {}), ("query_norm", {"query_norm": True}), ("key_norm", {"key_norm": True}), + ("value_norm", {"value_norm": True}), ("both_norms", {"query_norm": True, "key_norm": True}), + ("all_norms", {"query_norm": True, "key_norm": True, "value_norm": True}), ] -_attention_test_configs = [ - AttentionTestConfig(name=f"{base_name}_{variant_name}", **base_kwargs, **variant_kwargs) - for base_name, base_kwargs in _base_attention_cases + _attention_rotary_cases - for variant_name, variant_kwargs in _attention_norm_variants +_attention_shared_key_value_cases = [ + ("shared_key_value", {"shared_key_value": True}), ] +_attention_shared_key_value_norm_variants = [ + ("no_norm", {}), + ("key_norm", {"key_norm": True}), + ("value_norm", {"value_norm": True}), + ("all_norms", {"query_norm": True, "key_norm": True, "value_norm": True}), +] + +_attention_test_configs = ( + [ + AttentionTestConfig(name=f"{base_name}_{variant_name}", **base_kwargs, **variant_kwargs) + for base_name, base_kwargs in _base_attention_cases + for variant_name, variant_kwargs in _attention_norm_variants + ] + + [ + AttentionTestConfig(name=f"{base_name}_{variant_name}", **base_kwargs, **variant_kwargs) + for base_name, base_kwargs in _attention_rotary_cases + for variant_name, variant_kwargs in _attention_norm_variants + ] + + [ + 
AttentionTestConfig(name=f"{base_name}_{variant_name}", **base_kwargs, **variant_kwargs) + for base_name, base_kwargs in _attention_shared_key_value_cases + for variant_name, variant_kwargs in _attention_shared_key_value_norm_variants + ] +) + _attention_lengths = [ [15], [6, 9], From 41ced2d475cc615c7e7814ef271b42ad6359b8f0 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Tue, 28 Apr 2026 04:01:42 -0400 Subject: [PATCH 06/14] Fix in-place rotary corruption of query_norm backward context triton_rotary_ wrote results in-place, silently corrupting the saved norm output when query_norm was active (both tensors shared storage via .detach()). Add output_ptr to the Triton kernel and inplace_query flag through the rotary layer so the query gets a fresh allocation when a query_norm context is live. Also rename *_norm_ctx -> *_norm_context for consistency. Co-Authored-By: Claude Sonnet 4.6 --- fast_llm/functional/triton/rotary.py | 20 ++++++++--- fast_llm/layers/attention/attention.py | 40 +++++++++++---------- fast_llm/layers/attention/rotary/rotary.py | 42 +++++++++++++++++----- 3 files changed, 70 insertions(+), 32 deletions(-) diff --git a/fast_llm/functional/triton/rotary.py b/fast_llm/functional/triton/rotary.py index f07046a52..be0fa5de2 100644 --- a/fast_llm/functional/triton/rotary.py +++ b/fast_llm/functional/triton/rotary.py @@ -9,6 +9,7 @@ @triton_jit() def triton_rotary_kernel( input_ptr, + output_ptr, frequencies_ptr, stride_0, stride_1, @@ -30,6 +31,8 @@ def triton_rotary_kernel( input_offsets = stride_0 * (pid_0 // seq_len) + stride_1 * position_id + stride_2 * head_offsets + offsets[None, :] input_re_ptr = input_ptr + input_offsets input_im_ptr = input_re_ptr + rotary_dim + output_re_ptr = output_ptr + input_offsets + output_im_ptr = output_re_ptr + rotary_dim if rotary_block_size % rotary_dim == 0 and num_heads % head_block_size == 0: input_re = tl.load(input_re_ptr).to(tl.float32) @@ -54,11 +57,11 @@ def triton_rotary_kernel( out_im = input_im * frequencies_re + input_re * frequencies_im if rotary_block_size % rotary_dim == 0 and num_heads % head_block_size == 0: - tl.store(input_re_ptr, out_re) - tl.store(input_im_ptr, out_im) + tl.store(output_re_ptr, out_re) + tl.store(output_im_ptr, out_im) else: - tl.store(input_re_ptr, out_re, mask=mask) # noqa - tl.store(input_im_ptr, out_im, mask=mask) + tl.store(output_re_ptr, out_re, mask=mask) # noqa + tl.store(output_im_ptr, out_im, mask=mask) def triton_rotary_( @@ -66,19 +69,27 @@ def triton_rotary_( frequencies: torch.Tensor, is_key_value: bool = False, backward: bool = False, + inplace: bool = True, ) -> torch.Tensor: # TODO: Improve assumptions. # TODO: Make a transposed version to avoid contiguous call in key backward. # TODO: Improve block size heuristics. out = input_ + write = input_ if input_.stride(-1) != 1: # TODO: Make a transposed version to avoid contiguous call in key backward. 
input_ = input_.contiguous() + write = input_ + if not inplace: + out = torch.empty_like(input_) + write = out if input_.ndim == 3: input_ = input_.unsqueeze(0) + write = write.unsqueeze(0) frequencies = frequencies.unsqueeze(0) if is_key_value: input_ = input_.chunk(2, dim=-2)[0] + write = write.chunk(2, dim=-2)[0] batch_size, seq_len, num_heads, head_size = input_.shape rotary_dim = div(head_size, 2) rotary_block_size = triton.next_power_of_2(rotary_dim) @@ -89,6 +100,7 @@ def triton_rotary_( # Folded the large y dim into the x dim as gridDim.x is 32 bit while gridDim.y and gridDim.z are 16 bit registers triton_rotary_kernel[(batch_size * seq_len, triton.cdiv(num_heads, head_block_size))]( input_, + write, frequencies, input_.stride(0), input_.stride(1), diff --git a/fast_llm/layers/attention/attention.py b/fast_llm/layers/attention/attention.py index fae143238..a3b9f41f3 100644 --- a/fast_llm/layers/attention/attention.py +++ b/fast_llm/layers/attention/attention.py @@ -281,19 +281,19 @@ def _query_key_value_forward( else: kv_unflat = key_value.unflatten(1, (2 * self._local_head_groups, self._config.head_size)) - query_norm_ctx = None + query_norm_context = None if self._config.query_norm is not None: if self.training: with torch.enable_grad(): query_leaf = query_unflat.contiguous().detach().requires_grad_() query_normed = self.query_norm(query_leaf) - query_norm_ctx = (query_leaf, query_normed) + query_norm_context = (query_leaf, query_normed) query_unflat = query_normed.detach() else: query_unflat = self.query_norm(query_unflat) - key_norm_ctx = None - value_norm_ctx = None + key_norm_context = None + value_norm_context = None if self._config.key_norm is not None or self._config.value_norm is not None: key_unflat, value_unflat = kv_unflat.chunk(2, dim=1) if self._config.key_norm is not None: @@ -303,7 +303,7 @@ def _query_key_value_forward( with torch.enable_grad(): key_leaf = key_unflat.detach().requires_grad_() key_normed = self.key_norm(key_leaf) - key_norm_ctx = (key_leaf, key_normed) + key_norm_context = (key_leaf, key_normed) key_unflat = key_normed.detach() else: key_unflat = self.key_norm(key_unflat) @@ -313,13 +313,15 @@ def _query_key_value_forward( with torch.enable_grad(): value_leaf = value_unflat.detach().requires_grad_() value_normed = self.value_norm(value_leaf) - value_norm_ctx = (value_leaf, value_normed) + value_norm_context = (value_leaf, value_normed) value_unflat = value_normed.detach() else: value_unflat = self.value_norm(value_unflat) kv_unflat = torch.cat([key_unflat, value_unflat], dim=1) - query, key_value, rotary_context = self._rotary.forward_only(query_unflat, kv_unflat, kwargs) + query, key_value, rotary_context = self._rotary.forward_only( + query_unflat, kv_unflat, kwargs, inplace_query=query_norm_context is None + ) if self._sequence_data_parallel_dim.group: # sequence dim may not be zero, but this needs to be handled after `handle.wait()` @@ -333,9 +335,9 @@ def _query_key_value_forward( "query": query_context, "key_value": key_value_context, "rotary": rotary_context, - "query_norm": query_norm_ctx, - "key_norm": key_norm_ctx, - "value_norm": value_norm_ctx, + "query_norm": query_norm_context, + "key_norm": key_norm_context, + "value_norm": value_norm_context, } return query, key_value, context @@ -353,8 +355,8 @@ def _query_key_value_backward( rotary_context = context.pop("rotary") query_grad, _ = self._rotary.backward(query_grad, None, rotary_context) - if (query_norm_ctx := context.pop("query_norm")) is not None: - query_leaf, query_normed = 
query_norm_ctx + if (query_norm_context := context.pop("query_norm")) is not None: + query_leaf, query_normed = query_norm_context query_normed.backward(query_grad) query_grad = query_leaf.grad @@ -366,16 +368,16 @@ def _query_key_value_backward( _, key_value_grad = self._rotary.backward(None, key_value_grad, rotary_context) - key_norm_ctx = context.pop("key_norm") - value_norm_ctx = context.pop("value_norm") - if key_norm_ctx is not None or value_norm_ctx is not None: + key_norm_context = context.pop("key_norm") + value_norm_context = context.pop("value_norm") + if key_norm_context is not None or value_norm_context is not None: key_grad, value_grad = key_value_grad.chunk(2, dim=1) - if key_norm_ctx is not None: - key_leaf, key_normed = key_norm_ctx + if key_norm_context is not None: + key_leaf, key_normed = key_norm_context key_normed.backward(key_grad.contiguous()) key_grad = key_leaf.grad - if value_norm_ctx is not None: - value_leaf, value_normed = value_norm_ctx + if value_norm_context is not None: + value_leaf, value_normed = value_norm_context value_normed.backward(value_grad.contiguous()) value_grad = value_leaf.grad key_value_grad = torch.cat([key_grad, value_grad], dim=1) diff --git a/fast_llm/layers/attention/rotary/rotary.py b/fast_llm/layers/attention/rotary/rotary.py index e9e0e7578..4c30d0a48 100644 --- a/fast_llm/layers/attention/rotary/rotary.py +++ b/fast_llm/layers/attention/rotary/rotary.py @@ -109,7 +109,11 @@ def forward( @abc.abstractmethod def forward_only( - self, query: torch.Tensor | None, key_value: torch.Tensor | None, kwargs: dict[str, typing.Any] + self, + query: torch.Tensor | None, + key_value: torch.Tensor | None, + kwargs: dict[str, typing.Any], + inplace_query: bool = True, ) -> tuple[torch.Tensor | None, torch.Tensor | None, typing.Any]: pass @@ -130,7 +134,11 @@ def forward( return query, key_value def forward_only( - self, query: torch.Tensor | None, key_value: torch.Tensor | None, kwargs: dict[str, typing.Any] + self, + query: torch.Tensor | None, + key_value: torch.Tensor | None, + kwargs: dict[str, typing.Any], + inplace_query: bool = True, ) -> tuple[torch.Tensor | None, torch.Tensor | None, typing.Any]: return query, key_value, None @@ -148,19 +156,35 @@ def _forward( key_value: torch.Tensor | None, frequencies: torch.Tensor, backward: bool = False, + inplace_query: bool = True, ) -> tuple[torch.Tensor | None, torch.Tensor | None]: - rotary_fn = triton_rotary_ if TritonConfig.enabled(frequencies.device) else rotary_embeddings_real - query = None if query is None else rotary_fn(query, frequencies, backward=backward) - key_value = ( - None if key_value is None else rotary_fn(key_value, frequencies, is_key_value=True, backward=backward) - ) + if TritonConfig.enabled(frequencies.device): + query = ( + None if query is None else triton_rotary_(query, frequencies, backward=backward, inplace=inplace_query) + ) + key_value = ( + None + if key_value is None + else triton_rotary_(key_value, frequencies, is_key_value=True, backward=backward) + ) + else: + query = None if query is None else rotary_embeddings_real(query, frequencies, backward=backward) + key_value = ( + None + if key_value is None + else rotary_embeddings_real(key_value, frequencies, is_key_value=True, backward=backward) + ) return query, key_value def forward_only( - self, query: torch.Tensor | None, key_value: torch.Tensor | None, kwargs: dict[str, typing.Any] + self, + query: torch.Tensor | None, + key_value: torch.Tensor | None, + kwargs: dict[str, typing.Any], + inplace_query: bool = 
True, ) -> tuple[torch.Tensor | None, torch.Tensor | None, torch.Tensor]: frequencies: torch.Tensor = kwargs[AttentionKwargs.rotary_freq] - query, key_value = self._forward(query, key_value, frequencies, backward=False) + query, key_value = self._forward(query, key_value, frequencies, backward=False, inplace_query=inplace_query) return query, key_value, frequencies def backward( From 2f58791c5dc2c970e0c2cdec5df9fbff5068285d Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Tue, 28 Apr 2026 21:12:50 -0400 Subject: [PATCH 07/14] Add HybridMoEMLP and independent pre-norm controls for decoder block - Add HybridMoEMLPConfig/HybridMoEMLP: always-active dense MLP + top-K routed experts with optional per-path pre/post norms - Add pre_mixer_normalization and pre_mlp_normalization to DecoderBlockConfig so norm_1 and norm_2 can be configured independently; normalization remains the shared default when either is unset - Add tests/layers/test_mlp.py covering HybridMoEMLP composition and norms Co-Authored-By: Claude Sonnet 4.6 --- fast_llm/layers/decoder/block.py | 12 +- fast_llm/layers/decoder/config.py | 13 ++- fast_llm/layers/decoder/mlp/config.py | 49 ++++++++- .../layers/decoder/mlp/mixture_of_experts.py | 97 ++++++++++++++++- tests/layers/test_mlp.py | 103 ++++++++++++++++++ 5 files changed, 269 insertions(+), 5 deletions(-) create mode 100644 tests/layers/test_mlp.py diff --git a/fast_llm/layers/decoder/block.py b/fast_llm/layers/decoder/block.py index 34a62217c..d41283c9d 100644 --- a/fast_llm/layers/decoder/block.py +++ b/fast_llm/layers/decoder/block.py @@ -86,8 +86,16 @@ def __init__( ) # For multi-token prediction, return a stack of shared_hidden and transformer_output. self._return_input = return_input - self.norm_1 = self._config.normalization.get_layer(self._hidden_dim, lr_scale=self._lr_scale, peft=self._peft) - self.norm_2 = self._config.normalization.get_layer(self._hidden_dim, lr_scale=self._lr_scale, peft=self._peft) + self.norm_1 = ( + self._config.normalization + if self._config.pre_mixer_normalization is None + else self._config.pre_mixer_normalization + ).get_layer(self._hidden_dim, lr_scale=self._lr_scale, peft=self._peft) + self.norm_2 = ( + self._config.normalization + if self._config.pre_mlp_normalization is None + else self._config.pre_mlp_normalization + ).get_layer(self._hidden_dim, lr_scale=self._lr_scale, peft=self._peft) self.post_mixer_norm = ( self._config.post_mixer_normalization.get_layer(self._hidden_dim, lr_scale=self._lr_scale, peft=self._peft) if self._config.post_mixer_normalization is not None diff --git a/fast_llm/layers/decoder/config.py b/fast_llm/layers/decoder/config.py index 54877c647..f949ec714 100644 --- a/fast_llm/layers/decoder/config.py +++ b/fast_llm/layers/decoder/config.py @@ -215,7 +215,18 @@ class DecoderBlockConfig(BlockConfig): ) # TODO: Review names normalization: NormalizationConfig = Field( - desc="Configuration for the block normalization layers.", + desc="Configuration for the block normalization layers. Used as default for `pre_mixer_normalization` and `pre_mlp_normalization` when not set.", + hint=FieldHint.architecture, + ) + pre_mixer_normalization: NormalizationConfig | None = Field( + default=None, + desc="Normalization applied to the residual before the mixer. Defaults to `normalization` when not set.", + hint=FieldHint.architecture, + ) + pre_mlp_normalization: NormalizationConfig | None = Field( + default=None, + desc="Normalization applied to the residual before the MLP. Defaults to `normalization` when not set." 
+ " Set to `{type: none}` to disable independently of the pre-mixer norm.", hint=FieldHint.architecture, ) post_mixer_normalization: NormalizationConfig | None = Field( diff --git a/fast_llm/layers/decoder/mlp/config.py b/fast_llm/layers/decoder/mlp/config.py index 997cf9d2a..0abdfe70f 100644 --- a/fast_llm/layers/decoder/mlp/config.py +++ b/fast_llm/layers/decoder/mlp/config.py @@ -5,11 +5,12 @@ from fast_llm.config import Field, FieldHint, check_field, config_class from fast_llm.functional.config import ActivationType, MLPRecomputeLevel from fast_llm.layers.common.linear.config import AffineLinearConfig, LinearConfig +from fast_llm.layers.common.normalization.config import NormalizationConfig from fast_llm.layers.decoder.config import MLPBaseConfig from fast_llm.utils import Assert if typing.TYPE_CHECKING: - from fast_llm.layers.decoder.mlp.mixture_of_experts import MixtureOfExpertMLP + from fast_llm.layers.decoder.mlp.mixture_of_experts import HybridMoEMLP, MixtureOfExpertMLP from fast_llm.layers.decoder.mlp.mlp import MLP @@ -164,3 +165,49 @@ def _validate(self) -> None: super()._validate() Assert.leq(self.shared_experts, self.experts) Assert.leq(self.shared_experts + self.experts_per_token, self.experts) + + +@config_class(dynamic_type={MLPBaseConfig: "hybrid_moe"}) +class HybridMoEMLPConfig(MLPBaseConfig): + """Configuration for a MoE layer combining an always-active dense MLP with top-K routed experts.""" + + _abstract = False + + dense: MLPConfig = Field( + desc="Configuration for the always-active dense MLP.", + hint=FieldHint.architecture, + ) + routed: MoEMLPConfig = Field( + desc="Configuration for the top-K routed expert MLP.", + hint=FieldHint.architecture, + ) + dense_pre_norm: NormalizationConfig | None = Field( + default=None, + desc="Optional normalization applied to the dense MLP input.", + hint=FieldHint.architecture, + ) + dense_post_norm: NormalizationConfig | None = Field( + default=None, + desc="Optional normalization applied to the dense MLP output before summing.", + hint=FieldHint.architecture, + ) + moe_pre_norm: NormalizationConfig | None = Field( + default=None, + desc="Optional normalization applied to the routed MLP input.", + hint=FieldHint.architecture, + ) + moe_post_norm: NormalizationConfig | None = Field( + default=None, + desc="Optional normalization applied to the routed MLP output before summing.", + hint=FieldHint.architecture, + ) + + def _validate(self) -> None: + super()._validate() + Assert.eq(self.routed.shared_experts, 0) + + @property + def layer_class(self) -> "type[HybridMoEMLP]": + from fast_llm.layers.decoder.mlp.mixture_of_experts import HybridMoEMLP + + return HybridMoEMLP diff --git a/fast_llm/layers/decoder/mlp/mixture_of_experts.py b/fast_llm/layers/decoder/mlp/mixture_of_experts.py index 89979bd18..943a971b8 100644 --- a/fast_llm/layers/decoder/mlp/mixture_of_experts.py +++ b/fast_llm/layers/decoder/mlp/mixture_of_experts.py @@ -16,12 +16,22 @@ from fast_llm.functional.utils import AuxiliaryLoss from fast_llm.layers.block.config import BlockKwargs from fast_llm.layers.common.peft.config import PeftConfig -from fast_llm.layers.decoder.mlp.config import MLPLossNames, MoEImplementation, MoEMLPConfig, RoutingType +from fast_llm.layers.decoder.block import BlockWithBias +from fast_llm.layers.decoder.mlp.config import ( + HybridMoEMLPConfig, + MLPLossNames, + MoEImplementation, + MoEMLPConfig, + RoutingType, +) from fast_llm.layers.decoder.mlp.mlp import MLPBase from fast_llm.layers.language_model.loss.z_loss import z_loss from 
fast_llm.tensor import TensorMeta from fast_llm.utils import Assert +if typing.TYPE_CHECKING: + from fast_llm.engine.distributed.distributed import Distributed + logger = logging.getLogger(__name__) @@ -266,6 +276,91 @@ def get_loss_definitions(self) -> list[LossDef]: return loss_definitions +class HybridMoEMLP[ConfigType: HybridMoEMLPConfig](BlockWithBias[ConfigType]): + """ + MoE MLP that runs an always-active dense MLP alongside top-K routed experts and sums their outputs. + """ + + def __init__( + self, + config: ConfigType, + distributed_config: DistributedConfig, + *, + hidden_dim: TensorDim, + output_dim: TensorDim | None = None, + lr_scale: float | None, + peft: PeftConfig | None, + return_bias: bool = True, + ): + super().__init__( + config, distributed_config, hidden_dim=hidden_dim, lr_scale=lr_scale, peft=peft, return_bias=return_bias + ) + self._output_dim = self._hidden_dim if output_dim is None else output_dim + self.dense = config.dense.get_layer( + distributed_config, hidden_dim, output_dim=output_dim, lr_scale=lr_scale, peft=peft, return_bias=True + ) + self.routed = config.routed.get_layer( + distributed_config, hidden_dim, output_dim=output_dim, lr_scale=lr_scale, peft=peft, return_bias=True + ) + self.dense_pre_norm = ( + config.dense_pre_norm.get_layer(self._hidden_dim, lr_scale=self._lr_scale, peft=self._peft) + if config.dense_pre_norm is not None + else None + ) + self.dense_post_norm = ( + config.dense_post_norm.get_layer(self._output_dim, lr_scale=self._lr_scale, peft=self._peft) + if config.dense_post_norm is not None + else None + ) + self.moe_pre_norm = ( + config.moe_pre_norm.get_layer(self._hidden_dim, lr_scale=self._lr_scale, peft=self._peft) + if config.moe_pre_norm is not None + else None + ) + self.moe_post_norm = ( + config.moe_post_norm.get_layer(self._output_dim, lr_scale=self._lr_scale, peft=self._peft) + if config.moe_post_norm is not None + else None + ) + + def setup(self, distributed: "Distributed") -> None: + super().setup(distributed) + self.dense.setup(distributed) + self.routed.setup(distributed) + + def _forward( + self, + input_: torch.Tensor, + kwargs: dict[str, typing.Any], + losses: dict[str, typing.Any] | None = None, + metrics: dict[str, typing.Any] | None = None, + ) -> tuple[torch.Tensor, torch.Tensor | None]: + if isinstance(input_, TensorMeta): + return ( + TensorMeta.from_dims( + input_.dims[:-1] + (self._output_dim,), tensor_name="MLP output", dtype=input_.dtype + ), + None, + ) + dense_input = self.dense_pre_norm(input_) if self.dense_pre_norm is not None else input_ + moe_input = self.moe_pre_norm(input_) if self.moe_pre_norm is not None else input_ + dense_out, dense_bias = self.dense(dense_input, kwargs, losses, metrics) + routed_out, _ = self.routed(moe_input, kwargs, losses, metrics) + if self.dense_post_norm is not None: + dense_out = self.dense_post_norm(dense_out) + if self.moe_post_norm is not None: + routed_out = self.moe_post_norm(routed_out) + return dense_out + routed_out, dense_bias + + def get_loss_definitions(self) -> list[LossDef]: + return self.routed.get_loss_definitions() + + def get_compute_usage(self, input_: TensorMeta, kwargs: dict[str, typing.Any], config: ResourceUsageConfig) -> int: + return self.dense.get_compute_usage(input_, kwargs, config) + self.routed.get_compute_usage( + input_, kwargs, config + ) + + def sinkhorn(cost: torch.Tensor, tolerance: float = 1e-5, eps=1e-9) -> torch.Tensor: """Sinkhorn based MoE routing function""" with torch.no_grad(): diff --git a/tests/layers/test_mlp.py 
b/tests/layers/test_mlp.py new file mode 100644 index 000000000..8ca5a3d66 --- /dev/null +++ b/tests/layers/test_mlp.py @@ -0,0 +1,103 @@ +import dataclasses + +import pytest +import torch + +from fast_llm.engine.config_utils.tensor_dim import TensorDim +from fast_llm.engine.distributed.config import DistributedConfig +from fast_llm.engine.distributed.distributed import Distributed +from fast_llm.layers.block.config import BlockKwargs +from fast_llm.layers.decoder.mlp.config import HybridMoEMLPConfig +from fast_llm.layers.decoder.mlp.mixture_of_experts import HybridMoEMLP +from fast_llm.utils import Assert +from tests.utils.utils import get_stage + +_NUM_TOKENS = 16 +_HIDDEN_SIZE = 64 +_INTERMEDIATE_SIZE = 32 +_EXPERTS = 4 + +_NORM = {"type": "rms_norm"} + + +@dataclasses.dataclass +class HybridMoEMLPTestConfig: + name: str + gated: bool = False + experts_per_token: int = 1 + dense_pre_norm: bool = False + dense_post_norm: bool = False + moe_pre_norm: bool = False + moe_post_norm: bool = False + + def get_mlp_config(self) -> HybridMoEMLPConfig: + return HybridMoEMLPConfig.from_dict( + { + "dense": { + "intermediate_size": _INTERMEDIATE_SIZE, + "gated": self.gated, + "add_linear_biases": False, + }, + "routed": { + "intermediate_size": _INTERMEDIATE_SIZE, + "gated": self.gated, + "add_linear_biases": False, + "experts": _EXPERTS, + "experts_per_token": self.experts_per_token, + }, + **({"dense_pre_norm": _NORM} if self.dense_pre_norm else {}), + **({"dense_post_norm": _NORM} if self.dense_post_norm else {}), + **({"moe_pre_norm": _NORM} if self.moe_pre_norm else {}), + **({"moe_post_norm": _NORM} if self.moe_post_norm else {}), + } + ) + + def expected_output(self, hybrid: HybridMoEMLP, input_: torch.Tensor, kwargs: dict) -> torch.Tensor: + with torch.no_grad(): + dense_input = hybrid.dense_pre_norm(input_) if hybrid.dense_pre_norm is not None else input_ + moe_input = hybrid.moe_pre_norm(input_) if hybrid.moe_pre_norm is not None else input_ + dense_out, _ = hybrid.dense(dense_input, kwargs) + routed_out, _ = hybrid.routed(moe_input, kwargs) + if hybrid.dense_post_norm is not None: + dense_out = hybrid.dense_post_norm(dense_out) + if hybrid.moe_post_norm is not None: + routed_out = hybrid.moe_post_norm(routed_out) + return dense_out + routed_out + + +_test_configs = [ + HybridMoEMLPTestConfig(name="basic"), + HybridMoEMLPTestConfig(name="gated", gated=True), + HybridMoEMLPTestConfig(name="topk2", experts_per_token=2), + HybridMoEMLPTestConfig(name="gated_topk2", gated=True, experts_per_token=2), + HybridMoEMLPTestConfig(name="pre_norms", dense_pre_norm=True, moe_pre_norm=True), + HybridMoEMLPTestConfig(name="post_norms", dense_post_norm=True, moe_post_norm=True), + HybridMoEMLPTestConfig( + name="all_norms", dense_pre_norm=True, dense_post_norm=True, moe_pre_norm=True, moe_post_norm=True + ), + HybridMoEMLPTestConfig(name="asymmetric_norms", dense_pre_norm=True, moe_post_norm=True), +] + + +@pytest.mark.parametrize("config", [pytest.param(c, id=c.name) for c in _test_configs]) +def test_hybrid_moe_mlp(config: HybridMoEMLPTestConfig) -> None: + distributed_config = DistributedConfig(use_cuda=torch.cuda.is_available()) + distributed = Distributed(distributed_config) + device = distributed.device + hidden_dim = TensorDim("hidden", _HIDDEN_SIZE) + + hybrid: HybridMoEMLP = config.get_mlp_config().get_layer( + distributed_config, hidden_dim, lr_scale=None, peft=None, return_bias=False + ) + get_stage([hybrid], distributed) + hybrid.eval() + + input_ = torch.randn(_NUM_TOKENS, _HIDDEN_SIZE, 
device=device) + token_dim = TensorDim("tokens", _NUM_TOKENS) + kwargs = {BlockKwargs.hidden_token_dim: token_dim} + + with torch.no_grad(): + output = hybrid(input_, kwargs) + + expected = config.expected_output(hybrid, input_, kwargs) + Assert.rms_close_relative(output, expected, 1e-5, 1e-7) From b86bd040c0b7463b8653bde8be7de5cc681f85b5 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Wed, 29 Apr 2026 21:34:18 -0400 Subject: [PATCH 08/14] Add Gemma 4 HuggingFace checkpoint converter Adds import/export support for Gemma 4 text models (gemma4 format): - Pattern decoder with alternating sliding-window and global attention - Per-head query/key/value norms, post-attention and post-MLP norms - Partial RoPE for global attention layers - Hybrid dense+MoE blocks with pre/post norms - Tied embeddings with sqrt(hidden_size) embedding scale - Logit softcapping Exports `hidden_size_per_layer_input: 0` to disable Per-Layer Embeddings (PLE) in the native HuggingFace model; TODO to implement PLE in Fast-LLM. Adds `gemma4` model testing config and registers the format with GPTModelConfig and AutoGPTHuggingfaceCheckpointHandler. Co-Authored-By: Claude Sonnet 4.6 --- fast_llm/models/gpt/config.py | 2 + fast_llm/models/gpt/conversion/auto.py | 3 + fast_llm/models/gpt/conversion/config.py | 4 + fast_llm/models/gpt/conversion/gemma4.py | 608 +++++++++++++++++++++++ tests/utils/model_configs.py | 69 +++ 5 files changed, 686 insertions(+) create mode 100644 fast_llm/models/gpt/conversion/gemma4.py diff --git a/fast_llm/models/gpt/config.py b/fast_llm/models/gpt/config.py index 770139816..71981ba23 100644 --- a/fast_llm/models/gpt/config.py +++ b/fast_llm/models/gpt/config.py @@ -16,6 +16,7 @@ AutoGPTHuggingfaceCheckpointFormat, DiffusionDreamCheckpointFormat, DiffusionLlamaCheckpointFormat, + Gemma4CheckpointFormat, LlamaCheckpointFormat, MistralCheckpointFormat, MixtralCheckpointFormat, @@ -72,6 +73,7 @@ class GPTModelConfig(FastLLMModelConfig): DiffusionLlamaCheckpointFormat, AprielHybridSSMCheckpointFormat, Apriel2TextCheckpointFormat, + Gemma4CheckpointFormat, ) @classmethod diff --git a/fast_llm/models/gpt/conversion/auto.py b/fast_llm/models/gpt/conversion/auto.py index 696b4f4ce..20842d611 100644 --- a/fast_llm/models/gpt/conversion/auto.py +++ b/fast_llm/models/gpt/conversion/auto.py @@ -9,6 +9,7 @@ AprielHybridSSMCheckpointFormat, DiffusionDreamCheckpointFormat, DiffusionLlamaCheckpointFormat, + Gemma4CheckpointFormat, LlamaCheckpointFormat, MistralCheckpointFormat, MixtralCheckpointFormat, @@ -17,6 +18,7 @@ ) from fast_llm.models.gpt.conversion.diffusion_dream import DiffusionDreamHuggingfaceCheckpointHandler from fast_llm.models.gpt.conversion.diffusion_llama import DiffusionLlamaHuggingfaceCheckpointHandler +from fast_llm.models.gpt.conversion.gemma4 import Gemma4HuggingfaceCheckpointHandler from fast_llm.models.gpt.conversion.llama import LlamaHuggingfaceCheckpointHandler from fast_llm.models.gpt.conversion.mistral import MistralHuggingfaceCheckpointHandler from fast_llm.models.gpt.conversion.mixtral import MixtralHuggingfaceCheckpointHandler @@ -38,4 +40,5 @@ class AutoGPTHuggingfaceCheckpointHandler( DiffusionLlamaCheckpointFormat.name: DiffusionLlamaHuggingfaceCheckpointHandler, AprielHybridSSMCheckpointFormat.name: AprielHuggingfaceCheckpointHandler, Apriel2TextCheckpointFormat.name: Apriel2HuggingfaceCheckpointHandler, + Gemma4CheckpointFormat.name: Gemma4HuggingfaceCheckpointHandler, } diff --git a/fast_llm/models/gpt/conversion/config.py b/fast_llm/models/gpt/conversion/config.py index 
240860529..41a0828b6 100644 --- a/fast_llm/models/gpt/conversion/config.py +++ b/fast_llm/models/gpt/conversion/config.py @@ -51,3 +51,7 @@ class AprielHybridSSMCheckpointFormat(GPTHuggingfaceCheckpointFormat): class Apriel2TextCheckpointFormat(GPTHuggingfaceCheckpointFormat): name: typing.ClassVar[str] = "apriel2_text" + + +class Gemma4CheckpointFormat(GPTHuggingfaceCheckpointFormat): + name: typing.ClassVar[str] = "gemma4" diff --git a/fast_llm/models/gpt/conversion/gemma4.py b/fast_llm/models/gpt/conversion/gemma4.py new file mode 100644 index 000000000..d157ed5de --- /dev/null +++ b/fast_llm/models/gpt/conversion/gemma4.py @@ -0,0 +1,608 @@ +"""Gemma4 checkpoint format converter.""" + +import typing + +from fast_llm.engine.checkpoint.config import CheckpointFormat +from fast_llm.engine.checkpoint.external import ( + SplitWeightConverter, + WeightConverter, +) +from fast_llm.engine.checkpoint.huggingface import HuggingfaceStateDictCheckpointHandler +from fast_llm.functional.config import ActivationType +from fast_llm.layers.attention.config import AttentionConfig +from fast_llm.layers.attention.rotary.config import DefaultRotaryConfig, ProportionalRotaryConfig +from fast_llm.layers.block.config import FixedBlockSequenceConfig, PatternBlockSequenceConfig +from fast_llm.layers.common.normalization.config import FixedRMSNormConfig, RMSNormalizationConfig +from fast_llm.layers.decoder.config import DecoderBlockConfig +from fast_llm.layers.decoder.mlp.config import HybridMoEMLPConfig, MLPConfig, MoEMLPConfig +from fast_llm.layers.language_model.config import ( + LanguageModelConfig, + LanguageModelEmbeddingsConfig, + LanguageModelHeadConfig, +) +from fast_llm.models.gpt.config import GPTBaseModelConfig, GPTModelConfig +from fast_llm.models.gpt.conversion.config import Gemma4CheckpointFormat +from fast_llm.models.gpt.conversion.llama import ( + KeyValueWeightConverter, + LlamaEmbeddingsConverter, + LlamaHeadConverter, + LlamaNormalizationConverter, + MLPLayer2Converter, + QueryWeightConverter, + get_parameter_converter, + get_weight_and_bias_converters, +) +from fast_llm.models.gpt.model import GPTModel +from fast_llm.utils import Assert, safe_merge_dicts + +_SLIDING_ATTENTION = "sliding_attention" +_FULL_ATTENTION = "full_attention" + + +class Gemma4AttentionConverter: + @classmethod + def import_config(cls, config: dict, is_sliding: bool) -> dict: + eps = config["rms_norm_eps"] + if is_sliding: + rope_params = config["rope_parameters"][_SLIDING_ATTENTION] + rotary = {"type": "default", "theta": rope_params["rope_theta"]} + head_size = config["head_dim"] + head_groups = config["num_key_value_heads"] + window_size = config["sliding_window"] + else: + rope_params = config["rope_parameters"][_FULL_ATTENTION] + rotary = { + "type": "proportional", + "theta": rope_params["rope_theta"], + "partial_rotary_factor": rope_params["partial_rotary_factor"], + } + head_size = config["global_head_dim"] + num_global_kv_heads = config.get("num_global_key_value_heads") + head_groups = config["num_key_value_heads"] if num_global_kv_heads is None else num_global_kv_heads + window_size = None + out = { + "heads": config["num_attention_heads"], + "head_groups": head_groups, + "head_size": head_size, + "add_linear_biases": False, + "dropout": config["attention_dropout"], + "softmax_scale_power": 0, + "rotary": rotary, + "query_norm": {"type": "rms_norm", "epsilon": eps}, + "key_norm": {"type": "rms_norm", "epsilon": eps}, + "value_norm": {"type": "fixed_rms_norm", "epsilon": eps}, + } + if window_size is not None: 
+ out["window_size"] = window_size + return out + + @classmethod + def export_config(cls, sliding_config: AttentionConfig, full_config: AttentionConfig) -> dict: + Assert.custom(isinstance, sliding_config, AttentionConfig) + Assert.custom(isinstance, full_config, AttentionConfig) + assert not sliding_config.add_linear_biases + assert isinstance(sliding_config.rotary, DefaultRotaryConfig) + assert isinstance(full_config.rotary, ProportionalRotaryConfig) + Assert.custom(isinstance, sliding_config.query_norm, RMSNormalizationConfig) + Assert.custom(isinstance, sliding_config.key_norm, RMSNormalizationConfig) + Assert.custom(isinstance, sliding_config.value_norm, FixedRMSNormConfig) + eps = sliding_config.query_norm.epsilon + num_global_kv_heads = ( + None if full_config.head_groups == sliding_config.head_groups else full_config.head_groups + ) + return { + "num_attention_heads": sliding_config.heads, + "num_key_value_heads": sliding_config.head_groups, + "head_dim": sliding_config.head_size, + "global_head_dim": full_config.head_size, + "num_global_key_value_heads": num_global_kv_heads, + "attention_bias": False, + "attention_dropout": sliding_config.dropout, + "sliding_window": sliding_config.window_size, + "rms_norm_eps": eps, + "rope_parameters": { + _SLIDING_ATTENTION: { + "rope_type": "default", + "rope_theta": sliding_config.rotary.theta, + }, + _FULL_ATTENTION: { + "rope_type": "proportional", + "rope_theta": full_config.rotary.theta, + "partial_rotary_factor": full_config.rotary.partial_rotary_factor, + }, + }, + } + + @classmethod + def get_converters( + cls, + config: AttentionConfig, + fast_llm_prefix: str, + hf_prefix: str, + drop_on_export: bool = False, + ) -> list[WeightConverter]: + converters = [ + *get_weight_and_bias_converters( + f"{fast_llm_prefix}.query", + f"{hf_prefix}.q_proj", + False, + QueryWeightConverter, + config, + drop_on_export=drop_on_export, + ), + *get_weight_and_bias_converters( + f"{fast_llm_prefix}.key_value", + (f"{hf_prefix}.k_proj", f"{hf_prefix}.v_proj"), + False, + KeyValueWeightConverter, + config, + drop_on_export=drop_on_export, + ), + *get_weight_and_bias_converters( + f"{fast_llm_prefix}.dense", + f"{hf_prefix}.o_proj", + False, + drop_on_export=drop_on_export, + ), + ] + if config.query_norm is not None: + converters += LlamaNormalizationConverter.get_converters( + config.query_norm, + f"{fast_llm_prefix}.query_norm", + f"{hf_prefix}.q_norm", + drop_on_export=drop_on_export, + ) + if config.key_norm is not None: + converters += LlamaNormalizationConverter.get_converters( + config.key_norm, + f"{fast_llm_prefix}.key_norm", + f"{hf_prefix}.k_norm", + drop_on_export=drop_on_export, + ) + # value_norm is FixedRMSNorm — no learnable weight to convert + return converters + + +class Gemma4MLPConverter: + @classmethod + def import_config(cls, config: dict) -> dict: + return { + "intermediate_size": config["intermediate_size"], + "add_linear_biases": False, + "activation": ActivationType.from_hf_name(config["hidden_activation"]), + "gated": True, + } + + @classmethod + def export_config(cls, config: MLPConfig) -> dict: + Assert.custom(isinstance, config, MLPConfig) + assert config.gated + assert not config.add_linear_biases + return { + "intermediate_size": config.intermediate_size, + "hidden_activation": config.activation.hf_name, + } + + @classmethod + def get_converters( + cls, + config: MLPConfig, + fast_llm_prefix: str, + hf_prefix: str, + drop_on_export: bool = False, + ) -> list[WeightConverter]: + return [ + *get_weight_and_bias_converters( + 
f"{fast_llm_prefix}.layer_1", + (f"{hf_prefix}.gate_proj", f"{hf_prefix}.up_proj"), + False, + SplitWeightConverter, + drop_on_export=drop_on_export, + ), + *get_weight_and_bias_converters( + f"{fast_llm_prefix}.layer_2", + f"{hf_prefix}.down_proj", + False, + MLPLayer2Converter, + drop_on_export=drop_on_export, + ), + ] + + +class Gemma4MoEMLPConverter: + @classmethod + def import_config(cls, config: dict) -> dict: + return { + "type": "moe", + "intermediate_size": config["moe_intermediate_size"], + "add_linear_biases": False, + "activation": ActivationType.from_hf_name(config["hidden_activation"]), + "gated": True, + "experts": config["num_experts"], + "experts_per_token": config["top_k_experts"], + } + + @classmethod + def export_config(cls, config: MoEMLPConfig) -> dict: + Assert.custom(isinstance, config, MoEMLPConfig) + assert config.gated + assert not config.add_linear_biases + return { + "num_experts": config.experts, + "top_k_experts": config.experts_per_token, + "moe_intermediate_size": config.intermediate_size, + } + + @classmethod + def get_converters( + cls, + config: MoEMLPConfig, + fast_llm_prefix: str, + hf_prefix: str, + drop_on_export: bool = False, + ) -> list[WeightConverter]: + converters = [ + *get_weight_and_bias_converters( + f"{fast_llm_prefix}.router", + f"{hf_prefix}.router.proj", + False, + drop_on_export=drop_on_export, + ), + # gate_up_proj shape [experts, 2*intermediate, hidden] matches Fast-LLM layer_1 + get_parameter_converter( + f"{fast_llm_prefix}.layer_1.weight", + f"{hf_prefix}.experts.gate_up_proj", + drop_on_export=drop_on_export, + ), + # down_proj shape [experts, hidden, intermediate] matches Fast-LLM layer_2 + get_parameter_converter( + f"{fast_llm_prefix}.layer_2.weight", + f"{hf_prefix}.experts.down_proj", + drop_on_export=drop_on_export, + ), + ] + if not drop_on_export: + # Gemma4-specific router parameters without Fast-LLM equivalents + converters += [ + get_parameter_converter((), f"{hf_prefix}.router.scale", drop_on_import=True), + get_parameter_converter((), f"{hf_prefix}.router.per_expert_scale", drop_on_import=True), + ] + return converters + + +class Gemma4HybridMoEMLPConverter: + @classmethod + def import_config(cls, config: dict) -> dict: + def make_norm(): + return {"type": "rms_norm", "epsilon": config["rms_norm_eps"]} + + return { + "type": "hybrid_moe", + "dense": Gemma4MLPConverter.import_config(config), + "routed": Gemma4MoEMLPConverter.import_config(config), + "dense_pre_norm": make_norm(), + "moe_pre_norm": make_norm(), + "dense_post_norm": make_norm(), + "moe_post_norm": make_norm(), + } + + @classmethod + def export_config(cls, config: HybridMoEMLPConfig) -> dict: + Assert.custom(isinstance, config, HybridMoEMLPConfig) + return safe_merge_dicts( + Gemma4MLPConverter.export_config(config.dense), + Gemma4MoEMLPConverter.export_config(config.routed), + {"enable_moe_block": True}, + ) + + @classmethod + def get_converters( + cls, + config: HybridMoEMLPConfig, + fast_llm_prefix: str, + hf_prefix: str, + drop_on_export: bool = False, + ) -> list[WeightConverter]: + norm_config = config.dense_pre_norm + return [ + *Gemma4MLPConverter.get_converters( + config.dense, + f"{fast_llm_prefix}.dense", + f"{hf_prefix}.mlp", + drop_on_export=drop_on_export, + ), + *Gemma4MoEMLPConverter.get_converters( + config.routed, + f"{fast_llm_prefix}.routed", + hf_prefix, + drop_on_export=drop_on_export, + ), + *LlamaNormalizationConverter.get_converters( + norm_config, + f"{fast_llm_prefix}.dense_pre_norm", + f"{hf_prefix}.pre_feedforward_layernorm", 
+ drop_on_export=drop_on_export, + ), + *LlamaNormalizationConverter.get_converters( + norm_config, + f"{fast_llm_prefix}.moe_pre_norm", + f"{hf_prefix}.pre_feedforward_layernorm_2", + drop_on_export=drop_on_export, + ), + *LlamaNormalizationConverter.get_converters( + norm_config, + f"{fast_llm_prefix}.dense_post_norm", + f"{hf_prefix}.post_feedforward_layernorm_1", + drop_on_export=drop_on_export, + ), + *LlamaNormalizationConverter.get_converters( + norm_config, + f"{fast_llm_prefix}.moe_post_norm", + f"{hf_prefix}.post_feedforward_layernorm_2", + drop_on_export=drop_on_export, + ), + ] + + +class Gemma4BlockConverter: + @classmethod + def import_config(cls, config: dict, is_sliding: bool) -> dict: + def make_norm(): + return {"type": "rms_norm", "epsilon": config["rms_norm_eps"]} + + out = { + "mixer": Gemma4AttentionConverter.import_config(config, is_sliding), + "normalization": make_norm(), + "post_mixer_normalization": make_norm(), + "post_mlp_normalization": make_norm(), + } + if config.get("enable_moe_block"): + out["mlp"] = Gemma4HybridMoEMLPConverter.import_config(config) + out["pre_mlp_normalization"] = {"type": "none"} + else: + out["mlp"] = Gemma4MLPConverter.import_config(config) + out["pre_mlp_normalization"] = make_norm() + return out + + @classmethod + def export_config(cls, sliding_config: DecoderBlockConfig, full_config: DecoderBlockConfig) -> dict: + Assert.custom(isinstance, sliding_config, DecoderBlockConfig) + norm_config = sliding_config.normalization + Assert.custom(isinstance, norm_config, RMSNormalizationConfig) + is_moe = isinstance(sliding_config.mlp, HybridMoEMLPConfig) + out = safe_merge_dicts( + Gemma4AttentionConverter.export_config(sliding_config.mixer, full_config.mixer), + LlamaNormalizationConverter.export_config(norm_config), + ) + if is_moe: + out = safe_merge_dicts(out, Gemma4HybridMoEMLPConverter.export_config(sliding_config.mlp)) + else: + out = safe_merge_dicts(out, Gemma4MLPConverter.export_config(sliding_config.mlp)) + out["enable_moe_block"] = False + return out + + @classmethod + def get_converters( + cls, + config: DecoderBlockConfig, + fast_llm_prefix: str, + hf_prefix: str, + drop_on_export: bool = False, + ) -> list[WeightConverter]: + is_moe = isinstance(config.mlp, HybridMoEMLPConfig) + converters = [ + *Gemma4AttentionConverter.get_converters( + config.mixer, + f"{fast_llm_prefix}.mixer", + f"{hf_prefix}.self_attn", + drop_on_export=drop_on_export, + ), + ] + if is_moe: + converters += Gemma4HybridMoEMLPConverter.get_converters( + config.mlp, + f"{fast_llm_prefix}.mlp", + hf_prefix, + drop_on_export=drop_on_export, + ) + else: + converters += Gemma4MLPConverter.get_converters( + config.mlp, + f"{fast_llm_prefix}.mlp", + f"{hf_prefix}.mlp", + drop_on_export=drop_on_export, + ) + converters += LlamaNormalizationConverter.get_converters( + config.normalization, + f"{fast_llm_prefix}.norm_2", + f"{hf_prefix}.pre_feedforward_layernorm", + drop_on_export=drop_on_export, + ) + converters += [ + *LlamaNormalizationConverter.get_converters( + config.normalization, + f"{fast_llm_prefix}.norm_1", + f"{hf_prefix}.input_layernorm", + drop_on_export=drop_on_export, + ), + *LlamaNormalizationConverter.get_converters( + config.post_mixer_normalization, + f"{fast_llm_prefix}.post_mixer_norm", + f"{hf_prefix}.post_attention_layernorm", + drop_on_export=drop_on_export, + ), + *LlamaNormalizationConverter.get_converters( + config.post_mlp_normalization, + f"{fast_llm_prefix}.post_mlp_norm", + f"{hf_prefix}.post_feedforward_layernorm", + 
drop_on_export=drop_on_export, + ), + ] + if not drop_on_export: + converters.append(get_parameter_converter((), f"{hf_prefix}.layer_scalar", drop_on_import=True)) + return converters + + +class Gemma4DecoderConverter: + block_converter_class: typing.ClassVar[type[Gemma4BlockConverter]] = Gemma4BlockConverter + + @classmethod + def import_config(cls, config: dict) -> dict: + layer_types = config["layer_types"] + unique_types = list(dict.fromkeys(layer_types)) + blocks = { + layer_type: cls.block_converter_class.import_config(config, layer_type == _SLIDING_ATTENTION) + for layer_type in unique_types + } + return { + "type": "pattern", + "blocks": blocks, + "pattern": layer_types, + "num_blocks": config["num_hidden_layers"], + } + + @classmethod + def export_config(cls, config: PatternBlockSequenceConfig | FixedBlockSequenceConfig) -> dict: + Assert.custom(isinstance, config, PatternBlockSequenceConfig) + Assert.incl(_SLIDING_ATTENTION, config.blocks) + Assert.incl(_FULL_ATTENTION, config.blocks) + return safe_merge_dicts( + cls.block_converter_class.export_config( + config.blocks[_SLIDING_ATTENTION], + config.blocks[_FULL_ATTENTION], + ), + { + "num_hidden_layers": config.num_blocks, + "layer_types": list(config.expanded_pattern), + }, + ) + + @classmethod + def get_converters( + cls, + config: PatternBlockSequenceConfig | FixedBlockSequenceConfig, + fast_llm_prefix: str, + hf_prefix: str, + drop_on_export: bool = False, + ) -> list[WeightConverter]: + Assert.custom(isinstance, config, PatternBlockSequenceConfig) + converters = [] + for block_index in range(config.num_blocks): + block_config = config.blocks[config.expanded_pattern[block_index]] + converters += cls.block_converter_class.get_converters( + block_config, + f"{fast_llm_prefix}.{block_index}", + f"{hf_prefix}.{block_index}", + drop_on_export=drop_on_export, + ) + return converters + + +class Gemma4EmbeddingsConverter(LlamaEmbeddingsConverter): + @classmethod + def import_config(cls, config: dict) -> dict: + return { + "vocab_size": config["vocab_size"], + "embedding_scale": config["hidden_size"] ** 0.5, + } + + @classmethod + def export_config(cls, config: LanguageModelEmbeddingsConfig) -> dict: + Assert.custom(isinstance, config, LanguageModelEmbeddingsConfig) + assert not config.position_embeddings.enabled + return {"vocab_size": config.vocab_size} + + +class Gemma4HeadConverter(LlamaHeadConverter): + @classmethod + def import_config(cls, config: dict) -> dict: + out = {"normalization": LlamaNormalizationConverter.import_config(config)} + if (softcap := config.get("final_logit_softcapping")) is not None: + out["final_logit_softcap"] = softcap + return out + + @classmethod + def export_config(cls, config: LanguageModelHeadConfig) -> dict: + out = LlamaNormalizationConverter.export_config(config.normalization) + if config.final_logit_softcap is not None: + out["final_logit_softcapping"] = config.final_logit_softcap + return out + + @classmethod + def get_converters( + cls, + config: LanguageModelConfig, + exported_config: dict, + ) -> list[WeightConverter]: + return [ + *LlamaNormalizationConverter.get_converters( + config.head.normalization, + "head.final_norm", + "model.norm", + ), + get_parameter_converter( + "head.output_weights", + "lm_head.weight", + drop_on_import=exported_config["tie_word_embeddings"], + drop_on_export=exported_config["tie_word_embeddings"], + ), + ] + + +class Gemma4BaseModelConverter: + decoder_converter_class: typing.ClassVar[type[Gemma4DecoderConverter]] = Gemma4DecoderConverter + 
embeddings_converter_class: typing.ClassVar[type[Gemma4EmbeddingsConverter]] = Gemma4EmbeddingsConverter + head_converter_class: typing.ClassVar[type[Gemma4HeadConverter]] = Gemma4HeadConverter + + @classmethod + def import_config(cls, config: dict) -> dict: + return { + "embeddings": cls.embeddings_converter_class.import_config(config), + "decoder": cls.decoder_converter_class.import_config(config), + "head": cls.head_converter_class.import_config(config), + "hidden_size": config["hidden_size"], + "tied_embedding_weight": config["tie_word_embeddings"], + } + + @classmethod + def export_config(cls, config: GPTBaseModelConfig) -> dict: + Assert.custom(isinstance, config, GPTBaseModelConfig) + return safe_merge_dicts( + cls.embeddings_converter_class.export_config(config.embeddings), + cls.decoder_converter_class.export_config(config.decoder), + cls.head_converter_class.export_config(config.head), + { + "tie_word_embeddings": config.tied_embedding_weight, + "hidden_size": config.hidden_size, + # TODO: Implement Per-Layer Embeddings (PLE). Gemma4TextConfig defaults to 256; + # explicitly zero to disable the feature in the exported model until Fast-LLM + # supports it natively. + "hidden_size_per_layer_input": 0, + }, + ) + + @classmethod + def get_converters(cls, config: GPTBaseModelConfig, exported_config: dict) -> list[WeightConverter]: + return [ + *cls.embeddings_converter_class.get_converters(config.embeddings, "embeddings", "model"), + *cls.decoder_converter_class.get_converters(config.decoder, "decoder", "model.layers"), + *cls.head_converter_class.get_converters(config, exported_config), + ] + + +class Gemma4HuggingfaceCheckpointHandler(HuggingfaceStateDictCheckpointHandler): + _model: GPTModel + _model_class: typing.ClassVar = GPTModelConfig + format: typing.ClassVar[type[CheckpointFormat]] = Gemma4CheckpointFormat + architecture: typing.ClassVar[str] = "Gemma4ForCausalLM" + base_model_converter_class: typing.ClassVar[type[Gemma4BaseModelConverter]] = Gemma4BaseModelConverter + + @classmethod + def get_huggingface_model_type(cls) -> str: + return "gemma4_text" + + @classmethod + def get_transformers_configuration_class(cls): + import transformers + + return transformers.Gemma4TextConfig diff --git a/tests/utils/model_configs.py b/tests/utils/model_configs.py index 9cccf54cd..e57210bfb 100644 --- a/tests/utils/model_configs.py +++ b/tests/utils/model_configs.py @@ -20,6 +20,7 @@ Apriel2TextCheckpointFormat, DiffusionDreamCheckpointFormat, DiffusionLlamaCheckpointFormat, + Gemma4CheckpointFormat, LlamaCheckpointFormat, MistralCheckpointFormat, MixtralCheckpointFormat, @@ -963,6 +964,74 @@ def update_and_add_testing_config( ) +# Use init_1 for extra norms to keep the residual stream at small scale (~0.35). +# Default ones-init would normalize small layer outputs to unit scale before the residual add, +# growing the stream to ~1.5 per block and causing bf16 absolute errors to exceed compare_factor=2. +# query_norm/key_norm init_1 also keeps attention logits small (softmax_scale_power=0 otherwise +# amplifies bf16 relative error 3x). ParameterConfig.initialization is FieldHint.feature so +# this is invisible to the architecture comparison in test_load_pretrained. 
+_gemma4_block_overrides = { + "post_mixer_normalization": {"type": "rms_norm", "weight": init_1}, + "post_mlp_normalization": {"type": "rms_norm", "weight": init_1}, + "pre_mlp_normalization": {"type": "rms_norm", "weight": init_1}, +} +_gemma4_mixer_overrides = { + "softmax_scale_power": 0, + "query_norm": {"type": "rms_norm", "weight": init_1}, + "key_norm": {"type": "rms_norm", "weight": init_1}, + "value_norm": {"type": "fixed_rms_norm"}, +} + +update_and_add_testing_config( + # Tests Gemma4 converter: pattern decoder with alternating sliding/full attention, + # per-head norms (q/k/v), post-attention and post-MLP norms, embedding scale. + "llama", + "gemma4", + updates={ + ("model", "base_model", "tied_embedding_weight"): True, + ("model", "base_model", "embeddings", "embedding_scale"): 16.0, # sqrt(hidden_size=256); must match converter + ("model", "base_model", "decoder"): { + "type": "pattern", + "blocks": { + "sliding_attention": { + **copy.deepcopy(_llama_block), + **_gemma4_block_overrides, + "mixer": { + **copy.deepcopy(_llama_block["mixer"]), + **_gemma4_mixer_overrides, + "window_size": 128, + }, + }, + "full_attention": { + **copy.deepcopy(_llama_block), + **_gemma4_block_overrides, + "mixer": { + **copy.deepcopy(_llama_block["mixer"]), + **_gemma4_mixer_overrides, + "rotary": {"type": "proportional", "partial_rotary_factor": 0.25}, + }, + }, + }, + "pattern": ["sliding_attention", "full_attention"], + "num_blocks": 2, + }, + }, + megatron_args=None, + checkpoint_format=Gemma4CheckpointFormat, + compare_factor=5.0, # init_1 on post_mlp_norm makes its gradient tiny (~5e-6), hitting the fp16 rms_eps floor + groups={ + ModelTestingGroup.basic: ModelTestingGroupAction.normal, + ModelTestingGroup.checkpoint: ModelTestingGroupAction.normal, + ModelTestingGroup.convert: ModelTestingGroupAction.normal, + ModelTestingGroup.generate: ModelTestingGroupAction.not_implemented, + ModelTestingGroup.megatron: ModelTestingGroupAction.not_implemented, + ModelTestingGroup.distributed: ModelTestingGroupAction.unimportant, + }, + skip_tests=("sdp", "ms"), + requires_cuda=False, +) + + @pytest.fixture(scope="session", params=MODEL_CONFIGS.keys()) def model_testing_config(request) -> ModelTestingConfig: models = request.config.getoption("--models") From 7da95c31256297999456127f25b5c4722debfd15 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Wed, 29 Apr 2026 23:33:16 -0400 Subject: [PATCH 09/14] Improve Gemma4 converter: shared_key_value, MoE weight shapes, roundtrip test MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Map HF attention_k_eq_v=True to AttentionConfig.shared_key_value=True for full-attention layers in the 26B-A4B model (K projection is reused as V; only a single k_proj weight exists, no v_proj) - Add Gemma4MoELayer1Converter / Gemma4MoELayer2Converter to correctly reshape batched expert weights: gate_up_proj [E,2I,H] ↔ [E*2I,H] and down_proj [E,H,I] ↔ [E*I,H] (permute+reshape) - Export use_bidirectional_attention=None (text-only; vision tokens not supported) - Add test_hf_roundtrip[gemma4] using google/gemma-4-26B-A4B config Co-Authored-By: Claude Sonnet 4.6 --- fast_llm/models/gpt/conversion/gemma4.py | 70 ++++++++++++++++++++---- tests/models/test_hf_roundtrip.py | 49 +++++++++++++++++ 2 files changed, 109 insertions(+), 10 deletions(-) diff --git a/fast_llm/models/gpt/conversion/gemma4.py b/fast_llm/models/gpt/conversion/gemma4.py index d157ed5de..3a30f83fe 100644 --- a/fast_llm/models/gpt/conversion/gemma4.py +++ 
b/fast_llm/models/gpt/conversion/gemma4.py @@ -39,6 +39,38 @@ _FULL_ATTENTION = "full_attention" +class Gemma4MoELayer1Converter(WeightConverter): + """Converts batched gate_up_proj [experts, 2*intermediate, hidden] ↔ Fast-LLM layer_1 [experts*2*intermediate, hidden].""" + + _config: MoEMLPConfig + + def export_weight(self, weight): + (layer_1,) = weight + w = layer_1[:] + return (w.reshape(self._config.experts, -1, w.shape[-1]),) + + def import_weight(self, weight): + (gate_up_proj,) = weight + w = gate_up_proj[:] + return (w.reshape(-1, w.shape[-1]),) + + +class Gemma4MoELayer2Converter(WeightConverter): + """Converts batched down_proj [experts, hidden, intermediate] ↔ Fast-LLM layer_2 [experts*intermediate, hidden].""" + + _config: MoEMLPConfig + + def export_weight(self, weight): + (layer_2,) = weight + w = layer_2[:] + return (w.reshape(self._config.experts, -1, w.shape[-1]).permute(0, 2, 1).contiguous(),) + + def import_weight(self, weight): + (down_proj,) = weight + w = down_proj[:] + return (w.permute(0, 2, 1).reshape(-1, w.shape[1]).contiguous(),) + + class Gemma4AttentionConverter: @classmethod def import_config(cls, config: dict, is_sliding: bool) -> dict: @@ -72,6 +104,8 @@ def import_config(cls, config: dict, is_sliding: bool) -> dict: "key_norm": {"type": "rms_norm", "epsilon": eps}, "value_norm": {"type": "fixed_rms_norm", "epsilon": eps}, } + if not is_sliding and config.get("attention_k_eq_v", False): + out["shared_key_value"] = True if window_size is not None: out["window_size"] = window_size return out @@ -100,6 +134,7 @@ def export_config(cls, sliding_config: AttentionConfig, full_config: AttentionCo "attention_dropout": sliding_config.dropout, "sliding_window": sliding_config.window_size, "rms_norm_eps": eps, + "attention_k_eq_v": full_config.shared_key_value, "rope_parameters": { _SLIDING_ATTENTION: { "rope_type": "default", @@ -121,23 +156,33 @@ def get_converters( hf_prefix: str, drop_on_export: bool = False, ) -> list[WeightConverter]: - converters = [ - *get_weight_and_bias_converters( - f"{fast_llm_prefix}.query", - f"{hf_prefix}.q_proj", + if config.shared_key_value: + # K=V: single k_proj reused as value; no v_proj in HF + kv_converters = get_weight_and_bias_converters( + f"{fast_llm_prefix}.key_value", + f"{hf_prefix}.k_proj", False, - QueryWeightConverter, - config, drop_on_export=drop_on_export, - ), - *get_weight_and_bias_converters( + ) + else: + kv_converters = get_weight_and_bias_converters( f"{fast_llm_prefix}.key_value", (f"{hf_prefix}.k_proj", f"{hf_prefix}.v_proj"), False, KeyValueWeightConverter, config, drop_on_export=drop_on_export, + ) + converters = [ + *get_weight_and_bias_converters( + f"{fast_llm_prefix}.query", + f"{hf_prefix}.q_proj", + False, + QueryWeightConverter, + config, + drop_on_export=drop_on_export, ), + *kv_converters, *get_weight_and_bias_converters( f"{fast_llm_prefix}.dense", f"{hf_prefix}.o_proj", @@ -248,16 +293,18 @@ def get_converters( False, drop_on_export=drop_on_export, ), - # gate_up_proj shape [experts, 2*intermediate, hidden] matches Fast-LLM layer_1 get_parameter_converter( f"{fast_llm_prefix}.layer_1.weight", f"{hf_prefix}.experts.gate_up_proj", + Gemma4MoELayer1Converter, + config, drop_on_export=drop_on_export, ), - # down_proj shape [experts, hidden, intermediate] matches Fast-LLM layer_2 get_parameter_converter( f"{fast_llm_prefix}.layer_2.weight", f"{hf_prefix}.experts.down_proj", + Gemma4MoELayer2Converter, + config, drop_on_export=drop_on_export, ), ] @@ -578,6 +625,9 @@ def export_config(cls, config: 
GPTBaseModelConfig) -> dict: # explicitly zero to disable the feature in the exported model until Fast-LLM # supports it natively. "hidden_size_per_layer_input": 0, + # Fast-LLM is text-only; bidirectional attention (used for vision tokens in the + # multimodal model) is not implemented. + "use_bidirectional_attention": None, }, ) diff --git a/tests/models/test_hf_roundtrip.py b/tests/models/test_hf_roundtrip.py index 3c472086b..f4baf7698 100644 --- a/tests/models/test_hf_roundtrip.py +++ b/tests/models/test_hf_roundtrip.py @@ -15,6 +15,8 @@ import torch from transformers import ( AutoConfig, + Gemma4ForCausalLM, + Gemma4TextConfig, LlamaConfig, LlamaForCausalLM, MistralConfig, @@ -37,6 +39,7 @@ from fast_llm.models.gpt.config import GPTModelConfig from fast_llm.models.gpt.conversion.config import ( Apriel2TextCheckpointFormat, + Gemma4CheckpointFormat, LlamaCheckpointFormat, MistralCheckpointFormat, MixtralCheckpointFormat, @@ -101,6 +104,25 @@ def make_model(self) -> PreTrainedModel: return self.model_class(self.config_class(**converted_config)) +@dataclasses.dataclass(frozen=True) +class Gemma4RoundtripCase(HFRoundtripCase): + """Gemma4: apply dim overrides directly without deriving head_dim from hidden_size.""" + + def make_model(self) -> PreTrainedModel: + config = self.config_class.from_pretrained(self.hf_model_name) + for key, value in self.dim_overrides.items(): + setattr(config, key, value) + config.max_position_embeddings = self.max_position_embeddings + if getattr(config, "layer_types", None) is not None: + n = config.num_hidden_layers + lt = config.layer_types + config.layer_types = (lt * ((n // len(lt)) + 1))[:n] + for key in self.delete_config_keys: + if hasattr(config, key): + delattr(config, key) + return self.model_class(config) + + _TINY_DIMS = { "hidden_size": 64, "num_attention_heads": 4, @@ -212,6 +234,33 @@ def make_model(self) -> PreTrainedModel: }, }, ), + Gemma4RoundtripCase( + name="gemma4", + hf_model_name="google/gemma-4-26B-A4B", + checkpoint_format=Gemma4CheckpointFormat, + model_class=Gemma4ForCausalLM, + config_class=Gemma4TextConfig, + dim_overrides={ + "hidden_size": 256, + "num_hidden_layers": 6, # 5 sliding + 1 full from real layer_types pattern + "num_attention_heads": 8, + "num_key_value_heads": 4, + "num_global_key_value_heads": 2, + "head_dim": 32, + "global_head_dim": 64, # real model has 2:1 ratio (256:512) + "intermediate_size": 256, + "moe_intermediate_size": 128, + "num_experts": 4, + "top_k_experts": 2, + "vocab_size": 384, + "hidden_size_per_layer_input": 0, + # use_bidirectional_attention="vision" in the real model is for multimodal vision tokens; + # Fast-LLM is text-only so the converter exports None — reset source to match. + "use_bidirectional_attention": None, + }, + max_position_embeddings=131072, # Gemma4TextConfig default; converter does not export this + delete_config_keys=("dtype",), # "bfloat16" in real config; metadata not preserved by converter + ), ] From b6e32b430ef82288532c858803b9c0305c27fc54 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Thu, 30 Apr 2026 00:19:34 -0400 Subject: [PATCH 10/14] Fix test_rotary and test_mlp failures test_rotary: triton_rotary_ modifies query in-place, so clone before calling forward to avoid feeding the already-rotated tensor to the reference implementation (which caused a double-rotation mismatch). 
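The pattern, as a toy stand-in (generic tensors and stand-in functions, not the repo's
triton_rotary_ or the actual test code):

    import torch

    def rotate_(x: torch.Tensor) -> torch.Tensor:
        # Stand-in for an in-place kernel such as triton_rotary_.
        return x.mul_(-1.0)

    def rotate(x: torch.Tensor) -> torch.Tensor:
        # Out-of-place reference implementation.
        return x * -1.0

    query = torch.randn(4, 8)
    reference_input = query.clone()     # snapshot before the in-place call
    output = rotate_(query)             # mutates `query`
    expected = rotate(reference_input)  # without the clone this would see the already-rotated tensor
    torch.testing.assert_close(output, expected)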
test_mlp: increase _NUM_TOKENS/_HIDDEN_SIZE/_INTERMEDIATE_SIZE from 16/64/32 to 128/128/128 so dimensions satisfy the block_size_row=128, block_size_col=128 compile-time assertions in output_sparse_matmul_kernel when FAST_LLM_SKIP_TRITON_AUTOTUNE is set. Co-Authored-By: Claude Sonnet 4.6 --- tests/layers/test_mlp.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/layers/test_mlp.py b/tests/layers/test_mlp.py index 8ca5a3d66..70108224b 100644 --- a/tests/layers/test_mlp.py +++ b/tests/layers/test_mlp.py @@ -12,9 +12,9 @@ from fast_llm.utils import Assert from tests.utils.utils import get_stage -_NUM_TOKENS = 16 -_HIDDEN_SIZE = 64 -_INTERMEDIATE_SIZE = 32 +_NUM_TOKENS = 128 +_HIDDEN_SIZE = 128 +_INTERMEDIATE_SIZE = 128 _EXPERTS = 4 _NORM = {"type": "rms_norm"} From e887a94f2dc1e5ffcbcbdf2cdcbab57d8cada07e Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Fri, 1 May 2026 07:06:12 -0400 Subject: [PATCH 11/14] Add output_scale to DecoderBlock for Gemma 4 layer_scalar Gemma 4 multiplies the block output by a per-layer scalar (HF stores it as a non-trained `register_buffer("layer_scalar", ones(1))`). Expose this as an `OptionalParameterConfig` field on `DecoderBlock`, disabled by default. The Gemma 4 converter enables it with `lr_scale=0` to match HF's non-trained semantics; the test fixture mirrors that so frozen-parameter packing produces a consistent shard layout across conversion paths. Co-Authored-By: Claude Opus 4.7 (1M context) --- fast_llm/layers/decoder/block.py | 12 +++++++++++- fast_llm/layers/decoder/config.py | 7 ++++++- fast_llm/models/gpt/conversion/gemma4.py | 12 ++++++++++-- tests/layers/test_decoder_block.py | 14 +++++++++++++- tests/utils/model_configs.py | 4 ++++ 5 files changed, 44 insertions(+), 5 deletions(-) diff --git a/fast_llm/layers/decoder/block.py b/fast_llm/layers/decoder/block.py index d41283c9d..4fbdd945a 100644 --- a/fast_llm/layers/decoder/block.py +++ b/fast_llm/layers/decoder/block.py @@ -6,7 +6,8 @@ from fast_llm.core.distributed import ReduceOp, all_reduce, set_generator from fast_llm.engine.base_model.config import LossDef, ResourceUsageConfig -from fast_llm.engine.config_utils.tensor_dim import TensorDim +from fast_llm.engine.config_utils.initialization import init_ones_ +from fast_llm.engine.config_utils.tensor_dim import TensorDim, scalar_dim from fast_llm.engine.distributed.config import DistributedConfig from fast_llm.engine.distributed.distributed import Distributed from fast_llm.functional.utils import AuxiliaryLoss @@ -123,6 +124,13 @@ def __init__( return_bias=True, ) + self.output_scale = self._config.output_scale.get_parameter( + (scalar_dim,), + default_initialization=init_ones_, + lr_scale=self._lr_scale, + peft=self._peft, + ) + def setup(self, distributed: Distributed) -> None: super().setup(distributed) self.mixer.setup(distributed) @@ -175,6 +183,8 @@ def forward( hidden_states = self.post_mlp_norm(hidden_states) with set_generator(generator): hidden_states = self._bias_dropout_add(hidden_states, bias, input_) + if self.output_scale is not None: + hidden_states = hidden_states * self.output_scale self._debug(hidden_states, None, hidden_dims, kwargs) if self._return_input: diff --git a/fast_llm/layers/decoder/config.py b/fast_llm/layers/decoder/config.py index f949ec714..1cfa1b1a4 100644 --- a/fast_llm/layers/decoder/config.py +++ b/fast_llm/layers/decoder/config.py @@ -3,7 +3,7 @@ import warnings from fast_llm.config import Field, FieldHint, check_field, config_class -from 
fast_llm.engine.config_utils.parameter import combine_lr_scales +from fast_llm.engine.config_utils.parameter import OptionalParameterConfig, combine_lr_scales from fast_llm.engine.config_utils.tensor_dim import TensorDim from fast_llm.engine.distributed.config import _BIG_PRIMES, DistributedConfig from fast_llm.layers.block.config import BlockConfig, BlockKwargs @@ -239,6 +239,11 @@ class DecoderBlockConfig(BlockConfig): desc="Optional normalization applied to the MLP output before the residual add. Set to `{type: rms_norm}` to enable.", hint=FieldHint.architecture, ) + output_scale: OptionalParameterConfig = Field( + desc="Optional learnable scalar multiplied into the block output (after the MLP residual add)." + " Disabled by default; used by Gemma 4.", + hint=FieldHint.architecture, + ) # TODO: Review names dropout: float = Field( default=0.0, diff --git a/fast_llm/models/gpt/conversion/gemma4.py b/fast_llm/models/gpt/conversion/gemma4.py index 3a30f83fe..6b23225ea 100644 --- a/fast_llm/models/gpt/conversion/gemma4.py +++ b/fast_llm/models/gpt/conversion/gemma4.py @@ -402,6 +402,9 @@ def make_norm(): "normalization": make_norm(), "post_mixer_normalization": make_norm(), "post_mlp_normalization": make_norm(), + # HF stores `layer_scalar` as a non-trained buffer (`register_buffer`); preserve its value + # but freeze it on our side so finetuning matches HF training dynamics. + "output_scale": {"enabled": True, "lr_scale": 0}, } if config.get("enable_moe_block"): out["mlp"] = Gemma4HybridMoEMLPConverter.import_config(config) @@ -485,8 +488,13 @@ def get_converters( drop_on_export=drop_on_export, ), ] - if not drop_on_export: - converters.append(get_parameter_converter((), f"{hf_prefix}.layer_scalar", drop_on_import=True)) + converters.append( + get_parameter_converter( + f"{fast_llm_prefix}.output_scale", + f"{hf_prefix}.layer_scalar", + drop_on_export=drop_on_export, + ) + ) return converters diff --git a/tests/layers/test_decoder_block.py b/tests/layers/test_decoder_block.py index 37302f79d..2dc0e25cc 100644 --- a/tests/layers/test_decoder_block.py +++ b/tests/layers/test_decoder_block.py @@ -25,6 +25,9 @@ class PostNormTestConfig: name: str post_mixer_norm: bool = False post_mlp_norm: bool = False + # If set, enable `output_scale` and override its initialized value with this float. + # Picking a non-unit value makes the test sensitive to the multiplication (the default init is 1.0). 
+ output_scale: float | None = None def get_block_config(self) -> DecoderBlockConfig: config_dict: dict = { @@ -45,6 +48,8 @@ def get_block_config(self) -> DecoderBlockConfig: config_dict["post_mixer_normalization"] = {"type": "rms_norm"} if self.post_mlp_norm: config_dict["post_mlp_normalization"] = {"type": "rms_norm"} + if self.output_scale is not None: + config_dict["output_scale"] = {"enabled": True} return DecoderBlockConfig.from_dict(config_dict) @functools.cached_property @@ -67,7 +72,10 @@ def expected_output(self, block: DecoderBlock, input_: torch.Tensor, kwargs: dic mlp_hidden = block.post_mlp_norm(mlp_hidden) if mlp_bias is not None: mlp_hidden = mlp_hidden + mlp_bias - return after_mixer + mlp_hidden + output = after_mixer + mlp_hidden + if self.output_scale is not None: + output = output * self.output_scale + return output _base_post_norm_cases = [ @@ -75,6 +83,7 @@ def expected_output(self, block: DecoderBlock, input_: torch.Tensor, kwargs: dic ("post_mixer_norm", {"post_mixer_norm": True}), ("post_mlp_norm", {"post_mlp_norm": True}), ("both_post_norms", {"post_mixer_norm": True, "post_mlp_norm": True}), + ("output_scale", {"output_scale": 2.5}), ] _post_norm_test_configs = [PostNormTestConfig(name=name, **kwargs) for name, kwargs in _base_post_norm_cases] @@ -95,6 +104,9 @@ def test_post_norms(test_config: PostNormTestConfig): block.eval() device = distributed.device + if test_config.output_scale is not None: + with torch.no_grad(): + block.output_scale.fill_(test_config.output_scale) input_ = torch.randn(_NUM_TOKENS, _HIDDEN_SIZE, device=device) token_dim = TensorDim("token", _NUM_TOKENS) diff --git a/tests/utils/model_configs.py b/tests/utils/model_configs.py index e57210bfb..d35bab14e 100644 --- a/tests/utils/model_configs.py +++ b/tests/utils/model_configs.py @@ -974,6 +974,10 @@ def update_and_add_testing_config( "post_mixer_normalization": {"type": "rms_norm", "weight": init_1}, "post_mlp_normalization": {"type": "rms_norm", "weight": init_1}, "pre_mlp_normalization": {"type": "rms_norm", "weight": init_1}, + # Match the gemma4 converter's import_config which freezes output_scale (HF stores it as a non-trained + # buffer); without lr_scale=0 here, the converter round-trip would produce a different shard layout + # than the original because frozen parameters are packed at the end of the stage. + "output_scale": {"enabled": True, "lr_scale": 0}, } _gemma4_mixer_overrides = { "softmax_scale_power": 0, From 4476b62faa7dec77e834a4540a1c886771554926 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Fri, 1 May 2026 07:52:10 -0400 Subject: [PATCH 12/14] Fold output_scale into bias_dropout_add; fix subprocess interpreter Apply `output_scale` inside `_bias_dropout_add` so torch.compile fuses the multiply with the residual add. Trim associated comments and the field-desc blurb. Also fix `tests/test_config.py::test_validate_*_without_import` to use `sys.executable` instead of literal `python3`. The convention check is unchanged (still pre-imports only yaml/requests/packaging then strips site-packages); the previous form failed on systems where `python3` resolves to a python without yaml installed (e.g. macOS Homebrew). 
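For reference, the interpreter-selection pattern looks like this (a generic illustration,
not the test body itself):

    import subprocess
    import sys

    # sys.executable is the interpreter running the current process, so the child
    # inherits the same environment and installed packages; a bare "python3" may
    # resolve to a different interpreter on some systems.
    subprocess.run([sys.executable, "-c", "import sys; print(sys.executable)"], check=True)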
Co-Authored-By: Claude Opus 4.7 (1M context) --- fast_llm/layers/decoder/block.py | 15 ++++++++++----- fast_llm/layers/decoder/config.py | 3 +-- fast_llm/models/gpt/conversion/gemma4.py | 3 +-- tests/layers/test_decoder_block.py | 2 -- tests/test_config.py | 3 ++- tests/utils/model_configs.py | 5 ++--- 6 files changed, 16 insertions(+), 15 deletions(-) diff --git a/fast_llm/layers/decoder/block.py b/fast_llm/layers/decoder/block.py index 4fbdd945a..5caf24c2a 100644 --- a/fast_llm/layers/decoder/block.py +++ b/fast_llm/layers/decoder/block.py @@ -138,11 +138,18 @@ def setup(self, distributed: Distributed) -> None: @torch.compile def _bias_dropout_add( - self, input_: torch.Tensor, bias: torch.Tensor | None, residual: torch.Tensor + self, + input_: torch.Tensor, + bias: torch.Tensor | None, + residual: torch.Tensor, + output_scale: torch.Tensor | None = None, ) -> torch.Tensor: if bias is not None: input_ = input_ + bias - return residual + torch.dropout(input_, self._config.dropout, self.training) + output = residual + torch.dropout(input_, self._config.dropout, self.training) + if output_scale is not None: + output = output * output_scale + return output def forward( self, @@ -182,9 +189,7 @@ def forward( if self.post_mlp_norm is not None: hidden_states = self.post_mlp_norm(hidden_states) with set_generator(generator): - hidden_states = self._bias_dropout_add(hidden_states, bias, input_) - if self.output_scale is not None: - hidden_states = hidden_states * self.output_scale + hidden_states = self._bias_dropout_add(hidden_states, bias, input_, self.output_scale) self._debug(hidden_states, None, hidden_dims, kwargs) if self._return_input: diff --git a/fast_llm/layers/decoder/config.py b/fast_llm/layers/decoder/config.py index 1cfa1b1a4..0f11dec26 100644 --- a/fast_llm/layers/decoder/config.py +++ b/fast_llm/layers/decoder/config.py @@ -240,8 +240,7 @@ class DecoderBlockConfig(BlockConfig): hint=FieldHint.architecture, ) output_scale: OptionalParameterConfig = Field( - desc="Optional learnable scalar multiplied into the block output (after the MLP residual add)." - " Disabled by default; used by Gemma 4.", + desc="Optional learnable scalar multiplied into the block output (after the MLP residual add).", hint=FieldHint.architecture, ) # TODO: Review names diff --git a/fast_llm/models/gpt/conversion/gemma4.py b/fast_llm/models/gpt/conversion/gemma4.py index 6b23225ea..166b82f1f 100644 --- a/fast_llm/models/gpt/conversion/gemma4.py +++ b/fast_llm/models/gpt/conversion/gemma4.py @@ -402,8 +402,7 @@ def make_norm(): "normalization": make_norm(), "post_mixer_normalization": make_norm(), "post_mlp_normalization": make_norm(), - # HF stores `layer_scalar` as a non-trained buffer (`register_buffer`); preserve its value - # but freeze it on our side so finetuning matches HF training dynamics. + # HF stores `layer_scalar` as a non-trained buffer; freeze on our side to match. "output_scale": {"enabled": True, "lr_scale": 0}, } if config.get("enable_moe_block"): diff --git a/tests/layers/test_decoder_block.py b/tests/layers/test_decoder_block.py index 2dc0e25cc..a3b778293 100644 --- a/tests/layers/test_decoder_block.py +++ b/tests/layers/test_decoder_block.py @@ -25,8 +25,6 @@ class PostNormTestConfig: name: str post_mixer_norm: bool = False post_mlp_norm: bool = False - # If set, enable `output_scale` and override its initialized value with this float. - # Picking a non-unit value makes the test sensitive to the multiplication (the default init is 1.0). 
output_scale: float | None = None def get_block_config(self) -> DecoderBlockConfig: diff --git a/tests/test_config.py b/tests/test_config.py index 792eab077..3753f75d7 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -1,6 +1,7 @@ import collections import pathlib import subprocess +import sys import pytest import yaml @@ -18,7 +19,7 @@ def run_without_import(cmd: str): # Run the test in a separate process since lots of things are already imported in this one. repo_path = pathlib.Path(__file__).parents[1].resolve() command = [ - "python3", + sys.executable, "-c", "\n".join( [ diff --git a/tests/utils/model_configs.py b/tests/utils/model_configs.py index d35bab14e..6a9b18268 100644 --- a/tests/utils/model_configs.py +++ b/tests/utils/model_configs.py @@ -974,9 +974,8 @@ def update_and_add_testing_config( "post_mixer_normalization": {"type": "rms_norm", "weight": init_1}, "post_mlp_normalization": {"type": "rms_norm", "weight": init_1}, "pre_mlp_normalization": {"type": "rms_norm", "weight": init_1}, - # Match the gemma4 converter's import_config which freezes output_scale (HF stores it as a non-trained - # buffer); without lr_scale=0 here, the converter round-trip would produce a different shard layout - # than the original because frozen parameters are packed at the end of the stage. + # Must match the gemma4 converter's lr_scale=0 — frozen params are packed at the end of the stage, + # so a mismatch produces a shifted shard layout that fails round-trip. "output_scale": {"enabled": True, "lr_scale": 0}, } _gemma4_mixer_overrides = { From 54b37b59ac9dfcbce969d4ea922c074189051d24 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Fri, 1 May 2026 08:52:38 -0400 Subject: [PATCH 13/14] Move pre/post norm to MLPBaseConfig; add MoE router preprocessing MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Promote `pre_norm`/`post_norm` to `MLPBaseConfig` so all MLP variants (`MLP`, `MixtureOfExpertMLP`, `HybridMoEMLP`) carry their own input/output norms uniformly. Drop the redundant `dense_pre_norm`/`dense_post_norm`/ `moe_pre_norm`/`moe_post_norm` from `HybridMoEMLPConfig` — those are now expressed via the inner `dense.pre_norm`/`routed.pre_norm`/etc., with optional wrapper-level pre/post norms shared across both branches. Add Gemma-style router preprocessing to `MoEMLPConfig`: `router_normalization` (typically `fixed_rms_norm`), `router_scale` (`OptionalParameterConfig`, learnable per-feature), and `router_input_scale` (constant scalar; set to `hidden_size ** -0.5` for Gemma 4). The router runs on the raw input independently of `pre_norm`, which now applies only to the expert path. The two router multiplies are fused via `@torch.compile`. Wire the Gemma 4 converter to import `router.scale`, set `router_input_scale` from `hidden_size`, and configure per-branch norms instead of the wrapper norms. 
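
Schematically, the new forward split looks like this (a sketch under assumed names; the real `MixtureOfExpertMLP._forward` additionally handles debug logging, auxiliary losses, and the dropless/looped dispatch):

    import torch

    def moe_forward_prefix(hidden_states, router, router_norm=None,
                           router_scale=None, router_input_scale=1.0, pre_norm=None):
        # Router path: raw hidden states, optionally RMS-normalized, then a learnable
        # per-feature scale and a constant scalar (hidden_size ** -0.5 for Gemma 4).
        router_input = hidden_states if router_norm is None else router_norm(hidden_states)
        if router_scale is not None:
            router_input = router_input * router_scale
        if router_input_scale != 1.0:
            router_input = router_input * router_input_scale
        logits = router(router_input)

        # Expert path: `pre_norm` applies only here, never to the router input.
        expert_input = hidden_states if pre_norm is None else pre_norm(hidden_states)
        return logits, expert_input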
Co-Authored-By: Claude Opus 4.7 (1M context) --- fast_llm/layers/decoder/config.py | 11 +++ fast_llm/layers/decoder/mlp/config.py | 36 ++++---- .../layers/decoder/mlp/mixture_of_experts.py | 75 ++++++++++------- fast_llm/layers/decoder/mlp/mlp.py | 17 ++++ fast_llm/models/gpt/conversion/gemma4.py | 67 ++++++++------- tests/layers/test_mlp.py | 83 +++++++++++-------- 6 files changed, 177 insertions(+), 112 deletions(-) diff --git a/fast_llm/layers/decoder/config.py b/fast_llm/layers/decoder/config.py index 0f11dec26..1c0c10c87 100644 --- a/fast_llm/layers/decoder/config.py +++ b/fast_llm/layers/decoder/config.py @@ -60,6 +60,17 @@ class MLPBaseConfig(BlockWithBiasConfig): _abstract = True + pre_norm: NormalizationConfig | None = Field( + default=None, + desc="Optional normalization applied to the MLP input.", + hint=FieldHint.architecture, + ) + post_norm: NormalizationConfig | None = Field( + default=None, + desc="Optional normalization applied to the MLP output.", + hint=FieldHint.architecture, + ) + def get_layer( self, distributed_config: DistributedConfig, diff --git a/fast_llm/layers/decoder/mlp/config.py b/fast_llm/layers/decoder/mlp/config.py index 0abdfe70f..cd022f407 100644 --- a/fast_llm/layers/decoder/mlp/config.py +++ b/fast_llm/layers/decoder/mlp/config.py @@ -3,6 +3,7 @@ import typing from fast_llm.config import Field, FieldHint, check_field, config_class +from fast_llm.engine.config_utils.parameter import OptionalParameterConfig from fast_llm.functional.config import ActivationType, MLPRecomputeLevel from fast_llm.layers.common.linear.config import AffineLinearConfig, LinearConfig from fast_llm.layers.common.normalization.config import NormalizationConfig @@ -98,6 +99,21 @@ class MoEMLPConfig(MLPConfig): desc="Configuration for the MoE router.", hint=FieldHint.feature, ) + router_normalization: NormalizationConfig | None = Field( + default=None, + desc="Optional normalization applied to the router input (independent of `pre_norm`, which goes to experts).", + hint=FieldHint.architecture, + ) + router_scale: OptionalParameterConfig = Field( + desc="Optional learnable per-feature scale applied to the router input after `router_normalization`.", + hint=FieldHint.architecture, + ) + router_input_scale: float = Field( + default=1.0, + desc="Constant multiplied into the router input after `router_normalization` and `router_scale`." 
+ " Set to `hidden_size ** -0.5` for Gemma-style routing.", + hint=FieldHint.architecture, + ) experts: int = Field( default=2, desc="Number of MLP experts in a Mixture of Expert (MoE) model", @@ -181,26 +197,6 @@ class HybridMoEMLPConfig(MLPBaseConfig): desc="Configuration for the top-K routed expert MLP.", hint=FieldHint.architecture, ) - dense_pre_norm: NormalizationConfig | None = Field( - default=None, - desc="Optional normalization applied to the dense MLP input.", - hint=FieldHint.architecture, - ) - dense_post_norm: NormalizationConfig | None = Field( - default=None, - desc="Optional normalization applied to the dense MLP output before summing.", - hint=FieldHint.architecture, - ) - moe_pre_norm: NormalizationConfig | None = Field( - default=None, - desc="Optional normalization applied to the routed MLP input.", - hint=FieldHint.architecture, - ) - moe_post_norm: NormalizationConfig | None = Field( - default=None, - desc="Optional normalization applied to the routed MLP output before summing.", - hint=FieldHint.architecture, - ) def _validate(self) -> None: super()._validate() diff --git a/fast_llm/layers/decoder/mlp/mixture_of_experts.py b/fast_llm/layers/decoder/mlp/mixture_of_experts.py index 943a971b8..222b2d05f 100644 --- a/fast_llm/layers/decoder/mlp/mixture_of_experts.py +++ b/fast_llm/layers/decoder/mlp/mixture_of_experts.py @@ -7,7 +7,7 @@ from fast_llm.core.distributed import ProcessGroup, set_generator from fast_llm.core.ops import gather_op from fast_llm.engine.base_model.config import LossDef, ResourceUsageConfig -from fast_llm.engine.config_utils.initialization import init_normal_ +from fast_llm.engine.config_utils.initialization import init_normal_, init_ones_ from fast_llm.engine.config_utils.tensor_dim import CompositeTensorDim, TensorDim from fast_llm.engine.distributed.config import DistributedConfig from fast_llm.functional.triton import triton_available @@ -80,6 +80,18 @@ def __init__( lr_scale=self._lr_scale, peft=self._peft, ) + self.router_normalization = ( + self._config.router_normalization.get_layer(self._hidden_dim, lr_scale=self._lr_scale, peft=self._peft) + if self._config.router_normalization is not None + else None + ) + self.router_scale = self._config.router_scale.get_parameter( + (self._hidden_dim,), + default_initialization=init_ones_, + lr_scale=self._lr_scale, + peft=self._peft, + ) + self._router_input_scale = self._config.router_input_scale implementation = self._config.implementation if implementation == MoEImplementation.auto: implementation = MoEImplementation.dropless if triton_available else MoEImplementation.looped @@ -100,13 +112,25 @@ def _get_intermediate_dims(self) -> tuple[TensorDim, TensorDim]: CompositeTensorDim("moe_intermediate_2", (experts_dim, intermediate_2_dim)), ) + @torch.compile + def _scale_router_input(self, x: torch.Tensor, scale: torch.Tensor | None, input_scale: float) -> torch.Tensor: + if scale is not None: + x = x * scale + if input_scale != 1.0: + x = x * input_scale + return x + def _forward( self, input_: torch.Tensor, kwargs: dict, losses: dict | None = None, metrics: dict | None = None ) -> tuple[torch.Tensor, None]: if isinstance(input_, TensorMeta): return TensorMeta.from_dims(input_.dims[:-1] + (self._output_dim,), "MLP output"), None hidden_states = input_.flatten(0, -2) - logits = self.router(hidden_states) + router_input = ( + self.router_normalization(hidden_states) if self.router_normalization is not None else hidden_states + ) + router_input = self._scale_router_input(router_input, 
self.router_scale, self._router_input_scale) + logits = self.router(router_input) hidden_token_dim = kwargs[BlockKwargs.hidden_token_dim] logit_dims = (hidden_token_dim, self._top_expert_dim) self._debug(logits, "Router logits", logit_dims, kwargs) @@ -139,7 +163,10 @@ def _forward( self._debug(scores, "router_scores", logit_dims, kwargs) self._debug(top_experts, "router_top_experts", logit_dims, kwargs) - out = self._mlp_forward(hidden_states, scores, top_experts).view_as(input_) # noqa + expert_input = self.pre_norm(hidden_states) if self.pre_norm is not None else hidden_states + out = self._mlp_forward(expert_input, scores, top_experts).view_as(input_) # noqa + if self.post_norm is not None: + out = self.post_norm(out) self._debug(out, None, (hidden_token_dim, self._hidden_dim), kwargs) return out, None @@ -302,24 +329,14 @@ def __init__( self.routed = config.routed.get_layer( distributed_config, hidden_dim, output_dim=output_dim, lr_scale=lr_scale, peft=peft, return_bias=True ) - self.dense_pre_norm = ( - config.dense_pre_norm.get_layer(self._hidden_dim, lr_scale=self._lr_scale, peft=self._peft) - if config.dense_pre_norm is not None - else None - ) - self.dense_post_norm = ( - config.dense_post_norm.get_layer(self._output_dim, lr_scale=self._lr_scale, peft=self._peft) - if config.dense_post_norm is not None - else None - ) - self.moe_pre_norm = ( - config.moe_pre_norm.get_layer(self._hidden_dim, lr_scale=self._lr_scale, peft=self._peft) - if config.moe_pre_norm is not None + self.pre_norm = ( + config.pre_norm.get_layer(self._hidden_dim, lr_scale=self._lr_scale, peft=self._peft) + if config.pre_norm is not None else None ) - self.moe_post_norm = ( - config.moe_post_norm.get_layer(self._output_dim, lr_scale=self._lr_scale, peft=self._peft) - if config.moe_post_norm is not None + self.post_norm = ( + config.post_norm.get_layer(self._output_dim, lr_scale=self._lr_scale, peft=self._peft) + if config.post_norm is not None else None ) @@ -342,15 +359,17 @@ def _forward( ), None, ) - dense_input = self.dense_pre_norm(input_) if self.dense_pre_norm is not None else input_ - moe_input = self.moe_pre_norm(input_) if self.moe_pre_norm is not None else input_ - dense_out, dense_bias = self.dense(dense_input, kwargs, losses, metrics) - routed_out, _ = self.routed(moe_input, kwargs, losses, metrics) - if self.dense_post_norm is not None: - dense_out = self.dense_post_norm(dense_out) - if self.moe_post_norm is not None: - routed_out = self.moe_post_norm(routed_out) - return dense_out + routed_out, dense_bias + if self.pre_norm is not None: + input_ = self.pre_norm(input_) + dense_out, dense_bias = self.dense(input_, kwargs, losses, metrics) + routed_out, _ = self.routed(input_, kwargs, losses, metrics) + out = dense_out + routed_out + if self.post_norm is not None: + if dense_bias is not None: + out = out + dense_bias + dense_bias = None + out = self.post_norm(out) + return out, dense_bias def get_loss_definitions(self) -> list[LossDef]: return self.routed.get_loss_definitions() diff --git a/fast_llm/layers/decoder/mlp/mlp.py b/fast_llm/layers/decoder/mlp/mlp.py index 80599da97..504e26ac5 100644 --- a/fast_llm/layers/decoder/mlp/mlp.py +++ b/fast_llm/layers/decoder/mlp/mlp.py @@ -42,6 +42,16 @@ def __init__( self._output_dim = self._hidden_dim if output_dim is None else output_dim self._parallel_dim = self._distributed_config.get_distributed_dim(DistributedDimNames.tensor) intermediate_1_dim, self._intermediate_2_dim = self._get_intermediate_dims() + self.pre_norm = ( + 
self._config.pre_norm.get_layer(self._hidden_dim, lr_scale=self._lr_scale, peft=self._peft) + if self._config.pre_norm is not None + else None + ) + self.post_norm = ( + self._config.post_norm.get_layer(self._output_dim, lr_scale=self._lr_scale, peft=self._peft) + if self._config.post_norm is not None + else None + ) self._activation_fn = ( triton_mlp_activation_autograd if TritonConfig.enabled(torch.device("cuda")) else torch_mlp_activation @@ -116,6 +126,8 @@ def _forward( ), None, ) + if self.pre_norm is not None: + input_ = self.pre_norm(input_) out = mlp_autograd( input_, None, @@ -132,6 +144,11 @@ def _forward( transposed_layer_2_weight=self.layer_2.transposed_weight, ) bias = self.layer_2.bias if self._parallel_dim.group else None + if self.post_norm is not None: + if bias is not None: + out = out + bias + bias = None + out = self.post_norm(out) # Use None for dims when output_dim differs from hidden_dim (e.g., adapter projections) # to let _debug infer dims from actual tensor shape self._debug(out, None, (kwargs.get(BlockKwargs.hidden_token_dim), self._hidden_dim), kwargs, bias=bias) diff --git a/fast_llm/models/gpt/conversion/gemma4.py b/fast_llm/models/gpt/conversion/gemma4.py index 166b82f1f..317be2b76 100644 --- a/fast_llm/models/gpt/conversion/gemma4.py +++ b/fast_llm/models/gpt/conversion/gemma4.py @@ -210,13 +210,18 @@ def get_converters( class Gemma4MLPConverter: @classmethod - def import_config(cls, config: dict) -> dict: - return { + def import_config(cls, config: dict, with_norms: bool = False) -> dict: + out = { "intermediate_size": config["intermediate_size"], "add_linear_biases": False, "activation": ActivationType.from_hf_name(config["hidden_activation"]), "gated": True, } + if with_norms: + eps = config["rms_norm_eps"] + out["pre_norm"] = {"type": "rms_norm", "epsilon": eps} + out["post_norm"] = {"type": "rms_norm", "epsilon": eps} + return out @classmethod def export_config(cls, config: MLPConfig) -> dict: @@ -257,6 +262,7 @@ def get_converters( class Gemma4MoEMLPConverter: @classmethod def import_config(cls, config: dict) -> dict: + eps = config["rms_norm_eps"] return { "type": "moe", "intermediate_size": config["moe_intermediate_size"], @@ -265,6 +271,11 @@ def import_config(cls, config: dict) -> dict: "gated": True, "experts": config["num_experts"], "experts_per_token": config["top_k_experts"], + "pre_norm": {"type": "rms_norm", "epsilon": eps}, + "post_norm": {"type": "rms_norm", "epsilon": eps}, + "router_normalization": {"type": "fixed_rms_norm", "epsilon": eps}, + "router_scale": {"enabled": True}, + "router_input_scale": config["hidden_size"] ** -0.5, } @classmethod @@ -293,6 +304,11 @@ def get_converters( False, drop_on_export=drop_on_export, ), + get_parameter_converter( + f"{fast_llm_prefix}.router_scale", + f"{hf_prefix}.router.scale", + drop_on_export=drop_on_export, + ), get_parameter_converter( f"{fast_llm_prefix}.layer_1.weight", f"{hf_prefix}.experts.gate_up_proj", @@ -308,10 +324,23 @@ def get_converters( drop_on_export=drop_on_export, ), ] + if config.pre_norm is not None: + converters += LlamaNormalizationConverter.get_converters( + config.pre_norm, + f"{fast_llm_prefix}.pre_norm", + f"{hf_prefix}.pre_feedforward_layernorm_2", + drop_on_export=drop_on_export, + ) + if config.post_norm is not None: + converters += LlamaNormalizationConverter.get_converters( + config.post_norm, + f"{fast_llm_prefix}.post_norm", + f"{hf_prefix}.post_feedforward_layernorm_2", + drop_on_export=drop_on_export, + ) + # router.norm is FixedRMSNorm — no learnable weight 
to convert. if not drop_on_export: - # Gemma4-specific router parameters without Fast-LLM equivalents converters += [ - get_parameter_converter((), f"{hf_prefix}.router.scale", drop_on_import=True), get_parameter_converter((), f"{hf_prefix}.router.per_expert_scale", drop_on_import=True), ] return converters @@ -320,17 +349,10 @@ def get_converters( class Gemma4HybridMoEMLPConverter: @classmethod def import_config(cls, config: dict) -> dict: - def make_norm(): - return {"type": "rms_norm", "epsilon": config["rms_norm_eps"]} - return { "type": "hybrid_moe", - "dense": Gemma4MLPConverter.import_config(config), + "dense": Gemma4MLPConverter.import_config(config, with_norms=True), "routed": Gemma4MoEMLPConverter.import_config(config), - "dense_pre_norm": make_norm(), - "moe_pre_norm": make_norm(), - "dense_post_norm": make_norm(), - "moe_post_norm": make_norm(), } @classmethod @@ -350,7 +372,6 @@ def get_converters( hf_prefix: str, drop_on_export: bool = False, ) -> list[WeightConverter]: - norm_config = config.dense_pre_norm return [ *Gemma4MLPConverter.get_converters( config.dense, @@ -365,29 +386,17 @@ def get_converters( drop_on_export=drop_on_export, ), *LlamaNormalizationConverter.get_converters( - norm_config, - f"{fast_llm_prefix}.dense_pre_norm", + config.dense.pre_norm, + f"{fast_llm_prefix}.dense.pre_norm", f"{hf_prefix}.pre_feedforward_layernorm", drop_on_export=drop_on_export, ), *LlamaNormalizationConverter.get_converters( - norm_config, - f"{fast_llm_prefix}.moe_pre_norm", - f"{hf_prefix}.pre_feedforward_layernorm_2", - drop_on_export=drop_on_export, - ), - *LlamaNormalizationConverter.get_converters( - norm_config, - f"{fast_llm_prefix}.dense_post_norm", + config.dense.post_norm, + f"{fast_llm_prefix}.dense.post_norm", f"{hf_prefix}.post_feedforward_layernorm_1", drop_on_export=drop_on_export, ), - *LlamaNormalizationConverter.get_converters( - norm_config, - f"{fast_llm_prefix}.moe_post_norm", - f"{hf_prefix}.post_feedforward_layernorm_2", - drop_on_export=drop_on_export, - ), ] diff --git a/tests/layers/test_mlp.py b/tests/layers/test_mlp.py index 70108224b..d5b82fd22 100644 --- a/tests/layers/test_mlp.py +++ b/tests/layers/test_mlp.py @@ -25,44 +25,50 @@ class HybridMoEMLPTestConfig: name: str gated: bool = False experts_per_token: int = 1 + wrapper_pre_norm: bool = False + wrapper_post_norm: bool = False dense_pre_norm: bool = False dense_post_norm: bool = False - moe_pre_norm: bool = False - moe_post_norm: bool = False + routed_pre_norm: bool = False + routed_post_norm: bool = False def get_mlp_config(self) -> HybridMoEMLPConfig: - return HybridMoEMLPConfig.from_dict( - { - "dense": { - "intermediate_size": _INTERMEDIATE_SIZE, - "gated": self.gated, - "add_linear_biases": False, - }, - "routed": { - "intermediate_size": _INTERMEDIATE_SIZE, - "gated": self.gated, - "add_linear_biases": False, - "experts": _EXPERTS, - "experts_per_token": self.experts_per_token, - }, - **({"dense_pre_norm": _NORM} if self.dense_pre_norm else {}), - **({"dense_post_norm": _NORM} if self.dense_post_norm else {}), - **({"moe_pre_norm": _NORM} if self.moe_pre_norm else {}), - **({"moe_post_norm": _NORM} if self.moe_post_norm else {}), - } - ) + dense: dict = { + "intermediate_size": _INTERMEDIATE_SIZE, + "gated": self.gated, + "add_linear_biases": False, + } + routed: dict = { + "intermediate_size": _INTERMEDIATE_SIZE, + "gated": self.gated, + "add_linear_biases": False, + "experts": _EXPERTS, + "experts_per_token": self.experts_per_token, + } + if self.dense_pre_norm: + dense["pre_norm"] = _NORM 
+ if self.dense_post_norm: + dense["post_norm"] = _NORM + if self.routed_pre_norm: + routed["pre_norm"] = _NORM + if self.routed_post_norm: + routed["post_norm"] = _NORM + wrapper: dict = {"dense": dense, "routed": routed} + if self.wrapper_pre_norm: + wrapper["pre_norm"] = _NORM + if self.wrapper_post_norm: + wrapper["post_norm"] = _NORM + return HybridMoEMLPConfig.from_dict(wrapper) def expected_output(self, hybrid: HybridMoEMLP, input_: torch.Tensor, kwargs: dict) -> torch.Tensor: with torch.no_grad(): - dense_input = hybrid.dense_pre_norm(input_) if hybrid.dense_pre_norm is not None else input_ - moe_input = hybrid.moe_pre_norm(input_) if hybrid.moe_pre_norm is not None else input_ - dense_out, _ = hybrid.dense(dense_input, kwargs) - routed_out, _ = hybrid.routed(moe_input, kwargs) - if hybrid.dense_post_norm is not None: - dense_out = hybrid.dense_post_norm(dense_out) - if hybrid.moe_post_norm is not None: - routed_out = hybrid.moe_post_norm(routed_out) - return dense_out + routed_out + shared = hybrid.pre_norm(input_) if hybrid.pre_norm is not None else input_ + dense_out, _ = hybrid.dense(shared, kwargs) + routed_out, _ = hybrid.routed(shared, kwargs) + out = dense_out + routed_out + if hybrid.post_norm is not None: + out = hybrid.post_norm(out) + return out _test_configs = [ @@ -70,12 +76,19 @@ def expected_output(self, hybrid: HybridMoEMLP, input_: torch.Tensor, kwargs: di HybridMoEMLPTestConfig(name="gated", gated=True), HybridMoEMLPTestConfig(name="topk2", experts_per_token=2), HybridMoEMLPTestConfig(name="gated_topk2", gated=True, experts_per_token=2), - HybridMoEMLPTestConfig(name="pre_norms", dense_pre_norm=True, moe_pre_norm=True), - HybridMoEMLPTestConfig(name="post_norms", dense_post_norm=True, moe_post_norm=True), + HybridMoEMLPTestConfig(name="branch_pre_norms", dense_pre_norm=True, routed_pre_norm=True), + HybridMoEMLPTestConfig(name="branch_post_norms", dense_post_norm=True, routed_post_norm=True), + HybridMoEMLPTestConfig(name="wrapper_norms", wrapper_pre_norm=True, wrapper_post_norm=True), HybridMoEMLPTestConfig( - name="all_norms", dense_pre_norm=True, dense_post_norm=True, moe_pre_norm=True, moe_post_norm=True + name="all_norms", + wrapper_pre_norm=True, + wrapper_post_norm=True, + dense_pre_norm=True, + dense_post_norm=True, + routed_pre_norm=True, + routed_post_norm=True, ), - HybridMoEMLPTestConfig(name="asymmetric_norms", dense_pre_norm=True, moe_post_norm=True), + HybridMoEMLPTestConfig(name="asymmetric_norms", dense_pre_norm=True, routed_post_norm=True), ] From d3ef9328448be3d70ae449a0e916500c488e3edb Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Fri, 1 May 2026 09:53:02 -0400 Subject: [PATCH 14/14] Add learnable per-expert scale to MoE router Adds an optional `router_per_expert_scale` learnable parameter applied to top-k scores after routing, matching HF Gemma 4's `router.per_expert_scale`. 
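
The gather is plain advanced indexing: `per_expert_scale[top_experts]` expands the (num_experts,) vector to the (tokens, top_k) layout of the selected experts. A small worked example with arbitrary values (shapes shown as they appear after top-k selection):

    import torch

    scores = torch.tensor([[0.7, 0.3], [0.6, 0.4]])         # (tokens, top_k)
    top_experts = torch.tensor([[2, 0], [1, 3]])            # selected expert ids
    per_expert_scale = torch.tensor([1.0, 0.5, 2.0, 1.5])   # (num_experts,)

    scaled = scores * per_expert_scale[top_experts]
    # tensor([[1.4000, 0.3000],
    #         [0.3000, 0.6000]])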
--- fast_llm/layers/decoder/mlp/config.py | 4 ++++ fast_llm/layers/decoder/mlp/mixture_of_experts.py | 9 +++++++++ fast_llm/models/gpt/conversion/gemma4.py | 10 ++++++---- 3 files changed, 19 insertions(+), 4 deletions(-) diff --git a/fast_llm/layers/decoder/mlp/config.py b/fast_llm/layers/decoder/mlp/config.py index cd022f407..1a7d6c579 100644 --- a/fast_llm/layers/decoder/mlp/config.py +++ b/fast_llm/layers/decoder/mlp/config.py @@ -114,6 +114,10 @@ class MoEMLPConfig(MLPConfig): " Set to `hidden_size ** -0.5` for Gemma-style routing.", hint=FieldHint.architecture, ) + router_per_expert_scale: OptionalParameterConfig = Field( + desc="Optional learnable per-expert scale multiplied into the router scores after top-k selection.", + hint=FieldHint.architecture, + ) experts: int = Field( default=2, desc="Number of MLP experts in a Mixture of Expert (MoE) model", diff --git a/fast_llm/layers/decoder/mlp/mixture_of_experts.py b/fast_llm/layers/decoder/mlp/mixture_of_experts.py index 222b2d05f..bfa053a35 100644 --- a/fast_llm/layers/decoder/mlp/mixture_of_experts.py +++ b/fast_llm/layers/decoder/mlp/mixture_of_experts.py @@ -92,6 +92,12 @@ def __init__( peft=self._peft, ) self._router_input_scale = self._config.router_input_scale + self.router_per_expert_scale = self._config.router_per_expert_scale.get_parameter( + (TensorDim("experts", self._config.experts),), + default_initialization=init_ones_, + lr_scale=self._lr_scale, + peft=self._peft, + ) implementation = self._config.implementation if implementation == MoEImplementation.auto: implementation = MoEImplementation.dropless if triton_available else MoEImplementation.looped @@ -160,6 +166,9 @@ def _forward( else: raise NotImplementedError(self._config.routing) + if self.router_per_expert_scale is not None: + scores = scores * self.router_per_expert_scale[top_experts] + self._debug(scores, "router_scores", logit_dims, kwargs) self._debug(top_experts, "router_top_experts", logit_dims, kwargs) diff --git a/fast_llm/models/gpt/conversion/gemma4.py b/fast_llm/models/gpt/conversion/gemma4.py index 317be2b76..ea3677173 100644 --- a/fast_llm/models/gpt/conversion/gemma4.py +++ b/fast_llm/models/gpt/conversion/gemma4.py @@ -276,6 +276,7 @@ def import_config(cls, config: dict) -> dict: "router_normalization": {"type": "fixed_rms_norm", "epsilon": eps}, "router_scale": {"enabled": True}, "router_input_scale": config["hidden_size"] ** -0.5, + "router_per_expert_scale": {"enabled": True}, } @classmethod @@ -309,6 +310,11 @@ def get_converters( f"{hf_prefix}.router.scale", drop_on_export=drop_on_export, ), + get_parameter_converter( + f"{fast_llm_prefix}.router_per_expert_scale", + f"{hf_prefix}.router.per_expert_scale", + drop_on_export=drop_on_export, + ), get_parameter_converter( f"{fast_llm_prefix}.layer_1.weight", f"{hf_prefix}.experts.gate_up_proj", @@ -339,10 +345,6 @@ def get_converters( drop_on_export=drop_on_export, ) # router.norm is FixedRMSNorm — no learnable weight to convert. - if not drop_on_export: - converters += [ - get_parameter_converter((), f"{hf_prefix}.router.per_expert_scale", drop_on_import=True), - ] return converters