diff --git a/fast_llm/functional/triton/normalization.py b/fast_llm/functional/triton/normalization.py index 7c25ce735..247d705b7 100644 --- a/fast_llm/functional/triton/normalization.py +++ b/fast_llm/functional/triton/normalization.py @@ -18,6 +18,7 @@ def triton_normalization_forward_kernel( n_cols, eps, has_bias: tl_constexpr, + has_weight: tl_constexpr, zero_centered: tl_constexpr, block_size: tl_constexpr, ): @@ -40,11 +41,13 @@ def triton_normalization_forward_kernel( tl.store(inv_var_ptr + row, inv_var) # Weight - weight = tl.load(weight_ptr + cols, mask=mask) - if zero_centered: - weight += 1 - - output = input_ * inv_var * weight + if has_weight: + weight = tl.load(weight_ptr + cols, mask=mask) + if zero_centered: + weight += 1 + output = input_ * inv_var * weight + else: + output = input_ * inv_var # Bias if has_bias: @@ -69,6 +72,7 @@ def triton_normalization_backward_kernel_1( n_rows, eps, has_bias: tl_constexpr, + has_weight: tl_constexpr, parameter_grad: tl_constexpr, zero_centered: tl_constexpr, block_size: tl_constexpr, @@ -87,10 +91,6 @@ def triton_normalization_backward_kernel_1( # Load data output = tl.load(output_ptr + offsets, mask=mask, other=0).to(tl.float32) grad_output = tl.load(grad_output_ptr + offsets, mask=mask, other=0).to(tl.float32) - weight = tl.load(weight_ptr + cols, mask=col_mask).to(tl.float32) - if zero_centered: - weight += 1 - inv_var = tl.load(inv_var_ptr + rows, mask=row_mask) # Bias @@ -99,9 +99,18 @@ def triton_normalization_backward_kernel_1( output = output - bias # Input grad - weight_regularised = tl.where(weight >= 0, tl.maximum(weight, eps), tl.minimum(weight, -eps)) - input_normalized = tl.where(mask, output / weight_regularised, 0.0) - weight_grad_output = tl.where(mask, weight * grad_output * inv_var, 0.0) + if has_weight: + weight = tl.load(weight_ptr + cols, mask=col_mask).to(tl.float32) + if zero_centered: + weight += 1 + weight_regularised = tl.where(weight >= 0, tl.maximum(weight, eps), tl.minimum(weight, -eps)) + input_normalized = tl.where(mask, output / weight_regularised, 0.0) + weight_grad_output = tl.where(mask, weight * grad_output * inv_var, 0.0) + else: + # weight == 1 everywhere: forward output = input * inv_var, so input_normalized = output + input_normalized = tl.where(mask, output, 0.0) + weight_grad_output = tl.where(mask, grad_output * inv_var, 0.0) + grad_input = weight_grad_output - input_normalized * ( tl.sum(input_normalized * weight_grad_output, axis=1)[:, None] / n_cols ) @@ -170,7 +179,7 @@ def triton_normalization_backward_kernel_2( def triton_normalization_forward( input_: torch.Tensor, - weight: torch.Tensor, + weight: torch.Tensor | None, bias: torch.Tensor | None, eps: float, training: bool, @@ -179,14 +188,15 @@ def triton_normalization_forward( # Note: Converting input automatically to training dtype to match Apex behaviour, # needed for full precision residual. # TODO: Review this? - assert weight.shape == input_.shape[-1:] - if bias is not None: - assert weight.shape == bias.shape + if weight is not None: + assert weight.shape == input_.shape[-1:] + if bias is not None: + assert weight.shape == bias.shape assert input_.is_contiguous() n_rows = input_.shape[:-1].numel() - n_cols = weight.numel() + n_cols = input_.shape[-1] - output = torch.empty_like(input_, dtype=weight.dtype) + output = torch.empty_like(input_, dtype=weight.dtype if weight is not None else input_.dtype) inv_var = torch.empty(n_rows, dtype=torch.float32, device=input_.device) block_size = triton.next_power_of_2(n_cols) @@ -202,6 +212,7 @@ def triton_normalization_forward( n_cols, eps, bias is not None, + weight is not None, zero_centered, block_size, num_warps=num_warps, @@ -217,16 +228,18 @@ def triton_normalization_backward(grad_output: torch.Tensor, context: list[typin # We delete the context to prevent a memory leak context.clear() has_bias = bias is not None + has_weight = weight is not None - parameter_grad = weight.requires_grad - assert parameter_grad == hasattr(weight, "grad_buffer") + parameter_grad = weight.requires_grad if has_weight else False + if has_weight: + assert parameter_grad == hasattr(weight, "grad_buffer") if has_bias: assert parameter_grad == bias.requires_grad grad_output = grad_output.contiguous() n_rows = grad_output.shape[:-1].numel() - n_cols = weight.numel() + n_cols = grad_output.shape[-1] # TODO: Improve heuristics # The ones from triton tutorial (32, 128) are terrible. # These seem to match torch compile heuristics and were near-optimal for A100 tests with [8192, 4096], bf16. @@ -274,6 +287,7 @@ def triton_normalization_backward(grad_output: torch.Tensor, context: list[typin n_rows, eps, has_bias, + has_weight, parameter_grad, zero_centered, block_size, diff --git a/fast_llm/functional/triton/rotary.py b/fast_llm/functional/triton/rotary.py index f07046a52..be0fa5de2 100644 --- a/fast_llm/functional/triton/rotary.py +++ b/fast_llm/functional/triton/rotary.py @@ -9,6 +9,7 @@ @triton_jit() def triton_rotary_kernel( input_ptr, + output_ptr, frequencies_ptr, stride_0, stride_1, @@ -30,6 +31,8 @@ def triton_rotary_kernel( input_offsets = stride_0 * (pid_0 // seq_len) + stride_1 * position_id + stride_2 * head_offsets + offsets[None, :] input_re_ptr = input_ptr + input_offsets input_im_ptr = input_re_ptr + rotary_dim + output_re_ptr = output_ptr + input_offsets + output_im_ptr = output_re_ptr + rotary_dim if rotary_block_size % rotary_dim == 0 and num_heads % head_block_size == 0: input_re = tl.load(input_re_ptr).to(tl.float32) @@ -54,11 +57,11 @@ def triton_rotary_kernel( out_im = input_im * frequencies_re + input_re * frequencies_im if rotary_block_size % rotary_dim == 0 and num_heads % head_block_size == 0: - tl.store(input_re_ptr, out_re) - tl.store(input_im_ptr, out_im) + tl.store(output_re_ptr, out_re) + tl.store(output_im_ptr, out_im) else: - tl.store(input_re_ptr, out_re, mask=mask) # noqa - tl.store(input_im_ptr, out_im, mask=mask) + tl.store(output_re_ptr, out_re, mask=mask) # noqa + tl.store(output_im_ptr, out_im, mask=mask) def triton_rotary_( @@ -66,19 +69,27 @@ def triton_rotary_( frequencies: torch.Tensor, is_key_value: bool = False, backward: bool = False, + inplace: bool = True, ) -> torch.Tensor: # TODO: Improve assumptions. # TODO: Make a transposed version to avoid contiguous call in key backward. # TODO: Improve block size heuristics. out = input_ + write = input_ if input_.stride(-1) != 1: # TODO: Make a transposed version to avoid contiguous call in key backward. input_ = input_.contiguous() + write = input_ + if not inplace: + out = torch.empty_like(input_) + write = out if input_.ndim == 3: input_ = input_.unsqueeze(0) + write = write.unsqueeze(0) frequencies = frequencies.unsqueeze(0) if is_key_value: input_ = input_.chunk(2, dim=-2)[0] + write = write.chunk(2, dim=-2)[0] batch_size, seq_len, num_heads, head_size = input_.shape rotary_dim = div(head_size, 2) rotary_block_size = triton.next_power_of_2(rotary_dim) @@ -89,6 +100,7 @@ def triton_rotary_( # Folded the large y dim into the x dim as gridDim.x is 32 bit while gridDim.y and gridDim.z are 16 bit registers triton_rotary_kernel[(batch_size * seq_len, triton.cdiv(num_heads, head_block_size))]( input_, + write, frequencies, input_.stride(0), input_.stride(1), diff --git a/fast_llm/layers/attention/attention.py b/fast_llm/layers/attention/attention.py index be40317f3..a3b9f41f3 100644 --- a/fast_llm/layers/attention/attention.py +++ b/fast_llm/layers/attention/attention.py @@ -104,12 +104,17 @@ def __init__( head_size_dim = TensorDim("head_size", self._config.head_size) query_dim = CompositeTensorDim("query", (head_group_dim, group_heads_dim, head_size_dim)) - key_value_dim = ConcatenatedTensorDim( - "key_value", - ( - CompositeTensorDim("key", (head_group_dim, head_size_dim)), - CompositeTensorDim("value", (head_group_dim, head_size_dim)), - ), + key_dim = CompositeTensorDim("key", (head_group_dim, head_size_dim)) + key_value_dim = ( + key_dim + if self._config.shared_key_value + else ConcatenatedTensorDim( + "key_value", + ( + key_dim, + CompositeTensorDim("value", (head_group_dim, head_size_dim)), + ), + ) ) self._dense_dim = CompositeTensorDim("dense", (head_group_dim, group_heads_dim, head_size_dim)) @@ -136,7 +141,7 @@ def __init__( lr_scale=self._lr_scale, peft=None if self._config.key_layer.apply_peft is None else self._peft, ) - if self._peft is not None and self._config.key_layer.apply_peft is None: + if self._peft is not None and self._config.key_layer.apply_peft is None and not self._config.shared_key_value: # Default: Apply to value only. # TODO: Avoid this hack. self.key_value = self._peft.apply_linear( @@ -148,6 +153,23 @@ def __init__( # Rotary embeddings. self._rotary = self._config.rotary.get_layer(head_size_dim) + # QKV norms (applied after projection, before RoPE). + self.query_norm = ( + self._config.query_norm.get_layer(head_size_dim, lr_scale=self._lr_scale, peft=None) + if self._config.query_norm is not None + else None + ) + self.key_norm = ( + self._config.key_norm.get_layer(head_size_dim, lr_scale=self._lr_scale, peft=None) + if self._config.key_norm is not None + else None + ) + self.value_norm = ( + self._config.value_norm.get_layer(head_size_dim, lr_scale=self._lr_scale, peft=None) + if self._config.value_norm is not None + else None + ) + # Output. self.dense = self._config.dense_layer.get_layer( self._dense_dim, @@ -252,10 +274,53 @@ def _query_key_value_forward( # TODO: This is probably unnecessary. handle.wait() + query_unflat = query.unflatten(1, (self._local_heads, self._config.head_size)) + if self._config.shared_key_value: + kv_unflat = key_value.unflatten(1, (self._local_head_groups, self._config.head_size)) + kv_unflat = torch.cat([kv_unflat, kv_unflat], dim=1) + else: + kv_unflat = key_value.unflatten(1, (2 * self._local_head_groups, self._config.head_size)) + + query_norm_context = None + if self._config.query_norm is not None: + if self.training: + with torch.enable_grad(): + query_leaf = query_unflat.contiguous().detach().requires_grad_() + query_normed = self.query_norm(query_leaf) + query_norm_context = (query_leaf, query_normed) + query_unflat = query_normed.detach() + else: + query_unflat = self.query_norm(query_unflat) + + key_norm_context = None + value_norm_context = None + if self._config.key_norm is not None or self._config.value_norm is not None: + key_unflat, value_unflat = kv_unflat.chunk(2, dim=1) + if self._config.key_norm is not None: + # .contiguous() is required because RMSNormalization uses .view() internally. + key_unflat = key_unflat.contiguous() + if self.training: + with torch.enable_grad(): + key_leaf = key_unflat.detach().requires_grad_() + key_normed = self.key_norm(key_leaf) + key_norm_context = (key_leaf, key_normed) + key_unflat = key_normed.detach() + else: + key_unflat = self.key_norm(key_unflat) + if self._config.value_norm is not None: + value_unflat = value_unflat.contiguous() + if self.training: + with torch.enable_grad(): + value_leaf = value_unflat.detach().requires_grad_() + value_normed = self.value_norm(value_leaf) + value_norm_context = (value_leaf, value_normed) + value_unflat = value_normed.detach() + else: + value_unflat = self.value_norm(value_unflat) + kv_unflat = torch.cat([key_unflat, value_unflat], dim=1) + query, key_value, rotary_context = self._rotary.forward_only( - query.unflatten(1, (self._local_heads, self._config.head_size)), - key_value.unflatten(1, (2 * self._local_head_groups, self._config.head_size)), - kwargs, + query_unflat, kv_unflat, kwargs, inplace_query=query_norm_context is None ) if self._sequence_data_parallel_dim.group: @@ -266,7 +331,14 @@ def _query_key_value_forward( if handle: handle.wait() - context = {"query": query_context, "key_value": key_value_context, "rotary": rotary_context} + context = { + "query": query_context, + "key_value": key_value_context, + "rotary": rotary_context, + "query_norm": query_norm_context, + "key_norm": key_norm_context, + "value_norm": value_norm_context, + } return query, key_value, context def _query_key_value_backward( @@ -283,6 +355,11 @@ def _query_key_value_backward( rotary_context = context.pop("rotary") query_grad, _ = self._rotary.backward(query_grad, None, rotary_context) + if (query_norm_context := context.pop("query_norm")) is not None: + query_leaf, query_normed = query_norm_context + query_normed.backward(query_grad) + query_grad = query_leaf.grad + # TODO: Overlap with both. input_grad = self.query.backward(query_grad.flatten(1), context.pop("query")) @@ -290,7 +367,26 @@ def _query_key_value_backward( handle.wait() _, key_value_grad = self._rotary.backward(None, key_value_grad, rotary_context) - key_value_grad = key_value_grad.flatten(1) + + key_norm_context = context.pop("key_norm") + value_norm_context = context.pop("value_norm") + if key_norm_context is not None or value_norm_context is not None: + key_grad, value_grad = key_value_grad.chunk(2, dim=1) + if key_norm_context is not None: + key_leaf, key_normed = key_norm_context + key_normed.backward(key_grad.contiguous()) + key_grad = key_leaf.grad + if value_norm_context is not None: + value_leaf, value_normed = value_norm_context + value_normed.backward(value_grad.contiguous()) + value_grad = value_leaf.grad + key_value_grad = torch.cat([key_grad, value_grad], dim=1) + + if self._config.shared_key_value: + key_grad, value_grad = key_value_grad.chunk(2, dim=1) + key_value_grad = (key_grad + value_grad).flatten(1) + else: + key_value_grad = key_value_grad.flatten(1) if self._config.head_groups == 1 and (group := self._parallel_dim.group): if self._sequence_parallel: diff --git a/fast_llm/layers/attention/config.py b/fast_llm/layers/attention/config.py index fcb5bfaf6..cc5d80e88 100644 --- a/fast_llm/layers/attention/config.py +++ b/fast_llm/layers/attention/config.py @@ -6,6 +6,7 @@ from fast_llm.layers.attention.rotary.config import RotaryConfig from fast_llm.layers.block.config import BlockKwargs from fast_llm.layers.common.linear.config import AffineLinearConfig +from fast_llm.layers.common.normalization.config import NormalizationConfig from fast_llm.layers.decoder.config import MixerConfig from fast_llm.utils import Assert @@ -122,6 +123,26 @@ class AttentionConfig(MixerConfig): desc="The implementation to use for the attention layer. Default: `flash` if supported, otherwise `backup`.", hint=FieldHint.feature, ) + query_norm: NormalizationConfig | None = Field( + default=None, + desc="Normalization applied to query vectors before RoPE, per attention head. Set to `{type: rms_norm}` to enable.", + hint=FieldHint.architecture, + ) + key_norm: NormalizationConfig | None = Field( + default=None, + desc="Normalization applied to key vectors before RoPE, per attention head. Set to `{type: rms_norm}` to enable.", + hint=FieldHint.architecture, + ) + value_norm: NormalizationConfig | None = Field( + default=None, + desc="Normalization applied to value projections per head before attention. Use `{type: fixed_rms_norm}` for a no-weight RMS norm.", + hint=FieldHint.architecture, + ) + shared_key_value: bool = Field( + default=False, + desc="Use one shared key/value projection. The projected key is reused as value before separate K/V norms.", + hint=FieldHint.architecture, + ) def _validate(self) -> None: super()._validate() diff --git a/fast_llm/layers/attention/rotary/config.py b/fast_llm/layers/attention/rotary/config.py index 80f499748..588abb3bf 100644 --- a/fast_llm/layers/attention/rotary/config.py +++ b/fast_llm/layers/attention/rotary/config.py @@ -2,7 +2,7 @@ import math import typing -from fast_llm.config import Field, FieldHint, config_class +from fast_llm.config import Field, FieldHint, check_field, config_class from fast_llm.engine.base_model.config import ModuleConfig from fast_llm.engine.config_utils.tensor_dim import TensorDim from fast_llm.utils import Assert @@ -12,6 +12,7 @@ DefaultRotary, Llama3Rotary, NoRotary, + ProportionalRotary, Rotary, Rotary2D, YarnRotary, @@ -139,3 +140,28 @@ def _get_configurable_class(self) -> "type[Rotary2D]": from fast_llm.layers.attention.rotary.rotary import Rotary2D return Rotary2D + + +@config_class(dynamic_type={RotaryConfig: "proportional"}) +class ProportionalRotaryConfig(DefaultRotaryConfig): + """ + Rotary embeddings applied only to a leading fraction of head dimensions (NoPE for the rest). + Used by Gemma 4 global-attention layers (partial_rotary_factor=0.5). + """ + + _abstract = False + partial_rotary_factor: float = Field( + default=1.0, + desc="Fraction of head dimensions to apply rotary embeddings to.", + hint=FieldHint.architecture, + valid=check_field(Assert.gt, 0), + ) + + def _validate(self) -> None: + super()._validate() + Assert.leq(self.partial_rotary_factor, 1.0) + + def _get_configurable_class(self) -> "type[ProportionalRotary]": + from fast_llm.layers.attention.rotary.rotary import ProportionalRotary + + return ProportionalRotary diff --git a/fast_llm/layers/attention/rotary/rotary.py b/fast_llm/layers/attention/rotary/rotary.py index 7752e058c..4c30d0a48 100644 --- a/fast_llm/layers/attention/rotary/rotary.py +++ b/fast_llm/layers/attention/rotary/rotary.py @@ -13,6 +13,7 @@ DefaultRotaryConfig, Llama3RotaryConfig, NoRotaryConfig, + ProportionalRotaryConfig, Rotary2DConfig, RotaryConfig, YarnRotaryConfig, @@ -108,7 +109,11 @@ def forward( @abc.abstractmethod def forward_only( - self, query: torch.Tensor | None, key_value: torch.Tensor | None, kwargs: dict[str, typing.Any] + self, + query: torch.Tensor | None, + key_value: torch.Tensor | None, + kwargs: dict[str, typing.Any], + inplace_query: bool = True, ) -> tuple[torch.Tensor | None, torch.Tensor | None, typing.Any]: pass @@ -129,7 +134,11 @@ def forward( return query, key_value def forward_only( - self, query: torch.Tensor | None, key_value: torch.Tensor | None, kwargs: dict[str, typing.Any] + self, + query: torch.Tensor | None, + key_value: torch.Tensor | None, + kwargs: dict[str, typing.Any], + inplace_query: bool = True, ) -> tuple[torch.Tensor | None, torch.Tensor | None, typing.Any]: return query, key_value, None @@ -147,19 +156,35 @@ def _forward( key_value: torch.Tensor | None, frequencies: torch.Tensor, backward: bool = False, + inplace_query: bool = True, ) -> tuple[torch.Tensor | None, torch.Tensor | None]: - rotary_fn = triton_rotary_ if TritonConfig.enabled(frequencies.device) else rotary_embeddings_real - query = None if query is None else rotary_fn(query, frequencies, backward=backward) - key_value = ( - None if key_value is None else rotary_fn(key_value, frequencies, is_key_value=True, backward=backward) - ) + if TritonConfig.enabled(frequencies.device): + query = ( + None if query is None else triton_rotary_(query, frequencies, backward=backward, inplace=inplace_query) + ) + key_value = ( + None + if key_value is None + else triton_rotary_(key_value, frequencies, is_key_value=True, backward=backward) + ) + else: + query = None if query is None else rotary_embeddings_real(query, frequencies, backward=backward) + key_value = ( + None + if key_value is None + else rotary_embeddings_real(key_value, frequencies, is_key_value=True, backward=backward) + ) return query, key_value def forward_only( - self, query: torch.Tensor | None, key_value: torch.Tensor | None, kwargs: dict[str, typing.Any] + self, + query: torch.Tensor | None, + key_value: torch.Tensor | None, + kwargs: dict[str, typing.Any], + inplace_query: bool = True, ) -> tuple[torch.Tensor | None, torch.Tensor | None, torch.Tensor]: frequencies: torch.Tensor = kwargs[AttentionKwargs.rotary_freq] - query, key_value = self._forward(query, key_value, frequencies, backward=False) + query, key_value = self._forward(query, key_value, frequencies, backward=False, inplace_query=inplace_query) return query, key_value, frequencies def backward( @@ -269,6 +294,27 @@ def _get_correction(self, beta: float, dim: int) -> float: ) +class ProportionalRotary[ConfigType: ProportionalRotaryConfig](DefaultRotary[ConfigType]): + """ + Rotary embeddings applied only to the first rotary_dims head dimensions. + The remaining NoPE dimensions pass through unchanged (zero angle → identity rotation). + """ + + def __init__(self, config: ConfigType, head_size_dim: TensorDim) -> None: + super().__init__(config, head_size_dim) + self._rotary_dims = round(self._head_size * self._config.partial_rotary_factor) + Assert.gt(self._rotary_dims, 0) + Assert.multiple(self._rotary_dims, 2) + + def _get_angle_scales(self, head_size: int, device: torch.device) -> torch.Tensor: + rotary_pairs = self._rotary_dims // 2 + nope_pairs = head_size // 2 - rotary_pairs + scales = super()._get_angle_scales(head_size, device) + if nope_pairs == 0: + return scales + return torch.cat([scales[:rotary_pairs], scales.new_zeros(nope_pairs)]) + + class Rotary2D[ConfigType: Rotary2DConfig](RotaryBase[ConfigType]): _frequencies: torch.Tensor _config: ConfigType diff --git a/fast_llm/layers/common/normalization/config.py b/fast_llm/layers/common/normalization/config.py index 274215bf2..c84b055c6 100644 --- a/fast_llm/layers/common/normalization/config.py +++ b/fast_llm/layers/common/normalization/config.py @@ -138,6 +138,25 @@ def module_class(self): return RMSNormalization +@config_class(dynamic_type={NormalizationConfig: "fixed_rms_norm"}) +class FixedRMSNormConfig(NormalizationConfig): + """RMS normalization without a learnable weight (fixed unit scale). Used for value norms in Gemma-family models.""" + + _abstract = False + epsilon: float = Field( + default=1e-5, + desc="Regularizer for the division.", + hint=FieldHint.architecture, + valid=check_field(Assert.gt, 0), + ) + + @property + def module_class(self) -> type["Normalization"]: + from fast_llm.layers.common.normalization.normalization import FixedRMSNormalization + + return FixedRMSNormalization + + @config_class(dynamic_type={NormalizationConfig: "gated_rms_norm"}) class GatedRMSNormalizationConfig(RMSNormalizationConfig): """Configuration for gated RMS normalization, which applies a learned activation gate alongside the norm weight.""" diff --git a/fast_llm/layers/common/normalization/normalization.py b/fast_llm/layers/common/normalization/normalization.py index 2858b9370..dda12f17b 100644 --- a/fast_llm/layers/common/normalization/normalization.py +++ b/fast_llm/layers/common/normalization/normalization.py @@ -9,6 +9,7 @@ from fast_llm.functional.config import TritonConfig from fast_llm.functional.triton.normalization import triton_normalization_autograd from fast_llm.layers.common.normalization.config import ( + FixedRMSNormConfig, GatedRMSNormalizationConfig, LayerNormalizationConfig, NoNormalizationConfig, @@ -301,6 +302,27 @@ def _forward_torch(self, input_: torch.Tensor) -> torch.Tensor: return torch.rms_norm(input_.to(self.weight.dtype), self._normalized_shape, self.weight, self._config.epsilon) +class FixedRMSNormalization[ConfigType: FixedRMSNormConfig](Normalization[ConfigType]): + """RMS normalization with no learnable weight (fixed unit scale).""" + + def __init__(self, config: ConfigType, hidden_dim: TensorDim, lr_scale: float | None = None): + super().__init__(config, hidden_dim, lr_scale) + self._normalized_shape = (hidden_dim.size,) + if TritonConfig.enabled(torch.device("cuda")): + self._forward = self._forward_triton + else: + self._forward = self._forward_torch + + def forward(self, input_: torch.Tensor) -> torch.Tensor: + return self._forward(input_.view(-1, *self._normalized_shape)).view_as(input_) + + def _forward_triton(self, input_: torch.Tensor) -> torch.Tensor: + return triton_normalization_autograd(input_, None, None, self._config.epsilon, self.training, False) + + def _forward_torch(self, input_: torch.Tensor) -> torch.Tensor: + return torch.rms_norm(input_, self._normalized_shape, None, self._config.epsilon) + + class GatedRMSNormalization[ConfigType: GatedRMSNormalizationConfig](RMSNormalization[ConfigType], torch.nn.Module): """ A gated RMS normalization layer. diff --git a/fast_llm/layers/decoder/block.py b/fast_llm/layers/decoder/block.py index a2f2d3519..5caf24c2a 100644 --- a/fast_llm/layers/decoder/block.py +++ b/fast_llm/layers/decoder/block.py @@ -6,7 +6,8 @@ from fast_llm.core.distributed import ReduceOp, all_reduce, set_generator from fast_llm.engine.base_model.config import LossDef, ResourceUsageConfig -from fast_llm.engine.config_utils.tensor_dim import TensorDim +from fast_llm.engine.config_utils.initialization import init_ones_ +from fast_llm.engine.config_utils.tensor_dim import TensorDim, scalar_dim from fast_llm.engine.distributed.config import DistributedConfig from fast_llm.engine.distributed.distributed import Distributed from fast_llm.functional.utils import AuxiliaryLoss @@ -86,8 +87,26 @@ def __init__( ) # For multi-token prediction, return a stack of shared_hidden and transformer_output. self._return_input = return_input - self.norm_1 = self._config.normalization.get_layer(self._hidden_dim, lr_scale=self._lr_scale, peft=self._peft) - self.norm_2 = self._config.normalization.get_layer(self._hidden_dim, lr_scale=self._lr_scale, peft=self._peft) + self.norm_1 = ( + self._config.normalization + if self._config.pre_mixer_normalization is None + else self._config.pre_mixer_normalization + ).get_layer(self._hidden_dim, lr_scale=self._lr_scale, peft=self._peft) + self.norm_2 = ( + self._config.normalization + if self._config.pre_mlp_normalization is None + else self._config.pre_mlp_normalization + ).get_layer(self._hidden_dim, lr_scale=self._lr_scale, peft=self._peft) + self.post_mixer_norm = ( + self._config.post_mixer_normalization.get_layer(self._hidden_dim, lr_scale=self._lr_scale, peft=self._peft) + if self._config.post_mixer_normalization is not None + else None + ) + self.post_mlp_norm = ( + self._config.post_mlp_normalization.get_layer(self._hidden_dim, lr_scale=self._lr_scale, peft=self._peft) + if self._config.post_mlp_normalization is not None + else None + ) self.mixer = self._config.mixer.get_layer( self._distributed_config, @@ -105,6 +124,13 @@ def __init__( return_bias=True, ) + self.output_scale = self._config.output_scale.get_parameter( + (scalar_dim,), + default_initialization=init_ones_, + lr_scale=self._lr_scale, + peft=self._peft, + ) + def setup(self, distributed: Distributed) -> None: super().setup(distributed) self.mixer.setup(distributed) @@ -112,11 +138,18 @@ def setup(self, distributed: Distributed) -> None: @torch.compile def _bias_dropout_add( - self, input_: torch.Tensor, bias: torch.Tensor | None, residual: torch.Tensor + self, + input_: torch.Tensor, + bias: torch.Tensor | None, + residual: torch.Tensor, + output_scale: torch.Tensor | None = None, ) -> torch.Tensor: if bias is not None: input_ = input_ + bias - return residual + torch.dropout(input_, self._config.dropout, self.training) + output = residual + torch.dropout(input_, self._config.dropout, self.training) + if output_scale is not None: + output = output * output_scale + return output def forward( self, @@ -145,14 +178,18 @@ def forward( bias = None hidden_states = self._activation_distillation_loss(hidden_states, kwargs, losses, metrics) + if self.post_mixer_norm is not None: + hidden_states = self.post_mixer_norm(hidden_states) with set_generator(generator): input_ = self._bias_dropout_add(hidden_states, bias, input_) self._debug(input_, "mixer_residual", hidden_dims, kwargs) hidden_states = self.norm_2(input_) self._debug(hidden_states, "norm_2", hidden_dims, kwargs) hidden_states, bias = self.mlp(hidden_states, kwargs, losses, metrics) + if self.post_mlp_norm is not None: + hidden_states = self.post_mlp_norm(hidden_states) with set_generator(generator): - hidden_states = self._bias_dropout_add(hidden_states, bias, input_) + hidden_states = self._bias_dropout_add(hidden_states, bias, input_, self.output_scale) self._debug(hidden_states, None, hidden_dims, kwargs) if self._return_input: diff --git a/fast_llm/layers/decoder/config.py b/fast_llm/layers/decoder/config.py index 6ab259b2b..1c0c10c87 100644 --- a/fast_llm/layers/decoder/config.py +++ b/fast_llm/layers/decoder/config.py @@ -3,7 +3,7 @@ import warnings from fast_llm.config import Field, FieldHint, check_field, config_class -from fast_llm.engine.config_utils.parameter import combine_lr_scales +from fast_llm.engine.config_utils.parameter import OptionalParameterConfig, combine_lr_scales from fast_llm.engine.config_utils.tensor_dim import TensorDim from fast_llm.engine.distributed.config import _BIG_PRIMES, DistributedConfig from fast_llm.layers.block.config import BlockConfig, BlockKwargs @@ -60,6 +60,17 @@ class MLPBaseConfig(BlockWithBiasConfig): _abstract = True + pre_norm: NormalizationConfig | None = Field( + default=None, + desc="Optional normalization applied to the MLP input.", + hint=FieldHint.architecture, + ) + post_norm: NormalizationConfig | None = Field( + default=None, + desc="Optional normalization applied to the MLP output.", + hint=FieldHint.architecture, + ) + def get_layer( self, distributed_config: DistributedConfig, @@ -215,7 +226,32 @@ class DecoderBlockConfig(BlockConfig): ) # TODO: Review names normalization: NormalizationConfig = Field( - desc="Configuration for the block normalization layers.", + desc="Configuration for the block normalization layers. Used as default for `pre_mixer_normalization` and `pre_mlp_normalization` when not set.", + hint=FieldHint.architecture, + ) + pre_mixer_normalization: NormalizationConfig | None = Field( + default=None, + desc="Normalization applied to the residual before the mixer. Defaults to `normalization` when not set.", + hint=FieldHint.architecture, + ) + pre_mlp_normalization: NormalizationConfig | None = Field( + default=None, + desc="Normalization applied to the residual before the MLP. Defaults to `normalization` when not set." + " Set to `{type: none}` to disable independently of the pre-mixer norm.", + hint=FieldHint.architecture, + ) + post_mixer_normalization: NormalizationConfig | None = Field( + default=None, + desc="Optional normalization applied to the mixer output before the residual add. Set to `{type: rms_norm}` to enable.", + hint=FieldHint.architecture, + ) + post_mlp_normalization: NormalizationConfig | None = Field( + default=None, + desc="Optional normalization applied to the MLP output before the residual add. Set to `{type: rms_norm}` to enable.", + hint=FieldHint.architecture, + ) + output_scale: OptionalParameterConfig = Field( + desc="Optional learnable scalar multiplied into the block output (after the MLP residual add).", hint=FieldHint.architecture, ) # TODO: Review names diff --git a/fast_llm/layers/decoder/mlp/config.py b/fast_llm/layers/decoder/mlp/config.py index 997cf9d2a..1a7d6c579 100644 --- a/fast_llm/layers/decoder/mlp/config.py +++ b/fast_llm/layers/decoder/mlp/config.py @@ -3,13 +3,15 @@ import typing from fast_llm.config import Field, FieldHint, check_field, config_class +from fast_llm.engine.config_utils.parameter import OptionalParameterConfig from fast_llm.functional.config import ActivationType, MLPRecomputeLevel from fast_llm.layers.common.linear.config import AffineLinearConfig, LinearConfig +from fast_llm.layers.common.normalization.config import NormalizationConfig from fast_llm.layers.decoder.config import MLPBaseConfig from fast_llm.utils import Assert if typing.TYPE_CHECKING: - from fast_llm.layers.decoder.mlp.mixture_of_experts import MixtureOfExpertMLP + from fast_llm.layers.decoder.mlp.mixture_of_experts import HybridMoEMLP, MixtureOfExpertMLP from fast_llm.layers.decoder.mlp.mlp import MLP @@ -97,6 +99,25 @@ class MoEMLPConfig(MLPConfig): desc="Configuration for the MoE router.", hint=FieldHint.feature, ) + router_normalization: NormalizationConfig | None = Field( + default=None, + desc="Optional normalization applied to the router input (independent of `pre_norm`, which goes to experts).", + hint=FieldHint.architecture, + ) + router_scale: OptionalParameterConfig = Field( + desc="Optional learnable per-feature scale applied to the router input after `router_normalization`.", + hint=FieldHint.architecture, + ) + router_input_scale: float = Field( + default=1.0, + desc="Constant multiplied into the router input after `router_normalization` and `router_scale`." + " Set to `hidden_size ** -0.5` for Gemma-style routing.", + hint=FieldHint.architecture, + ) + router_per_expert_scale: OptionalParameterConfig = Field( + desc="Optional learnable per-expert scale multiplied into the router scores after top-k selection.", + hint=FieldHint.architecture, + ) experts: int = Field( default=2, desc="Number of MLP experts in a Mixture of Expert (MoE) model", @@ -164,3 +185,29 @@ def _validate(self) -> None: super()._validate() Assert.leq(self.shared_experts, self.experts) Assert.leq(self.shared_experts + self.experts_per_token, self.experts) + + +@config_class(dynamic_type={MLPBaseConfig: "hybrid_moe"}) +class HybridMoEMLPConfig(MLPBaseConfig): + """Configuration for a MoE layer combining an always-active dense MLP with top-K routed experts.""" + + _abstract = False + + dense: MLPConfig = Field( + desc="Configuration for the always-active dense MLP.", + hint=FieldHint.architecture, + ) + routed: MoEMLPConfig = Field( + desc="Configuration for the top-K routed expert MLP.", + hint=FieldHint.architecture, + ) + + def _validate(self) -> None: + super()._validate() + Assert.eq(self.routed.shared_experts, 0) + + @property + def layer_class(self) -> "type[HybridMoEMLP]": + from fast_llm.layers.decoder.mlp.mixture_of_experts import HybridMoEMLP + + return HybridMoEMLP diff --git a/fast_llm/layers/decoder/mlp/mixture_of_experts.py b/fast_llm/layers/decoder/mlp/mixture_of_experts.py index 89979bd18..bfa053a35 100644 --- a/fast_llm/layers/decoder/mlp/mixture_of_experts.py +++ b/fast_llm/layers/decoder/mlp/mixture_of_experts.py @@ -7,7 +7,7 @@ from fast_llm.core.distributed import ProcessGroup, set_generator from fast_llm.core.ops import gather_op from fast_llm.engine.base_model.config import LossDef, ResourceUsageConfig -from fast_llm.engine.config_utils.initialization import init_normal_ +from fast_llm.engine.config_utils.initialization import init_normal_, init_ones_ from fast_llm.engine.config_utils.tensor_dim import CompositeTensorDim, TensorDim from fast_llm.engine.distributed.config import DistributedConfig from fast_llm.functional.triton import triton_available @@ -16,12 +16,22 @@ from fast_llm.functional.utils import AuxiliaryLoss from fast_llm.layers.block.config import BlockKwargs from fast_llm.layers.common.peft.config import PeftConfig -from fast_llm.layers.decoder.mlp.config import MLPLossNames, MoEImplementation, MoEMLPConfig, RoutingType +from fast_llm.layers.decoder.block import BlockWithBias +from fast_llm.layers.decoder.mlp.config import ( + HybridMoEMLPConfig, + MLPLossNames, + MoEImplementation, + MoEMLPConfig, + RoutingType, +) from fast_llm.layers.decoder.mlp.mlp import MLPBase from fast_llm.layers.language_model.loss.z_loss import z_loss from fast_llm.tensor import TensorMeta from fast_llm.utils import Assert +if typing.TYPE_CHECKING: + from fast_llm.engine.distributed.distributed import Distributed + logger = logging.getLogger(__name__) @@ -70,6 +80,24 @@ def __init__( lr_scale=self._lr_scale, peft=self._peft, ) + self.router_normalization = ( + self._config.router_normalization.get_layer(self._hidden_dim, lr_scale=self._lr_scale, peft=self._peft) + if self._config.router_normalization is not None + else None + ) + self.router_scale = self._config.router_scale.get_parameter( + (self._hidden_dim,), + default_initialization=init_ones_, + lr_scale=self._lr_scale, + peft=self._peft, + ) + self._router_input_scale = self._config.router_input_scale + self.router_per_expert_scale = self._config.router_per_expert_scale.get_parameter( + (TensorDim("experts", self._config.experts),), + default_initialization=init_ones_, + lr_scale=self._lr_scale, + peft=self._peft, + ) implementation = self._config.implementation if implementation == MoEImplementation.auto: implementation = MoEImplementation.dropless if triton_available else MoEImplementation.looped @@ -90,13 +118,25 @@ def _get_intermediate_dims(self) -> tuple[TensorDim, TensorDim]: CompositeTensorDim("moe_intermediate_2", (experts_dim, intermediate_2_dim)), ) + @torch.compile + def _scale_router_input(self, x: torch.Tensor, scale: torch.Tensor | None, input_scale: float) -> torch.Tensor: + if scale is not None: + x = x * scale + if input_scale != 1.0: + x = x * input_scale + return x + def _forward( self, input_: torch.Tensor, kwargs: dict, losses: dict | None = None, metrics: dict | None = None ) -> tuple[torch.Tensor, None]: if isinstance(input_, TensorMeta): return TensorMeta.from_dims(input_.dims[:-1] + (self._output_dim,), "MLP output"), None hidden_states = input_.flatten(0, -2) - logits = self.router(hidden_states) + router_input = ( + self.router_normalization(hidden_states) if self.router_normalization is not None else hidden_states + ) + router_input = self._scale_router_input(router_input, self.router_scale, self._router_input_scale) + logits = self.router(router_input) hidden_token_dim = kwargs[BlockKwargs.hidden_token_dim] logit_dims = (hidden_token_dim, self._top_expert_dim) self._debug(logits, "Router logits", logit_dims, kwargs) @@ -126,10 +166,16 @@ def _forward( else: raise NotImplementedError(self._config.routing) + if self.router_per_expert_scale is not None: + scores = scores * self.router_per_expert_scale[top_experts] + self._debug(scores, "router_scores", logit_dims, kwargs) self._debug(top_experts, "router_top_experts", logit_dims, kwargs) - out = self._mlp_forward(hidden_states, scores, top_experts).view_as(input_) # noqa + expert_input = self.pre_norm(hidden_states) if self.pre_norm is not None else hidden_states + out = self._mlp_forward(expert_input, scores, top_experts).view_as(input_) # noqa + if self.post_norm is not None: + out = self.post_norm(out) self._debug(out, None, (hidden_token_dim, self._hidden_dim), kwargs) return out, None @@ -266,6 +312,83 @@ def get_loss_definitions(self) -> list[LossDef]: return loss_definitions +class HybridMoEMLP[ConfigType: HybridMoEMLPConfig](BlockWithBias[ConfigType]): + """ + MoE MLP that runs an always-active dense MLP alongside top-K routed experts and sums their outputs. + """ + + def __init__( + self, + config: ConfigType, + distributed_config: DistributedConfig, + *, + hidden_dim: TensorDim, + output_dim: TensorDim | None = None, + lr_scale: float | None, + peft: PeftConfig | None, + return_bias: bool = True, + ): + super().__init__( + config, distributed_config, hidden_dim=hidden_dim, lr_scale=lr_scale, peft=peft, return_bias=return_bias + ) + self._output_dim = self._hidden_dim if output_dim is None else output_dim + self.dense = config.dense.get_layer( + distributed_config, hidden_dim, output_dim=output_dim, lr_scale=lr_scale, peft=peft, return_bias=True + ) + self.routed = config.routed.get_layer( + distributed_config, hidden_dim, output_dim=output_dim, lr_scale=lr_scale, peft=peft, return_bias=True + ) + self.pre_norm = ( + config.pre_norm.get_layer(self._hidden_dim, lr_scale=self._lr_scale, peft=self._peft) + if config.pre_norm is not None + else None + ) + self.post_norm = ( + config.post_norm.get_layer(self._output_dim, lr_scale=self._lr_scale, peft=self._peft) + if config.post_norm is not None + else None + ) + + def setup(self, distributed: "Distributed") -> None: + super().setup(distributed) + self.dense.setup(distributed) + self.routed.setup(distributed) + + def _forward( + self, + input_: torch.Tensor, + kwargs: dict[str, typing.Any], + losses: dict[str, typing.Any] | None = None, + metrics: dict[str, typing.Any] | None = None, + ) -> tuple[torch.Tensor, torch.Tensor | None]: + if isinstance(input_, TensorMeta): + return ( + TensorMeta.from_dims( + input_.dims[:-1] + (self._output_dim,), tensor_name="MLP output", dtype=input_.dtype + ), + None, + ) + if self.pre_norm is not None: + input_ = self.pre_norm(input_) + dense_out, dense_bias = self.dense(input_, kwargs, losses, metrics) + routed_out, _ = self.routed(input_, kwargs, losses, metrics) + out = dense_out + routed_out + if self.post_norm is not None: + if dense_bias is not None: + out = out + dense_bias + dense_bias = None + out = self.post_norm(out) + return out, dense_bias + + def get_loss_definitions(self) -> list[LossDef]: + return self.routed.get_loss_definitions() + + def get_compute_usage(self, input_: TensorMeta, kwargs: dict[str, typing.Any], config: ResourceUsageConfig) -> int: + return self.dense.get_compute_usage(input_, kwargs, config) + self.routed.get_compute_usage( + input_, kwargs, config + ) + + def sinkhorn(cost: torch.Tensor, tolerance: float = 1e-5, eps=1e-9) -> torch.Tensor: """Sinkhorn based MoE routing function""" with torch.no_grad(): diff --git a/fast_llm/layers/decoder/mlp/mlp.py b/fast_llm/layers/decoder/mlp/mlp.py index 80599da97..504e26ac5 100644 --- a/fast_llm/layers/decoder/mlp/mlp.py +++ b/fast_llm/layers/decoder/mlp/mlp.py @@ -42,6 +42,16 @@ def __init__( self._output_dim = self._hidden_dim if output_dim is None else output_dim self._parallel_dim = self._distributed_config.get_distributed_dim(DistributedDimNames.tensor) intermediate_1_dim, self._intermediate_2_dim = self._get_intermediate_dims() + self.pre_norm = ( + self._config.pre_norm.get_layer(self._hidden_dim, lr_scale=self._lr_scale, peft=self._peft) + if self._config.pre_norm is not None + else None + ) + self.post_norm = ( + self._config.post_norm.get_layer(self._output_dim, lr_scale=self._lr_scale, peft=self._peft) + if self._config.post_norm is not None + else None + ) self._activation_fn = ( triton_mlp_activation_autograd if TritonConfig.enabled(torch.device("cuda")) else torch_mlp_activation @@ -116,6 +126,8 @@ def _forward( ), None, ) + if self.pre_norm is not None: + input_ = self.pre_norm(input_) out = mlp_autograd( input_, None, @@ -132,6 +144,11 @@ def _forward( transposed_layer_2_weight=self.layer_2.transposed_weight, ) bias = self.layer_2.bias if self._parallel_dim.group else None + if self.post_norm is not None: + if bias is not None: + out = out + bias + bias = None + out = self.post_norm(out) # Use None for dims when output_dim differs from hidden_dim (e.g., adapter projections) # to let _debug infer dims from actual tensor shape self._debug(out, None, (kwargs.get(BlockKwargs.hidden_token_dim), self._hidden_dim), kwargs, bias=bias) diff --git a/fast_llm/layers/language_model/config.py b/fast_llm/layers/language_model/config.py index 4a8efdab6..bde33f297 100644 --- a/fast_llm/layers/language_model/config.py +++ b/fast_llm/layers/language_model/config.py @@ -79,6 +79,12 @@ class LanguageModelEmbeddingsConfig(BlockConfig): " Affects RNG for initialization and dropout.", hint=FieldHint.performance, ) + embedding_scale: float = Field( + default=1.0, + desc="Multiplicative scale applied to word embeddings after lookup.", + hint=FieldHint.architecture, + valid=check_field(Assert.gt, 0), + ) @property def layer_class(self) -> "type[LanguageModelEmbedding]": @@ -119,6 +125,12 @@ class LanguageModelHeadConfig(BlockConfig): hint=FieldHint.feature, valid=check_field(Assert.geq, 0), ) + final_logit_softcap: float | None = Field( + default=None, + desc="Soft-cap applied to logits before loss: logits = tanh(logits / cap) * cap.", + hint=FieldHint.architecture, + valid=skip_valid_if_none(check_field(Assert.gt, 0)), + ) prediction_heads: int = Field( default=1, desc="Prediction heads.", diff --git a/fast_llm/layers/language_model/embedding.py b/fast_llm/layers/language_model/embedding.py index f01d6ad73..9574bb15c 100644 --- a/fast_llm/layers/language_model/embedding.py +++ b/fast_llm/layers/language_model/embedding.py @@ -79,6 +79,7 @@ def _forward( position_ids: torch.Tensor | None, mask_inputs: bool, embedding_map: torch.Tensor, + embedding_scale: float, ) -> torch.Tensor: group = self._parallel_dim.group if self._vocab_parallel: @@ -132,6 +133,8 @@ def _forward( (embedding_map,), input_[: embedding_map.size(0)], accumulate=True ) + if embedding_scale != 1.0: + embeddings = embeddings * embedding_scale with set_generator( self._distributed.tp_generator if self._sequence_parallel else self._distributed.pp_generator ): @@ -162,6 +165,7 @@ def forward( # Masking is needed with image tokens or padding. input_ is not None or kwargs[LanguageModelKwargs.num_tokens] < kwargs[LanguageModelKwargs.token_dim].size, embedding_map, + self._config.embedding_scale, ) self._debug(out, None, (kwargs.get(LanguageModelKwargs.hidden_token_dim), self._hidden_dim), kwargs) return out diff --git a/fast_llm/layers/language_model/head.py b/fast_llm/layers/language_model/head.py index 95be18035..22c750082 100644 --- a/fast_llm/layers/language_model/head.py +++ b/fast_llm/layers/language_model/head.py @@ -30,6 +30,16 @@ OUTPUT_WEIGHTS = "output_weights" +@torch.compile +def _softcap(logits: torch.Tensor, cap: float) -> torch.Tensor: + return torch.tanh(logits / cap) * cap + + +@torch.compile +def _softcap_backward(grad: torch.Tensor, softcapped: torch.Tensor, cap: float) -> torch.Tensor: + return grad * (1.0 - (softcapped / cap) ** 2) + + class LanguageModelHead[ConfigType: LanguageModelHeadConfig](Block[ConfigType]): """ A language model head (GPT), which combines the final layer norm, logits and cross-entropy (if applicable). @@ -249,6 +259,8 @@ def _logits_loss_forward_backward_partial( group=self._parallel_dim.group if self._vocab_parallel else None, sequence_parallel=self._sequence_parallel and self._vocab_parallel, ) + if self._config.final_logit_softcap is not None: + logits = _softcap(logits, self._config.final_logit_softcap) self._debug( logits, f"logits{"" if self._config.cross_entropy_splits == 1 else f"_{split_index}"}", @@ -273,6 +285,9 @@ def _logits_loss_forward_backward_partial( if loss_value is not None: losses_.append(loss_value.detach()) + if grad is not None and self._config.final_logit_softcap is not None: + grad = _softcap_backward(grad, logits, self._config.final_logit_softcap) + return sum(losses_) if losses_ else None, ( output_parallel_linear_backward(grad, context) if self.training else None ) diff --git a/fast_llm/models/gpt/config.py b/fast_llm/models/gpt/config.py index 770139816..71981ba23 100644 --- a/fast_llm/models/gpt/config.py +++ b/fast_llm/models/gpt/config.py @@ -16,6 +16,7 @@ AutoGPTHuggingfaceCheckpointFormat, DiffusionDreamCheckpointFormat, DiffusionLlamaCheckpointFormat, + Gemma4CheckpointFormat, LlamaCheckpointFormat, MistralCheckpointFormat, MixtralCheckpointFormat, @@ -72,6 +73,7 @@ class GPTModelConfig(FastLLMModelConfig): DiffusionLlamaCheckpointFormat, AprielHybridSSMCheckpointFormat, Apriel2TextCheckpointFormat, + Gemma4CheckpointFormat, ) @classmethod diff --git a/fast_llm/models/gpt/conversion/auto.py b/fast_llm/models/gpt/conversion/auto.py index 696b4f4ce..20842d611 100644 --- a/fast_llm/models/gpt/conversion/auto.py +++ b/fast_llm/models/gpt/conversion/auto.py @@ -9,6 +9,7 @@ AprielHybridSSMCheckpointFormat, DiffusionDreamCheckpointFormat, DiffusionLlamaCheckpointFormat, + Gemma4CheckpointFormat, LlamaCheckpointFormat, MistralCheckpointFormat, MixtralCheckpointFormat, @@ -17,6 +18,7 @@ ) from fast_llm.models.gpt.conversion.diffusion_dream import DiffusionDreamHuggingfaceCheckpointHandler from fast_llm.models.gpt.conversion.diffusion_llama import DiffusionLlamaHuggingfaceCheckpointHandler +from fast_llm.models.gpt.conversion.gemma4 import Gemma4HuggingfaceCheckpointHandler from fast_llm.models.gpt.conversion.llama import LlamaHuggingfaceCheckpointHandler from fast_llm.models.gpt.conversion.mistral import MistralHuggingfaceCheckpointHandler from fast_llm.models.gpt.conversion.mixtral import MixtralHuggingfaceCheckpointHandler @@ -38,4 +40,5 @@ class AutoGPTHuggingfaceCheckpointHandler( DiffusionLlamaCheckpointFormat.name: DiffusionLlamaHuggingfaceCheckpointHandler, AprielHybridSSMCheckpointFormat.name: AprielHuggingfaceCheckpointHandler, Apriel2TextCheckpointFormat.name: Apriel2HuggingfaceCheckpointHandler, + Gemma4CheckpointFormat.name: Gemma4HuggingfaceCheckpointHandler, } diff --git a/fast_llm/models/gpt/conversion/config.py b/fast_llm/models/gpt/conversion/config.py index 240860529..41a0828b6 100644 --- a/fast_llm/models/gpt/conversion/config.py +++ b/fast_llm/models/gpt/conversion/config.py @@ -51,3 +51,7 @@ class AprielHybridSSMCheckpointFormat(GPTHuggingfaceCheckpointFormat): class Apriel2TextCheckpointFormat(GPTHuggingfaceCheckpointFormat): name: typing.ClassVar[str] = "apriel2_text" + + +class Gemma4CheckpointFormat(GPTHuggingfaceCheckpointFormat): + name: typing.ClassVar[str] = "gemma4" diff --git a/fast_llm/models/gpt/conversion/gemma4.py b/fast_llm/models/gpt/conversion/gemma4.py new file mode 100644 index 000000000..ea3677173 --- /dev/null +++ b/fast_llm/models/gpt/conversion/gemma4.py @@ -0,0 +1,676 @@ +"""Gemma4 checkpoint format converter.""" + +import typing + +from fast_llm.engine.checkpoint.config import CheckpointFormat +from fast_llm.engine.checkpoint.external import ( + SplitWeightConverter, + WeightConverter, +) +from fast_llm.engine.checkpoint.huggingface import HuggingfaceStateDictCheckpointHandler +from fast_llm.functional.config import ActivationType +from fast_llm.layers.attention.config import AttentionConfig +from fast_llm.layers.attention.rotary.config import DefaultRotaryConfig, ProportionalRotaryConfig +from fast_llm.layers.block.config import FixedBlockSequenceConfig, PatternBlockSequenceConfig +from fast_llm.layers.common.normalization.config import FixedRMSNormConfig, RMSNormalizationConfig +from fast_llm.layers.decoder.config import DecoderBlockConfig +from fast_llm.layers.decoder.mlp.config import HybridMoEMLPConfig, MLPConfig, MoEMLPConfig +from fast_llm.layers.language_model.config import ( + LanguageModelConfig, + LanguageModelEmbeddingsConfig, + LanguageModelHeadConfig, +) +from fast_llm.models.gpt.config import GPTBaseModelConfig, GPTModelConfig +from fast_llm.models.gpt.conversion.config import Gemma4CheckpointFormat +from fast_llm.models.gpt.conversion.llama import ( + KeyValueWeightConverter, + LlamaEmbeddingsConverter, + LlamaHeadConverter, + LlamaNormalizationConverter, + MLPLayer2Converter, + QueryWeightConverter, + get_parameter_converter, + get_weight_and_bias_converters, +) +from fast_llm.models.gpt.model import GPTModel +from fast_llm.utils import Assert, safe_merge_dicts + +_SLIDING_ATTENTION = "sliding_attention" +_FULL_ATTENTION = "full_attention" + + +class Gemma4MoELayer1Converter(WeightConverter): + """Converts batched gate_up_proj [experts, 2*intermediate, hidden] ↔ Fast-LLM layer_1 [experts*2*intermediate, hidden].""" + + _config: MoEMLPConfig + + def export_weight(self, weight): + (layer_1,) = weight + w = layer_1[:] + return (w.reshape(self._config.experts, -1, w.shape[-1]),) + + def import_weight(self, weight): + (gate_up_proj,) = weight + w = gate_up_proj[:] + return (w.reshape(-1, w.shape[-1]),) + + +class Gemma4MoELayer2Converter(WeightConverter): + """Converts batched down_proj [experts, hidden, intermediate] ↔ Fast-LLM layer_2 [experts*intermediate, hidden].""" + + _config: MoEMLPConfig + + def export_weight(self, weight): + (layer_2,) = weight + w = layer_2[:] + return (w.reshape(self._config.experts, -1, w.shape[-1]).permute(0, 2, 1).contiguous(),) + + def import_weight(self, weight): + (down_proj,) = weight + w = down_proj[:] + return (w.permute(0, 2, 1).reshape(-1, w.shape[1]).contiguous(),) + + +class Gemma4AttentionConverter: + @classmethod + def import_config(cls, config: dict, is_sliding: bool) -> dict: + eps = config["rms_norm_eps"] + if is_sliding: + rope_params = config["rope_parameters"][_SLIDING_ATTENTION] + rotary = {"type": "default", "theta": rope_params["rope_theta"]} + head_size = config["head_dim"] + head_groups = config["num_key_value_heads"] + window_size = config["sliding_window"] + else: + rope_params = config["rope_parameters"][_FULL_ATTENTION] + rotary = { + "type": "proportional", + "theta": rope_params["rope_theta"], + "partial_rotary_factor": rope_params["partial_rotary_factor"], + } + head_size = config["global_head_dim"] + num_global_kv_heads = config.get("num_global_key_value_heads") + head_groups = config["num_key_value_heads"] if num_global_kv_heads is None else num_global_kv_heads + window_size = None + out = { + "heads": config["num_attention_heads"], + "head_groups": head_groups, + "head_size": head_size, + "add_linear_biases": False, + "dropout": config["attention_dropout"], + "softmax_scale_power": 0, + "rotary": rotary, + "query_norm": {"type": "rms_norm", "epsilon": eps}, + "key_norm": {"type": "rms_norm", "epsilon": eps}, + "value_norm": {"type": "fixed_rms_norm", "epsilon": eps}, + } + if not is_sliding and config.get("attention_k_eq_v", False): + out["shared_key_value"] = True + if window_size is not None: + out["window_size"] = window_size + return out + + @classmethod + def export_config(cls, sliding_config: AttentionConfig, full_config: AttentionConfig) -> dict: + Assert.custom(isinstance, sliding_config, AttentionConfig) + Assert.custom(isinstance, full_config, AttentionConfig) + assert not sliding_config.add_linear_biases + assert isinstance(sliding_config.rotary, DefaultRotaryConfig) + assert isinstance(full_config.rotary, ProportionalRotaryConfig) + Assert.custom(isinstance, sliding_config.query_norm, RMSNormalizationConfig) + Assert.custom(isinstance, sliding_config.key_norm, RMSNormalizationConfig) + Assert.custom(isinstance, sliding_config.value_norm, FixedRMSNormConfig) + eps = sliding_config.query_norm.epsilon + num_global_kv_heads = ( + None if full_config.head_groups == sliding_config.head_groups else full_config.head_groups + ) + return { + "num_attention_heads": sliding_config.heads, + "num_key_value_heads": sliding_config.head_groups, + "head_dim": sliding_config.head_size, + "global_head_dim": full_config.head_size, + "num_global_key_value_heads": num_global_kv_heads, + "attention_bias": False, + "attention_dropout": sliding_config.dropout, + "sliding_window": sliding_config.window_size, + "rms_norm_eps": eps, + "attention_k_eq_v": full_config.shared_key_value, + "rope_parameters": { + _SLIDING_ATTENTION: { + "rope_type": "default", + "rope_theta": sliding_config.rotary.theta, + }, + _FULL_ATTENTION: { + "rope_type": "proportional", + "rope_theta": full_config.rotary.theta, + "partial_rotary_factor": full_config.rotary.partial_rotary_factor, + }, + }, + } + + @classmethod + def get_converters( + cls, + config: AttentionConfig, + fast_llm_prefix: str, + hf_prefix: str, + drop_on_export: bool = False, + ) -> list[WeightConverter]: + if config.shared_key_value: + # K=V: single k_proj reused as value; no v_proj in HF + kv_converters = get_weight_and_bias_converters( + f"{fast_llm_prefix}.key_value", + f"{hf_prefix}.k_proj", + False, + drop_on_export=drop_on_export, + ) + else: + kv_converters = get_weight_and_bias_converters( + f"{fast_llm_prefix}.key_value", + (f"{hf_prefix}.k_proj", f"{hf_prefix}.v_proj"), + False, + KeyValueWeightConverter, + config, + drop_on_export=drop_on_export, + ) + converters = [ + *get_weight_and_bias_converters( + f"{fast_llm_prefix}.query", + f"{hf_prefix}.q_proj", + False, + QueryWeightConverter, + config, + drop_on_export=drop_on_export, + ), + *kv_converters, + *get_weight_and_bias_converters( + f"{fast_llm_prefix}.dense", + f"{hf_prefix}.o_proj", + False, + drop_on_export=drop_on_export, + ), + ] + if config.query_norm is not None: + converters += LlamaNormalizationConverter.get_converters( + config.query_norm, + f"{fast_llm_prefix}.query_norm", + f"{hf_prefix}.q_norm", + drop_on_export=drop_on_export, + ) + if config.key_norm is not None: + converters += LlamaNormalizationConverter.get_converters( + config.key_norm, + f"{fast_llm_prefix}.key_norm", + f"{hf_prefix}.k_norm", + drop_on_export=drop_on_export, + ) + # value_norm is FixedRMSNorm — no learnable weight to convert + return converters + + +class Gemma4MLPConverter: + @classmethod + def import_config(cls, config: dict, with_norms: bool = False) -> dict: + out = { + "intermediate_size": config["intermediate_size"], + "add_linear_biases": False, + "activation": ActivationType.from_hf_name(config["hidden_activation"]), + "gated": True, + } + if with_norms: + eps = config["rms_norm_eps"] + out["pre_norm"] = {"type": "rms_norm", "epsilon": eps} + out["post_norm"] = {"type": "rms_norm", "epsilon": eps} + return out + + @classmethod + def export_config(cls, config: MLPConfig) -> dict: + Assert.custom(isinstance, config, MLPConfig) + assert config.gated + assert not config.add_linear_biases + return { + "intermediate_size": config.intermediate_size, + "hidden_activation": config.activation.hf_name, + } + + @classmethod + def get_converters( + cls, + config: MLPConfig, + fast_llm_prefix: str, + hf_prefix: str, + drop_on_export: bool = False, + ) -> list[WeightConverter]: + return [ + *get_weight_and_bias_converters( + f"{fast_llm_prefix}.layer_1", + (f"{hf_prefix}.gate_proj", f"{hf_prefix}.up_proj"), + False, + SplitWeightConverter, + drop_on_export=drop_on_export, + ), + *get_weight_and_bias_converters( + f"{fast_llm_prefix}.layer_2", + f"{hf_prefix}.down_proj", + False, + MLPLayer2Converter, + drop_on_export=drop_on_export, + ), + ] + + +class Gemma4MoEMLPConverter: + @classmethod + def import_config(cls, config: dict) -> dict: + eps = config["rms_norm_eps"] + return { + "type": "moe", + "intermediate_size": config["moe_intermediate_size"], + "add_linear_biases": False, + "activation": ActivationType.from_hf_name(config["hidden_activation"]), + "gated": True, + "experts": config["num_experts"], + "experts_per_token": config["top_k_experts"], + "pre_norm": {"type": "rms_norm", "epsilon": eps}, + "post_norm": {"type": "rms_norm", "epsilon": eps}, + "router_normalization": {"type": "fixed_rms_norm", "epsilon": eps}, + "router_scale": {"enabled": True}, + "router_input_scale": config["hidden_size"] ** -0.5, + "router_per_expert_scale": {"enabled": True}, + } + + @classmethod + def export_config(cls, config: MoEMLPConfig) -> dict: + Assert.custom(isinstance, config, MoEMLPConfig) + assert config.gated + assert not config.add_linear_biases + return { + "num_experts": config.experts, + "top_k_experts": config.experts_per_token, + "moe_intermediate_size": config.intermediate_size, + } + + @classmethod + def get_converters( + cls, + config: MoEMLPConfig, + fast_llm_prefix: str, + hf_prefix: str, + drop_on_export: bool = False, + ) -> list[WeightConverter]: + converters = [ + *get_weight_and_bias_converters( + f"{fast_llm_prefix}.router", + f"{hf_prefix}.router.proj", + False, + drop_on_export=drop_on_export, + ), + get_parameter_converter( + f"{fast_llm_prefix}.router_scale", + f"{hf_prefix}.router.scale", + drop_on_export=drop_on_export, + ), + get_parameter_converter( + f"{fast_llm_prefix}.router_per_expert_scale", + f"{hf_prefix}.router.per_expert_scale", + drop_on_export=drop_on_export, + ), + get_parameter_converter( + f"{fast_llm_prefix}.layer_1.weight", + f"{hf_prefix}.experts.gate_up_proj", + Gemma4MoELayer1Converter, + config, + drop_on_export=drop_on_export, + ), + get_parameter_converter( + f"{fast_llm_prefix}.layer_2.weight", + f"{hf_prefix}.experts.down_proj", + Gemma4MoELayer2Converter, + config, + drop_on_export=drop_on_export, + ), + ] + if config.pre_norm is not None: + converters += LlamaNormalizationConverter.get_converters( + config.pre_norm, + f"{fast_llm_prefix}.pre_norm", + f"{hf_prefix}.pre_feedforward_layernorm_2", + drop_on_export=drop_on_export, + ) + if config.post_norm is not None: + converters += LlamaNormalizationConverter.get_converters( + config.post_norm, + f"{fast_llm_prefix}.post_norm", + f"{hf_prefix}.post_feedforward_layernorm_2", + drop_on_export=drop_on_export, + ) + # router.norm is FixedRMSNorm — no learnable weight to convert. + return converters + + +class Gemma4HybridMoEMLPConverter: + @classmethod + def import_config(cls, config: dict) -> dict: + return { + "type": "hybrid_moe", + "dense": Gemma4MLPConverter.import_config(config, with_norms=True), + "routed": Gemma4MoEMLPConverter.import_config(config), + } + + @classmethod + def export_config(cls, config: HybridMoEMLPConfig) -> dict: + Assert.custom(isinstance, config, HybridMoEMLPConfig) + return safe_merge_dicts( + Gemma4MLPConverter.export_config(config.dense), + Gemma4MoEMLPConverter.export_config(config.routed), + {"enable_moe_block": True}, + ) + + @classmethod + def get_converters( + cls, + config: HybridMoEMLPConfig, + fast_llm_prefix: str, + hf_prefix: str, + drop_on_export: bool = False, + ) -> list[WeightConverter]: + return [ + *Gemma4MLPConverter.get_converters( + config.dense, + f"{fast_llm_prefix}.dense", + f"{hf_prefix}.mlp", + drop_on_export=drop_on_export, + ), + *Gemma4MoEMLPConverter.get_converters( + config.routed, + f"{fast_llm_prefix}.routed", + hf_prefix, + drop_on_export=drop_on_export, + ), + *LlamaNormalizationConverter.get_converters( + config.dense.pre_norm, + f"{fast_llm_prefix}.dense.pre_norm", + f"{hf_prefix}.pre_feedforward_layernorm", + drop_on_export=drop_on_export, + ), + *LlamaNormalizationConverter.get_converters( + config.dense.post_norm, + f"{fast_llm_prefix}.dense.post_norm", + f"{hf_prefix}.post_feedforward_layernorm_1", + drop_on_export=drop_on_export, + ), + ] + + +class Gemma4BlockConverter: + @classmethod + def import_config(cls, config: dict, is_sliding: bool) -> dict: + def make_norm(): + return {"type": "rms_norm", "epsilon": config["rms_norm_eps"]} + + out = { + "mixer": Gemma4AttentionConverter.import_config(config, is_sliding), + "normalization": make_norm(), + "post_mixer_normalization": make_norm(), + "post_mlp_normalization": make_norm(), + # HF stores `layer_scalar` as a non-trained buffer; freeze on our side to match. + "output_scale": {"enabled": True, "lr_scale": 0}, + } + if config.get("enable_moe_block"): + out["mlp"] = Gemma4HybridMoEMLPConverter.import_config(config) + out["pre_mlp_normalization"] = {"type": "none"} + else: + out["mlp"] = Gemma4MLPConverter.import_config(config) + out["pre_mlp_normalization"] = make_norm() + return out + + @classmethod + def export_config(cls, sliding_config: DecoderBlockConfig, full_config: DecoderBlockConfig) -> dict: + Assert.custom(isinstance, sliding_config, DecoderBlockConfig) + norm_config = sliding_config.normalization + Assert.custom(isinstance, norm_config, RMSNormalizationConfig) + is_moe = isinstance(sliding_config.mlp, HybridMoEMLPConfig) + out = safe_merge_dicts( + Gemma4AttentionConverter.export_config(sliding_config.mixer, full_config.mixer), + LlamaNormalizationConverter.export_config(norm_config), + ) + if is_moe: + out = safe_merge_dicts(out, Gemma4HybridMoEMLPConverter.export_config(sliding_config.mlp)) + else: + out = safe_merge_dicts(out, Gemma4MLPConverter.export_config(sliding_config.mlp)) + out["enable_moe_block"] = False + return out + + @classmethod + def get_converters( + cls, + config: DecoderBlockConfig, + fast_llm_prefix: str, + hf_prefix: str, + drop_on_export: bool = False, + ) -> list[WeightConverter]: + is_moe = isinstance(config.mlp, HybridMoEMLPConfig) + converters = [ + *Gemma4AttentionConverter.get_converters( + config.mixer, + f"{fast_llm_prefix}.mixer", + f"{hf_prefix}.self_attn", + drop_on_export=drop_on_export, + ), + ] + if is_moe: + converters += Gemma4HybridMoEMLPConverter.get_converters( + config.mlp, + f"{fast_llm_prefix}.mlp", + hf_prefix, + drop_on_export=drop_on_export, + ) + else: + converters += Gemma4MLPConverter.get_converters( + config.mlp, + f"{fast_llm_prefix}.mlp", + f"{hf_prefix}.mlp", + drop_on_export=drop_on_export, + ) + converters += LlamaNormalizationConverter.get_converters( + config.normalization, + f"{fast_llm_prefix}.norm_2", + f"{hf_prefix}.pre_feedforward_layernorm", + drop_on_export=drop_on_export, + ) + converters += [ + *LlamaNormalizationConverter.get_converters( + config.normalization, + f"{fast_llm_prefix}.norm_1", + f"{hf_prefix}.input_layernorm", + drop_on_export=drop_on_export, + ), + *LlamaNormalizationConverter.get_converters( + config.post_mixer_normalization, + f"{fast_llm_prefix}.post_mixer_norm", + f"{hf_prefix}.post_attention_layernorm", + drop_on_export=drop_on_export, + ), + *LlamaNormalizationConverter.get_converters( + config.post_mlp_normalization, + f"{fast_llm_prefix}.post_mlp_norm", + f"{hf_prefix}.post_feedforward_layernorm", + drop_on_export=drop_on_export, + ), + ] + converters.append( + get_parameter_converter( + f"{fast_llm_prefix}.output_scale", + f"{hf_prefix}.layer_scalar", + drop_on_export=drop_on_export, + ) + ) + return converters + + +class Gemma4DecoderConverter: + block_converter_class: typing.ClassVar[type[Gemma4BlockConverter]] = Gemma4BlockConverter + + @classmethod + def import_config(cls, config: dict) -> dict: + layer_types = config["layer_types"] + unique_types = list(dict.fromkeys(layer_types)) + blocks = { + layer_type: cls.block_converter_class.import_config(config, layer_type == _SLIDING_ATTENTION) + for layer_type in unique_types + } + return { + "type": "pattern", + "blocks": blocks, + "pattern": layer_types, + "num_blocks": config["num_hidden_layers"], + } + + @classmethod + def export_config(cls, config: PatternBlockSequenceConfig | FixedBlockSequenceConfig) -> dict: + Assert.custom(isinstance, config, PatternBlockSequenceConfig) + Assert.incl(_SLIDING_ATTENTION, config.blocks) + Assert.incl(_FULL_ATTENTION, config.blocks) + return safe_merge_dicts( + cls.block_converter_class.export_config( + config.blocks[_SLIDING_ATTENTION], + config.blocks[_FULL_ATTENTION], + ), + { + "num_hidden_layers": config.num_blocks, + "layer_types": list(config.expanded_pattern), + }, + ) + + @classmethod + def get_converters( + cls, + config: PatternBlockSequenceConfig | FixedBlockSequenceConfig, + fast_llm_prefix: str, + hf_prefix: str, + drop_on_export: bool = False, + ) -> list[WeightConverter]: + Assert.custom(isinstance, config, PatternBlockSequenceConfig) + converters = [] + for block_index in range(config.num_blocks): + block_config = config.blocks[config.expanded_pattern[block_index]] + converters += cls.block_converter_class.get_converters( + block_config, + f"{fast_llm_prefix}.{block_index}", + f"{hf_prefix}.{block_index}", + drop_on_export=drop_on_export, + ) + return converters + + +class Gemma4EmbeddingsConverter(LlamaEmbeddingsConverter): + @classmethod + def import_config(cls, config: dict) -> dict: + return { + "vocab_size": config["vocab_size"], + "embedding_scale": config["hidden_size"] ** 0.5, + } + + @classmethod + def export_config(cls, config: LanguageModelEmbeddingsConfig) -> dict: + Assert.custom(isinstance, config, LanguageModelEmbeddingsConfig) + assert not config.position_embeddings.enabled + return {"vocab_size": config.vocab_size} + + +class Gemma4HeadConverter(LlamaHeadConverter): + @classmethod + def import_config(cls, config: dict) -> dict: + out = {"normalization": LlamaNormalizationConverter.import_config(config)} + if (softcap := config.get("final_logit_softcapping")) is not None: + out["final_logit_softcap"] = softcap + return out + + @classmethod + def export_config(cls, config: LanguageModelHeadConfig) -> dict: + out = LlamaNormalizationConverter.export_config(config.normalization) + if config.final_logit_softcap is not None: + out["final_logit_softcapping"] = config.final_logit_softcap + return out + + @classmethod + def get_converters( + cls, + config: LanguageModelConfig, + exported_config: dict, + ) -> list[WeightConverter]: + return [ + *LlamaNormalizationConverter.get_converters( + config.head.normalization, + "head.final_norm", + "model.norm", + ), + get_parameter_converter( + "head.output_weights", + "lm_head.weight", + drop_on_import=exported_config["tie_word_embeddings"], + drop_on_export=exported_config["tie_word_embeddings"], + ), + ] + + +class Gemma4BaseModelConverter: + decoder_converter_class: typing.ClassVar[type[Gemma4DecoderConverter]] = Gemma4DecoderConverter + embeddings_converter_class: typing.ClassVar[type[Gemma4EmbeddingsConverter]] = Gemma4EmbeddingsConverter + head_converter_class: typing.ClassVar[type[Gemma4HeadConverter]] = Gemma4HeadConverter + + @classmethod + def import_config(cls, config: dict) -> dict: + return { + "embeddings": cls.embeddings_converter_class.import_config(config), + "decoder": cls.decoder_converter_class.import_config(config), + "head": cls.head_converter_class.import_config(config), + "hidden_size": config["hidden_size"], + "tied_embedding_weight": config["tie_word_embeddings"], + } + + @classmethod + def export_config(cls, config: GPTBaseModelConfig) -> dict: + Assert.custom(isinstance, config, GPTBaseModelConfig) + return safe_merge_dicts( + cls.embeddings_converter_class.export_config(config.embeddings), + cls.decoder_converter_class.export_config(config.decoder), + cls.head_converter_class.export_config(config.head), + { + "tie_word_embeddings": config.tied_embedding_weight, + "hidden_size": config.hidden_size, + # TODO: Implement Per-Layer Embeddings (PLE). Gemma4TextConfig defaults to 256; + # explicitly zero to disable the feature in the exported model until Fast-LLM + # supports it natively. + "hidden_size_per_layer_input": 0, + # Fast-LLM is text-only; bidirectional attention (used for vision tokens in the + # multimodal model) is not implemented. + "use_bidirectional_attention": None, + }, + ) + + @classmethod + def get_converters(cls, config: GPTBaseModelConfig, exported_config: dict) -> list[WeightConverter]: + return [ + *cls.embeddings_converter_class.get_converters(config.embeddings, "embeddings", "model"), + *cls.decoder_converter_class.get_converters(config.decoder, "decoder", "model.layers"), + *cls.head_converter_class.get_converters(config, exported_config), + ] + + +class Gemma4HuggingfaceCheckpointHandler(HuggingfaceStateDictCheckpointHandler): + _model: GPTModel + _model_class: typing.ClassVar = GPTModelConfig + format: typing.ClassVar[type[CheckpointFormat]] = Gemma4CheckpointFormat + architecture: typing.ClassVar[str] = "Gemma4ForCausalLM" + base_model_converter_class: typing.ClassVar[type[Gemma4BaseModelConverter]] = Gemma4BaseModelConverter + + @classmethod + def get_huggingface_model_type(cls) -> str: + return "gemma4_text" + + @classmethod + def get_transformers_configuration_class(cls): + import transformers + + return transformers.Gemma4TextConfig diff --git a/tests/layers/test_attention.py b/tests/layers/test_attention.py index 05ee6c778..3cbb3e1c3 100644 --- a/tests/layers/test_attention.py +++ b/tests/layers/test_attention.py @@ -44,6 +44,10 @@ class AttentionTestConfig: head_size: int = _HEAD_SIZE causal: bool = True window_size: int | None = None + query_norm: bool = False + key_norm: bool = False + value_norm: bool = False + shared_key_value: bool = False rotary: bool = False rotary_theta: float = 10000.0 @@ -66,6 +70,14 @@ def get_attention_config(self, implementation: str = "backup") -> AttentionConfi } if self.window_size is not None: config["window_size"] = self.window_size + if self.query_norm: + config["query_norm"] = {"type": "rms_norm"} + if self.key_norm: + config["key_norm"] = {"type": "rms_norm"} + if self.value_norm: + config["value_norm"] = {"type": "fixed_rms_norm"} + if self.shared_key_value: + config["shared_key_value"] = True if self.rotary: config["rotary"] = {"type": "default", "theta": self.rotary_theta} return AttentionConfig.from_dict(config) @@ -77,16 +89,33 @@ def expected_output( lengths: list[int], ) -> torch.Tensor: """ - Independent reference: plain F.linear + rotary + per-document einsum attention. - No calls to Fast-LLM attention internals. + Independent reference: plain F.linear + torch.rms_norm + rotary + per-document einsum attention. + No calls to Fast-LLM attention or norm internals. """ with torch.no_grad(): q = torch.nn.functional.linear(input_, attention.query.weight.detach()).unflatten( 1, (self.heads, self.head_size) ) - kv = torch.nn.functional.linear(input_, attention.key_value.weight.detach()).unflatten( - 1, (2 * self.kv_heads, self.head_size) - ) + if self.shared_key_value: + key_projected = torch.nn.functional.linear(input_, attention.key_value.weight.detach()).unflatten( + 1, (self.kv_heads, self.head_size) + ) + kv = torch.cat([key_projected, key_projected], dim=1) + else: + kv = torch.nn.functional.linear(input_, attention.key_value.weight.detach()).unflatten( + 1, (2 * self.kv_heads, self.head_size) + ) + + if self.query_norm: + q = torch.rms_norm(q, (self.head_size,), attention.query_norm.weight.detach(), 1e-5) + if self.key_norm: + key_normed = torch.rms_norm( + kv[:, : self.kv_heads, :], (self.head_size,), attention.key_norm.weight.detach(), 1e-5 + ) + kv = torch.cat([key_normed, kv[:, self.kv_heads :, :]], dim=1) + if self.value_norm: + value_normed = torch.rms_norm(kv[:, self.kv_heads :, :], (self.head_size,), None, 1e-5) + kv = torch.cat([kv[:, : self.kv_heads, :], value_normed], dim=1) if self.rotary: freqs = _compute_rotary_freqs(input_.shape[0], self.head_size, self.rotary_theta, input_.device) @@ -94,6 +123,7 @@ def expected_output( k_rotated = _apply_rotary(kv[:, : self.kv_heads, :], freqs) kv = torch.cat([k_rotated, kv[:, self.kv_heads :, :]], dim=1) + k, v = kv[:, : self.kv_heads, :], kv[:, self.kv_heads :, :] scale = self.head_size**-0.5 @@ -136,11 +166,44 @@ def expected_output( ("causal_rotary", {"causal": True, "rotary": True}), ] -_attention_test_configs = [ - AttentionTestConfig(name=base_name, **base_kwargs) - for base_name, base_kwargs in _base_attention_cases + _attention_rotary_cases +_attention_norm_variants = [ + ("no_norm", {}), + ("query_norm", {"query_norm": True}), + ("key_norm", {"key_norm": True}), + ("value_norm", {"value_norm": True}), + ("both_norms", {"query_norm": True, "key_norm": True}), + ("all_norms", {"query_norm": True, "key_norm": True, "value_norm": True}), +] + +_attention_shared_key_value_cases = [ + ("shared_key_value", {"shared_key_value": True}), ] +_attention_shared_key_value_norm_variants = [ + ("no_norm", {}), + ("key_norm", {"key_norm": True}), + ("value_norm", {"value_norm": True}), + ("all_norms", {"query_norm": True, "key_norm": True, "value_norm": True}), +] + +_attention_test_configs = ( + [ + AttentionTestConfig(name=f"{base_name}_{variant_name}", **base_kwargs, **variant_kwargs) + for base_name, base_kwargs in _base_attention_cases + for variant_name, variant_kwargs in _attention_norm_variants + ] + + [ + AttentionTestConfig(name=f"{base_name}_{variant_name}", **base_kwargs, **variant_kwargs) + for base_name, base_kwargs in _attention_rotary_cases + for variant_name, variant_kwargs in _attention_norm_variants + ] + + [ + AttentionTestConfig(name=f"{base_name}_{variant_name}", **base_kwargs, **variant_kwargs) + for base_name, base_kwargs in _attention_shared_key_value_cases + for variant_name, variant_kwargs in _attention_shared_key_value_norm_variants + ] +) + _attention_lengths = [ [15], [6, 9], diff --git a/tests/layers/test_decoder_block.py b/tests/layers/test_decoder_block.py new file mode 100644 index 000000000..a3b778293 --- /dev/null +++ b/tests/layers/test_decoder_block.py @@ -0,0 +1,127 @@ +import dataclasses +import functools + +import pytest +import torch + +from fast_llm.engine.config_utils.tensor_dim import TensorDim +from fast_llm.engine.distributed.config import DistributedConfig +from fast_llm.engine.distributed.distributed import Distributed +from fast_llm.layers.attention.config import AttentionKwargs +from fast_llm.layers.decoder.block import DecoderBlock +from fast_llm.layers.decoder.config import DecoderBlockConfig +from tests.utils.utils import get_stage + +_NUM_TOKENS = 16 +_HIDDEN_SIZE = 64 +_HEADS = 4 +_KV_HEADS = 2 +_HEAD_SIZE = 16 +_INTERMEDIATE_SIZE = 128 + + +@dataclasses.dataclass +class PostNormTestConfig: + name: str + post_mixer_norm: bool = False + post_mlp_norm: bool = False + output_scale: float | None = None + + def get_block_config(self) -> DecoderBlockConfig: + config_dict: dict = { + "mixer": { + "heads": _HEADS, + "head_groups": _KV_HEADS, + "head_size": _HEAD_SIZE, + "add_linear_biases": False, + "implementation": "backup", + }, + "mlp": { + "intermediate_size": _INTERMEDIATE_SIZE, + "add_linear_biases": False, + }, + "normalization": {"type": "rms_norm"}, + } + if self.post_mixer_norm: + config_dict["post_mixer_normalization"] = {"type": "rms_norm"} + if self.post_mlp_norm: + config_dict["post_mlp_normalization"] = {"type": "rms_norm"} + if self.output_scale is not None: + config_dict["output_scale"] = {"enabled": True} + return DecoderBlockConfig.from_dict(config_dict) + + @functools.cached_property + def threshold(self) -> float: + return 1e-5 + + def expected_output(self, block: DecoderBlock, input_: torch.Tensor, kwargs: dict) -> torch.Tensor: + with torch.no_grad(): + norm1_out = block.norm_1(input_) + mixer_hidden, mixer_bias = block.mixer(norm1_out, kwargs) + if block.post_mixer_norm is not None: + mixer_hidden = block.post_mixer_norm(mixer_hidden) + if mixer_bias is not None: + mixer_hidden = mixer_hidden + mixer_bias + after_mixer = input_ + mixer_hidden + + norm2_out = block.norm_2(after_mixer) + mlp_hidden, mlp_bias = block.mlp(norm2_out, kwargs) + if block.post_mlp_norm is not None: + mlp_hidden = block.post_mlp_norm(mlp_hidden) + if mlp_bias is not None: + mlp_hidden = mlp_hidden + mlp_bias + output = after_mixer + mlp_hidden + if self.output_scale is not None: + output = output * self.output_scale + return output + + +_base_post_norm_cases = [ + ("no_post_norms", {}), + ("post_mixer_norm", {"post_mixer_norm": True}), + ("post_mlp_norm", {"post_mlp_norm": True}), + ("both_post_norms", {"post_mixer_norm": True, "post_mlp_norm": True}), + ("output_scale", {"output_scale": 2.5}), +] + +_post_norm_test_configs = [PostNormTestConfig(name=name, **kwargs) for name, kwargs in _base_post_norm_cases] + + +@pytest.mark.parametrize( + "test_config", + [pytest.param(c, id=c.name) for c in _post_norm_test_configs], +) +def test_post_norms(test_config: PostNormTestConfig): + distributed_config = DistributedConfig(use_cuda=torch.cuda.is_available()) + distributed = Distributed(distributed_config) + hidden_dim = TensorDim("hidden", _HIDDEN_SIZE) + block: DecoderBlock = test_config.get_block_config().get_layer( + distributed_config, hidden_dim, lr_scale=None, peft=None + ) + get_stage([block], distributed) + block.eval() + + device = distributed.device + if test_config.output_scale is not None: + with torch.no_grad(): + block.output_scale.fill_(test_config.output_scale) + input_ = torch.randn(_NUM_TOKENS, _HIDDEN_SIZE, device=device) + + token_dim = TensorDim("token", _NUM_TOKENS) + kwargs = { + AttentionKwargs.sequence_k_dim: TensorDim("sequence_k", _NUM_TOKENS), + AttentionKwargs.token_dim: token_dim, + AttentionKwargs.hidden_token_dim: token_dim, + AttentionKwargs.key_value_token_dim: token_dim, + AttentionKwargs.sequence_length: _NUM_TOKENS, + AttentionKwargs.document_index_k: torch.zeros(_NUM_TOKENS, dtype=torch.int64, device=device), + AttentionKwargs.document_index_q: torch.zeros(_NUM_TOKENS, dtype=torch.int64, device=device), + AttentionKwargs.device: device, + } + block.preprocess(kwargs) + + with torch.no_grad(): + output = block(input_, kwargs) + + expected = test_config.expected_output(block, input_, kwargs) + torch.testing.assert_close(output, expected, rtol=test_config.threshold, atol=test_config.threshold) diff --git a/tests/layers/test_embedding.py b/tests/layers/test_embedding.py index b11d21ecc..2177bb63b 100644 --- a/tests/layers/test_embedding.py +++ b/tests/layers/test_embedding.py @@ -21,6 +21,7 @@ @dataclasses.dataclass class EmbeddingTestConfig: name: str + embedding_scale: float = 1.0 compute_dtype: DataType = DataType.float32 full_precision_residual: bool = False with_position_embeddings: bool = False @@ -33,6 +34,7 @@ def residual_dtype(self) -> torch.dtype: def get_config(self) -> GPTModelConfig: embeddings: dict = { "vocab_size": VOCAB_SIZE, + "embedding_scale": self.embedding_scale, "full_precision_residual": self.full_precision_residual, } if self.with_position_embeddings: @@ -88,6 +90,9 @@ def get_reference_output(self, layer: LanguageModelEmbedding, kwargs: dict) -> t if mask_inputs: embeddings = embeddings * token_mask.unsqueeze(-1) + if self.embedding_scale != 1.0: + embeddings = embeddings * self.embedding_scale + return embeddings.to(dtype=self.residual_dtype) @@ -103,6 +108,7 @@ def get_reference_output(self, layer: LanguageModelEmbedding, kwargs: dict) -> t ("float32", {}), ("bfloat16", {"compute_dtype": DataType.bfloat16}), ("full_precision_residual", {"full_precision_residual": True}), + ("embedding_scale", {"embedding_scale": 2.0}), ] diff --git a/tests/layers/test_lm_head.py b/tests/layers/test_lm_head.py index aa50fbb5e..5832fea3f 100644 --- a/tests/layers/test_lm_head.py +++ b/tests/layers/test_lm_head.py @@ -28,6 +28,7 @@ class LMHeadTestConfig: z_loss: bool | float = False grpo_loss: bool | float = False logits_scale_factor: float = 1.0 + final_logit_softcap: float | None = None compute_dtype: DataType = DataType.float32 full_precision_residual: bool = False loss_masking: bool = False @@ -53,6 +54,8 @@ def get_config(self) -> GPTModelConfig: "cross_entropy_splits": self.num_splits, "prediction_heads": self.prediction_heads, } + if self.final_logit_softcap is not None: + head_config["final_logit_softcap"] = self.final_logit_softcap losses = {} if self.label_loss is not False: losses["label"] = {"type": "label"} @@ -167,6 +170,10 @@ def get_reference_outputs( hidden = torch.rms_norm(input_.to(normalization_weight.dtype), input_.shape[-1:], normalization_weight, 1e-5) logits = torch.nn.functional.linear(hidden, logit_weight).float() + if self.final_logit_softcap is not None: + cap = self.final_logit_softcap + logits = torch.tanh(logits / cap) * cap + if self.logits_scale_factor is not None: logits = logits * self.logits_scale_factor @@ -248,6 +255,7 @@ def _add_configs(base_name: str, **kwargs): _add_configs("bfloat16", compute_dtype=DataType.bfloat16) _add_configs("full_precision_residual", full_precision_residual=True) _add_configs("logit_scaling", logits_scale_factor=5.0) +_add_configs("final_logit_softcap", final_logit_softcap=2.0) _add_configs("tied_embedding_weight", tied_embedding_weight=True) _add_configs("multi_token_prediction", prediction_heads=2) _add_configs("label_loss", label_loss=True) diff --git a/tests/layers/test_mlp.py b/tests/layers/test_mlp.py new file mode 100644 index 000000000..d5b82fd22 --- /dev/null +++ b/tests/layers/test_mlp.py @@ -0,0 +1,116 @@ +import dataclasses + +import pytest +import torch + +from fast_llm.engine.config_utils.tensor_dim import TensorDim +from fast_llm.engine.distributed.config import DistributedConfig +from fast_llm.engine.distributed.distributed import Distributed +from fast_llm.layers.block.config import BlockKwargs +from fast_llm.layers.decoder.mlp.config import HybridMoEMLPConfig +from fast_llm.layers.decoder.mlp.mixture_of_experts import HybridMoEMLP +from fast_llm.utils import Assert +from tests.utils.utils import get_stage + +_NUM_TOKENS = 128 +_HIDDEN_SIZE = 128 +_INTERMEDIATE_SIZE = 128 +_EXPERTS = 4 + +_NORM = {"type": "rms_norm"} + + +@dataclasses.dataclass +class HybridMoEMLPTestConfig: + name: str + gated: bool = False + experts_per_token: int = 1 + wrapper_pre_norm: bool = False + wrapper_post_norm: bool = False + dense_pre_norm: bool = False + dense_post_norm: bool = False + routed_pre_norm: bool = False + routed_post_norm: bool = False + + def get_mlp_config(self) -> HybridMoEMLPConfig: + dense: dict = { + "intermediate_size": _INTERMEDIATE_SIZE, + "gated": self.gated, + "add_linear_biases": False, + } + routed: dict = { + "intermediate_size": _INTERMEDIATE_SIZE, + "gated": self.gated, + "add_linear_biases": False, + "experts": _EXPERTS, + "experts_per_token": self.experts_per_token, + } + if self.dense_pre_norm: + dense["pre_norm"] = _NORM + if self.dense_post_norm: + dense["post_norm"] = _NORM + if self.routed_pre_norm: + routed["pre_norm"] = _NORM + if self.routed_post_norm: + routed["post_norm"] = _NORM + wrapper: dict = {"dense": dense, "routed": routed} + if self.wrapper_pre_norm: + wrapper["pre_norm"] = _NORM + if self.wrapper_post_norm: + wrapper["post_norm"] = _NORM + return HybridMoEMLPConfig.from_dict(wrapper) + + def expected_output(self, hybrid: HybridMoEMLP, input_: torch.Tensor, kwargs: dict) -> torch.Tensor: + with torch.no_grad(): + shared = hybrid.pre_norm(input_) if hybrid.pre_norm is not None else input_ + dense_out, _ = hybrid.dense(shared, kwargs) + routed_out, _ = hybrid.routed(shared, kwargs) + out = dense_out + routed_out + if hybrid.post_norm is not None: + out = hybrid.post_norm(out) + return out + + +_test_configs = [ + HybridMoEMLPTestConfig(name="basic"), + HybridMoEMLPTestConfig(name="gated", gated=True), + HybridMoEMLPTestConfig(name="topk2", experts_per_token=2), + HybridMoEMLPTestConfig(name="gated_topk2", gated=True, experts_per_token=2), + HybridMoEMLPTestConfig(name="branch_pre_norms", dense_pre_norm=True, routed_pre_norm=True), + HybridMoEMLPTestConfig(name="branch_post_norms", dense_post_norm=True, routed_post_norm=True), + HybridMoEMLPTestConfig(name="wrapper_norms", wrapper_pre_norm=True, wrapper_post_norm=True), + HybridMoEMLPTestConfig( + name="all_norms", + wrapper_pre_norm=True, + wrapper_post_norm=True, + dense_pre_norm=True, + dense_post_norm=True, + routed_pre_norm=True, + routed_post_norm=True, + ), + HybridMoEMLPTestConfig(name="asymmetric_norms", dense_pre_norm=True, routed_post_norm=True), +] + + +@pytest.mark.parametrize("config", [pytest.param(c, id=c.name) for c in _test_configs]) +def test_hybrid_moe_mlp(config: HybridMoEMLPTestConfig) -> None: + distributed_config = DistributedConfig(use_cuda=torch.cuda.is_available()) + distributed = Distributed(distributed_config) + device = distributed.device + hidden_dim = TensorDim("hidden", _HIDDEN_SIZE) + + hybrid: HybridMoEMLP = config.get_mlp_config().get_layer( + distributed_config, hidden_dim, lr_scale=None, peft=None, return_bias=False + ) + get_stage([hybrid], distributed) + hybrid.eval() + + input_ = torch.randn(_NUM_TOKENS, _HIDDEN_SIZE, device=device) + token_dim = TensorDim("tokens", _NUM_TOKENS) + kwargs = {BlockKwargs.hidden_token_dim: token_dim} + + with torch.no_grad(): + output = hybrid(input_, kwargs) + + expected = config.expected_output(hybrid, input_, kwargs) + Assert.rms_close_relative(output, expected, 1e-5, 1e-7) diff --git a/tests/layers/test_rotary.py b/tests/layers/test_rotary.py index 13f7575b6..7557912bd 100644 --- a/tests/layers/test_rotary.py +++ b/tests/layers/test_rotary.py @@ -10,6 +10,7 @@ from fast_llm.layers.attention.rotary.config import ( DefaultRotaryConfig, Llama3RotaryConfig, + ProportionalRotaryConfig, Rotary2DConfig, RotaryConfig, YarnRotaryConfig, @@ -77,6 +78,8 @@ class RotaryTestConfig: head_size: int rotary_type: str = "default" theta: float = 10000.0 + # proportional + partial_rotary_factor: float = 1.0 # llama3 and yarn scale_factor: float = 8.0 original_context_length: int = 8192 @@ -96,6 +99,8 @@ def attention_factor(self) -> float: def get_rotary_config(self) -> RotaryConfig: if self.rotary_type == "default": return DefaultRotaryConfig(theta=self.theta) + if self.rotary_type == "proportional": + return ProportionalRotaryConfig(theta=self.theta, partial_rotary_factor=self.partial_rotary_factor) if self.rotary_type == "llama3": return Llama3RotaryConfig( theta=self.theta, @@ -121,6 +126,12 @@ def reference_angle_scales(self) -> torch.Tensor: base = self.theta ** -torch.arange(0, 1, 2 / self.head_size, dtype=torch.float64) if self.rotary_type in ("default", "2d"): return base + if self.rotary_type == "proportional": + rotary_pairs = round(self.head_size * self.partial_rotary_factor) // 2 + nope_pairs = self.head_size // 2 - rotary_pairs + if nope_pairs == 0: + return base + return torch.cat([base[:rotary_pairs], base.new_zeros(nope_pairs)]) if self.rotary_type == "llama3": high_freq_wavelength = self.original_context_length / self.high_frequency_factor low_freq_wavelength = self.original_context_length / self.low_frequency_factor @@ -200,6 +211,18 @@ def reference_output( for head_size in _head_sizes ] +for _head_size in _head_sizes: + for _factor in [0.25, 0.5, 0.75, 1.0]: + if round(_head_size * _factor) % 2 == 0 and round(_head_size * _factor) > 0: + _rotary_test_configs.append( + RotaryTestConfig( + name=f"proportional_{int(_factor * 100)}pct_h{_head_size}", + head_size=_head_size, + rotary_type="proportional", + partial_rotary_factor=_factor, + ) + ) + _sequence_lengths = [8, 24] diff --git a/tests/models/test_hf_roundtrip.py b/tests/models/test_hf_roundtrip.py index 3c472086b..f4baf7698 100644 --- a/tests/models/test_hf_roundtrip.py +++ b/tests/models/test_hf_roundtrip.py @@ -15,6 +15,8 @@ import torch from transformers import ( AutoConfig, + Gemma4ForCausalLM, + Gemma4TextConfig, LlamaConfig, LlamaForCausalLM, MistralConfig, @@ -37,6 +39,7 @@ from fast_llm.models.gpt.config import GPTModelConfig from fast_llm.models.gpt.conversion.config import ( Apriel2TextCheckpointFormat, + Gemma4CheckpointFormat, LlamaCheckpointFormat, MistralCheckpointFormat, MixtralCheckpointFormat, @@ -101,6 +104,25 @@ def make_model(self) -> PreTrainedModel: return self.model_class(self.config_class(**converted_config)) +@dataclasses.dataclass(frozen=True) +class Gemma4RoundtripCase(HFRoundtripCase): + """Gemma4: apply dim overrides directly without deriving head_dim from hidden_size.""" + + def make_model(self) -> PreTrainedModel: + config = self.config_class.from_pretrained(self.hf_model_name) + for key, value in self.dim_overrides.items(): + setattr(config, key, value) + config.max_position_embeddings = self.max_position_embeddings + if getattr(config, "layer_types", None) is not None: + n = config.num_hidden_layers + lt = config.layer_types + config.layer_types = (lt * ((n // len(lt)) + 1))[:n] + for key in self.delete_config_keys: + if hasattr(config, key): + delattr(config, key) + return self.model_class(config) + + _TINY_DIMS = { "hidden_size": 64, "num_attention_heads": 4, @@ -212,6 +234,33 @@ def make_model(self) -> PreTrainedModel: }, }, ), + Gemma4RoundtripCase( + name="gemma4", + hf_model_name="google/gemma-4-26B-A4B", + checkpoint_format=Gemma4CheckpointFormat, + model_class=Gemma4ForCausalLM, + config_class=Gemma4TextConfig, + dim_overrides={ + "hidden_size": 256, + "num_hidden_layers": 6, # 5 sliding + 1 full from real layer_types pattern + "num_attention_heads": 8, + "num_key_value_heads": 4, + "num_global_key_value_heads": 2, + "head_dim": 32, + "global_head_dim": 64, # real model has 2:1 ratio (256:512) + "intermediate_size": 256, + "moe_intermediate_size": 128, + "num_experts": 4, + "top_k_experts": 2, + "vocab_size": 384, + "hidden_size_per_layer_input": 0, + # use_bidirectional_attention="vision" in the real model is for multimodal vision tokens; + # Fast-LLM is text-only so the converter exports None — reset source to match. + "use_bidirectional_attention": None, + }, + max_position_embeddings=131072, # Gemma4TextConfig default; converter does not export this + delete_config_keys=("dtype",), # "bfloat16" in real config; metadata not preserved by converter + ), ] diff --git a/tests/test_config.py b/tests/test_config.py index 792eab077..3753f75d7 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -1,6 +1,7 @@ import collections import pathlib import subprocess +import sys import pytest import yaml @@ -18,7 +19,7 @@ def run_without_import(cmd: str): # Run the test in a separate process since lots of things are already imported in this one. repo_path = pathlib.Path(__file__).parents[1].resolve() command = [ - "python3", + sys.executable, "-c", "\n".join( [ diff --git a/tests/utils/model_configs.py b/tests/utils/model_configs.py index 9cccf54cd..6a9b18268 100644 --- a/tests/utils/model_configs.py +++ b/tests/utils/model_configs.py @@ -20,6 +20,7 @@ Apriel2TextCheckpointFormat, DiffusionDreamCheckpointFormat, DiffusionLlamaCheckpointFormat, + Gemma4CheckpointFormat, LlamaCheckpointFormat, MistralCheckpointFormat, MixtralCheckpointFormat, @@ -963,6 +964,77 @@ def update_and_add_testing_config( ) +# Use init_1 for extra norms to keep the residual stream at small scale (~0.35). +# Default ones-init would normalize small layer outputs to unit scale before the residual add, +# growing the stream to ~1.5 per block and causing bf16 absolute errors to exceed compare_factor=2. +# query_norm/key_norm init_1 also keeps attention logits small (softmax_scale_power=0 otherwise +# amplifies bf16 relative error 3x). ParameterConfig.initialization is FieldHint.feature so +# this is invisible to the architecture comparison in test_load_pretrained. +_gemma4_block_overrides = { + "post_mixer_normalization": {"type": "rms_norm", "weight": init_1}, + "post_mlp_normalization": {"type": "rms_norm", "weight": init_1}, + "pre_mlp_normalization": {"type": "rms_norm", "weight": init_1}, + # Must match the gemma4 converter's lr_scale=0 — frozen params are packed at the end of the stage, + # so a mismatch produces a shifted shard layout that fails round-trip. + "output_scale": {"enabled": True, "lr_scale": 0}, +} +_gemma4_mixer_overrides = { + "softmax_scale_power": 0, + "query_norm": {"type": "rms_norm", "weight": init_1}, + "key_norm": {"type": "rms_norm", "weight": init_1}, + "value_norm": {"type": "fixed_rms_norm"}, +} + +update_and_add_testing_config( + # Tests Gemma4 converter: pattern decoder with alternating sliding/full attention, + # per-head norms (q/k/v), post-attention and post-MLP norms, embedding scale. + "llama", + "gemma4", + updates={ + ("model", "base_model", "tied_embedding_weight"): True, + ("model", "base_model", "embeddings", "embedding_scale"): 16.0, # sqrt(hidden_size=256); must match converter + ("model", "base_model", "decoder"): { + "type": "pattern", + "blocks": { + "sliding_attention": { + **copy.deepcopy(_llama_block), + **_gemma4_block_overrides, + "mixer": { + **copy.deepcopy(_llama_block["mixer"]), + **_gemma4_mixer_overrides, + "window_size": 128, + }, + }, + "full_attention": { + **copy.deepcopy(_llama_block), + **_gemma4_block_overrides, + "mixer": { + **copy.deepcopy(_llama_block["mixer"]), + **_gemma4_mixer_overrides, + "rotary": {"type": "proportional", "partial_rotary_factor": 0.25}, + }, + }, + }, + "pattern": ["sliding_attention", "full_attention"], + "num_blocks": 2, + }, + }, + megatron_args=None, + checkpoint_format=Gemma4CheckpointFormat, + compare_factor=5.0, # init_1 on post_mlp_norm makes its gradient tiny (~5e-6), hitting the fp16 rms_eps floor + groups={ + ModelTestingGroup.basic: ModelTestingGroupAction.normal, + ModelTestingGroup.checkpoint: ModelTestingGroupAction.normal, + ModelTestingGroup.convert: ModelTestingGroupAction.normal, + ModelTestingGroup.generate: ModelTestingGroupAction.not_implemented, + ModelTestingGroup.megatron: ModelTestingGroupAction.not_implemented, + ModelTestingGroup.distributed: ModelTestingGroupAction.unimportant, + }, + skip_tests=("sdp", "ms"), + requires_cuda=False, +) + + @pytest.fixture(scope="session", params=MODEL_CONFIGS.keys()) def model_testing_config(request) -> ModelTestingConfig: models = request.config.getoption("--models")