From 1a219c45692da48f61b853af019a054d136282b6 Mon Sep 17 00:00:00 2001 From: oleksost Date: Tue, 18 Nov 2025 15:21:19 +0000 Subject: [PATCH 01/43] wip --- fast_llm/models/auto.py | 1 + 1 file changed, 1 insertion(+) diff --git a/fast_llm/models/auto.py b/fast_llm/models/auto.py index 7830c69a1..3f67fe710 100644 --- a/fast_llm/models/auto.py +++ b/fast_llm/models/auto.py @@ -4,6 +4,7 @@ from fast_llm.layers.attention.config import AttentionConfig # isort: skip from fast_llm.layers.ssm.config import MambaConfig, Mamba2Config, DiscreteMamba2Config # isort: skip +from fast_llm.layers.attention.config import AttentionConfig # isort: skip from fast_llm.models.gpt.config import GPTModelConfig, GPTTrainerConfig # isort: skip from fast_llm.models.multimodal.config import MultiModalModelConfig, MultiModalTrainerConfig # isort: skip from fast_llm.engine.evaluation.evaluators import EvaluatorsConfig # isort: skip From 5242eb6b0be0e755ae2d86ef1a7836bca0a97754 Mon Sep 17 00:00:00 2001 From: oleksost Date: Tue, 18 Nov 2025 21:23:36 +0000 Subject: [PATCH 02/43] added gdn --- .../layers/common/normalization/config.py | 11 + .../common/normalization/normalization.py | 42 ++ fast_llm/layers/ssm/config.py | 90 +++++ fast_llm/layers/ssm/gdn.py | 372 ++++++++++++++++++ fast_llm/models/auto.py | 7 +- 5 files changed, 521 insertions(+), 1 deletion(-) create mode 100644 fast_llm/layers/ssm/gdn.py diff --git a/fast_llm/layers/common/normalization/config.py b/fast_llm/layers/common/normalization/config.py index a80a19280..4ecb7a3be 100644 --- a/fast_llm/layers/common/normalization/config.py +++ b/fast_llm/layers/common/normalization/config.py @@ -127,3 +127,14 @@ def module_class(self): from fast_llm.layers.common.normalization.normalization import RMSNormalization return RMSNormalization + + +@config_class(dynamic_type={NormalizationConfig: "gated_rms_norm"}) +class GatedRMSNormalizationConfig(RMSNormalizationConfig): + _abstract = False + + @property + def module_class(self): + from fast_llm.layers.common.normalization.normalization import GatedRMSNormalization + + return GatedRMSNormalization diff --git a/fast_llm/layers/common/normalization/normalization.py b/fast_llm/layers/common/normalization/normalization.py index d0a5ab151..ec8a52e26 100644 --- a/fast_llm/layers/common/normalization/normalization.py +++ b/fast_llm/layers/common/normalization/normalization.py @@ -1,6 +1,7 @@ import abc import torch +import torch.nn.functional as F from fast_llm.config import Configurable from fast_llm.engine.config_utils.initialization import init_ones_, init_zeros_ @@ -9,6 +10,7 @@ from fast_llm.functional.config import TritonConfig from fast_llm.functional.triton.normalization import triton_normalization_autograd from fast_llm.layers.common.normalization.config import ( + GatedRMSNormalizationConfig, LayerNormalizationConfig, NoNormalizationConfig, NormalizationConfig, @@ -33,6 +35,12 @@ _fast_normalization_available = False +try: + from fla.modules.fused_norm_gate import rms_norm_gated # noqa +except ImportError: + rms_norm_gated = None + + _PERSIST_LN_SIZES = ( 1024, 1536, @@ -292,3 +300,37 @@ def _forward_fused(self, input_: torch.Tensor) -> torch.Tensor: def _forward_torch(self, input_: torch.Tensor) -> torch.Tensor: return torch.rms_norm(input_.to(self.weight.dtype), self._normalized_shape, self.weight, self._config.epsilon) + + +class GatedRMSNormalization[ConfigType: GatedRMSNormalizationConfig](RMSNormalization[ConfigType], torch.nn.Module): + """ + A gated RMS normalization layer. + """ + + def __init__(self, config: ConfigType, hidden_dim: TensorDim, lr_scale: float | None = None): + super().__init__(config, hidden_dim, lr_scale) + + if rms_norm_gated is not None: + self._forward = self._forward_fused + else: + self._forward = self._forward + + def forward(self, input_: torch.Tensor, gate: torch.Tensor) -> torch.Tensor: + return self._forward(input_.view(-1, *self._normalized_shape), gate).view_as(input_) + + def _forward_fused(self, input_: torch.Tensor, gate: torch.Tensor) -> torch.Tensor: + return rms_norm_gated( + input_, + gate, + self.weight, + None, + activation="silu", + eps=self._config.epsilon, + residual=None, + prenorm=False, + residual_in_fp32=False, + ) + + def _forward(self, input_: torch.Tensor, gate: torch.Tensor) -> torch.Tensor: + normalized = self.rmsnorm(input_) + return normalized * F.silu(gate) diff --git a/fast_llm/layers/ssm/config.py b/fast_llm/layers/ssm/config.py index e541341e5..35ed6f6a8 100644 --- a/fast_llm/layers/ssm/config.py +++ b/fast_llm/layers/ssm/config.py @@ -4,7 +4,9 @@ from fast_llm.config import Field, FieldHint, check_field, config_class from fast_llm.engine.config_utils.initialization import InitializationConfig, Initializer, LambdaInitializer from fast_llm.engine.config_utils.parameter import ParameterConfig +from fast_llm.functional.config import ActivationType from fast_llm.layers.common.linear.config import AffineLinearConfig, CausalConv1dConfig, LinearConfig +from fast_llm.layers.common.normalization.config import NormalizationConfig from fast_llm.layers.decoder.config import MixerConfig from fast_llm.utils import Assert @@ -16,6 +18,94 @@ from fast_llm.tensor import ParameterMeta +@config_class(dynamic_type={MixerConfig: "gdn"}) +class GatedDeltaNetConfig(MixerConfig): + """ + Configuration for the gated DeltaNet mixer used in Qwen3Next style linear attention blocks. + """ + + _abstract = False + normalization: NormalizationConfig = Field( + desc="Configuration for the block normalization layers.", + hint=FieldHint.architecture, + ) + qkv_projection_layer: AffineLinearConfig = Field( + desc="Projection that produces query, key, value and modulation vectors.", + hint=FieldHint.architecture, + ) + ba_projection_layer: AffineLinearConfig = Field( + desc="Projection that produces the decay and beta terms.", + hint=FieldHint.architecture, + ) + convolution_layer: CausalConv1dConfig = Field( + desc="Depth-wise convolution applied to the concatenated QKV streams.", + hint=FieldHint.architecture, + ) + output_layer: AffineLinearConfig = Field( + desc="Output projection applied after the DeltaNet recurrence and gated RMS norm.", + hint=FieldHint.architecture, + ) + dt_bias_weight: ParameterConfig = Field( + desc="Parameter configuration for the DeltaNet time-step bias.", + hint=FieldHint.architecture, + ) + a_log_weight: ParameterConfig = Field( + desc="Parameter configuration for the DeltaNet decay rates.", + hint=FieldHint.architecture, + ) + + value_heads: int = Field( + default=16, + desc="Number of value heads.", + hint=FieldHint.architecture, + valid=check_field(Assert.gt, 0), + ) + key_heads: int = Field( + default=8, + desc="Number of key heads.", + hint=FieldHint.architecture, + valid=check_field(Assert.gt, 0), + ) + key_head_dim: int = Field( + default=64, + desc="Dimension of each key head.", + hint=FieldHint.architecture, + valid=check_field(Assert.gt, 0), + ) + value_head_dim: int = Field( + default=64, + desc="Dimension of each value head.", + hint=FieldHint.architecture, + valid=check_field(Assert.gt, 0), + ) + norm_epsilon: float = Field( + default=1e-6, + desc="Epsilon used by the gated RMS norm.", + hint=FieldHint.architecture, + valid=check_field(Assert.gt, 0), + ) + use_qk_l2norm: bool = Field( + default=True, + desc="Apply L2 normalization on query/key vectors inside the Delta rule kernel.", + hint=FieldHint.architecture, + ) + activation: ActivationType = Field( + default=ActivationType.silu, + desc="Activation used after the convolution.", + hint=FieldHint.architecture, + ) + + def _validate(self) -> None: + super()._validate() + Assert.multiple(self.value_heads, self.key_heads) + + @property + def layer_class(self) -> "type": + from fast_llm.layers.ssm.gdn import GatedDeltaNet + + return GatedDeltaNet + + @config_class() class SSMConfig(MixerConfig): # Layers diff --git a/fast_llm/layers/ssm/gdn.py b/fast_llm/layers/ssm/gdn.py new file mode 100644 index 000000000..a1a62a5a9 --- /dev/null +++ b/fast_llm/layers/ssm/gdn.py @@ -0,0 +1,372 @@ +import logging +import typing + +import torch +import torch.nn.functional as F + +from fast_llm.engine.base_model.config import ResourceUsageConfig +from fast_llm.engine.config_utils.initialization import LambdaInitializer, init_normal_, init_ones_ +from fast_llm.engine.config_utils.tensor_dim import CompositeTensorDim, ConcatenatedTensorDim, TensorDim +from fast_llm.engine.distributed.config import DistributedConfig, DistributedDimNames +from fast_llm.layers.block.config import BlockKwargs +from fast_llm.layers.common.peft.config import PeftConfig +from fast_llm.layers.decoder.block import BlockWithBias +from fast_llm.layers.ssm.config import GatedDeltaNetConfig +from fast_llm.tensor import ParameterMeta, TensorMeta +from fast_llm.utils import div + +logger = logging.getLogger(__name__) + +try: + from fla.ops.gated_delta_rule import chunk_gated_delta_rule +except ImportError: + chunk_gated_delta_rule = None + + +is_fast_path_available = chunk_gated_delta_rule is not None + + +def _l2norm(x: torch.Tensor, dim: int = -1, eps: float = 1e-6) -> torch.Tensor: + return x * torch.rsqrt((x * x).sum(dim=dim, keepdim=True) + eps) + + +def torch_recurrent_gated_delta_rule( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + g: torch.Tensor, + beta: torch.Tensor, + *, + use_qk_l2norm_in_kernel: bool, +) -> torch.Tensor: + """ + Simplified gated Delta rule used during training. + Args expect tensors shaped as (batch, heads, seq, dim) except for g/beta which are (batch, heads, seq). + """ + initial_dtype = query.dtype + if use_qk_l2norm_in_kernel: + query = _l2norm(query, dim=-1) + key = _l2norm(key, dim=-1) + + query = query.to(torch.float32) + key = key.to(torch.float32) + value = value.to(torch.float32) + beta = beta.to(torch.float32) + g = g.to(torch.float32) + + batch_size, num_heads, sequence_length, k_head_dim = key.shape + v_head_dim = value.shape[-1] + state = torch.zeros(batch_size, num_heads, k_head_dim, v_head_dim, device=key.device, dtype=key.dtype) + outputs = torch.zeros(batch_size, num_heads, sequence_length, v_head_dim, device=value.device, dtype=value.dtype) + + for idx in range(sequence_length): + q_t = query[:, :, idx] + k_t = key[:, :, idx] + v_t = value[:, :, idx] + g_t = g[:, :, idx].exp().unsqueeze(-1).unsqueeze(-1) + beta_t = beta[:, :, idx].unsqueeze(-1) + state = state * g_t + kv_mem = (state * k_t.unsqueeze(-1)).sum(dim=-2) + delta = (v_t - kv_mem) * beta_t + state = state + k_t.unsqueeze(-1) * delta.unsqueeze(-2) + outputs[:, :, idx] = (state * q_t.unsqueeze(-1)).sum(dim=-2) + + return outputs.to(initial_dtype), state + + +def torch_chunk_gated_delta_rule( + query, + key, + value, + g, + beta, + chunk_size=64, + initial_state=None, + output_final_state=False, + use_qk_l2norm_in_kernel=False, +): + initial_dtype = query.dtype + if use_qk_l2norm_in_kernel: + query = _l2norm(query, dim=-1, eps=1e-6) + key = _l2norm(key, dim=-1, eps=1e-6) + query, key, value, beta, g = ( + x.transpose(1, 2).contiguous().to(torch.float32) for x in (query, key, value, beta, g) + ) + + batch_size, num_heads, sequence_length, k_head_dim = key.shape + v_head_dim = value.shape[-1] + pad_size = (chunk_size - sequence_length % chunk_size) % chunk_size + query = F.pad(query, (0, 0, 0, pad_size)) + key = F.pad(key, (0, 0, 0, pad_size)) + value = F.pad(value, (0, 0, 0, pad_size)) + beta = F.pad(beta, (0, pad_size)) + g = F.pad(g, (0, pad_size)) + total_sequence_length = sequence_length + pad_size + scale = 1 / (query.shape[-1] ** 0.5) + query = query * scale + + v_beta = value * beta.unsqueeze(-1) + k_beta = key * beta.unsqueeze(-1) + # reshape to chunks + query, key, value, k_beta, v_beta = ( + x.reshape(x.shape[0], x.shape[1], -1, chunk_size, x.shape[-1]) for x in (query, key, value, k_beta, v_beta) + ) + g = g.reshape(g.shape[0], g.shape[1], -1, chunk_size) + mask = torch.triu(torch.ones(chunk_size, chunk_size, dtype=torch.bool, device=query.device), diagonal=0) + + # chunk decay + g = g.cumsum(dim=-1) + decay_mask = ((g.unsqueeze(-1) - g.unsqueeze(-2)).tril().exp().float()).tril() + attn = -((k_beta @ key.transpose(-1, -2)) * decay_mask).masked_fill(mask, 0) + for i in range(1, chunk_size): + row = attn[..., i, :i].clone() + sub = attn[..., :i, :i].clone() + attn[..., i, :i] = row + (row.unsqueeze(-1) * sub).sum(-2) + attn = attn + torch.eye(chunk_size, dtype=attn.dtype, device=attn.device) + value = attn @ v_beta + k_cumdecay = attn @ (k_beta * g.exp().unsqueeze(-1)) + last_recurrent_state = ( + torch.zeros(batch_size, num_heads, k_head_dim, v_head_dim).to(value) + if initial_state is None + else initial_state.to(value) + ) + core_attn_out = torch.zeros_like(value) + mask = torch.triu(torch.ones(chunk_size, chunk_size, dtype=torch.bool, device=query.device), diagonal=1) + + # for each chunk + for i in range(0, total_sequence_length // chunk_size): + q_i, k_i, v_i = query[:, :, i], key[:, :, i], value[:, :, i] + attn = (q_i @ k_i.transpose(-1, -2) * decay_mask[:, :, i]).masked_fill_(mask, 0) + v_prime = (k_cumdecay[:, :, i]) @ last_recurrent_state + v_new = v_i - v_prime + attn_inter = (q_i * g[:, :, i, :, None].exp()) @ last_recurrent_state + core_attn_out[:, :, i] = attn_inter + attn @ v_new + last_recurrent_state = ( + last_recurrent_state * g[:, :, i, -1, None, None].exp() + + (k_i * (g[:, :, i, -1, None] - g[:, :, i]).exp()[..., None]).transpose(-1, -2) @ v_new + ) + + if not output_final_state: + last_recurrent_state = None + core_attn_out = core_attn_out.reshape(core_attn_out.shape[0], core_attn_out.shape[1], -1, core_attn_out.shape[-1]) + core_attn_out = core_attn_out[:, :, :sequence_length] + core_attn_out = core_attn_out.transpose(1, 2).contiguous().to(initial_dtype) + return core_attn_out, last_recurrent_state + + +# class _GatedRMSNorm(torch.nn.Module): +# def __init__(self, hidden_size: int, eps: float): +# super().__init__() +# self.weight = torch.nn.Parameter(torch.ones(hidden_size)) +# self.eps = eps + +# def forward(self, hidden_states: torch.Tensor, gate: torch.Tensor) -> torch.Tensor: +# dtype = hidden_states.dtype +# hidden_states = hidden_states.to(torch.float32) +# variance = hidden_states.pow(2).mean(dim=-1, keepdim=True) +# hidden_states = hidden_states * torch.rsqrt(variance + self.eps) +# hidden_states = self.weight * hidden_states.to(dtype) +# hidden_states = hidden_states * F.silu(gate.to(torch.float32)) +# return hidden_states.to(dtype) + + +class GatedDeltaNet[ConfigType: GatedDeltaNetConfig](BlockWithBias[ConfigType]): + _config: ConfigType + + def __init__( + self, + config: ConfigType, + distributed_config: DistributedConfig, + *, + hidden_dim: TensorDim, + lr_scale: float | None, + peft: PeftConfig | None, + return_bias: bool = True, + ): + super().__init__( + config, distributed_config, hidden_dim=hidden_dim, lr_scale=lr_scale, peft=peft, return_bias=return_bias + ) + self._parallel_dim = self._distributed_config.get_distributed_dim(DistributedDimNames.tensor) + self._value_heads_dim = TensorDim( + "gdn_value_heads", self._config.value_heads, self._parallel_dim if self._config.value_heads > 1 else None + ) + self._key_heads_dim = TensorDim( + "gdn_key_heads", self._config.key_heads, self._parallel_dim if self._config.key_heads > 1 else None + ) + self._value_head_dim = TensorDim("gdn_value_head_dim", self._config.value_head_dim) + self._key_head_dim = TensorDim("gdn_key_head_dim", self._config.key_head_dim) + self._local_value_heads = self._value_heads_dim.size + self._local_key_heads = self._key_heads_dim.size + self._value_heads_per_key = div(self._local_value_heads, max(self._local_key_heads, 1)) + + query_dim = CompositeTensorDim("gdn_query", (self._key_heads_dim, self._key_head_dim)) + key_dim = CompositeTensorDim("gdn_key", (self._key_heads_dim, self._key_head_dim)) + value_dim = CompositeTensorDim("gdn_value", (self._value_heads_dim, self._value_head_dim)) + z_dim = CompositeTensorDim("gdn_z", (self._value_heads_dim, self._value_head_dim)) + qkvz_dim = ConcatenatedTensorDim("gdn_qkvz", (query_dim, key_dim, value_dim, z_dim)) + ba_dim = ConcatenatedTensorDim( + "gdn_ba", + ( + CompositeTensorDim("gdn_beta", (self._value_heads_dim,)), + CompositeTensorDim("gdn_alpha", (self._value_heads_dim,)), + ), + ) + + qkv_channels_dim = ConcatenatedTensorDim("gdn_qkv", (query_dim, key_dim, value_dim)) + + self.in_proj_qkvz = self._config.qkv_projection_layer.get_layer( + hidden_dim, + qkvz_dim, + default_weight_initialization=init_normal_(std=self._hidden_size**-0.5), + default_add_bias=False, + sequence_parallel=self._sequence_parallel, + lr_scale=self._lr_scale, + peft=self._peft, + ) + self.in_proj_ba = self._config.ba_projection_layer.get_layer( + hidden_dim, + ba_dim, + default_weight_initialization=init_normal_(std=self._hidden_size**-0.5), + default_add_bias=False, + sequence_parallel=self._sequence_parallel, + lr_scale=self._lr_scale, + peft=self._peft, + ) + self.convolution = self._config.convolution_layer.get_layer( + qkv_channels_dim, + default_add_bias=False, + default_activation=self._config.activation, + lr_scale=self._lr_scale, + peft=self._peft, + ) + self.out_proj = self._config.output_layer.get_layer( + value_dim, + hidden_dim, + default_weight_initialization=init_normal_(std=self._hidden_size**-0.5), + default_add_bias=False, + sequence_parallel=self._sequence_parallel, + lr_scale=self._lr_scale, + peft=self._peft, + ) + self.dt_bias: ParameterMeta = self._config.dt_bias_weight.get_parameter( + (self._value_heads_dim,), + default_initialization=init_ones_, + lr_scale=self._lr_scale, + peft=self._peft, + ) + self.A_log: ParameterMeta = self._config.a_log_weight.get_parameter( + (self._value_heads_dim,), + default_initialization=LambdaInitializer( + lambda _, tensor, generator: tensor.uniform_(0, 16, generator=generator).log_() + ), + lr_scale=self._lr_scale, + peft=self._peft, + ) + self.norm = self._config.normalization.get_layer( + self._value_head_dim, lr_scale=self._lr_scale, peft=self._peft + ) + # _GatedRMSNorm(self._config.value_head_dim, self._config.norm_epsilon) + self._use_qk_l2norm = self._config.use_qk_l2norm + + self._value_dim = value_dim + self._query_dim = query_dim + self.chunk_gated_delta_rule = chunk_gated_delta_rule or torch_chunk_gated_delta_rule + + if not is_fast_path_available: + logger.warning( + "Fast paths for GatedDeltaNet are not available. Please ensure that 'causal_conv1d' and 'fla' are properly installed." + ) + + def _reshape_heads(self, tensor: torch.Tensor, num_heads: int, head_dim: int) -> torch.Tensor: + batch, seq, _ = tensor.shape + return tensor.view(batch, seq, num_heads, head_dim) + + def _forward( + self, + input_: torch.Tensor, + kwargs: dict[str, typing.Any], + losses: dict[str, typing.Any] | None = None, + metrics: dict[str, typing.Any] | None = None, + ) -> tuple[torch.Tensor, torch.Tensor | None]: + sequence_first = kwargs[BlockKwargs.sequence_first] + if sequence_first: + hidden_states = input_.transpose(0, 1) + else: + hidden_states = input_ + + batch_size, sequence_length, _ = hidden_states.shape + qkvz = self.in_proj_qkvz(hidden_states) + ba = self.in_proj_ba(hidden_states) + key_size = self._query_dim.size + value_size = self._value_dim.size + query, key, value, z = torch.split(qkvz, (key_size, key_size, value_size, value_size), dim=-1) + beta, alpha = torch.split(ba, (self._local_value_heads, self._local_value_heads), dim=-1) + + query = self._reshape_heads(query, self._local_key_heads, self._config.key_head_dim) + key = self._reshape_heads(key, self._local_key_heads, self._config.key_head_dim) + value = self._reshape_heads(value, self._local_value_heads, self._config.value_head_dim) + z = self._reshape_heads(z, self._local_value_heads, self._config.value_head_dim) + + mixed_qkv = torch.cat( + ( + query.reshape(batch_size, sequence_length, -1), + key.reshape(batch_size, sequence_length, -1), + value.reshape(batch_size, sequence_length, -1), + ), + dim=-1, + ) + mixed_qkv = mixed_qkv.transpose(1, 2) + mixed_qkv = self.convolution(mixed_qkv) + mixed_qkv = mixed_qkv.transpose(1, 2) + query, key, value = torch.split( + mixed_qkv, + ( + self._local_key_heads * self._config.key_head_dim, + self._local_key_heads * self._config.key_head_dim, + self._local_value_heads * self._config.value_head_dim, + ), + dim=-1, + ) + query = self._reshape_heads(query, self._local_key_heads, self._config.key_head_dim) + key = self._reshape_heads(key, self._local_key_heads, self._config.key_head_dim) + value = self._reshape_heads(value, self._local_value_heads, self._config.value_head_dim) + + beta = beta.view(batch_size, sequence_length, self._local_value_heads).sigmoid() + alpha = alpha.view(batch_size, sequence_length, self._local_value_heads) + dt_bias = self.dt_bias.to(hidden_states.dtype) + a_log = self.A_log.to(hidden_states.dtype) + g = -torch.exp(a_log) * F.softplus(alpha + dt_bias) + + if self._value_heads_per_key > 1: + query = query.repeat_interleave(self._value_heads_per_key, dim=2) + key = key.repeat_interleave(self._value_heads_per_key, dim=2) + + core_attn_out, _ = self.chunk_gated_delta_rule( + query.permute(0, 2, 1, 3), + key.permute(0, 2, 1, 3), + value.permute(0, 2, 1, 3), + g=g.permute(0, 2, 1), + beta=beta.permute(0, 2, 1), + use_qk_l2norm_in_kernel=self._use_qk_l2norm, + ) + + core_attn_out = core_attn_out.permute(0, 2, 1, 3).reshape( + batch_size, sequence_length, -1, self._config.value_head_dim + ) + z = z.reshape(batch_size, sequence_length, -1, self._config.value_head_dim) + norm_input = core_attn_out.reshape(-1, self._config.value_head_dim) + norm_gate = z.reshape(-1, self._config.value_head_dim) + norm_out = self.norm(norm_input, norm_gate).view(batch_size, sequence_length, -1) + output = self.out_proj(norm_out) + + if sequence_first: + output = output.transpose(0, 1) + return output + + def get_compute_usage(self, input_: TensorMeta, kwargs: dict[str, typing.Any], config: ResourceUsageConfig) -> int: + # return ( + # self.in_proj_qkvz.get_compute_usage(input_, config) + # + self.in_proj_ba.get_compute_usage(input_, config) + # + self.out_proj.get_compute_usage(input_, config) + # ) + raise NotImplementedError() diff --git a/fast_llm/models/auto.py b/fast_llm/models/auto.py index 3f67fe710..f7c34a973 100644 --- a/fast_llm/models/auto.py +++ b/fast_llm/models/auto.py @@ -3,7 +3,12 @@ """ from fast_llm.layers.attention.config import AttentionConfig # isort: skip -from fast_llm.layers.ssm.config import MambaConfig, Mamba2Config, DiscreteMamba2Config # isort: skip +from fast_llm.layers.ssm.config import ( + MambaConfig, + Mamba2Config, + DiscreteMamba2Config, + GatedDeltaNetConfig, +) # isort: skip from fast_llm.layers.attention.config import AttentionConfig # isort: skip from fast_llm.models.gpt.config import GPTModelConfig, GPTTrainerConfig # isort: skip from fast_llm.models.multimodal.config import MultiModalModelConfig, MultiModalTrainerConfig # isort: skip From bec22ded712f7b8b721ad3bf1e5a7f14030e8328 Mon Sep 17 00:00:00 2001 From: oleksost Date: Wed, 19 Nov 2025 14:27:59 +0000 Subject: [PATCH 03/43] gdn layer --- fast_llm/layers/ssm/gdn.py | 178 ++++++++++++++----------------------- 1 file changed, 69 insertions(+), 109 deletions(-) diff --git a/fast_llm/layers/ssm/gdn.py b/fast_llm/layers/ssm/gdn.py index a1a62a5a9..62360acc8 100644 --- a/fast_llm/layers/ssm/gdn.py +++ b/fast_llm/layers/ssm/gdn.py @@ -30,50 +30,6 @@ def _l2norm(x: torch.Tensor, dim: int = -1, eps: float = 1e-6) -> torch.Tensor: return x * torch.rsqrt((x * x).sum(dim=dim, keepdim=True) + eps) -def torch_recurrent_gated_delta_rule( - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - g: torch.Tensor, - beta: torch.Tensor, - *, - use_qk_l2norm_in_kernel: bool, -) -> torch.Tensor: - """ - Simplified gated Delta rule used during training. - Args expect tensors shaped as (batch, heads, seq, dim) except for g/beta which are (batch, heads, seq). - """ - initial_dtype = query.dtype - if use_qk_l2norm_in_kernel: - query = _l2norm(query, dim=-1) - key = _l2norm(key, dim=-1) - - query = query.to(torch.float32) - key = key.to(torch.float32) - value = value.to(torch.float32) - beta = beta.to(torch.float32) - g = g.to(torch.float32) - - batch_size, num_heads, sequence_length, k_head_dim = key.shape - v_head_dim = value.shape[-1] - state = torch.zeros(batch_size, num_heads, k_head_dim, v_head_dim, device=key.device, dtype=key.dtype) - outputs = torch.zeros(batch_size, num_heads, sequence_length, v_head_dim, device=value.device, dtype=value.dtype) - - for idx in range(sequence_length): - q_t = query[:, :, idx] - k_t = key[:, :, idx] - v_t = value[:, :, idx] - g_t = g[:, :, idx].exp().unsqueeze(-1).unsqueeze(-1) - beta_t = beta[:, :, idx].unsqueeze(-1) - state = state * g_t - kv_mem = (state * k_t.unsqueeze(-1)).sum(dim=-2) - delta = (v_t - kv_mem) * beta_t - state = state + k_t.unsqueeze(-1) * delta.unsqueeze(-2) - outputs[:, :, idx] = (state * q_t.unsqueeze(-1)).sum(dim=-2) - - return outputs.to(initial_dtype), state - - def torch_chunk_gated_delta_rule( query, key, @@ -154,23 +110,11 @@ def torch_chunk_gated_delta_rule( return core_attn_out, last_recurrent_state -# class _GatedRMSNorm(torch.nn.Module): -# def __init__(self, hidden_size: int, eps: float): -# super().__init__() -# self.weight = torch.nn.Parameter(torch.ones(hidden_size)) -# self.eps = eps - -# def forward(self, hidden_states: torch.Tensor, gate: torch.Tensor) -> torch.Tensor: -# dtype = hidden_states.dtype -# hidden_states = hidden_states.to(torch.float32) -# variance = hidden_states.pow(2).mean(dim=-1, keepdim=True) -# hidden_states = hidden_states * torch.rsqrt(variance + self.eps) -# hidden_states = self.weight * hidden_states.to(dtype) -# hidden_states = hidden_states * F.silu(gate.to(torch.float32)) -# return hidden_states.to(dtype) - - class GatedDeltaNet[ConfigType: GatedDeltaNetConfig](BlockWithBias[ConfigType]): + """ + Follows implementation here: https://github.com/huggingface/transformers/blob/a5c903f877fda21e739027eed133e03162eb7712/src/transformers/models/qwen3_next/modeling_qwen3_next.py#L593 + """ + _config: ConfigType def __init__( @@ -265,11 +209,7 @@ def __init__( self.norm = self._config.normalization.get_layer( self._value_head_dim, lr_scale=self._lr_scale, peft=self._peft ) - # _GatedRMSNorm(self._config.value_head_dim, self._config.norm_epsilon) - self._use_qk_l2norm = self._config.use_qk_l2norm - self._value_dim = value_dim - self._query_dim = query_dim self.chunk_gated_delta_rule = chunk_gated_delta_rule or torch_chunk_gated_delta_rule if not is_fast_path_available: @@ -277,9 +217,41 @@ def __init__( "Fast paths for GatedDeltaNet are not available. Please ensure that 'causal_conv1d' and 'fla' are properly installed." ) - def _reshape_heads(self, tensor: torch.Tensor, num_heads: int, head_dim: int) -> torch.Tensor: - batch, seq, _ = tensor.shape - return tensor.view(batch, seq, num_heads, head_dim) + def fix_query_key_value_ordering(self, mixed_qkvz, mixed_ba): + """ + Derives `query`, `key` and `value` tensors from `mixed_qkvz` and `mixed_ba`. + """ + + new_tensor_shape_qkvz = mixed_qkvz.size()[:-1] + ( + self._local_key_heads, + 2 * self._config.key_head_dim + + 2 * self._config.value_head_dim * self._local_value_heads // self._local_key_heads, + ) + new_tensor_shape_ba = mixed_ba.size()[:-1] + ( + self._local_key_heads, + 2 * self._local_value_heads // self._local_key_heads, + ) + + mixed_qkvz = mixed_qkvz.view(*new_tensor_shape_qkvz) + mixed_ba = mixed_ba.view(*new_tensor_shape_ba) + split_arg_list_qkvz = [ + self._config.key_head_dim, + self._config.key_head_dim, + (self._local_value_heads // self._local_key_heads * self._config.value_head_dim), + (self._local_value_heads // self._local_key_heads * self._config.value_head_dim), + ] + split_arg_list_ba = [ + self._local_value_heads // self._local_key_heads, + self._local_value_heads // self._local_key_heads, + ] + query, key, value, z = torch.split(mixed_qkvz, split_arg_list_qkvz, dim=3) + b, a = torch.split(mixed_ba, split_arg_list_ba, dim=3) + # [b, sq, ng, np/ng * hn] -> [b, sq, np, hn] + value = value.reshape(value.size(0), value.size(1), -1, self._config.value_head_dim) + z = z.reshape(z.size(0), z.size(1), -1, self._config.value_head_dim) + b = b.reshape(b.size(0), b.size(1), self._local_value_heads) + a = a.reshape(a.size(0), a.size(1), self._local_value_heads) + return query, key, value, z, b, a def _forward( self, @@ -289,32 +261,22 @@ def _forward( metrics: dict[str, typing.Any] | None = None, ) -> tuple[torch.Tensor, torch.Tensor | None]: sequence_first = kwargs[BlockKwargs.sequence_first] + + # TODO: do we need maksing of padding tokens? if sequence_first: hidden_states = input_.transpose(0, 1) else: hidden_states = input_ - batch_size, sequence_length, _ = hidden_states.shape - qkvz = self.in_proj_qkvz(hidden_states) - ba = self.in_proj_ba(hidden_states) - key_size = self._query_dim.size - value_size = self._value_dim.size - query, key, value, z = torch.split(qkvz, (key_size, key_size, value_size, value_size), dim=-1) - beta, alpha = torch.split(ba, (self._local_value_heads, self._local_value_heads), dim=-1) - - query = self._reshape_heads(query, self._local_key_heads, self._config.key_head_dim) - key = self._reshape_heads(key, self._local_key_heads, self._config.key_head_dim) - value = self._reshape_heads(value, self._local_value_heads, self._config.value_head_dim) - z = self._reshape_heads(z, self._local_value_heads, self._config.value_head_dim) - - mixed_qkv = torch.cat( - ( - query.reshape(batch_size, sequence_length, -1), - key.reshape(batch_size, sequence_length, -1), - value.reshape(batch_size, sequence_length, -1), - ), - dim=-1, + # batch_size, sequence_length, _ = hidden_states.shape + projected_states_qkvz = self.in_proj_qkvz(hidden_states) # bs x seq_len x (qkvz) + projected_states_ba = self.in_proj_ba(hidden_states) # bs x seq_len x (b a) + query, key, value, z, beta, alpha = self.fix_query_key_value_ordering( + projected_states_qkvz, projected_states_ba ) + query, key, value = (x.reshape(x.shape[0], x.shape[1], -1) for x in (query, key, value)) + + mixed_qkv = torch.cat((query, key, value), dim=-1) mixed_qkv = mixed_qkv.transpose(1, 2) mixed_qkv = self.convolution(mixed_qkv) mixed_qkv = mixed_qkv.transpose(1, 2) @@ -327,37 +289,35 @@ def _forward( ), dim=-1, ) - query = self._reshape_heads(query, self._local_key_heads, self._config.key_head_dim) - key = self._reshape_heads(key, self._local_key_heads, self._config.key_head_dim) - value = self._reshape_heads(value, self._local_value_heads, self._config.value_head_dim) + query = query.reshape(query.shape[0], query.shape[1], -1, self._config.key_head_dim) + key = key.reshape(key.shape[0], key.shape[1], -1, self._config.key_head_dim) + value = value.reshape(value.shape[0], value.shape[1], -1, self._config.value_head_dim) - beta = beta.view(batch_size, sequence_length, self._local_value_heads).sigmoid() - alpha = alpha.view(batch_size, sequence_length, self._local_value_heads) - dt_bias = self.dt_bias.to(hidden_states.dtype) - a_log = self.A_log.to(hidden_states.dtype) - g = -torch.exp(a_log) * F.softplus(alpha + dt_bias) + beta = beta.sigmoid() + g = -torch.exp(self.A_log) * F.softplus(alpha + self.dt_bias) if self._value_heads_per_key > 1: query = query.repeat_interleave(self._value_heads_per_key, dim=2) key = key.repeat_interleave(self._value_heads_per_key, dim=2) core_attn_out, _ = self.chunk_gated_delta_rule( - query.permute(0, 2, 1, 3), - key.permute(0, 2, 1, 3), - value.permute(0, 2, 1, 3), - g=g.permute(0, 2, 1), - beta=beta.permute(0, 2, 1), - use_qk_l2norm_in_kernel=self._use_qk_l2norm, + query, + key, + value, + g=g, + beta=beta, + initial_state=None, + output_final_state=False, + use_qk_l2norm_in_kernel=True, ) - core_attn_out = core_attn_out.permute(0, 2, 1, 3).reshape( - batch_size, sequence_length, -1, self._config.value_head_dim - ) - z = z.reshape(batch_size, sequence_length, -1, self._config.value_head_dim) - norm_input = core_attn_out.reshape(-1, self._config.value_head_dim) - norm_gate = z.reshape(-1, self._config.value_head_dim) - norm_out = self.norm(norm_input, norm_gate).view(batch_size, sequence_length, -1) - output = self.out_proj(norm_out) + z_shape_og = z.shape + core_attn_out = core_attn_out.reshape(-1, core_attn_out.shape[-1]) + z = z.reshape(-1, z.shape[-1]) + core_attn_out = self.norm(core_attn_out, z) + core_attn_out = core_attn_out.reshape(z_shape_og) + core_attn_out = core_attn_out.reshape(core_attn_out.shape[0], core_attn_out.shape[1], -1) + output = self.out_proj(core_attn_out) if sequence_first: output = output.transpose(0, 1) From 7f7990983c064bb9a0627dbeb93ace5b28d87b58 Mon Sep 17 00:00:00 2001 From: oleksost Date: Wed, 19 Nov 2025 16:44:03 +0000 Subject: [PATCH 04/43] kda --- fast_llm/layers/ssm/config.py | 91 ++++++++++++ fast_llm/layers/ssm/kda.py | 268 ++++++++++++++++++++++++++++++++++ 2 files changed, 359 insertions(+) create mode 100644 fast_llm/layers/ssm/kda.py diff --git a/fast_llm/layers/ssm/config.py b/fast_llm/layers/ssm/config.py index 35ed6f6a8..95ef9bed2 100644 --- a/fast_llm/layers/ssm/config.py +++ b/fast_llm/layers/ssm/config.py @@ -106,6 +106,97 @@ def layer_class(self) -> "type": return GatedDeltaNet +@config_class(dynamic_type={MixerConfig: "kda"}) +class KimiDeltaAttentionConfig(MixerConfig): + """ + Configuration for the KimiDeltaAttention mixer inspired by the Kimi Linear models. + """ + + _abstract = False + normalization: NormalizationConfig = Field( + desc="Configuration for the gated normalization applied to the KDA output.", + hint=FieldHint.architecture, + ) + q_projection_layer: AffineLinearConfig = Field( + desc="Projection that produces query vectors.", + hint=FieldHint.architecture, + ) + k_projection_layer: AffineLinearConfig = Field( + desc="Projection that produces key vectors.", + hint=FieldHint.architecture, + ) + v_projection_layer: AffineLinearConfig = Field( + desc="Projection that produces value vectors.", + hint=FieldHint.architecture, + ) + f_a_projection_layer: AffineLinearConfig = Field( + desc="Projection used for the Delta gating pre-activation.", + hint=FieldHint.architecture, + ) + f_b_projection_layer: AffineLinearConfig = Field( + desc="Projection used for the Delta gating expansion.", + hint=FieldHint.architecture, + ) + g_a_projection_layer: AffineLinearConfig = Field( + desc="Projection used for the output gating pre-activation.", + hint=FieldHint.architecture, + ) + g_b_projection_layer: AffineLinearConfig = Field( + desc="Projection used for the output gating expansion.", + hint=FieldHint.architecture, + ) + beta_projection_layer: AffineLinearConfig = Field( + desc="Projection that produces the Beta gate.", + hint=FieldHint.architecture, + ) + output_projection_layer: AffineLinearConfig = Field( + desc="Projection applied after the Delta recurrence and gated normalization.", + hint=FieldHint.architecture, + ) + convolution_layer: CausalConv1dConfig = Field( + desc="Depth-wise convolution applied independently on each Q, K and V stream.", + hint=FieldHint.architecture, + ) + dt_bias_weight: ParameterConfig = Field( + desc="Parameter configuration for the Delta gate bias.", + hint=FieldHint.architecture, + ) + a_log_weight: ParameterConfig = Field( + desc="Parameter configuration for the decay rates.", + hint=FieldHint.architecture, + ) + + heads: int = Field( + default=16, + desc="Number of attention heads.", + hint=FieldHint.architecture, + valid=check_field(Assert.gt, 0), + ) + head_dim: int = Field( + default=64, + desc="Dimension of each head.", + hint=FieldHint.architecture, + valid=check_field(Assert.gt, 0), + ) + recurrent_threshold: int = Field( + default=64, + desc="Switch to the fused recurrent kernel below this sequence length.", + hint=FieldHint.performance, + valid=check_field(Assert.gt, 0), + ) + use_qk_l2norm: bool = Field( + default=True, + desc="Apply L2 normalization to query/key vectors inside the Delta kernel.", + hint=FieldHint.architecture, + ) + + @property + def layer_class(self) -> "type": + from fast_llm.layers.ssm.kda import KimiDeltaAttention + + return KimiDeltaAttention + + @config_class() class SSMConfig(MixerConfig): # Layers diff --git a/fast_llm/layers/ssm/kda.py b/fast_llm/layers/ssm/kda.py new file mode 100644 index 000000000..78cabe8da --- /dev/null +++ b/fast_llm/layers/ssm/kda.py @@ -0,0 +1,268 @@ +import logging +import typing + +import torch + +from fast_llm.engine.base_model.config import ResourceUsageConfig +from fast_llm.engine.config_utils.initialization import LambdaInitializer, init_normal_, init_ones_ +from fast_llm.engine.config_utils.tensor_dim import CompositeTensorDim, TensorDim +from fast_llm.engine.distributed.config import DistributedConfig, DistributedDimNames +from fast_llm.layers.block.config import BlockKwargs +from fast_llm.layers.common.peft.config import PeftConfig +from fast_llm.layers.decoder.block import BlockWithBias +from fast_llm.layers.ssm.config import KimiDeltaAttentionConfig +from fast_llm.tensor import ParameterMeta, TensorMeta + +logger = logging.getLogger(__name__) + +try: + from fla.ops.kda import chunk_kda + from fla.ops.kda.gate import fused_kda_gate +except ImportError: + chunk_kda = None + fused_kda_gate = None + + +class KimiDeltaAttention[ConfigType: KimiDeltaAttentionConfig](BlockWithBias[ConfigType]): + """ + Implementation of the Kimi Delta Attention mixer. + Reference Implementation: https://huggingface.co/moonshotai/Kimi-Linear-48B-A3B-Instruct/blob/main/modeling_kimi.py + """ + + _config: ConfigType + + def __init__( + self, + config: ConfigType, + distributed_config: DistributedConfig, + *, + hidden_dim: TensorDim, + lr_scale: float | None, + peft: PeftConfig | None, + return_bias: bool = True, + ): + super().__init__( + config, distributed_config, hidden_dim=hidden_dim, lr_scale=lr_scale, peft=peft, return_bias=return_bias + ) + if chunk_kda is None or fused_kda_gate is None: + raise ImportError( + "KimiDeltaAttention requires the `fla-core` package. " + "Please install it with `pip install -U fla-core`." + ) + + self._parallel_dim = self._distributed_config.get_distributed_dim(DistributedDimNames.tensor) + self._heads_dim = TensorDim( + "kda_heads", self._config.heads, self._parallel_dim if self._config.heads > 1 else None + ) + self._head_dim = TensorDim("kda_head_dim", self._config.head_dim) + self._projection_dim = CompositeTensorDim("kda_projection", (self._heads_dim, self._head_dim)) + self._local_heads = self._heads_dim.size + self._projection_size = self._projection_dim.size + + init = init_normal_(std=self._hidden_size**-0.5) + self.q_proj = self._config.q_projection_layer.get_layer( + hidden_dim, + self._projection_dim, + default_weight_initialization=init, + default_add_bias=False, + sequence_parallel=self._sequence_parallel, + lr_scale=self._lr_scale, + peft=self._peft, + ) + self.k_proj = self._config.k_projection_layer.get_layer( + hidden_dim, + self._projection_dim, + default_weight_initialization=init, + default_add_bias=False, + sequence_parallel=self._sequence_parallel, + lr_scale=self._lr_scale, + peft=self._peft, + ) + self.v_proj = self._config.v_projection_layer.get_layer( + hidden_dim, + self._projection_dim, + default_weight_initialization=init, + default_add_bias=False, + sequence_parallel=self._sequence_parallel, + lr_scale=self._lr_scale, + peft=self._peft, + ) + + self.q_conv = self._config.convolution_layer.get_layer( + self._projection_dim, + default_add_bias=False, + default_activation=self._config.convolution_layer.activation, + lr_scale=self._lr_scale, + peft=self._peft, + ) + self.k_conv = self._config.convolution_layer.get_layer( + self._projection_dim, + default_add_bias=False, + default_activation=self._config.convolution_layer.activation, + lr_scale=self._lr_scale, + peft=self._peft, + ) + self.v_conv = self._config.convolution_layer.get_layer( + self._projection_dim, + default_add_bias=False, + default_activation=self._config.convolution_layer.activation, + lr_scale=self._lr_scale, + peft=self._peft, + ) + + self.f_a_proj = self._config.f_a_projection_layer.get_layer( + hidden_dim, + self._head_dim, + default_weight_initialization=init, + default_add_bias=False, + sequence_parallel=self._sequence_parallel, + lr_scale=self._lr_scale, + peft=self._peft, + ) + self.f_b_proj = self._config.f_b_projection_layer.get_layer( + self._head_dim, + self._projection_dim, + default_weight_initialization=init, + default_add_bias=False, + sequence_parallel=self._sequence_parallel, + lr_scale=self._lr_scale, + peft=self._peft, + ) + self.g_a_proj = self._config.g_a_projection_layer.get_layer( + hidden_dim, + self._head_dim, + default_weight_initialization=init, + default_add_bias=False, + sequence_parallel=self._sequence_parallel, + lr_scale=self._lr_scale, + peft=self._peft, + ) + self.g_b_proj = self._config.g_b_projection_layer.get_layer( + self._head_dim, + self._projection_dim, + default_weight_initialization=init, + default_add_bias=False, + sequence_parallel=self._sequence_parallel, + lr_scale=self._lr_scale, + peft=self._peft, + ) + + self.beta_proj = self._config.beta_projection_layer.get_layer( + hidden_dim, + self._heads_dim, + default_weight_initialization=init, + default_add_bias=False, + sequence_parallel=self._sequence_parallel, + lr_scale=self._lr_scale, + peft=self._peft, + ) + self.o_proj = self._config.output_projection_layer.get_layer( + self._projection_dim, + hidden_dim, + default_weight_initialization=init, + default_add_bias=False, + sequence_parallel=self._sequence_parallel, + lr_scale=self._lr_scale, + peft=self._peft, + ) + + self.dt_bias: ParameterMeta = self._config.dt_bias_weight.get_parameter( + (self._projection_dim,), + default_initialization=init_ones_, + lr_scale=self._lr_scale, + peft=self._peft, + ) + self.a_log: ParameterMeta = self._config.a_log_weight.get_parameter( + (self._heads_dim,), + default_initialization=LambdaInitializer( + lambda _, tensor, generator: tensor.uniform_(1, 16, generator=generator).log_() + ), + lr_scale=self._lr_scale, + peft=self._peft, + ) + self.norm = self._config.normalization.get_layer( + self._head_dim, + lr_scale=self._lr_scale, + peft=self._peft, + ) + + def _apply_conv(self, tensor: torch.Tensor, conv: torch.nn.Module) -> torch.Tensor: + """ + Applies convolution. + Note, in the reference code they use Short Convolution from flash-linear-attention/fla/modules/convolution.py, but that one kjust uses causal_conv1danyways. + TODO: make sure varlen is supported correctly. + """ + tensor = tensor.transpose(1, 2).contiguous() + tensor = conv(tensor) + return tensor.transpose(1, 2).contiguous() + + def _reshape_heads(self, tensor: torch.Tensor) -> torch.Tensor: + tensor = tensor.contiguous() + # since head_dim is the same vor k,q and v + # same as rearrange(v, '... (h d) -> ... h d', d=self.head_dim) + return tensor.view(tensor.shape[0], tensor.shape[1], self._local_heads, self._config.head_dim) + + def _get_dt_bias(self) -> torch.Tensor: + return self.dt_bias.view(1, 1, self._local_heads, self._config.head_dim) + + def _get_a_log(self) -> torch.Tensor: + return self.a_log.view(1, 1, self._local_heads, 1) + + def _forward( + self, + input_: torch.Tensor, + kwargs: dict[str, typing.Any], + losses: dict[str, typing.Any] | None = None, + metrics: dict[str, typing.Any] | None = None, + ) -> tuple[torch.Tensor, torch.Tensor | None]: + # TODO: make sure varlen is supported + # TODO: make sure we dont need to mask padding tokens in training + # TODO: make sure sequence first is handdled correctly + sequence_first = kwargs[BlockKwargs.sequence_first] + hidden_states = input_.transpose(0, 1) if sequence_first else input_ + batch_size, sequence_length, _ = hidden_states.shape + residual_dtype = hidden_states.dtype + + q = self._apply_conv(self.q_proj(hidden_states), self.q_conv) + k = self._apply_conv(self.k_proj(hidden_states), self.k_conv) + v = self._apply_conv(self.v_proj(hidden_states), self.v_conv) + + g_kernel = self.f_b_proj(self.f_a_proj(hidden_states)) + g_kernel = fused_kda_gate(g_kernel, self._get_a_log(), self._config.head_dim, g_bias=self._get_dt_bias()) + + beta = torch.sigmoid(self.beta_proj(hidden_states).float()) + + q = self._reshape_heads(q) + k = self._reshape_heads(k) + v = self._reshape_heads(v) + # currently on supports Ampere??? + attn_out, _ = chunk_kda( + q=q, + k=k, + v=v, + g=g_kernel, + beta=beta, + initial_state=None, + output_final_state=False, + use_qk_l2norm_in_kernel=self._config.use_qk_l2norm, + cu_seqlens=None, + ) + + attn_out = attn_out.to(residual_dtype) + attn_out = self._reshape_heads(attn_out) + + g_out = self._reshape_heads(self.g_b_proj(self.g_a_proj(hidden_states))) + + attn_out = attn_out.reshape(-1, self._config.head_dim) + g_out = g_out.reshape(-1, self._config.head_dim) + attn_out = self.norm(attn_out, g_out) + attn_out = attn_out.view(batch_size, sequence_length, self._projection_size) + attn_out = self.o_proj(attn_out) + + if sequence_first: + attn_out = attn_out.transpose(0, 1) + + return attn_out + + def get_compute_usage(self, input_: TensorMeta, kwargs: dict[str, typing.Any], config: ResourceUsageConfig) -> int: + raise NotImplementedError() From 8636f092ba89738a98e3c43e875ca38ce23b7d32 Mon Sep 17 00:00:00 2001 From: oleksost Date: Wed, 19 Nov 2025 21:23:25 +0000 Subject: [PATCH 05/43] wip --- fast_llm/layers/ssm/kda.py | 12 +++--------- 1 file changed, 3 insertions(+), 9 deletions(-) diff --git a/fast_llm/layers/ssm/kda.py b/fast_llm/layers/ssm/kda.py index 78cabe8da..f6f776541 100644 --- a/fast_llm/layers/ssm/kda.py +++ b/fast_llm/layers/ssm/kda.py @@ -202,12 +202,6 @@ def _reshape_heads(self, tensor: torch.Tensor) -> torch.Tensor: # same as rearrange(v, '... (h d) -> ... h d', d=self.head_dim) return tensor.view(tensor.shape[0], tensor.shape[1], self._local_heads, self._config.head_dim) - def _get_dt_bias(self) -> torch.Tensor: - return self.dt_bias.view(1, 1, self._local_heads, self._config.head_dim) - - def _get_a_log(self) -> torch.Tensor: - return self.a_log.view(1, 1, self._local_heads, 1) - def _forward( self, input_: torch.Tensor, @@ -228,14 +222,14 @@ def _forward( v = self._apply_conv(self.v_proj(hidden_states), self.v_conv) g_kernel = self.f_b_proj(self.f_a_proj(hidden_states)) - g_kernel = fused_kda_gate(g_kernel, self._get_a_log(), self._config.head_dim, g_bias=self._get_dt_bias()) + g_kernel = fused_kda_gate(g_kernel, self.a_log, self._config.head_dim, g_bias=self.dt_bias) beta = torch.sigmoid(self.beta_proj(hidden_states).float()) q = self._reshape_heads(q) k = self._reshape_heads(k) v = self._reshape_heads(v) - # currently on supports Ampere??? + # need to install nightly triton for now attn_out, _ = chunk_kda( q=q, k=k, @@ -244,7 +238,7 @@ def _forward( beta=beta, initial_state=None, output_final_state=False, - use_qk_l2norm_in_kernel=self._config.use_qk_l2norm, + use_qk_l2norm_in_kernel=True, cu_seqlens=None, ) From a20c9586387d5a9605c90ac345f1515c99f93017 Mon Sep 17 00:00:00 2001 From: oleksost Date: Thu, 20 Nov 2025 01:01:18 +0000 Subject: [PATCH 06/43] convertion kda --- fast_llm/layers/ssm/config.py | 31 +++---- fast_llm/layers/ssm/kda.py | 6 +- fast_llm/models/gpt/conversion/apriel.py | 113 ++++++++++++++++++++++- 3 files changed, 127 insertions(+), 23 deletions(-) diff --git a/fast_llm/layers/ssm/config.py b/fast_llm/layers/ssm/config.py index 95ef9bed2..b8a5a64c9 100644 --- a/fast_llm/layers/ssm/config.py +++ b/fast_llm/layers/ssm/config.py @@ -6,7 +6,7 @@ from fast_llm.engine.config_utils.parameter import ParameterConfig from fast_llm.functional.config import ActivationType from fast_llm.layers.common.linear.config import AffineLinearConfig, CausalConv1dConfig, LinearConfig -from fast_llm.layers.common.normalization.config import NormalizationConfig +from fast_llm.layers.common.normalization.config import GatedRMSNormalizationConfig, NormalizationConfig from fast_llm.layers.decoder.config import MixerConfig from fast_llm.utils import Assert @@ -84,11 +84,6 @@ class GatedDeltaNetConfig(MixerConfig): hint=FieldHint.architecture, valid=check_field(Assert.gt, 0), ) - use_qk_l2norm: bool = Field( - default=True, - desc="Apply L2 normalization on query/key vectors inside the Delta rule kernel.", - hint=FieldHint.architecture, - ) activation: ActivationType = Field( default=ActivationType.silu, desc="Activation used after the convolution.", @@ -113,7 +108,7 @@ class KimiDeltaAttentionConfig(MixerConfig): """ _abstract = False - normalization: NormalizationConfig = Field( + normalization: GatedRMSNormalizationConfig = Field( desc="Configuration for the gated normalization applied to the KDA output.", hint=FieldHint.architecture, ) @@ -178,17 +173,6 @@ class KimiDeltaAttentionConfig(MixerConfig): hint=FieldHint.architecture, valid=check_field(Assert.gt, 0), ) - recurrent_threshold: int = Field( - default=64, - desc="Switch to the fused recurrent kernel below this sequence length.", - hint=FieldHint.performance, - valid=check_field(Assert.gt, 0), - ) - use_qk_l2norm: bool = Field( - default=True, - desc="Apply L2 normalization to query/key vectors inside the Delta kernel.", - hint=FieldHint.architecture, - ) @property def layer_class(self) -> "type": @@ -196,6 +180,17 @@ def layer_class(self) -> "type": return KimiDeltaAttention + def _validate(self) -> None: + with self._set_implicit_default(): + if "epsilon" not in self.normalization._explicit_fields: + self.normalization.epsilon = 1.0e-5 + if "activation" not in self.convolution_layer._explicit_fields: + self.convolution_layer.activation = "silu" + if "kernel_size" not in self.convolution_layer._explicit_fields: + self.convolution_layer.kernel_size = 4 + + super()._validate() + @config_class() class SSMConfig(MixerConfig): diff --git a/fast_llm/layers/ssm/kda.py b/fast_llm/layers/ssm/kda.py index f6f776541..1ce6bed76 100644 --- a/fast_llm/layers/ssm/kda.py +++ b/fast_llm/layers/ssm/kda.py @@ -172,7 +172,7 @@ def __init__( lr_scale=self._lr_scale, peft=self._peft, ) - self.a_log: ParameterMeta = self._config.a_log_weight.get_parameter( + self.A_log: ParameterMeta = self._config.a_log_weight.get_parameter( (self._heads_dim,), default_initialization=LambdaInitializer( lambda _, tensor, generator: tensor.uniform_(1, 16, generator=generator).log_() @@ -210,8 +210,8 @@ def _forward( metrics: dict[str, typing.Any] | None = None, ) -> tuple[torch.Tensor, torch.Tensor | None]: # TODO: make sure varlen is supported - # TODO: make sure we dont need to mask padding tokens in training # TODO: make sure sequence first is handdled correctly + # TODO: make sure we dont need to mask padding tokens in training sequence_first = kwargs[BlockKwargs.sequence_first] hidden_states = input_.transpose(0, 1) if sequence_first else input_ batch_size, sequence_length, _ = hidden_states.shape @@ -222,7 +222,7 @@ def _forward( v = self._apply_conv(self.v_proj(hidden_states), self.v_conv) g_kernel = self.f_b_proj(self.f_a_proj(hidden_states)) - g_kernel = fused_kda_gate(g_kernel, self.a_log, self._config.head_dim, g_bias=self.dt_bias) + g_kernel = fused_kda_gate(g_kernel, self.A_log, self._config.head_dim, g_bias=self.dt_bias) beta = torch.sigmoid(self.beta_proj(hidden_states).float()) diff --git a/fast_llm/models/gpt/conversion/apriel.py b/fast_llm/models/gpt/conversion/apriel.py index e16eac4de..215cc5257 100644 --- a/fast_llm/models/gpt/conversion/apriel.py +++ b/fast_llm/models/gpt/conversion/apriel.py @@ -8,7 +8,13 @@ from fast_llm.layers.attention.config import AttentionConfig from fast_llm.layers.block.config import BlockSequenceConfig, FixedBlockSequenceConfig, PatternBlockSequenceConfig from fast_llm.layers.decoder.config import DecoderBlockConfig -from fast_llm.layers.ssm.config import DiscreteMamba2Config, Mamba2Config +from fast_llm.layers.decoder.mlp.config import MLPConfig +from fast_llm.layers.ssm.config import ( + DiscreteMamba2Config, + GatedDeltaNetConfig, + KimiDeltaAttentionConfig, + Mamba2Config, +) from fast_llm.models.gpt.config import GPTModelConfig from fast_llm.models.gpt.conversion.config import AprielHybridSSMCheckpointFormat from fast_llm.models.gpt.conversion.llama import get_parameter_converter, get_weight_and_bias_converters @@ -224,7 +230,102 @@ def get_converters( ] -class AprielDiscreteMamba2BlockConverter(MistralBlockConverter): +class AprielMLPConverter(LlamaMLPConverter): + @classmethod + def import_config(cls, config: dict) -> dict: + config["mlp_bias"] = False + return super().import_config(config) + + @classmethod + def export_config(cls, config: MLPConfig) -> dict: + out = super().export_config(config) + del out["mlp_bias"] + return out + + +class GatedDeltaNetConverter: + @classmethod + def import_config(cls, config: dict) -> dict: + return { + "type": "gated_delta_net", + "value_heads": config["linear_attn_config"]["gdn_value_head_dim"], + "key_heads": config["linear_attn_config"]["gdn_num_key_heads"], + "key_head_dim": config["linear_attn_config"]["gdn_key_head_dim"], + "value_head_dim": config["linear_attn_config"]["value_head_dim"], + "convolution_layer": { + "kernel_size": config["linear_attn_config"]["gdn_linear_conv_kernel_size"], + }, + } + + @classmethod + def export_config(cls, config: GatedDeltaNetConfig) -> dict: + return { + "linear_attn_config": { + "gdn_num_value_heads": config.value_heads, + "gdn_num_key_heads": config.key_heads, + "gdn_key_head_dim": config.key_head_dim, + "gdn_value_head_dim": config.value_head_dim, + "gdn_linear_conv_kernel_size": config.convolution_layer.kernel_size, + }, + } + + @classmethod + def get_converters( + cls, + config: KimiDeltaAttentionConfig, + fast_llm_prefix: str, + hf_prefix: str, + drop_on_export: bool = False, + ) -> list[WeightConverter]: + return [ + *get_weight_and_bias_converters( + f"{fast_llm_prefix}.in_proj_qkvz", + f"{hf_prefix}.in_proj_qkvz", + False, + drop_on_export=drop_on_export, + ), + *get_weight_and_bias_converters( + f"{fast_llm_prefix}.in_proj_ba", + f"{hf_prefix}.in_proj_ba", + False, + drop_on_export=drop_on_export, + ), + *get_weight_and_bias_converters( + f"{fast_llm_prefix}.convolution", + f"{hf_prefix}.convolution", + False, + drop_on_export=drop_on_export, + ), + *get_weight_and_bias_converters( + f"{fast_llm_prefix}.out_proj", + f"{hf_prefix}.out_proj", + False, + drop_on_export=drop_on_export, + ), + get_parameter_converter( + f"{fast_llm_prefix}.A_log", + f"{hf_prefix}.A_log", + drop_on_export=drop_on_export, + ), + get_parameter_converter( + f"{fast_llm_prefix}.dt_bias", + f"{hf_prefix}.dt_bias", + drop_on_export=drop_on_export, + ), + *get_weight_and_bias_converters( + f"{fast_llm_prefix}.norm", + f"{hf_prefix}.norm", + False, + drop_on_export=drop_on_export, + ), + ] + + +class AprielBlockConverterBase(MistralBlockConverter): + mlp_converter_class: typing.ClassVar[type[AprielMLPConverter]] = AprielMLPConverter + + +class AprielDiscreteMamba2BlockConverter(AprielBlockConverterBase): mixer_converter_class: typing.ClassVar[type[AprielDiscreteMamba2Converter]] = AprielDiscreteMamba2Converter hf_mixer_name: typing.ClassVar[str] = "mixer" @@ -234,16 +335,24 @@ class AprielMamba2BlockConverter(MistralBlockConverter): hf_mixer_name: typing.ClassVar[str] = "mixer" +class AprielGatedDeltaNetBlockConverter(AprielBlockConverterBase): + mixer_converter_class: typing.ClassVar[type[GatedDeltaNetConverter]] = GatedDeltaNetConverter + hf_mixer_name: typing.ClassVar[str] = "mixer" + + class AprielBlockConverter: layout_names = { AttentionConfig: "t", Mamba2Config: "m2", DiscreteMamba2Config: "m2d", + KimiDeltaAttentionConfig: "kda", + GatedDeltaNetConfig: "gdn", } _converter_classes = { AttentionConfig: MistralBlockConverter, Mamba2Config: AprielMamba2BlockConverter, DiscreteMamba2Config: AprielDiscreteMamba2BlockConverter, + GatedDeltaNetConfig: AprielGatedDeltaNetBlockConverter, } _config_classes = {value: key for key, value in layout_names.items()} From 8ac5167f62fe047ff2004093df8fa72f91b5e19c Mon Sep 17 00:00:00 2001 From: oleksost Date: Fri, 21 Nov 2025 16:26:07 +0000 Subject: [PATCH 07/43] tp and sequence tp --- fast_llm/layers/ssm/config.py | 15 ++++++++++-- fast_llm/layers/ssm/gdn.py | 21 ++++++++++------- fast_llm/layers/ssm/kda.py | 44 +++++++++++++++++++++++++---------- 3 files changed, 58 insertions(+), 22 deletions(-) diff --git a/fast_llm/layers/ssm/config.py b/fast_llm/layers/ssm/config.py index b8a5a64c9..29f66c8be 100644 --- a/fast_llm/layers/ssm/config.py +++ b/fast_llm/layers/ssm/config.py @@ -6,7 +6,7 @@ from fast_llm.engine.config_utils.parameter import ParameterConfig from fast_llm.functional.config import ActivationType from fast_llm.layers.common.linear.config import AffineLinearConfig, CausalConv1dConfig, LinearConfig -from fast_llm.layers.common.normalization.config import GatedRMSNormalizationConfig, NormalizationConfig +from fast_llm.layers.common.normalization.config import GatedRMSNormalizationConfig from fast_llm.layers.decoder.config import MixerConfig from fast_llm.utils import Assert @@ -25,7 +25,7 @@ class GatedDeltaNetConfig(MixerConfig): """ _abstract = False - normalization: NormalizationConfig = Field( + normalization: GatedRMSNormalizationConfig = Field( desc="Configuration for the block normalization layers.", hint=FieldHint.architecture, ) @@ -100,6 +100,17 @@ def layer_class(self) -> "type": return GatedDeltaNet + def _validate(self) -> None: + with self._set_implicit_default(): + if "epsilon" not in self.normalization._explicit_fields: + self.normalization.epsilon = 1.0e-5 + if "activation" not in self.convolution_layer._explicit_fields: + self.convolution_layer.activation = "silu" + if "kernel_size" not in self.convolution_layer._explicit_fields: + self.convolution_layer.kernel_size = 4 + + super()._validate() + @config_class(dynamic_type={MixerConfig: "kda"}) class KimiDeltaAttentionConfig(MixerConfig): diff --git a/fast_llm/layers/ssm/gdn.py b/fast_llm/layers/ssm/gdn.py index 62360acc8..d07cb5e21 100644 --- a/fast_llm/layers/ssm/gdn.py +++ b/fast_llm/layers/ssm/gdn.py @@ -113,6 +113,9 @@ def torch_chunk_gated_delta_rule( class GatedDeltaNet[ConfigType: GatedDeltaNetConfig](BlockWithBias[ConfigType]): """ Follows implementation here: https://github.com/huggingface/transformers/blob/a5c903f877fda21e739027eed133e03162eb7712/src/transformers/models/qwen3_next/modeling_qwen3_next.py#L593 + - For tensor parallel implementtion (no sequnece prallel): we scatter teh heads accross ranks. + - Sequence Tensor parallel: in_proj_qkvz all reduces across sequence dim. --> each rank performs work on full sequence but only a subset of heads (standrd TP). + """ _config: ConfigType @@ -261,16 +264,18 @@ def _forward( metrics: dict[str, typing.Any] | None = None, ) -> tuple[torch.Tensor, torch.Tensor | None]: sequence_first = kwargs[BlockKwargs.sequence_first] - - # TODO: do we need maksing of padding tokens? - if sequence_first: - hidden_states = input_.transpose(0, 1) - else: - hidden_states = input_ + # in sequence parallel TP the input here is already scattered across sequence dimension + # TODO: do we need masking of padding tokens? + # TODO: make sure varlen is supported + hidden_states = input_ # batch_size, sequence_length, _ = hidden_states.shape projected_states_qkvz = self.in_proj_qkvz(hidden_states) # bs x seq_len x (qkvz) projected_states_ba = self.in_proj_ba(hidden_states) # bs x seq_len x (b a) + if sequence_first: + projected_states_qkvz = projected_states_qkvz.transpose(0, 1) + projected_states_ba = projected_states_ba.transpose(0, 1) + query, key, value, z, beta, alpha = self.fix_query_key_value_ordering( projected_states_qkvz, projected_states_ba ) @@ -317,10 +322,10 @@ def _forward( core_attn_out = self.norm(core_attn_out, z) core_attn_out = core_attn_out.reshape(z_shape_og) core_attn_out = core_attn_out.reshape(core_attn_out.shape[0], core_attn_out.shape[1], -1) + if sequence_first: + core_attn_out = core_attn_out.transpose(0, 1) output = self.out_proj(core_attn_out) - if sequence_first: - output = output.transpose(0, 1) return output def get_compute_usage(self, input_: TensorMeta, kwargs: dict[str, typing.Any], config: ResourceUsageConfig) -> int: diff --git a/fast_llm/layers/ssm/kda.py b/fast_llm/layers/ssm/kda.py index 1ce6bed76..9f85cd069 100644 --- a/fast_llm/layers/ssm/kda.py +++ b/fast_llm/layers/ssm/kda.py @@ -115,7 +115,7 @@ def __init__( self._head_dim, default_weight_initialization=init, default_add_bias=False, - sequence_parallel=self._sequence_parallel, + sequence_parallel=False, # self._sequence_parallel, lr_scale=self._lr_scale, peft=self._peft, ) @@ -133,7 +133,7 @@ def __init__( self._head_dim, default_weight_initialization=init, default_add_bias=False, - sequence_parallel=self._sequence_parallel, + sequence_parallel=False, # self._sequence_parallel, lr_scale=self._lr_scale, peft=self._peft, ) @@ -210,25 +210,46 @@ def _forward( metrics: dict[str, typing.Any] | None = None, ) -> tuple[torch.Tensor, torch.Tensor | None]: # TODO: make sure varlen is supported - # TODO: make sure sequence first is handdled correctly # TODO: make sure we dont need to mask padding tokens in training sequence_first = kwargs[BlockKwargs.sequence_first] - hidden_states = input_.transpose(0, 1) if sequence_first else input_ - batch_size, sequence_length, _ = hidden_states.shape + hidden_states = input_ + residual_dtype = hidden_states.dtype - q = self._apply_conv(self.q_proj(hidden_states), self.q_conv) - k = self._apply_conv(self.k_proj(hidden_states), self.k_conv) - v = self._apply_conv(self.v_proj(hidden_states), self.v_conv) + q = self.q_proj(hidden_states) + k = self.k_proj(hidden_states) + v = self.v_proj(hidden_states) + if sequence_first: + q = q.transpose(0, 1) + k = k.transpose(0, 1) + v = v.transpose(0, 1) + + q = self._apply_conv(q, self.q_conv) + k = self._apply_conv(k, self.k_conv) + v = self._apply_conv(v, self.v_conv) + + if sequence_first: + _, batch_size, _ = hidden_states.shape + sequence_length = q.size(1) + # hidden_states = gather_op(hidden_states, self._distributed.tensor_group, dim=0, async_op=False).transpose( + # 0, 1 + # ) + # hidden_states = hidden_states.transpose(0, 1) + else: + batch_size, sequence_length, _ = hidden_states.shape g_kernel = self.f_b_proj(self.f_a_proj(hidden_states)) + if sequence_first: + g_kernel = g_kernel.transpose(0, 1) g_kernel = fused_kda_gate(g_kernel, self.A_log, self._config.head_dim, g_bias=self.dt_bias) beta = torch.sigmoid(self.beta_proj(hidden_states).float()) - q = self._reshape_heads(q) k = self._reshape_heads(k) v = self._reshape_heads(v) + if sequence_first: + beta = beta.transpose(0, 1) + # need to install nightly triton for now attn_out, _ = chunk_kda( q=q, @@ -245,16 +266,15 @@ def _forward( attn_out = attn_out.to(residual_dtype) attn_out = self._reshape_heads(attn_out) - g_out = self._reshape_heads(self.g_b_proj(self.g_a_proj(hidden_states))) + g_out = self._reshape_heads(self.g_b_proj(self.g_a_proj(hidden_states))) # bs x seq x n_local_heads x head dim attn_out = attn_out.reshape(-1, self._config.head_dim) g_out = g_out.reshape(-1, self._config.head_dim) attn_out = self.norm(attn_out, g_out) attn_out = attn_out.view(batch_size, sequence_length, self._projection_size) - attn_out = self.o_proj(attn_out) - if sequence_first: attn_out = attn_out.transpose(0, 1) + attn_out = self.o_proj(attn_out) return attn_out From f1a51f2754e90bc6982d0ee3edbded8010416460 Mon Sep 17 00:00:00 2001 From: oleksost Date: Sat, 22 Nov 2025 00:19:41 +0000 Subject: [PATCH 08/43] varlen kda --- fast_llm/layers/common/linear/convolution.py | 3 +- fast_llm/layers/ssm/config.py | 6 + fast_llm/layers/ssm/kda.py | 114 ++++++-- tests/test_ssm_varlen.py | 259 +++++++++++++++++++ 4 files changed, 355 insertions(+), 27 deletions(-) create mode 100644 tests/test_ssm_varlen.py diff --git a/fast_llm/layers/common/linear/convolution.py b/fast_llm/layers/common/linear/convolution.py index 6281348e1..2f682c460 100644 --- a/fast_llm/layers/common/linear/convolution.py +++ b/fast_llm/layers/common/linear/convolution.py @@ -45,12 +45,13 @@ def _forward_torch(self, input_: torch.Tensor) -> torch.Tensor: )[..., : input_.size(1)] ) - def _forward_causal_conv1d(self, input_: torch.Tensor) -> torch.Tensor: + def _forward_causal_conv1d(self, input_: torch.Tensor, **kwargs) -> torch.Tensor: return _causal_conv1d_fn( input_, self.weight.squeeze(1), self.bias, activation=(None if self._activation == ActivationType.identity else self._activation.value), + **kwargs, ) def get_compute_usage(self, input_: TensorMeta, config: ResourceUsageConfig) -> int: diff --git a/fast_llm/layers/ssm/config.py b/fast_llm/layers/ssm/config.py index 29f66c8be..2fa90aff9 100644 --- a/fast_llm/layers/ssm/config.py +++ b/fast_llm/layers/ssm/config.py @@ -5,6 +5,7 @@ from fast_llm.engine.config_utils.initialization import InitializationConfig, Initializer, LambdaInitializer from fast_llm.engine.config_utils.parameter import ParameterConfig from fast_llm.functional.config import ActivationType +from fast_llm.layers.block.config import BlockKwargs from fast_llm.layers.common.linear.config import AffineLinearConfig, CausalConv1dConfig, LinearConfig from fast_llm.layers.common.normalization.config import GatedRMSNormalizationConfig from fast_llm.layers.decoder.config import MixerConfig @@ -18,6 +19,11 @@ from fast_llm.tensor import ParameterMeta +class LinearAttentionKwargs(BlockKwargs): + cu_seqlens = "cu_seqlens" + seq_idx = "seq_idx" + + @config_class(dynamic_type={MixerConfig: "gdn"}) class GatedDeltaNetConfig(MixerConfig): """ diff --git a/fast_llm/layers/ssm/kda.py b/fast_llm/layers/ssm/kda.py index 9f85cd069..b14fd4592 100644 --- a/fast_llm/layers/ssm/kda.py +++ b/fast_llm/layers/ssm/kda.py @@ -2,6 +2,7 @@ import typing import torch +from einops import rearrange, repeat from fast_llm.engine.base_model.config import ResourceUsageConfig from fast_llm.engine.config_utils.initialization import LambdaInitializer, init_normal_, init_ones_ @@ -10,7 +11,7 @@ from fast_llm.layers.block.config import BlockKwargs from fast_llm.layers.common.peft.config import PeftConfig from fast_llm.layers.decoder.block import BlockWithBias -from fast_llm.layers.ssm.config import KimiDeltaAttentionConfig +from fast_llm.layers.ssm.config import KimiDeltaAttentionConfig, LinearAttentionKwargs from fast_llm.tensor import ParameterMeta, TensorMeta logger = logging.getLogger(__name__) @@ -23,6 +24,16 @@ fused_kda_gate = None +def index_first_axis(x, indices): + other_shape = x.shape[1:] + second_dim = other_shape.numel() + return torch.gather( + rearrange(x, "b ... -> b (...)"), + 0, + repeat(indices, "z -> z d", d=second_dim), + ).reshape(-1, *other_shape) + + class KimiDeltaAttention[ConfigType: KimiDeltaAttentionConfig](BlockWithBias[ConfigType]): """ Implementation of the Kimi Delta Attention mixer. @@ -186,14 +197,16 @@ def __init__( peft=self._peft, ) - def _apply_conv(self, tensor: torch.Tensor, conv: torch.nn.Module) -> torch.Tensor: + def _apply_conv(self, tensor: torch.Tensor, conv: torch.nn.Module, seq_idx: torch.Tensor = None) -> torch.Tensor: """ Applies convolution. - Note, in the reference code they use Short Convolution from flash-linear-attention/fla/modules/convolution.py, but that one kjust uses causal_conv1danyways. - TODO: make sure varlen is supported correctly. + Note, in the reference code they use Short Convolution from flash-linear-attention/fla/modules/convolution.py, but that one just uses causal_conv1d anyways. + Varlen: + - seq. idx are only suppored in channel last layout, i.e. no transpose """ - tensor = tensor.transpose(1, 2).contiguous() - tensor = conv(tensor) + tensor = rearrange(tensor, "b t d -> b d t") + # tensor = tensor.transpose(1, 2).contiguous() if seq_idx is None else tensor.transpose(1, 2) + tensor = conv(tensor, seq_idx=seq_idx) return tensor.transpose(1, 2).contiguous() def _reshape_heads(self, tensor: torch.Tensor) -> torch.Tensor: @@ -210,37 +223,44 @@ def _forward( metrics: dict[str, typing.Any] | None = None, ) -> tuple[torch.Tensor, torch.Tensor | None]: # TODO: make sure varlen is supported - # TODO: make sure we dont need to mask padding tokens in training + # TODO: do we need to deal with padding tokens? sequence_first = kwargs[BlockKwargs.sequence_first] hidden_states = input_ + cu_seqlens = kwargs.get(LinearAttentionKwargs.cu_seqlens, None) + seq_idx = kwargs.get(LinearAttentionKwargs.seq_idx, None) + # TODO: can be made more efficeint by rearranging hidden states directly residual_dtype = hidden_states.dtype q = self.q_proj(hidden_states) k = self.k_proj(hidden_states) v = self.v_proj(hidden_states) + if sequence_first: + # make bs first dim again q = q.transpose(0, 1) k = k.transpose(0, 1) v = v.transpose(0, 1) - q = self._apply_conv(q, self.q_conv) - k = self._apply_conv(k, self.k_conv) - v = self._apply_conv(v, self.v_conv) + batch_size, sequence_length, _ = q.size() - if sequence_first: - _, batch_size, _ = hidden_states.shape - sequence_length = q.size(1) - # hidden_states = gather_op(hidden_states, self._distributed.tensor_group, dim=0, async_op=False).transpose( - # 0, 1 - # ) - # hidden_states = hidden_states.transpose(0, 1) - else: - batch_size, sequence_length, _ = hidden_states.shape + # work with bs = 1 to make sure varlen works correctly, only needed if micro batch size is > 1 + # can this be applied once to hidden state only? pr + q = rearrange(q, "b s ... -> (b s) ...").unsqueeze(0) + k = rearrange(k, "b s ... -> (b s) ...").unsqueeze(0) + v = rearrange(v, "b s ... -> (b s) ...").unsqueeze(0) + + # because we use cu_seqlens, chunk_kda requires batch size to be 1 (flatten, see https://github.com/fla-org/flash-linear-attention/blob/71260ecd573cfaaa94305b726465143199e99734/fla/ops/kda/chunk.py#L303) + # similarly to ShortConvolution from fla we already operate on flattened batches here (https://github.com/fla-org/flash-linear-attention/blob/71260ecd573cfaaa94305b726465143199e99734/fla/modules/convolution.py#L914) + q = self._apply_conv(q, self.q_conv, seq_idx=seq_idx) + k = self._apply_conv(k, self.k_conv, seq_idx=seq_idx) + v = self._apply_conv(v, self.v_conv, seq_idx=seq_idx) g_kernel = self.f_b_proj(self.f_a_proj(hidden_states)) if sequence_first: g_kernel = g_kernel.transpose(0, 1) + g_kernel = rearrange(g_kernel, "b s ... -> (b s) ...").unsqueeze(0) + g_kernel = fused_kda_gate(g_kernel, self.A_log, self._config.head_dim, g_bias=self.dt_bias) beta = torch.sigmoid(self.beta_proj(hidden_states).float()) @@ -249,8 +269,10 @@ def _forward( v = self._reshape_heads(v) if sequence_first: beta = beta.transpose(0, 1) + beta = rearrange(beta, "b s h -> (b s) h").unsqueeze(0) - # need to install nightly triton for now + # need to install nightly triton for this to work on H100, see https://github.com/fla-org/flash-linear-attention/blob/main/FAQs.md + # cu_seqlens requires batch ssize to be 1, i.e. flattened bacthes attn_out, _ = chunk_kda( q=q, k=k, @@ -260,18 +282,19 @@ def _forward( initial_state=None, output_final_state=False, use_qk_l2norm_in_kernel=True, - cu_seqlens=None, + cu_seqlens=cu_seqlens, ) attn_out = attn_out.to(residual_dtype) - attn_out = self._reshape_heads(attn_out) - g_out = self._reshape_heads(self.g_b_proj(self.g_a_proj(hidden_states))) # bs x seq x n_local_heads x head dim + g_out = self.g_b_proj(self.g_a_proj(hidden_states)) # bs x seq x n_local_heads x head dim + g_out = self._reshape_heads(g_out) + if sequence_first: + g_out = g_out.transpose(0, 1) - attn_out = attn_out.reshape(-1, self._config.head_dim) - g_out = g_out.reshape(-1, self._config.head_dim) + attn_out = rearrange(attn_out.squeeze(0), "(b s) h d -> b s h d", b=batch_size, s=sequence_length) attn_out = self.norm(attn_out, g_out) - attn_out = attn_out.view(batch_size, sequence_length, self._projection_size) + attn_out = rearrange(attn_out, "b s h d -> b s (h d)") if sequence_first: attn_out = attn_out.transpose(0, 1) attn_out = self.o_proj(attn_out) @@ -280,3 +303,42 @@ def _forward( def get_compute_usage(self, input_: TensorMeta, kwargs: dict[str, typing.Any], config: ResourceUsageConfig) -> int: raise NotImplementedError() + + def _preprocess_for_varlen(self, batch: torch.Tensor, kwargs: dict[str, typing.Any]) -> None: + sequence_lengths = kwargs[LinearAttentionKwargs.sequence_lengths] + if sequence_lengths is None: + raise ValueError("sequence_lengths must be provided in kwargs for variable-length sequences.") + sequence_lengths = kwargs[LinearAttentionKwargs.sequence_lengths] + seqlens = torch.cat(sequence_lengths) + cu_seqlens = torch.cat( + ( + torch.zeros(1, dtype=torch.long, device=batch.device), + torch.cumsum(seqlens, dim=0, dtype=torch.long).to(batch.device), + ) + ) + # this is supposed to be flattened, see https://github.com/fla-org/flash-linear-attention/blob/71260ecd573cfaaa94305b726465143199e99734/fla/ops/kda/chunk.py#L303 + # also whenever cu_seqlens is used, batchs size must be forced to 1: see https://github.com/fla-org/flash-linear-attention/blob/71260ecd573cfaaa94305b726465143199e99734/fla/ops/kda/chunk.py#L347 + kwargs[LinearAttentionKwargs.cu_seqlens] = cu_seqlens + # seq_idx has to be (bs, seqlen), but bs is forced to 1 + kwargs[LinearAttentionKwargs.seq_idx] = ( + ( + torch.cat( + [ + torch.arange(n, dtype=cu_seqlens.dtype, device=cu_seqlens.device) + for n in (torch.diff(cu_seqlens).to(torch.int32)) + ], + dim=0, + ) + .eq(0) + .cumsum(0) + - 1 + ) + .to(torch.int32) + .unsqueeze(0) + ) + + def preprocess(self, batch: torch.Tensor, kwargs: dict[str, typing.Any]) -> None: + if LinearAttentionKwargs.sequence_lengths in kwargs: + # TODO: packing is enabled by default, i.e. its always used? + # only get here when cross_document_attention is False + self._preprocess_for_varlen(batch, kwargs) diff --git a/tests/test_ssm_varlen.py b/tests/test_ssm_varlen.py new file mode 100644 index 000000000..9ca491e3d --- /dev/null +++ b/tests/test_ssm_varlen.py @@ -0,0 +1,259 @@ +import inspect +import itertools + +import pytest +import torch + +from fast_llm.engine.config_utils.tensor_dim import TensorDim +from fast_llm.engine.distributed.config import DistributedConfig +from fast_llm.engine.distributed.distributed import Distributed +from fast_llm.layers.block.config import BlockKwargs +from fast_llm.layers.ssm import kda as kda_module +from fast_llm.layers.ssm.config import KimiDeltaAttentionConfig, LinearAttentionKwargs + +# from mamba2 import NemotronHMamba2 + + +_mamba_varlen = False +try: + from mamba_ssm.ops.selective_scan_interface import selective_scan_fn # noqa + + _mamba_available = True + sig = inspect.signature(selective_scan_fn) + if "position_indices" in sig.parameters: + _mamba_varlen = True + else: + _mamba_varlen = False + # for training with packing install https://github.com/jxiw/varlen_mamba + # see https://github.com/jxiw/M1/blob/main/HYBRID_PACK.md + +except (ImportError, RuntimeError): + _mamba_available = False + + +@pytest.fixture +def distributed_config(): + return DistributedConfig( + tensor_parallel=1, + pipeline_parallel=1, + sequence_data_parallel=1, + local_world_size=1, + world_size=1, + ) + + +@pytest.fixture +def distributed(distributed_config): + return Distributed(config=distributed_config) + + +def materialize_meta_tensors(model, tensor_space): + # Materialize parameters that are on meta device + for name, param in model.named_parameters(): + if param.device.type == "meta": + # Check if the parameter is a custom tensor type + if hasattr(param, "tensor_name") and hasattr(param, "init_parameter"): + param_data = param.new_empty(param.shape, device="cuda") + # Initialize param_data + param.init_parameter(param_data, tensor_space.distributed) + # Replace the parameter in the module + module_path, param_name = name.rsplit(".", 1) if "." in name else (None, name) + module = model + if module_path is not None: + for part in module_path.split("."): + module = getattr(module, part) + param = torch.nn.Parameter(param_data, requires_grad=param.requires_grad) + # TODO: add param_grad_is_zero etc., grad_buffer, etc., see test_mlp_recomputation + param.grad = None + param.grad_buffer = torch.empty_like(param) + param.param_grad_is_zero = True + module._parameters[param_name] = param + return model + + +def unpack(packed_hidden_states, cu_seqlens): + batch_size = packed_hidden_states.shape[0] + package_num = cu_seqlens.shape[0] - 1 + seq_len = (cu_seqlens[1:] - cu_seqlens[:-1]).max() + hidden_dim = packed_hidden_states.shape[2] + hidden_states = torch.zeros( + package_num * batch_size, + seq_len, + hidden_dim, + dtype=packed_hidden_states.dtype, + device=packed_hidden_states.device, + ) + for j in range(batch_size): + for i in range(package_num): + line = j * package_num + i + hidden_states[line, : cu_seqlens[i + 1] - cu_seqlens[i], :] = packed_hidden_states[ + j, cu_seqlens[i] : cu_seqlens[i + 1], : + ] + return hidden_states + + +def pack(hidden_states, cu_seqlens, batch_size): + package_num, seq_len, hidden_dim = hidden_states.shape + seq_len_list = cu_seqlens[1:] - cu_seqlens[:-1] + seq_len_list_3d = seq_len_list.unsqueeze(1).unsqueeze(2) + indices_3d = ( + torch.arange(seq_len, device=hidden_states.device).unsqueeze(0).unsqueeze(2).repeat(package_num, 1, hidden_dim) + ) + mask_3d = indices_3d < seq_len_list_3d.repeat(batch_size, 1, 1) + packed_hidden_states = hidden_states[mask_3d].view(batch_size, -1, hidden_dim) + return packed_hidden_states + + +def generate_random_cu_seqlens(seq_len, packages_num=2): + if packages_num < 1: + raise ValueError("packages_num must be at least 1") + + # base size of each chunk, and how many get an extra token + base, rem = divmod(seq_len, packages_num) + # lengths: e.g. for seq_len=10, packages=3 → [4,3,3] + lengths = [base + 1 if i < rem else base for i in range(packages_num)] + + # split points exclude the final cumulative (seq_len) + split_points = list(itertools.accumulate(lengths))[:-1] + + cu_seqlens = [0] + split_points + [seq_len] + # cu_seqlens = split_points # + [seq_len] + + # index: for each chunk, we emit 0,1,...,length-1 + index = [] + for length in lengths: + index.extend(range(length)) + + # sanity check + assert len(cu_seqlens) - 1 == packages_num + assert sum(lengths) == seq_len + assert len(index) == seq_len + + return cu_seqlens, index + + +def _materialize_kda_tensors(module: torch.nn.Module, distributed: Distributed, device: torch.device) -> None: + """ + Materialize meta parameters on the requested device for KDA mixer layers. + """ + for name, param in module.named_parameters(): + if param.device.type != "meta": + continue + param_data = torch.empty_like(param, device=device) + param.init_parameter(param_data, distributed) + module_path, param_name = name.rsplit(".", 1) if "." in name else (None, name) + target = module + if module_path is not None: + for part in module_path.split("."): + target = getattr(target, part) + new_param = torch.nn.Parameter(param_data, requires_grad=param.requires_grad) + new_param.grad = None + new_param.grad_buffer = torch.zeros_like(param_data) + new_param.param_grad_is_zero = True + target._parameters[param_name] = new_param + + +def _param_grad(param: torch.nn.Parameter) -> torch.Tensor | None: + # ParameterMeta stores grads in grad_buffer; fall back to .grad otherwise. + return param.grad_buffer if hasattr(param, "grad_buffer") and param.grad_buffer is not None else param.grad + + +@pytest.mark.slow +@pytest.mark.skipif(not torch.cuda.is_available(), reason="KDA varlen needs CUDA") +@pytest.mark.skipif( + kda_module.chunk_kda is None or kda_module.fused_kda_gate is None, + reason="KDA fused kernels not available", +) +def test_kda_varlen_stacking_equivalence(distributed_config, distributed): + """ + Check that KDA forward/backward match with and without stacking using the real kernels. + """ + device = torch.device("cuda") + dtype = torch.float16 + heads, head_dim = 2, 16 + hidden_size = heads * head_dim + + config = KimiDeltaAttentionConfig(heads=heads, head_dim=head_dim) + hidden_dim = TensorDim("hidden", hidden_size) + kda_packed = config.get_layer(distributed_config, hidden_dim, lr_scale=None, peft=None, return_bias=False) + kda_ref = config.get_layer(distributed_config, hidden_dim, lr_scale=None, peft=None, return_bias=False) + kda_packed.setup(distributed) + kda_ref.setup(distributed) + _materialize_kda_tensors(kda_packed, distributed, device) + _materialize_kda_tensors(kda_ref, distributed, device) + kda_ref.load_state_dict(kda_packed.state_dict()) + kda_packed.to(device=device, dtype=dtype) + kda_ref.to(device=device, dtype=dtype) + + batch_size = 2 # cu_seqlens path requires flattened batch + seq_len = 15 + packages_num = torch.randint(2, 5, (1, batch_size))[0] # randomize packages num between 2 and 4 + lengths = [ + torch.tensor( + generate_random_cu_seqlens(seq_len, packages_num=packages_num[i].item())[0], + device=device, + dtype=torch.long, + ).diff() + for i in range(batch_size) + ] + + # lengths = torch.tensor(cu_seqlens, device=device, dtype=torch.long)#.diff() + # total_tokens = lengths.sum().item() + packed = torch.randn(batch_size, seq_len, hidden_size, device=device, dtype=dtype, requires_grad=True) + + kwargs_packed = { + LinearAttentionKwargs.sequence_lengths: lengths, + BlockKwargs.sequence_first: False, + BlockKwargs.hidden_dims: (hidden_dim,), + # BlockKwargs.sequence_q_dim: TensorDim("sequence_q", lengths.sum().item()), + } + # Use the layer's preprocess to construct cu_seqlens/seq_idx the same way as the implementation. + kda_packed.preprocess(packed, kwargs_packed) + + kwargs_ref = { + BlockKwargs.sequence_first: False, + BlockKwargs.hidden_dims: (hidden_dim,), + } + + out_packed = kda_packed(packed, kwargs_packed) + # Run reference path separately per sequence without varlen packing, then concatenate. + ref_outs = [] + for b in range(batch_size): + out_batch = [] + length = lengths[b] + ref_seqs = torch.split(packed[b].unsqueeze(0), length.tolist(), dim=1) + for seq in ref_seqs: + kwargs_ref_seq = { + **kwargs_ref, + BlockKwargs.sequence_q_dim: TensorDim("sequence_q", seq.shape[1]), + } + out_batch.append(kda_ref(seq, kwargs_ref_seq)) + ref_outs.append(torch.cat(out_batch, dim=1)) + out_ref = torch.cat(ref_outs, dim=0) + out_ref_packed = out_ref + + assert out_packed.shape == packed.shape + assert out_ref_packed.shape == out_packed.shape + assert torch.allclose(out_packed, out_ref_packed, atol=1e-3, rtol=1e-3) + + out_packed.sum().backward() + out_ref_packed.sum().backward() + + assert _param_grad(kda_packed.q_proj.weight) is not None + assert _param_grad(kda_ref.q_proj.weight) is not None + assert torch.allclose( + _param_grad(kda_packed.q_proj.weight), _param_grad(kda_ref.q_proj.weight), atol=1e-3, rtol=1e-3 + ) + assert torch.allclose( + _param_grad(kda_packed.k_proj.weight), _param_grad(kda_ref.k_proj.weight), atol=1e-3, rtol=1e-3 + ) + assert torch.allclose( + _param_grad(kda_packed.v_proj.weight), _param_grad(kda_ref.v_proj.weight), atol=1e-3, rtol=1e-3 + ) + assert torch.allclose( + _param_grad(kda_packed.o_proj.weight), _param_grad(kda_ref.o_proj.weight), atol=1e-3, rtol=1e-3 + ) + + +if __name__ == "__main__": + pytest.main([__file__]) From 3b367d871b658818091469be1502d001ae033db2 Mon Sep 17 00:00:00 2001 From: oleksost Date: Mon, 24 Nov 2025 20:54:00 +0000 Subject: [PATCH 09/43] gdn only: varlen test --- fast_llm/layers/ssm/config.py | 91 --------- fast_llm/layers/ssm/gdn.py | 91 +++++++-- fast_llm/layers/ssm/kda.py | 344 ---------------------------------- tests/test_ssm_varlen.py | 108 ++++++----- 4 files changed, 134 insertions(+), 500 deletions(-) delete mode 100644 fast_llm/layers/ssm/kda.py diff --git a/fast_llm/layers/ssm/config.py b/fast_llm/layers/ssm/config.py index 2fa90aff9..6f36321ec 100644 --- a/fast_llm/layers/ssm/config.py +++ b/fast_llm/layers/ssm/config.py @@ -118,97 +118,6 @@ def _validate(self) -> None: super()._validate() -@config_class(dynamic_type={MixerConfig: "kda"}) -class KimiDeltaAttentionConfig(MixerConfig): - """ - Configuration for the KimiDeltaAttention mixer inspired by the Kimi Linear models. - """ - - _abstract = False - normalization: GatedRMSNormalizationConfig = Field( - desc="Configuration for the gated normalization applied to the KDA output.", - hint=FieldHint.architecture, - ) - q_projection_layer: AffineLinearConfig = Field( - desc="Projection that produces query vectors.", - hint=FieldHint.architecture, - ) - k_projection_layer: AffineLinearConfig = Field( - desc="Projection that produces key vectors.", - hint=FieldHint.architecture, - ) - v_projection_layer: AffineLinearConfig = Field( - desc="Projection that produces value vectors.", - hint=FieldHint.architecture, - ) - f_a_projection_layer: AffineLinearConfig = Field( - desc="Projection used for the Delta gating pre-activation.", - hint=FieldHint.architecture, - ) - f_b_projection_layer: AffineLinearConfig = Field( - desc="Projection used for the Delta gating expansion.", - hint=FieldHint.architecture, - ) - g_a_projection_layer: AffineLinearConfig = Field( - desc="Projection used for the output gating pre-activation.", - hint=FieldHint.architecture, - ) - g_b_projection_layer: AffineLinearConfig = Field( - desc="Projection used for the output gating expansion.", - hint=FieldHint.architecture, - ) - beta_projection_layer: AffineLinearConfig = Field( - desc="Projection that produces the Beta gate.", - hint=FieldHint.architecture, - ) - output_projection_layer: AffineLinearConfig = Field( - desc="Projection applied after the Delta recurrence and gated normalization.", - hint=FieldHint.architecture, - ) - convolution_layer: CausalConv1dConfig = Field( - desc="Depth-wise convolution applied independently on each Q, K and V stream.", - hint=FieldHint.architecture, - ) - dt_bias_weight: ParameterConfig = Field( - desc="Parameter configuration for the Delta gate bias.", - hint=FieldHint.architecture, - ) - a_log_weight: ParameterConfig = Field( - desc="Parameter configuration for the decay rates.", - hint=FieldHint.architecture, - ) - - heads: int = Field( - default=16, - desc="Number of attention heads.", - hint=FieldHint.architecture, - valid=check_field(Assert.gt, 0), - ) - head_dim: int = Field( - default=64, - desc="Dimension of each head.", - hint=FieldHint.architecture, - valid=check_field(Assert.gt, 0), - ) - - @property - def layer_class(self) -> "type": - from fast_llm.layers.ssm.kda import KimiDeltaAttention - - return KimiDeltaAttention - - def _validate(self) -> None: - with self._set_implicit_default(): - if "epsilon" not in self.normalization._explicit_fields: - self.normalization.epsilon = 1.0e-5 - if "activation" not in self.convolution_layer._explicit_fields: - self.convolution_layer.activation = "silu" - if "kernel_size" not in self.convolution_layer._explicit_fields: - self.convolution_layer.kernel_size = 4 - - super()._validate() - - @config_class() class SSMConfig(MixerConfig): # Layers diff --git a/fast_llm/layers/ssm/gdn.py b/fast_llm/layers/ssm/gdn.py index d07cb5e21..6fd86c6e5 100644 --- a/fast_llm/layers/ssm/gdn.py +++ b/fast_llm/layers/ssm/gdn.py @@ -3,6 +3,7 @@ import torch import torch.nn.functional as F +from einops import rearrange from fast_llm.engine.base_model.config import ResourceUsageConfig from fast_llm.engine.config_utils.initialization import LambdaInitializer, init_normal_, init_ones_ @@ -11,7 +12,7 @@ from fast_llm.layers.block.config import BlockKwargs from fast_llm.layers.common.peft.config import PeftConfig from fast_llm.layers.decoder.block import BlockWithBias -from fast_llm.layers.ssm.config import GatedDeltaNetConfig +from fast_llm.layers.ssm.config import GatedDeltaNetConfig, LinearAttentionKwargs from fast_llm.tensor import ParameterMeta, TensorMeta from fast_llm.utils import div @@ -263,28 +264,46 @@ def _forward( losses: dict[str, typing.Any] | None = None, metrics: dict[str, typing.Any] | None = None, ) -> tuple[torch.Tensor, torch.Tensor | None]: + """ + - we flatten batch + seq + - forward as packed sequence, i.e. BS = 1, cu_seqlens and seq_idx created in the preprocessing step must reflect this (these are None if cross_document_attention is True) + - scatter results back to B x T x D + - note, if there are padding tokens they are note removed, they are assumed to be ignored later in the loss calculation and are assumed to be always ont he right + """ + sequence_first = kwargs[BlockKwargs.sequence_first] # in sequence parallel TP the input here is already scattered across sequence dimension # TODO: do we need masking of padding tokens? - # TODO: make sure varlen is supported + # TODO: fuse soome of the reshapes into rearranges hidden_states = input_ - # batch_size, sequence_length, _ = hidden_states.shape - projected_states_qkvz = self.in_proj_qkvz(hidden_states) # bs x seq_len x (qkvz) - projected_states_ba = self.in_proj_ba(hidden_states) # bs x seq_len x (b a) + # these are not + cu_seqlens = kwargs.get(LinearAttentionKwargs.cu_seqlens, None) + seq_idx = kwargs.get(LinearAttentionKwargs.seq_idx, None) + + projected_states_qkvz = self.in_proj_qkvz(hidden_states) # bs/seq x seq_len/bs x (qkvz) + projected_states_ba = self.in_proj_ba(hidden_states) # bs/seq x seq_len/bs x (b a) if sequence_first: projected_states_qkvz = projected_states_qkvz.transpose(0, 1) projected_states_ba = projected_states_ba.transpose(0, 1) + batch_size, sequence_length = projected_states_qkvz.shape[:2] + + # note: to support var len training (packing) we need to flatten hidden states to batch_size = 1 + # this is does not seem to be required by causal_conv1d_fn, but it it required by chunked_gdn_rule: https://github.com/fla-org/flash-linear-attention/blob/71260ecd573cfaaa94305b726465143199e99734/fla/ops/gated_delta_rule/chunk.py#L299 + # similarly to kimi linear and to SHortCOnv from fla, we pass it flattened tro conv_1d as well, i.e. see https://github.com/fla-org/flash-linear-attention/blob/71260ecd573cfaaa94305b726465143199e99734/fla/modules/convolution.py#L914 query, key, value, z, beta, alpha = self.fix_query_key_value_ordering( projected_states_qkvz, projected_states_ba ) query, key, value = (x.reshape(x.shape[0], x.shape[1], -1) for x in (query, key, value)) mixed_qkv = torch.cat((query, key, value), dim=-1) - mixed_qkv = mixed_qkv.transpose(1, 2) - mixed_qkv = self.convolution(mixed_qkv) - mixed_qkv = mixed_qkv.transpose(1, 2) + mixed_qkv = rearrange(mixed_qkv, "b s ... -> (b s) ...").unsqueeze(0) # 1 s d + mixed_qkv = rearrange(mixed_qkv, "b t d -> b d t") # mixed_qkv.transpose(1, 2) + mixed_qkv = self.convolution( + mixed_qkv, seq_idx=seq_idx + ) # conv func. gets sequence dim as last dim, see https://github.com/Dao-AILab/causal-conv1d/blob/22a4577d8ace9d5703daea91a7fb56695492152b/causal_conv1d/causal_conv1d_interface.py#L110 + mixed_qkv = rearrange(mixed_qkv, "b d t -> b t d") # mixed_qkv.transpose(1, 2) query, key, value = torch.split( mixed_qkv, ( @@ -301,6 +320,9 @@ def _forward( beta = beta.sigmoid() g = -torch.exp(self.A_log) * F.softplus(alpha + self.dt_bias) + beta = rearrange(beta, "b s ... -> (b s) ...").unsqueeze(0) + g = rearrange(g, "b s ... -> (b s) ...").unsqueeze(0) + if self._value_heads_per_key > 1: query = query.repeat_interleave(self._value_heads_per_key, dim=2) key = key.repeat_interleave(self._value_heads_per_key, dim=2) @@ -314,9 +336,12 @@ def _forward( initial_state=None, output_final_state=False, use_qk_l2norm_in_kernel=True, + cu_seqlens=cu_seqlens, ) z_shape_og = z.shape + core_attn_out = rearrange(core_attn_out.squeeze(0), "(b s) ... -> b s ...", b=batch_size, s=sequence_length) + core_attn_out = core_attn_out.reshape(-1, core_attn_out.shape[-1]) z = z.reshape(-1, z.shape[-1]) core_attn_out = self.norm(core_attn_out, z) @@ -328,10 +353,50 @@ def _forward( return output + def _preprocess_for_varlen(self, batch: torch.Tensor, kwargs: dict[str, typing.Any]) -> None: + """ + Creates seqlens and cu_seqlens for packed training (varlen). + This assumes that forward pass is performed on a fully packed sequence, i.e. where sequences are flattened out into BS = 1. + + Sets: + - seq_idx to [1, BS x T] tensor, where each elemnt is the sequence index of the corresponding token + - cu_seqlens to [N+1] tensor, where N is the total number of sequences in the batch, each element is the cumulative sequence length of packed sequences sofar + """ + + sequence_lengths = kwargs[LinearAttentionKwargs.sequence_lengths] + if sequence_lengths is None: + raise ValueError("sequence_lengths must be provided in kwargs for variable-length sequences.") + seqlens = torch.cat(sequence_lengths) + cu_seqlens = torch.cat( + ( + torch.zeros(1, dtype=torch.long, device=batch.device), + torch.cumsum(seqlens, dim=0, dtype=torch.long).to(batch.device), + ) + ) + # this is supposed to be flattened, see https://github.com/fla-org/flash-linear-attention/blob/71260ecd573cfaaa94305b726465143199e99734/fla/ops/kda/chunk.py#L303 + # also whenever cu_seqlens is used, batchs size must be forced to 1: see https://github.com/fla-org/flash-linear-attention/blob/71260ecd573cfaaa94305b726465143199e99734/fla/ops/kda/chunk.py#L347 + kwargs[LinearAttentionKwargs.cu_seqlens] = cu_seqlens + # seq_idx has to be (bs, seqlen), but bs is forced to 1 + kwargs[LinearAttentionKwargs.seq_idx] = ( + ( + torch.cat( + [ + torch.arange(n, dtype=cu_seqlens.dtype, device=cu_seqlens.device) + for n in (torch.diff(cu_seqlens).to(torch.int32)) + ], + dim=0, + ) + .eq(0) + .cumsum(0) + - 1 + ) + .to(torch.int32) + .unsqueeze(0) + ) + + def preprocess(self, batch: torch.Tensor, kwargs: dict[str, typing.Any]) -> None: + if LinearAttentionKwargs.sequence_lengths in kwargs: + self._preprocess_for_varlen(batch, kwargs) + def get_compute_usage(self, input_: TensorMeta, kwargs: dict[str, typing.Any], config: ResourceUsageConfig) -> int: - # return ( - # self.in_proj_qkvz.get_compute_usage(input_, config) - # + self.in_proj_ba.get_compute_usage(input_, config) - # + self.out_proj.get_compute_usage(input_, config) - # ) raise NotImplementedError() diff --git a/fast_llm/layers/ssm/kda.py b/fast_llm/layers/ssm/kda.py deleted file mode 100644 index b14fd4592..000000000 --- a/fast_llm/layers/ssm/kda.py +++ /dev/null @@ -1,344 +0,0 @@ -import logging -import typing - -import torch -from einops import rearrange, repeat - -from fast_llm.engine.base_model.config import ResourceUsageConfig -from fast_llm.engine.config_utils.initialization import LambdaInitializer, init_normal_, init_ones_ -from fast_llm.engine.config_utils.tensor_dim import CompositeTensorDim, TensorDim -from fast_llm.engine.distributed.config import DistributedConfig, DistributedDimNames -from fast_llm.layers.block.config import BlockKwargs -from fast_llm.layers.common.peft.config import PeftConfig -from fast_llm.layers.decoder.block import BlockWithBias -from fast_llm.layers.ssm.config import KimiDeltaAttentionConfig, LinearAttentionKwargs -from fast_llm.tensor import ParameterMeta, TensorMeta - -logger = logging.getLogger(__name__) - -try: - from fla.ops.kda import chunk_kda - from fla.ops.kda.gate import fused_kda_gate -except ImportError: - chunk_kda = None - fused_kda_gate = None - - -def index_first_axis(x, indices): - other_shape = x.shape[1:] - second_dim = other_shape.numel() - return torch.gather( - rearrange(x, "b ... -> b (...)"), - 0, - repeat(indices, "z -> z d", d=second_dim), - ).reshape(-1, *other_shape) - - -class KimiDeltaAttention[ConfigType: KimiDeltaAttentionConfig](BlockWithBias[ConfigType]): - """ - Implementation of the Kimi Delta Attention mixer. - Reference Implementation: https://huggingface.co/moonshotai/Kimi-Linear-48B-A3B-Instruct/blob/main/modeling_kimi.py - """ - - _config: ConfigType - - def __init__( - self, - config: ConfigType, - distributed_config: DistributedConfig, - *, - hidden_dim: TensorDim, - lr_scale: float | None, - peft: PeftConfig | None, - return_bias: bool = True, - ): - super().__init__( - config, distributed_config, hidden_dim=hidden_dim, lr_scale=lr_scale, peft=peft, return_bias=return_bias - ) - if chunk_kda is None or fused_kda_gate is None: - raise ImportError( - "KimiDeltaAttention requires the `fla-core` package. " - "Please install it with `pip install -U fla-core`." - ) - - self._parallel_dim = self._distributed_config.get_distributed_dim(DistributedDimNames.tensor) - self._heads_dim = TensorDim( - "kda_heads", self._config.heads, self._parallel_dim if self._config.heads > 1 else None - ) - self._head_dim = TensorDim("kda_head_dim", self._config.head_dim) - self._projection_dim = CompositeTensorDim("kda_projection", (self._heads_dim, self._head_dim)) - self._local_heads = self._heads_dim.size - self._projection_size = self._projection_dim.size - - init = init_normal_(std=self._hidden_size**-0.5) - self.q_proj = self._config.q_projection_layer.get_layer( - hidden_dim, - self._projection_dim, - default_weight_initialization=init, - default_add_bias=False, - sequence_parallel=self._sequence_parallel, - lr_scale=self._lr_scale, - peft=self._peft, - ) - self.k_proj = self._config.k_projection_layer.get_layer( - hidden_dim, - self._projection_dim, - default_weight_initialization=init, - default_add_bias=False, - sequence_parallel=self._sequence_parallel, - lr_scale=self._lr_scale, - peft=self._peft, - ) - self.v_proj = self._config.v_projection_layer.get_layer( - hidden_dim, - self._projection_dim, - default_weight_initialization=init, - default_add_bias=False, - sequence_parallel=self._sequence_parallel, - lr_scale=self._lr_scale, - peft=self._peft, - ) - - self.q_conv = self._config.convolution_layer.get_layer( - self._projection_dim, - default_add_bias=False, - default_activation=self._config.convolution_layer.activation, - lr_scale=self._lr_scale, - peft=self._peft, - ) - self.k_conv = self._config.convolution_layer.get_layer( - self._projection_dim, - default_add_bias=False, - default_activation=self._config.convolution_layer.activation, - lr_scale=self._lr_scale, - peft=self._peft, - ) - self.v_conv = self._config.convolution_layer.get_layer( - self._projection_dim, - default_add_bias=False, - default_activation=self._config.convolution_layer.activation, - lr_scale=self._lr_scale, - peft=self._peft, - ) - - self.f_a_proj = self._config.f_a_projection_layer.get_layer( - hidden_dim, - self._head_dim, - default_weight_initialization=init, - default_add_bias=False, - sequence_parallel=False, # self._sequence_parallel, - lr_scale=self._lr_scale, - peft=self._peft, - ) - self.f_b_proj = self._config.f_b_projection_layer.get_layer( - self._head_dim, - self._projection_dim, - default_weight_initialization=init, - default_add_bias=False, - sequence_parallel=self._sequence_parallel, - lr_scale=self._lr_scale, - peft=self._peft, - ) - self.g_a_proj = self._config.g_a_projection_layer.get_layer( - hidden_dim, - self._head_dim, - default_weight_initialization=init, - default_add_bias=False, - sequence_parallel=False, # self._sequence_parallel, - lr_scale=self._lr_scale, - peft=self._peft, - ) - self.g_b_proj = self._config.g_b_projection_layer.get_layer( - self._head_dim, - self._projection_dim, - default_weight_initialization=init, - default_add_bias=False, - sequence_parallel=self._sequence_parallel, - lr_scale=self._lr_scale, - peft=self._peft, - ) - - self.beta_proj = self._config.beta_projection_layer.get_layer( - hidden_dim, - self._heads_dim, - default_weight_initialization=init, - default_add_bias=False, - sequence_parallel=self._sequence_parallel, - lr_scale=self._lr_scale, - peft=self._peft, - ) - self.o_proj = self._config.output_projection_layer.get_layer( - self._projection_dim, - hidden_dim, - default_weight_initialization=init, - default_add_bias=False, - sequence_parallel=self._sequence_parallel, - lr_scale=self._lr_scale, - peft=self._peft, - ) - - self.dt_bias: ParameterMeta = self._config.dt_bias_weight.get_parameter( - (self._projection_dim,), - default_initialization=init_ones_, - lr_scale=self._lr_scale, - peft=self._peft, - ) - self.A_log: ParameterMeta = self._config.a_log_weight.get_parameter( - (self._heads_dim,), - default_initialization=LambdaInitializer( - lambda _, tensor, generator: tensor.uniform_(1, 16, generator=generator).log_() - ), - lr_scale=self._lr_scale, - peft=self._peft, - ) - self.norm = self._config.normalization.get_layer( - self._head_dim, - lr_scale=self._lr_scale, - peft=self._peft, - ) - - def _apply_conv(self, tensor: torch.Tensor, conv: torch.nn.Module, seq_idx: torch.Tensor = None) -> torch.Tensor: - """ - Applies convolution. - Note, in the reference code they use Short Convolution from flash-linear-attention/fla/modules/convolution.py, but that one just uses causal_conv1d anyways. - Varlen: - - seq. idx are only suppored in channel last layout, i.e. no transpose - """ - tensor = rearrange(tensor, "b t d -> b d t") - # tensor = tensor.transpose(1, 2).contiguous() if seq_idx is None else tensor.transpose(1, 2) - tensor = conv(tensor, seq_idx=seq_idx) - return tensor.transpose(1, 2).contiguous() - - def _reshape_heads(self, tensor: torch.Tensor) -> torch.Tensor: - tensor = tensor.contiguous() - # since head_dim is the same vor k,q and v - # same as rearrange(v, '... (h d) -> ... h d', d=self.head_dim) - return tensor.view(tensor.shape[0], tensor.shape[1], self._local_heads, self._config.head_dim) - - def _forward( - self, - input_: torch.Tensor, - kwargs: dict[str, typing.Any], - losses: dict[str, typing.Any] | None = None, - metrics: dict[str, typing.Any] | None = None, - ) -> tuple[torch.Tensor, torch.Tensor | None]: - # TODO: make sure varlen is supported - # TODO: do we need to deal with padding tokens? - sequence_first = kwargs[BlockKwargs.sequence_first] - hidden_states = input_ - - cu_seqlens = kwargs.get(LinearAttentionKwargs.cu_seqlens, None) - seq_idx = kwargs.get(LinearAttentionKwargs.seq_idx, None) - # TODO: can be made more efficeint by rearranging hidden states directly - residual_dtype = hidden_states.dtype - - q = self.q_proj(hidden_states) - k = self.k_proj(hidden_states) - v = self.v_proj(hidden_states) - - if sequence_first: - # make bs first dim again - q = q.transpose(0, 1) - k = k.transpose(0, 1) - v = v.transpose(0, 1) - - batch_size, sequence_length, _ = q.size() - - # work with bs = 1 to make sure varlen works correctly, only needed if micro batch size is > 1 - # can this be applied once to hidden state only? pr - q = rearrange(q, "b s ... -> (b s) ...").unsqueeze(0) - k = rearrange(k, "b s ... -> (b s) ...").unsqueeze(0) - v = rearrange(v, "b s ... -> (b s) ...").unsqueeze(0) - - # because we use cu_seqlens, chunk_kda requires batch size to be 1 (flatten, see https://github.com/fla-org/flash-linear-attention/blob/71260ecd573cfaaa94305b726465143199e99734/fla/ops/kda/chunk.py#L303) - # similarly to ShortConvolution from fla we already operate on flattened batches here (https://github.com/fla-org/flash-linear-attention/blob/71260ecd573cfaaa94305b726465143199e99734/fla/modules/convolution.py#L914) - q = self._apply_conv(q, self.q_conv, seq_idx=seq_idx) - k = self._apply_conv(k, self.k_conv, seq_idx=seq_idx) - v = self._apply_conv(v, self.v_conv, seq_idx=seq_idx) - - g_kernel = self.f_b_proj(self.f_a_proj(hidden_states)) - if sequence_first: - g_kernel = g_kernel.transpose(0, 1) - g_kernel = rearrange(g_kernel, "b s ... -> (b s) ...").unsqueeze(0) - - g_kernel = fused_kda_gate(g_kernel, self.A_log, self._config.head_dim, g_bias=self.dt_bias) - - beta = torch.sigmoid(self.beta_proj(hidden_states).float()) - q = self._reshape_heads(q) - k = self._reshape_heads(k) - v = self._reshape_heads(v) - if sequence_first: - beta = beta.transpose(0, 1) - beta = rearrange(beta, "b s h -> (b s) h").unsqueeze(0) - - # need to install nightly triton for this to work on H100, see https://github.com/fla-org/flash-linear-attention/blob/main/FAQs.md - # cu_seqlens requires batch ssize to be 1, i.e. flattened bacthes - attn_out, _ = chunk_kda( - q=q, - k=k, - v=v, - g=g_kernel, - beta=beta, - initial_state=None, - output_final_state=False, - use_qk_l2norm_in_kernel=True, - cu_seqlens=cu_seqlens, - ) - - attn_out = attn_out.to(residual_dtype) - - g_out = self.g_b_proj(self.g_a_proj(hidden_states)) # bs x seq x n_local_heads x head dim - g_out = self._reshape_heads(g_out) - if sequence_first: - g_out = g_out.transpose(0, 1) - - attn_out = rearrange(attn_out.squeeze(0), "(b s) h d -> b s h d", b=batch_size, s=sequence_length) - attn_out = self.norm(attn_out, g_out) - attn_out = rearrange(attn_out, "b s h d -> b s (h d)") - if sequence_first: - attn_out = attn_out.transpose(0, 1) - attn_out = self.o_proj(attn_out) - - return attn_out - - def get_compute_usage(self, input_: TensorMeta, kwargs: dict[str, typing.Any], config: ResourceUsageConfig) -> int: - raise NotImplementedError() - - def _preprocess_for_varlen(self, batch: torch.Tensor, kwargs: dict[str, typing.Any]) -> None: - sequence_lengths = kwargs[LinearAttentionKwargs.sequence_lengths] - if sequence_lengths is None: - raise ValueError("sequence_lengths must be provided in kwargs for variable-length sequences.") - sequence_lengths = kwargs[LinearAttentionKwargs.sequence_lengths] - seqlens = torch.cat(sequence_lengths) - cu_seqlens = torch.cat( - ( - torch.zeros(1, dtype=torch.long, device=batch.device), - torch.cumsum(seqlens, dim=0, dtype=torch.long).to(batch.device), - ) - ) - # this is supposed to be flattened, see https://github.com/fla-org/flash-linear-attention/blob/71260ecd573cfaaa94305b726465143199e99734/fla/ops/kda/chunk.py#L303 - # also whenever cu_seqlens is used, batchs size must be forced to 1: see https://github.com/fla-org/flash-linear-attention/blob/71260ecd573cfaaa94305b726465143199e99734/fla/ops/kda/chunk.py#L347 - kwargs[LinearAttentionKwargs.cu_seqlens] = cu_seqlens - # seq_idx has to be (bs, seqlen), but bs is forced to 1 - kwargs[LinearAttentionKwargs.seq_idx] = ( - ( - torch.cat( - [ - torch.arange(n, dtype=cu_seqlens.dtype, device=cu_seqlens.device) - for n in (torch.diff(cu_seqlens).to(torch.int32)) - ], - dim=0, - ) - .eq(0) - .cumsum(0) - - 1 - ) - .to(torch.int32) - .unsqueeze(0) - ) - - def preprocess(self, batch: torch.Tensor, kwargs: dict[str, typing.Any]) -> None: - if LinearAttentionKwargs.sequence_lengths in kwargs: - # TODO: packing is enabled by default, i.e. its always used? - # only get here when cross_document_attention is False - self._preprocess_for_varlen(batch, kwargs) diff --git a/tests/test_ssm_varlen.py b/tests/test_ssm_varlen.py index 9ca491e3d..1f7a83e65 100644 --- a/tests/test_ssm_varlen.py +++ b/tests/test_ssm_varlen.py @@ -8,8 +8,9 @@ from fast_llm.engine.distributed.config import DistributedConfig from fast_llm.engine.distributed.distributed import Distributed from fast_llm.layers.block.config import BlockKwargs -from fast_llm.layers.ssm import kda as kda_module -from fast_llm.layers.ssm.config import KimiDeltaAttentionConfig, LinearAttentionKwargs +from fast_llm.layers.decoder.config import MixerConfig +from fast_llm.layers.ssm import gdn as gdn_module +from fast_llm.layers.ssm.config import GatedDeltaNetConfig # from mamba2 import NemotronHMamba2 @@ -71,9 +72,8 @@ def materialize_meta_tensors(model, tensor_space): return model -def unpack(packed_hidden_states, cu_seqlens): +def unpack_and_padd(packed_hidden_states, cu_seqlens, package_num): batch_size = packed_hidden_states.shape[0] - package_num = cu_seqlens.shape[0] - 1 seq_len = (cu_seqlens[1:] - cu_seqlens[:-1]).max() hidden_dim = packed_hidden_states.shape[2] hidden_states = torch.zeros( @@ -132,7 +132,7 @@ def generate_random_cu_seqlens(seq_len, packages_num=2): return cu_seqlens, index -def _materialize_kda_tensors(module: torch.nn.Module, distributed: Distributed, device: torch.device) -> None: +def _materialize_mixer_tensors(module: torch.nn.Module, distributed: Distributed, device: torch.device) -> None: """ Materialize meta parameters on the requested device for KDA mixer layers. """ @@ -154,41 +154,46 @@ def _materialize_kda_tensors(module: torch.nn.Module, distributed: Distributed, def _param_grad(param: torch.nn.Parameter) -> torch.Tensor | None: - # ParameterMeta stores grads in grad_buffer; fall back to .grad otherwise. return param.grad_buffer if hasattr(param, "grad_buffer") and param.grad_buffer is not None else param.grad +# TODO: include mamba varlen @pytest.mark.slow -@pytest.mark.skipif(not torch.cuda.is_available(), reason="KDA varlen needs CUDA") +@pytest.mark.skipif(not torch.cuda.is_available(), reason="Varlen test needs CUDA") @pytest.mark.skipif( - kda_module.chunk_kda is None or kda_module.fused_kda_gate is None, - reason="KDA fused kernels not available", + gdn_module.chunk_gated_delta_rule is None, + reason="Gated Delta Net fused kernels not available", ) -def test_kda_varlen_stacking_equivalence(distributed_config, distributed): +@pytest.mark.parametrize( + "config, sequence_first", + [ + pytest.param(GatedDeltaNetConfig(value_heads=4, key_heads=2, key_head_dim=16, value_head_dim=16), False), + pytest.param(GatedDeltaNetConfig(value_heads=4, key_heads=2, key_head_dim=16, value_head_dim=16), True), + # pytest.param(KimiDeltaAttentionConfig) + ], +) +def test_mixer_varlen_stacking_equivalence(config: MixerConfig, sequence_first: bool, distributed_config, distributed): """ - Check that KDA forward/backward match with and without stacking using the real kernels. + Check that Gated Delta Net forward/backward match with and without packing. """ device = torch.device("cuda") dtype = torch.float16 - heads, head_dim = 2, 16 - hidden_size = heads * head_dim - - config = KimiDeltaAttentionConfig(heads=heads, head_dim=head_dim) + hidden_size = 32 hidden_dim = TensorDim("hidden", hidden_size) - kda_packed = config.get_layer(distributed_config, hidden_dim, lr_scale=None, peft=None, return_bias=False) - kda_ref = config.get_layer(distributed_config, hidden_dim, lr_scale=None, peft=None, return_bias=False) - kda_packed.setup(distributed) - kda_ref.setup(distributed) - _materialize_kda_tensors(kda_packed, distributed, device) - _materialize_kda_tensors(kda_ref, distributed, device) - kda_ref.load_state_dict(kda_packed.state_dict()) - kda_packed.to(device=device, dtype=dtype) - kda_ref.to(device=device, dtype=dtype) + mixer_packed = config.get_layer(distributed_config, hidden_dim, lr_scale=None, peft=None, return_bias=False) + mixer_ref = config.get_layer(distributed_config, hidden_dim, lr_scale=None, peft=None, return_bias=False) + mixer_packed.setup(distributed) + mixer_ref.setup(distributed) + _materialize_mixer_tensors(mixer_packed, distributed, device) + _materialize_mixer_tensors(mixer_ref, distributed, device) + mixer_ref.load_state_dict(mixer_packed.state_dict()) + mixer_packed.to(device=device, dtype=dtype) + mixer_ref.to(device=device, dtype=dtype) batch_size = 2 # cu_seqlens path requires flattened batch seq_len = 15 - packages_num = torch.randint(2, 5, (1, batch_size))[0] # randomize packages num between 2 and 4 - lengths = [ + packages_num = torch.tensor([2, 3], device=device, dtype=torch.long) + sequence_lengths = [ torch.tensor( generate_random_cu_seqlens(seq_len, packages_num=packages_num[i].item())[0], device=device, @@ -196,63 +201,62 @@ def test_kda_varlen_stacking_equivalence(distributed_config, distributed): ).diff() for i in range(batch_size) ] + seqlens = torch.cat(sequence_lengths) + cu_seqlen = torch.cat( + ( + torch.zeros(1, dtype=torch.long, device=device), + torch.cumsum(seqlens, dim=0, dtype=torch.long).to(device), + ) + ) - # lengths = torch.tensor(cu_seqlens, device=device, dtype=torch.long)#.diff() - # total_tokens = lengths.sum().item() packed = torch.randn(batch_size, seq_len, hidden_size, device=device, dtype=dtype, requires_grad=True) + if sequence_first: + packed = packed.transpose(0, 1) kwargs_packed = { - LinearAttentionKwargs.sequence_lengths: lengths, - BlockKwargs.sequence_first: False, + BlockKwargs.sequence_lengths: sequence_lengths, + BlockKwargs.sequence_first: sequence_first, BlockKwargs.hidden_dims: (hidden_dim,), - # BlockKwargs.sequence_q_dim: TensorDim("sequence_q", lengths.sum().item()), } - # Use the layer's preprocess to construct cu_seqlens/seq_idx the same way as the implementation. - kda_packed.preprocess(packed, kwargs_packed) + mixer_packed.preprocess(packed, kwargs_packed) + assert torch.all(kwargs_packed["cu_seqlens"] == cu_seqlen) kwargs_ref = { BlockKwargs.sequence_first: False, BlockKwargs.hidden_dims: (hidden_dim,), } - out_packed = kda_packed(packed, kwargs_packed) + out_packed = mixer_packed(packed, kwargs_packed) + if sequence_first: + out_packed = out_packed.transpose(0, 1) # Run reference path separately per sequence without varlen packing, then concatenate. ref_outs = [] for b in range(batch_size): out_batch = [] - length = lengths[b] - ref_seqs = torch.split(packed[b].unsqueeze(0), length.tolist(), dim=1) + length = sequence_lengths[b] + if sequence_first: + ref_seqs = torch.split(packed[:, b].unsqueeze(0), length.tolist(), dim=1) + else: + ref_seqs = torch.split(packed[b].unsqueeze(0), length.tolist(), dim=1) for seq in ref_seqs: kwargs_ref_seq = { **kwargs_ref, BlockKwargs.sequence_q_dim: TensorDim("sequence_q", seq.shape[1]), } - out_batch.append(kda_ref(seq, kwargs_ref_seq)) + out_batch.append(mixer_ref(seq, kwargs_ref_seq)) ref_outs.append(torch.cat(out_batch, dim=1)) out_ref = torch.cat(ref_outs, dim=0) out_ref_packed = out_ref - assert out_packed.shape == packed.shape assert out_ref_packed.shape == out_packed.shape assert torch.allclose(out_packed, out_ref_packed, atol=1e-3, rtol=1e-3) out_packed.sum().backward() out_ref_packed.sum().backward() - assert _param_grad(kda_packed.q_proj.weight) is not None - assert _param_grad(kda_ref.q_proj.weight) is not None - assert torch.allclose( - _param_grad(kda_packed.q_proj.weight), _param_grad(kda_ref.q_proj.weight), atol=1e-3, rtol=1e-3 - ) - assert torch.allclose( - _param_grad(kda_packed.k_proj.weight), _param_grad(kda_ref.k_proj.weight), atol=1e-3, rtol=1e-3 - ) - assert torch.allclose( - _param_grad(kda_packed.v_proj.weight), _param_grad(kda_ref.v_proj.weight), atol=1e-3, rtol=1e-3 - ) - assert torch.allclose( - _param_grad(kda_packed.o_proj.weight), _param_grad(kda_ref.o_proj.weight), atol=1e-3, rtol=1e-3 - ) + for (name, param), (_, param_ref) in zip(mixer_packed.named_parameters(), mixer_ref.named_parameters()): + if param.requires_grad: + assert torch.allclose(_param_grad(param), _param_grad(param_ref), atol=1e-3, rtol=1e-3) if __name__ == "__main__": From c48d4ee44ccf71fd66db298c3b6a316925412570 Mon Sep 17 00:00:00 2001 From: oleksost Date: Mon, 24 Nov 2025 21:02:42 +0000 Subject: [PATCH 10/43] clean up --- fast_llm/layers/ssm/gdn.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/fast_llm/layers/ssm/gdn.py b/fast_llm/layers/ssm/gdn.py index 6fd86c6e5..af770e381 100644 --- a/fast_llm/layers/ssm/gdn.py +++ b/fast_llm/layers/ssm/gdn.py @@ -267,13 +267,16 @@ def _forward( """ - we flatten batch + seq - forward as packed sequence, i.e. BS = 1, cu_seqlens and seq_idx created in the preprocessing step must reflect this (these are None if cross_document_attention is True) - - scatter results back to B x T x D - - note, if there are padding tokens they are note removed, they are assumed to be ignored later in the loss calculation and are assumed to be always ont he right + - scatter results back to B/T x T/B x D + - note, if there are padding tokens they are not treated in a special way here. + They are + - assumed to be ignored later in the loss calculation and + - are assumed to be always on the right and, hence, will be reflected in seq_idx and cu_seqlens (i.e. treated as a seperate packed sequence?) + - """ sequence_first = kwargs[BlockKwargs.sequence_first] # in sequence parallel TP the input here is already scattered across sequence dimension - # TODO: do we need masking of padding tokens? # TODO: fuse soome of the reshapes into rearranges hidden_states = input_ @@ -355,8 +358,9 @@ def _forward( def _preprocess_for_varlen(self, batch: torch.Tensor, kwargs: dict[str, typing.Any]) -> None: """ - Creates seqlens and cu_seqlens for packed training (varlen). + Creates seqlens and cu_seqlens for packed forward. This assumes that forward pass is performed on a fully packed sequence, i.e. where sequences are flattened out into BS = 1. + Note: padding tokens are always on the right and get their own entry in LinearAttentionKwargs.sequence_lengths --> they are treated as seperate sequence. Sets: - seq_idx to [1, BS x T] tensor, where each elemnt is the sequence index of the corresponding token From e2bb25cf49168a6ec43603691745ac2930168a62 Mon Sep 17 00:00:00 2001 From: oleksost Date: Tue, 25 Nov 2025 14:35:37 +0000 Subject: [PATCH 11/43] test config --- tests/utils/model_configs.py | 39 ++++++++++++++++++++++++++++++++++++ 1 file changed, 39 insertions(+) diff --git a/tests/utils/model_configs.py b/tests/utils/model_configs.py index f7797e3c8..7ee095d3f 100644 --- a/tests/utils/model_configs.py +++ b/tests/utils/model_configs.py @@ -817,6 +817,45 @@ def _update_and_add_testing_config( ) +_update_and_add_testing_config( + # Tests hybrid with gated delta net mixer. + "llama", + "hybrid_gdn", + updates={ + ("model", "base_model", "decoder"): { + "type": "pattern", + "blocks": { + "t": copy.deepcopy(_llama_block), + "gdn": { + **copy.deepcopy(_llama_block), + "mixer": { + "type": "gdn", + "value_heads": 4, + "key_heads": 2, + "key_head_dim": 16, + "value_head_dim": 16, + }, + }, + }, + "num_blocks": 2, + "pattern": ["t", "gdn"], + }, + }, + megatron_args=None, + checkpoint_format=AprielHybridSSMCheckpointFormat, + groups={ + ModelTestingGroup.basic: ModelTestingGroupAction.normal, + ModelTestingGroup.checkpoint: ModelTestingGroupAction.normal, + ModelTestingGroup.convert: ModelTestingGroupAction.normal, + ModelTestingGroup.generate: ModelTestingGroupAction.not_implemented, + ModelTestingGroup.megatron: ModelTestingGroupAction.not_implemented, + ModelTestingGroup.distributed: ModelTestingGroupAction.normal, + }, + compare_factor=16, + skip_tests=("sdp", "ms", "stp"), +) + + @pytest.fixture(scope="session", params=MODEL_CONFIGS.keys()) def model_testing_config(request) -> ModelTestingConfig: models = request.config.getoption("--models") From d4f9b856f534cb7899b3dc5d887909c10c78cf7a Mon Sep 17 00:00:00 2001 From: oleksost Date: Tue, 25 Nov 2025 19:08:04 +0000 Subject: [PATCH 12/43] wip --- fast_llm/layers/ssm/gdn.py | 106 +++++++++++++------ setup.cfg | 4 +- tests/{test_ssm_varlen.py => test_varlen.py} | 20 ---- tests/utils/model_configs.py | 2 +- 4 files changed, 75 insertions(+), 57 deletions(-) rename tests/{test_ssm_varlen.py => test_varlen.py} (94%) diff --git a/fast_llm/layers/ssm/gdn.py b/fast_llm/layers/ssm/gdn.py index af770e381..cba40e48c 100644 --- a/fast_llm/layers/ssm/gdn.py +++ b/fast_llm/layers/ssm/gdn.py @@ -210,9 +210,9 @@ def __init__( lr_scale=self._lr_scale, peft=self._peft, ) - self.norm = self._config.normalization.get_layer( - self._value_head_dim, lr_scale=self._lr_scale, peft=self._peft - ) + # self.norm = self._config.normalization.get_layer( + # self._value_head_dim, lr_scale=self._lr_scale, peft=self._peft + # ) self.chunk_gated_delta_rule = chunk_gated_delta_rule or torch_chunk_gated_delta_rule @@ -221,41 +221,65 @@ def __init__( "Fast paths for GatedDeltaNet are not available. Please ensure that 'causal_conv1d' and 'fla' are properly installed." ) + # def fix_query_key_value_ordering(self, mixed_qkvz, mixed_ba): + # """Derives query, key and value tensors from mixed_qkvz and mixed_ba.""" + # new_tensor_shape_qkvz = mixed_qkvz.size()[:-1] + ( + # self._local_key_heads, + # 2 * self._config.key_head_dim + # + 2 * self._config.value_head_dim * self._local_value_heads // self._local_key_heads, + # ) + # new_tensor_shape_ba = mixed_ba.size()[:-1] + ( + # self._local_key_heads, + # 2 * self._local_value_heads // self._local_key_heads, + # ) + # mixed_qkvz = mixed_qkvz.view(*new_tensor_shape_qkvz) + # mixed_ba = mixed_ba.view(*new_tensor_shape_ba) + # split_arg_list_qkvz = [ + # self._config.key_head_dim, + # self._config.key_head_dim, + # (self._local_value_heads // self._local_key_heads * self._config.value_head_dim), + # (self._local_value_heads // self._local_key_heads * self._config.value_head_dim), + # ] + # split_arg_list_ba = [ + # self._local_value_heads // self._local_key_heads, + # self._local_value_heads // self._local_key_heads, + # ] + # query, key, value, z = torch.split(mixed_qkvz, split_arg_list_qkvz, dim=3) + # b, a = torch.split(mixed_ba, split_arg_list_ba, dim=3) + # # [b, sq, ng, np/ng * hn] -> [b, sq, np, hn] + # value = value.reshape(value.size(0), value.size(1), -1, self._config.value_head_dim) + # z = z.reshape(z.size(0), z.size(1), -1, self._config.value_head_dim) + # b = b.reshape(b.size(0), b.size(1), self._local_value_heads) + # a = a.reshape(a.size(0), a.size(1), self._local_value_heads) + # return query, key, value, z, b, a + def fix_query_key_value_ordering(self, mixed_qkvz, mixed_ba): """ + note this must be the right way to split the TP, because TP splits each subdimention of ConcatenatedTensorDim("gdn_qkvz", (query_dim, key_dim, value_dim, z_dim)) seperately. Derives `query`, `key` and `value` tensors from `mixed_qkvz` and `mixed_ba`. """ - new_tensor_shape_qkvz = mixed_qkvz.size()[:-1] + ( - self._local_key_heads, - 2 * self._config.key_head_dim - + 2 * self._config.value_head_dim * self._local_value_heads // self._local_key_heads, + # Split contiguous q/k/v/z blocks and only then project them into per-head shapes. + local_qkv_sizes = ( + self._local_key_heads * self._config.key_head_dim, + self._local_key_heads * self._config.key_head_dim, + self._local_value_heads * self._config.value_head_dim, + self._local_value_heads * self._config.value_head_dim, ) - new_tensor_shape_ba = mixed_ba.size()[:-1] + ( - self._local_key_heads, - 2 * self._local_value_heads // self._local_key_heads, + query, key, value, z = torch.split(mixed_qkvz, local_qkv_sizes, dim=-1) + query = query.reshape(*query.shape[:-1], self._local_key_heads, self._config.key_head_dim) + key = key.reshape(*key.shape[:-1], self._local_key_heads, self._config.key_head_dim) + value = value.reshape(*value.shape[:-1], self._local_value_heads, self._config.value_head_dim) + z = z.reshape(*z.shape[:-1], self._local_value_heads, self._config.value_head_dim) + + beta, alpha = torch.split( + mixed_ba, + (self._local_value_heads, self._local_value_heads), + dim=-1, ) - - mixed_qkvz = mixed_qkvz.view(*new_tensor_shape_qkvz) - mixed_ba = mixed_ba.view(*new_tensor_shape_ba) - split_arg_list_qkvz = [ - self._config.key_head_dim, - self._config.key_head_dim, - (self._local_value_heads // self._local_key_heads * self._config.value_head_dim), - (self._local_value_heads // self._local_key_heads * self._config.value_head_dim), - ] - split_arg_list_ba = [ - self._local_value_heads // self._local_key_heads, - self._local_value_heads // self._local_key_heads, - ] - query, key, value, z = torch.split(mixed_qkvz, split_arg_list_qkvz, dim=3) - b, a = torch.split(mixed_ba, split_arg_list_ba, dim=3) - # [b, sq, ng, np/ng * hn] -> [b, sq, np, hn] - value = value.reshape(value.size(0), value.size(1), -1, self._config.value_head_dim) - z = z.reshape(z.size(0), z.size(1), -1, self._config.value_head_dim) - b = b.reshape(b.size(0), b.size(1), self._local_value_heads) - a = a.reshape(a.size(0), a.size(1), self._local_value_heads) - return query, key, value, z, b, a + beta = beta.reshape(*beta.shape[:-1], self._local_value_heads) + alpha = alpha.reshape(*alpha.shape[:-1], self._local_value_heads) + return query, key, value, z, beta, alpha def _forward( self, @@ -280,7 +304,6 @@ def _forward( # TODO: fuse soome of the reshapes into rearranges hidden_states = input_ - # these are not cu_seqlens = kwargs.get(LinearAttentionKwargs.cu_seqlens, None) seq_idx = kwargs.get(LinearAttentionKwargs.seq_idx, None) @@ -321,7 +344,7 @@ def _forward( value = value.reshape(value.shape[0], value.shape[1], -1, self._config.value_head_dim) beta = beta.sigmoid() - g = -torch.exp(self.A_log) * F.softplus(alpha + self.dt_bias) + g = -self.A_log.float().exp() * F.softplus(alpha.float() + self.dt_bias) beta = rearrange(beta, "b s ... -> (b s) ...").unsqueeze(0) g = rearrange(g, "b s ... -> (b s) ...").unsqueeze(0) @@ -347,7 +370,7 @@ def _forward( core_attn_out = core_attn_out.reshape(-1, core_attn_out.shape[-1]) z = z.reshape(-1, z.shape[-1]) - core_attn_out = self.norm(core_attn_out, z) + # core_attn_out = self.norm(core_attn_out, z) core_attn_out = core_attn_out.reshape(z_shape_og) core_attn_out = core_attn_out.reshape(core_attn_out.shape[0], core_attn_out.shape[1], -1) if sequence_first: @@ -398,9 +421,24 @@ def _preprocess_for_varlen(self, batch: torch.Tensor, kwargs: dict[str, typing.A .unsqueeze(0) ) + def _preprocess_for_cross_doc_attetion(self, batch: torch.Tensor, kwargs: dict[str, typing.Any]) -> None: + """ + Since forward is packed by default, this is needed for tests to path. + """ + if LinearAttentionKwargs.sequence_lengths in kwargs: + return self._preprocess_for_varlen(batch, kwargs) + bs, sequence_lengths = ( + batch.shape[:2] if not kwargs[BlockKwargs.sequence_first] else (batch.shape[1], batch.shape[0]) + ) + sequence_lengths = [torch.tensor([sequence_lengths] * bs, device=batch.device)] + kwargs[LinearAttentionKwargs.sequence_lengths] = sequence_lengths + self._preprocess_for_varlen(batch, kwargs) + def preprocess(self, batch: torch.Tensor, kwargs: dict[str, typing.Any]) -> None: if LinearAttentionKwargs.sequence_lengths in kwargs: self._preprocess_for_varlen(batch, kwargs) + else: + self._preprocess_for_cross_doc_attetion(batch, kwargs) def get_compute_usage(self, input_: TensorMeta, kwargs: dict[str, typing.Any], config: ResourceUsageConfig) -> int: raise NotImplementedError() diff --git a/setup.cfg b/setup.cfg index 14e9dba28..77664cd0a 100644 --- a/setup.cfg +++ b/setup.cfg @@ -52,8 +52,8 @@ HUGGINGFACE = # To install on cpu environment (ex. for IDE support): # MAMBA_FORCE_BUILD=TRUE CAUSAL_CONV1D_FORCE_BUILD=TRUE CAUSAL_CONV1D_SKIP_CUDA_BUILD=TRUE pip install -e ".[CORE,SSM]" --no-build-isolation SSM = - mamba_ssm[causal-conv1d]>=2.2.4 - cartesia_pytorch>=0.0.2 + mamba_ssm[causal-conv1d]==2.2.4 + flash-linear-attention>=0.4.1 GENERATION = lm_eval>=0.4.9 diff --git a/tests/test_ssm_varlen.py b/tests/test_varlen.py similarity index 94% rename from tests/test_ssm_varlen.py rename to tests/test_varlen.py index 1f7a83e65..0b5a6caca 100644 --- a/tests/test_ssm_varlen.py +++ b/tests/test_varlen.py @@ -1,4 +1,3 @@ -import inspect import itertools import pytest @@ -12,25 +11,6 @@ from fast_llm.layers.ssm import gdn as gdn_module from fast_llm.layers.ssm.config import GatedDeltaNetConfig -# from mamba2 import NemotronHMamba2 - - -_mamba_varlen = False -try: - from mamba_ssm.ops.selective_scan_interface import selective_scan_fn # noqa - - _mamba_available = True - sig = inspect.signature(selective_scan_fn) - if "position_indices" in sig.parameters: - _mamba_varlen = True - else: - _mamba_varlen = False - # for training with packing install https://github.com/jxiw/varlen_mamba - # see https://github.com/jxiw/M1/blob/main/HYBRID_PACK.md - -except (ImportError, RuntimeError): - _mamba_available = False - @pytest.fixture def distributed_config(): diff --git a/tests/utils/model_configs.py b/tests/utils/model_configs.py index 7ee095d3f..ac431f7dc 100644 --- a/tests/utils/model_configs.py +++ b/tests/utils/model_configs.py @@ -831,7 +831,7 @@ def _update_and_add_testing_config( "mixer": { "type": "gdn", "value_heads": 4, - "key_heads": 2, + "key_heads": 4, "key_head_dim": 16, "value_head_dim": 16, }, From 8017a80d053585b6577566186db68a8798f5f418 Mon Sep 17 00:00:00 2001 From: oleksost Date: Tue, 25 Nov 2025 20:57:03 +0000 Subject: [PATCH 13/43] gdn tests --- .../common/normalization/normalization.py | 12 +- fast_llm/layers/ssm/gdn.py | 9 +- tests/utils/model_configs.py | 113 ++---------------- 3 files changed, 22 insertions(+), 112 deletions(-) diff --git a/fast_llm/layers/common/normalization/normalization.py b/fast_llm/layers/common/normalization/normalization.py index ec8a52e26..ae46ee1dc 100644 --- a/fast_llm/layers/common/normalization/normalization.py +++ b/fast_llm/layers/common/normalization/normalization.py @@ -311,14 +311,14 @@ def __init__(self, config: ConfigType, hidden_dim: TensorDim, lr_scale: float | super().__init__(config, hidden_dim, lr_scale) if rms_norm_gated is not None: - self._forward = self._forward_fused + self._forward_gated = self._forward_local else: - self._forward = self._forward + self._forward_gated = self._forward_local def forward(self, input_: torch.Tensor, gate: torch.Tensor) -> torch.Tensor: - return self._forward(input_.view(-1, *self._normalized_shape), gate).view_as(input_) + return self._forward_gated(input_.view(-1, *self._normalized_shape), gate).view_as(input_) - def _forward_fused(self, input_: torch.Tensor, gate: torch.Tensor) -> torch.Tensor: + def _forward_fla(self, input_: torch.Tensor, gate: torch.Tensor) -> torch.Tensor: return rms_norm_gated( input_, gate, @@ -331,6 +331,6 @@ def _forward_fused(self, input_: torch.Tensor, gate: torch.Tensor) -> torch.Tens residual_in_fp32=False, ) - def _forward(self, input_: torch.Tensor, gate: torch.Tensor) -> torch.Tensor: - normalized = self.rmsnorm(input_) + def _forward_local(self, input_: torch.Tensor, gate: torch.Tensor) -> torch.Tensor: + normalized = self._forward(input_) return normalized * F.silu(gate) diff --git a/fast_llm/layers/ssm/gdn.py b/fast_llm/layers/ssm/gdn.py index cba40e48c..cb3249b96 100644 --- a/fast_llm/layers/ssm/gdn.py +++ b/fast_llm/layers/ssm/gdn.py @@ -210,9 +210,9 @@ def __init__( lr_scale=self._lr_scale, peft=self._peft, ) - # self.norm = self._config.normalization.get_layer( - # self._value_head_dim, lr_scale=self._lr_scale, peft=self._peft - # ) + self.norm = self._config.normalization.get_layer( + self._value_head_dim, lr_scale=self._lr_scale, peft=self._peft + ) self.chunk_gated_delta_rule = chunk_gated_delta_rule or torch_chunk_gated_delta_rule @@ -259,7 +259,6 @@ def fix_query_key_value_ordering(self, mixed_qkvz, mixed_ba): Derives `query`, `key` and `value` tensors from `mixed_qkvz` and `mixed_ba`. """ - # Split contiguous q/k/v/z blocks and only then project them into per-head shapes. local_qkv_sizes = ( self._local_key_heads * self._config.key_head_dim, self._local_key_heads * self._config.key_head_dim, @@ -370,7 +369,7 @@ def _forward( core_attn_out = core_attn_out.reshape(-1, core_attn_out.shape[-1]) z = z.reshape(-1, z.shape[-1]) - # core_attn_out = self.norm(core_attn_out, z) + core_attn_out = self.norm(core_attn_out, z) core_attn_out = core_attn_out.reshape(z_shape_og) core_attn_out = core_attn_out.reshape(core_attn_out.shape[0], core_attn_out.shape[1], -1) if sequence_first: diff --git a/tests/utils/model_configs.py b/tests/utils/model_configs.py index ac431f7dc..31e84eec7 100644 --- a/tests/utils/model_configs.py +++ b/tests/utils/model_configs.py @@ -4,6 +4,7 @@ import functools import os import pathlib +import re import typing import pytest @@ -80,7 +81,7 @@ class ModelTestingConfig: groups: dict[ModelTestingGroup, ModelTestingGroupAction] # Scale the comparison thresholds for specific models. compare_factor: float = 1.0 - # Option to skip specific distributed configuration with name containing any of the provided strings. + # Option to skip specific distributed configuration with name matching any of the provided regex patterns. skip_tests: tuple[str] = () get_dataset: typing.Callable[[bool], tuple[pathlib.Path, dict[str, typing.Any], pathlib.Path]] = ( get_model_test_dataset @@ -136,7 +137,7 @@ def base_model_config_class(self): return self.model_config_class.get_base_model_config_class() def should_skip(self, distributed_config: DistributedTestingConfig) -> bool: - return any(key in distributed_config.name for key in self.skip_tests) + return any(re.search(pattern, distributed_config.name) for pattern in self.skip_tests) def _update_and_add_testing_config( @@ -461,7 +462,7 @@ def _update_and_add_testing_config( }, compare_factor=2.0, # Arg update for cross-entropy splits doesn't work here. - skip_tests=("ce4", "ms"), + skip_tests=(r"ce4", r"ms"), ) _update_and_add_testing_config( @@ -594,7 +595,7 @@ def _update_and_add_testing_config( }, compare_factor=2.0, # Micro-sequence split not supported. - skip_tests=("sdp", "ms"), + skip_tests=(r"sdp", r"ms"), ) _update_and_add_testing_config( @@ -636,8 +637,8 @@ def _update_and_add_testing_config( compare_factor=2.0, # Micro-sequence split not supported. skip_tests=( - "sdp", - "ms", + r"sdp", + r"ms", ), # "pp","dp", "ce","16", "bf", "df", "stp"), ) @@ -721,99 +722,7 @@ def _update_and_add_testing_config( }, compare_factor=6.0, # Micro-sequence split and sequence-first not supported. - # TODO: Gradient accumulation works but comparison is broken. - skip_tests=("sdp", "ms", "bf4", "df"), -) - - -_update_and_add_testing_config( - # Tests apriel2 format with pattern decoder mixing all mixer types. - # This comprehensive test exercises: attention, mamba, stochastic mixer, sliding window attention. - "llama", - "apriel2", - updates={ - ("model", "base_model", "tied_embedding_weight"): True, - ("model", "base_model", "decoder"): { - "type": "pattern", - "blocks": { - "attn_full": { - **copy.deepcopy(_llama_block), - "mixer": { - "type": "attention", - "rotary": {"type": "default", "theta": 10000}, - "heads": 8, - "head_groups": 4, - "head_size": 32, - "add_linear_biases": False, - }, - }, - "mamba": { - **copy.deepcopy(_llama_block), - "mixer": { - "type": "mamba_2", - "d_inner": 512, - "state_size": 16, - "dt_rank": 16, - "d_xb": 256, - "add_linear_biases": False, - }, - }, - "stochastic": { - **copy.deepcopy(_llama_block), - "mixer": { - "type": "stochastic", - "mixers": { - "attn": { - "type": "attention", - "rotary": {"type": "default", "theta": 10000}, - "heads": 8, - "head_groups": 4, - "head_size": 32, - "add_linear_biases": False, - }, - "mamba": { - "type": "mamba_2", - "d_inner": 512, - "state_size": 16, - "dt_rank": 16, - "d_xb": 256, - "add_linear_biases": False, - }, - }, - "sampling_strategy": "uniform", - "main_mixer_name": "attn", - }, - }, - "attn_swa": { - **copy.deepcopy(_llama_block), - "mixer": { - "type": "attention", - "rotary": {"type": "default", "theta": 10000}, - "heads": 8, - "head_groups": 4, - "head_size": 32, - "window_size": 128, - "add_linear_biases": False, - }, - }, - }, - "pattern": ["attn_full", "mamba", "stochastic", "attn_swa"], - "num_blocks": 4, - }, - }, - megatron_args=None, - checkpoint_format=Apriel2CheckpointFormat, - groups={ - ModelTestingGroup.basic: ModelTestingGroupAction.normal, - ModelTestingGroup.checkpoint: ModelTestingGroupAction.normal, - ModelTestingGroup.convert: ModelTestingGroupAction.normal, - ModelTestingGroup.generate: ModelTestingGroupAction.not_implemented, - ModelTestingGroup.megatron: ModelTestingGroupAction.not_implemented, - ModelTestingGroup.distributed: ModelTestingGroupAction.normal, - }, - compare_factor=2.0, - # Micro-sequence split not supported for Mamba. - skip_tests=("sdp", "ms"), + skip_tests=(r"sdp", r"ms", r"bf4", r"df"), ) @@ -851,8 +760,10 @@ def _update_and_add_testing_config( ModelTestingGroup.megatron: ModelTestingGroupAction.not_implemented, ModelTestingGroup.distributed: ModelTestingGroupAction.normal, }, - compare_factor=16, - skip_tests=("sdp", "ms", "stp"), + compare_factor=10.0, # with compare_factor 2 fails fp16 and bf16 tests in the normalizaiton layer when using rms_norm_gated from fla (passes with local non-fla norm) + # note: tp is excluded because there is currently no gradient reductions implemented for tp norm in gdn.py (STP works though). + # we should be using STP with this model! + skip_tests=(r"sdp", r"ms", r"^tp2$"), ) From 1e016014f88375151f80006ecbd9a093906b8aff Mon Sep 17 00:00:00 2001 From: oleksost Date: Tue, 25 Nov 2025 21:37:28 +0000 Subject: [PATCH 14/43] tests --- fast_llm/layers/common/normalization/normalization.py | 2 +- tests/utils/model_configs.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/fast_llm/layers/common/normalization/normalization.py b/fast_llm/layers/common/normalization/normalization.py index ae46ee1dc..651d8e4b1 100644 --- a/fast_llm/layers/common/normalization/normalization.py +++ b/fast_llm/layers/common/normalization/normalization.py @@ -311,7 +311,7 @@ def __init__(self, config: ConfigType, hidden_dim: TensorDim, lr_scale: float | super().__init__(config, hidden_dim, lr_scale) if rms_norm_gated is not None: - self._forward_gated = self._forward_local + self._forward_gated = self._forward_fla else: self._forward_gated = self._forward_local diff --git a/tests/utils/model_configs.py b/tests/utils/model_configs.py index 31e84eec7..1eacb7840 100644 --- a/tests/utils/model_configs.py +++ b/tests/utils/model_configs.py @@ -760,9 +760,9 @@ def _update_and_add_testing_config( ModelTestingGroup.megatron: ModelTestingGroupAction.not_implemented, ModelTestingGroup.distributed: ModelTestingGroupAction.normal, }, - compare_factor=10.0, # with compare_factor 2 fails fp16 and bf16 tests in the normalizaiton layer when using rms_norm_gated from fla (passes with local non-fla norm) + compare_factor=10.0, # with compare_factor 2 fails fp16 and bf16 tests in the normalizaiton layer when using rms_norm_gated from fla # note: tp is excluded because there is currently no gradient reductions implemented for tp norm in gdn.py (STP works though). - # we should be using STP with this model! + # we should be using STP with this model, not TP! skip_tests=(r"sdp", r"ms", r"^tp2$"), ) From ca8cb5cb4514359b48ebacb6013906e01d1d14d0 Mon Sep 17 00:00:00 2001 From: oleksost Date: Tue, 25 Nov 2025 22:19:29 +0000 Subject: [PATCH 15/43] tests --- fast_llm/layers/ssm/gdn.py | 27 +++++++++++++++------------ 1 file changed, 15 insertions(+), 12 deletions(-) diff --git a/fast_llm/layers/ssm/gdn.py b/fast_llm/layers/ssm/gdn.py index cb3249b96..3feac971f 100644 --- a/fast_llm/layers/ssm/gdn.py +++ b/fast_llm/layers/ssm/gdn.py @@ -378,7 +378,7 @@ def _forward( return output - def _preprocess_for_varlen(self, batch: torch.Tensor, kwargs: dict[str, typing.Any]) -> None: + def _preprocess_for_varlen(self, kwargs: dict[str, typing.Any]) -> None: """ Creates seqlens and cu_seqlens for packed forward. This assumes that forward pass is performed on a fully packed sequence, i.e. where sequences are flattened out into BS = 1. @@ -390,15 +390,21 @@ def _preprocess_for_varlen(self, batch: torch.Tensor, kwargs: dict[str, typing.A """ sequence_lengths = kwargs[LinearAttentionKwargs.sequence_lengths] + device = kwargs.get("device", None) if sequence_lengths is None: raise ValueError("sequence_lengths must be provided in kwargs for variable-length sequences.") - seqlens = torch.cat(sequence_lengths) - cu_seqlens = torch.cat( - ( - torch.zeros(1, dtype=torch.long, device=batch.device), - torch.cumsum(seqlens, dim=0, dtype=torch.long).to(batch.device), - ) + seqlens = torch.tensor( + [ + 0, + *( + sequence_length + for sequence_lengths in kwargs[LinearAttentionKwargs.sequence_lengths] + for sequence_length in sequence_lengths + ), + ], + dtype=torch.int32, ) + cu_seqlens = seqlens.cumsum_(0).to(device) # this is supposed to be flattened, see https://github.com/fla-org/flash-linear-attention/blob/71260ecd573cfaaa94305b726465143199e99734/fla/ops/kda/chunk.py#L303 # also whenever cu_seqlens is used, batchs size must be forced to 1: see https://github.com/fla-org/flash-linear-attention/blob/71260ecd573cfaaa94305b726465143199e99734/fla/ops/kda/chunk.py#L347 kwargs[LinearAttentionKwargs.cu_seqlens] = cu_seqlens @@ -433,11 +439,8 @@ def _preprocess_for_cross_doc_attetion(self, batch: torch.Tensor, kwargs: dict[s kwargs[LinearAttentionKwargs.sequence_lengths] = sequence_lengths self._preprocess_for_varlen(batch, kwargs) - def preprocess(self, batch: torch.Tensor, kwargs: dict[str, typing.Any]) -> None: - if LinearAttentionKwargs.sequence_lengths in kwargs: - self._preprocess_for_varlen(batch, kwargs) - else: - self._preprocess_for_cross_doc_attetion(batch, kwargs) + def preprocess(self, kwargs: dict[str, typing.Any]) -> None: + self._preprocess_for_varlen(kwargs) def get_compute_usage(self, input_: TensorMeta, kwargs: dict[str, typing.Any], config: ResourceUsageConfig) -> int: raise NotImplementedError() From 694d2877d05df28a4979295146af65c2335e4e37 Mon Sep 17 00:00:00 2001 From: oleksost Date: Tue, 25 Nov 2025 22:27:37 +0000 Subject: [PATCH 16/43] nvm --- fast_llm/layers/ssm/gdn.py | 13 ------------- 1 file changed, 13 deletions(-) diff --git a/fast_llm/layers/ssm/gdn.py b/fast_llm/layers/ssm/gdn.py index 3feac971f..d9db6ad4e 100644 --- a/fast_llm/layers/ssm/gdn.py +++ b/fast_llm/layers/ssm/gdn.py @@ -426,19 +426,6 @@ def _preprocess_for_varlen(self, kwargs: dict[str, typing.Any]) -> None: .unsqueeze(0) ) - def _preprocess_for_cross_doc_attetion(self, batch: torch.Tensor, kwargs: dict[str, typing.Any]) -> None: - """ - Since forward is packed by default, this is needed for tests to path. - """ - if LinearAttentionKwargs.sequence_lengths in kwargs: - return self._preprocess_for_varlen(batch, kwargs) - bs, sequence_lengths = ( - batch.shape[:2] if not kwargs[BlockKwargs.sequence_first] else (batch.shape[1], batch.shape[0]) - ) - sequence_lengths = [torch.tensor([sequence_lengths] * bs, device=batch.device)] - kwargs[LinearAttentionKwargs.sequence_lengths] = sequence_lengths - self._preprocess_for_varlen(batch, kwargs) - def preprocess(self, kwargs: dict[str, typing.Any]) -> None: self._preprocess_for_varlen(kwargs) From d3bd916f77ff47cfe107a8e2207dc5c73b9c4f42 Mon Sep 17 00:00:00 2001 From: oleksost Date: Wed, 26 Nov 2025 13:43:50 +0000 Subject: [PATCH 17/43] requirements --- setup.cfg | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.cfg b/setup.cfg index 77664cd0a..f4b2c904b 100644 --- a/setup.cfg +++ b/setup.cfg @@ -53,7 +53,7 @@ HUGGINGFACE = # MAMBA_FORCE_BUILD=TRUE CAUSAL_CONV1D_FORCE_BUILD=TRUE CAUSAL_CONV1D_SKIP_CUDA_BUILD=TRUE pip install -e ".[CORE,SSM]" --no-build-isolation SSM = mamba_ssm[causal-conv1d]==2.2.4 - flash-linear-attention>=0.4.1 + flash-linear-attention @ git+https://github.com/fla-org/flash-linear-attention@main GENERATION = lm_eval>=0.4.9 From 3ff7799979babf5d4cd2318cadd9e6d834739dde Mon Sep 17 00:00:00 2001 From: oleksost Date: Wed, 26 Nov 2025 13:51:38 +0000 Subject: [PATCH 18/43] wip --- fast_llm/layers/ssm/config.py | 91 ++++++ fast_llm/layers/ssm/kda.py | 346 +++++++++++++++++++++++ fast_llm/models/gpt/conversion/apriel.py | 130 ++++++++- 3 files changed, 566 insertions(+), 1 deletion(-) create mode 100644 fast_llm/layers/ssm/kda.py diff --git a/fast_llm/layers/ssm/config.py b/fast_llm/layers/ssm/config.py index 6f36321ec..f9db4d350 100644 --- a/fast_llm/layers/ssm/config.py +++ b/fast_llm/layers/ssm/config.py @@ -118,6 +118,97 @@ def _validate(self) -> None: super()._validate() +@config_class(dynamic_type={MixerConfig: "kda"}) +class KimiDeltaAttentionConfig(MixerConfig): + """ + Configuration for the KimiDeltaAttention mixer inspired by the Kimi Linear models. + """ + + _abstract = False + normalization: GatedRMSNormalizationConfig = Field( + desc="Configuration for the gated normalization applied to the KDA output.", + hint=FieldHint.architecture, + ) + q_projection_layer: AffineLinearConfig = Field( + desc="Projection that produces query vectors.", + hint=FieldHint.architecture, + ) + k_projection_layer: AffineLinearConfig = Field( + desc="Projection that produces key vectors.", + hint=FieldHint.architecture, + ) + v_projection_layer: AffineLinearConfig = Field( + desc="Projection that produces value vectors.", + hint=FieldHint.architecture, + ) + f_a_projection_layer: AffineLinearConfig = Field( + desc="Projection used for the Delta gating pre-activation.", + hint=FieldHint.architecture, + ) + f_b_projection_layer: AffineLinearConfig = Field( + desc="Projection used for the Delta gating expansion.", + hint=FieldHint.architecture, + ) + g_a_projection_layer: AffineLinearConfig = Field( + desc="Projection used for the output gating pre-activation.", + hint=FieldHint.architecture, + ) + g_b_projection_layer: AffineLinearConfig = Field( + desc="Projection used for the output gating expansion.", + hint=FieldHint.architecture, + ) + beta_projection_layer: AffineLinearConfig = Field( + desc="Projection that produces the Beta gate.", + hint=FieldHint.architecture, + ) + output_projection_layer: AffineLinearConfig = Field( + desc="Projection applied after the Delta recurrence and gated normalization.", + hint=FieldHint.architecture, + ) + convolution_layer: CausalConv1dConfig = Field( + desc="Depth-wise convolution applied independently on each Q, K and V stream.", + hint=FieldHint.architecture, + ) + dt_bias_weight: ParameterConfig = Field( + desc="Parameter configuration for the Delta gate bias.", + hint=FieldHint.architecture, + ) + a_log_weight: ParameterConfig = Field( + desc="Parameter configuration for the decay rates.", + hint=FieldHint.architecture, + ) + + heads: int = Field( + default=16, + desc="Number of attention heads.", + hint=FieldHint.architecture, + valid=check_field(Assert.gt, 0), + ) + head_dim: int = Field( + default=64, + desc="Dimension of each head.", + hint=FieldHint.architecture, + valid=check_field(Assert.gt, 0), + ) + + @property + def layer_class(self) -> "type": + from kda import KimiDeltaAttention + + return KimiDeltaAttention + + def _validate(self) -> None: + with self._set_implicit_default(): + if "epsilon" not in self.normalization._explicit_fields: + self.normalization.epsilon = 1.0e-5 + if "activation" not in self.convolution_layer._explicit_fields: + self.convolution_layer.activation = "silu" + if "kernel_size" not in self.convolution_layer._explicit_fields: + self.convolution_layer.kernel_size = 4 + + super()._validate() + + @config_class() class SSMConfig(MixerConfig): # Layers diff --git a/fast_llm/layers/ssm/kda.py b/fast_llm/layers/ssm/kda.py new file mode 100644 index 000000000..03a682852 --- /dev/null +++ b/fast_llm/layers/ssm/kda.py @@ -0,0 +1,346 @@ +import logging +import typing + +import torch +from einops import rearrange, repeat + +from fast_llm.engine.base_model.config import ResourceUsageConfig +from fast_llm.engine.config_utils.initialization import LambdaInitializer, init_normal_, init_ones_ +from fast_llm.engine.config_utils.tensor_dim import CompositeTensorDim, TensorDim +from fast_llm.engine.distributed.config import DistributedConfig, DistributedDimNames +from fast_llm.layers.block.config import BlockKwargs +from fast_llm.layers.common.peft.config import PeftConfig +from fast_llm.layers.decoder.block import BlockWithBias +from fast_llm.layers.ssm.config import KimiDeltaAttentionConfig, LinearAttentionKwargs +from fast_llm.tensor import ParameterMeta, TensorMeta + +logger = logging.getLogger(__name__) + +try: + from fla.ops.kda import chunk_kda + from fla.ops.kda.gate import fused_kda_gate +except ImportError: + chunk_kda = None + fused_kda_gate = None + + +def index_first_axis(x, indices): + other_shape = x.shape[1:] + second_dim = other_shape.numel() + return torch.gather( + rearrange(x, "b ... -> b (...)"), + 0, + repeat(indices, "z -> z d", d=second_dim), + ).reshape(-1, *other_shape) + + +class KimiDeltaAttention[ConfigType: KimiDeltaAttentionConfig](BlockWithBias[ConfigType]): + """ + Implementation of the Kimi Delta Attention mixer. + Reference Implementation: https://huggingface.co/moonshotai/Kimi-Linear-48B-A3B-Instruct/blob/main/modeling_kimi.py + """ + + _config: ConfigType + + def __init__( + self, + config: ConfigType, + distributed_config: DistributedConfig, + *, + hidden_dim: TensorDim, + lr_scale: float | None, + peft: PeftConfig | None, + return_bias: bool = True, + ): + super().__init__( + config, distributed_config, hidden_dim=hidden_dim, lr_scale=lr_scale, peft=peft, return_bias=return_bias + ) + if chunk_kda is None or fused_kda_gate is None: + raise ImportError( + "KimiDeltaAttention requires the `fla-core` package. " + "Please install it with `pip install -U fla-core`." + ) + + self._parallel_dim = self._distributed_config.get_distributed_dim(DistributedDimNames.tensor) + self._heads_dim = TensorDim( + "kda_heads", self._config.heads, self._parallel_dim if self._config.heads > 1 else None + ) + self._head_dim = TensorDim("kda_head_dim", self._config.head_dim) + self._projection_dim = CompositeTensorDim("kda_projection", (self._heads_dim, self._head_dim)) + self._local_heads = self._heads_dim.size + self._projection_size = self._projection_dim.size + + init = init_normal_(std=self._hidden_size**-0.5) + self.q_proj = self._config.q_projection_layer.get_layer( + hidden_dim, + self._projection_dim, + default_weight_initialization=init, + default_add_bias=False, + sequence_parallel=self._sequence_parallel, + lr_scale=self._lr_scale, + peft=self._peft, + ) + self.k_proj = self._config.k_projection_layer.get_layer( + hidden_dim, + self._projection_dim, + default_weight_initialization=init, + default_add_bias=False, + sequence_parallel=self._sequence_parallel, + lr_scale=self._lr_scale, + peft=self._peft, + ) + self.v_proj = self._config.v_projection_layer.get_layer( + hidden_dim, + self._projection_dim, + default_weight_initialization=init, + default_add_bias=False, + sequence_parallel=self._sequence_parallel, + lr_scale=self._lr_scale, + peft=self._peft, + ) + + self.q_conv = self._config.convolution_layer.get_layer( + self._projection_dim, + default_add_bias=False, + default_activation=self._config.convolution_layer.activation, + lr_scale=self._lr_scale, + peft=self._peft, + ) + self.k_conv = self._config.convolution_layer.get_layer( + self._projection_dim, + default_add_bias=False, + default_activation=self._config.convolution_layer.activation, + lr_scale=self._lr_scale, + peft=self._peft, + ) + self.v_conv = self._config.convolution_layer.get_layer( + self._projection_dim, + default_add_bias=False, + default_activation=self._config.convolution_layer.activation, + lr_scale=self._lr_scale, + peft=self._peft, + ) + + self.f_a_proj = self._config.f_a_projection_layer.get_layer( + hidden_dim, + self._head_dim, + default_weight_initialization=init, + default_add_bias=False, + sequence_parallel=False, # self._sequence_parallel, + lr_scale=self._lr_scale, + peft=self._peft, + ) + self.f_b_proj = self._config.f_b_projection_layer.get_layer( + self._head_dim, + self._projection_dim, + default_weight_initialization=init, + default_add_bias=False, + sequence_parallel=self._sequence_parallel, + lr_scale=self._lr_scale, + peft=self._peft, + ) + self.g_a_proj = self._config.g_a_projection_layer.get_layer( + hidden_dim, + self._head_dim, + default_weight_initialization=init, + default_add_bias=False, + sequence_parallel=False, # self._sequence_parallel, + lr_scale=self._lr_scale, + peft=self._peft, + ) + self.g_b_proj = self._config.g_b_projection_layer.get_layer( + self._head_dim, + self._projection_dim, + default_weight_initialization=init, + default_add_bias=False, + sequence_parallel=self._sequence_parallel, + lr_scale=self._lr_scale, + peft=self._peft, + ) + + self.beta_proj = self._config.beta_projection_layer.get_layer( + hidden_dim, + self._heads_dim, + default_weight_initialization=init, + default_add_bias=False, + sequence_parallel=self._sequence_parallel, + lr_scale=self._lr_scale, + peft=self._peft, + ) + self.o_proj = self._config.output_projection_layer.get_layer( + self._projection_dim, + hidden_dim, + default_weight_initialization=init, + default_add_bias=False, + sequence_parallel=self._sequence_parallel, + lr_scale=self._lr_scale, + peft=self._peft, + ) + + self.dt_bias: ParameterMeta = self._config.dt_bias_weight.get_parameter( + (self._projection_dim,), + default_initialization=init_ones_, + lr_scale=self._lr_scale, + peft=self._peft, + ) + self.A_log: ParameterMeta = self._config.a_log_weight.get_parameter( + (self._heads_dim,), + default_initialization=LambdaInitializer( + lambda _, tensor, generator: tensor.uniform_(1, 16, generator=generator).log_() + ), + lr_scale=self._lr_scale, + peft=self._peft, + ) + self.norm = self._config.normalization.get_layer( + self._head_dim, + lr_scale=self._lr_scale, + peft=self._peft, + ) + + def _apply_conv(self, tensor: torch.Tensor, conv: torch.nn.Module, seq_idx: torch.Tensor = None) -> torch.Tensor: + """ + Applies convolution. + Note, in the reference code they use Short Convolution from flash-linear-attention/fla/modules/convolution.py, but that one just uses causal_conv1d anyways. + Varlen: + - seq. idx are only suppored in channel last layout, i.e. no transpose + """ + tensor = rearrange(tensor, "b t d -> b d t") + # tensor = tensor.transpose(1, 2).contiguous() if seq_idx is None else tensor.transpose(1, 2) + tensor = conv(tensor, seq_idx=seq_idx) + return tensor.transpose(1, 2).contiguous() + + def _reshape_heads(self, tensor: torch.Tensor) -> torch.Tensor: + tensor = tensor.contiguous() + # since head_dim is the same vor k,q and v + # same as rearrange(v, '... (h d) -> ... h d', d=self.head_dim) + return tensor.view(tensor.shape[0], tensor.shape[1], self._local_heads, self._config.head_dim) + + def _forward( + self, + input_: torch.Tensor, + kwargs: dict[str, typing.Any], + losses: dict[str, typing.Any] | None = None, + metrics: dict[str, typing.Any] | None = None, + ) -> tuple[torch.Tensor, torch.Tensor | None]: + # TODO: make sure varlen is supported + # TODO: do we need to deal with padding tokens? + sequence_first = kwargs[BlockKwargs.sequence_first] + hidden_states = input_ + # padding_tokens = kwargs.get(LinearAttentionKwargs.padding_tokens, None) + + # if padding_tokens is not None: + + cu_seqlens = kwargs.get(LinearAttentionKwargs.cu_seqlens, None) + seq_idx = kwargs.get(LinearAttentionKwargs.seq_idx, None) + # TODO: can be made more efficeint by rearranging hidden states directly + residual_dtype = hidden_states.dtype + + q = self.q_proj(hidden_states) + k = self.k_proj(hidden_states) + v = self.v_proj(hidden_states) + + if sequence_first: + # make bs first dim again + q = q.transpose(0, 1) + k = k.transpose(0, 1) + v = v.transpose(0, 1) + + batch_size, sequence_length, _ = q.size() + + # work with bs = 1 to make sure varlen works correctly, only needed if micro batch size is > 1 + # can this be applied once to hidden state only? pr + q = rearrange(q, "b s ... -> (b s) ...").unsqueeze(0) + k = rearrange(k, "b s ... -> (b s) ...").unsqueeze(0) + v = rearrange(v, "b s ... -> (b s) ...").unsqueeze(0) + + # because we use cu_seqlens, chunk_kda requires batch size to be 1 (flatten, see https://github.com/fla-org/flash-linear-attention/blob/71260ecd573cfaaa94305b726465143199e99734/fla/ops/kda/chunk.py#L303) + # similarly to ShortConvolution from fla we already operate on flattened batches here (https://github.com/fla-org/flash-linear-attention/blob/71260ecd573cfaaa94305b726465143199e99734/fla/modules/convolution.py#L914) + q = self._apply_conv(q, self.q_conv, seq_idx=seq_idx) + k = self._apply_conv(k, self.k_conv, seq_idx=seq_idx) + v = self._apply_conv(v, self.v_conv, seq_idx=seq_idx) + + g_kernel = self.f_b_proj(self.f_a_proj(hidden_states)) + if sequence_first: + g_kernel = g_kernel.transpose(0, 1) + g_kernel = rearrange(g_kernel, "b s ... -> (b s) ...").unsqueeze(0) + + g_kernel = fused_kda_gate(g_kernel, self.A_log, self._config.head_dim, g_bias=self.dt_bias) + + beta = torch.sigmoid(self.beta_proj(hidden_states).float()) + q = self._reshape_heads(q) + k = self._reshape_heads(k) + v = self._reshape_heads(v) + if sequence_first: + beta = beta.transpose(0, 1) + beta = rearrange(beta, "b s h -> (b s) h").unsqueeze(0) + + # need to install nightly triton for this to work on H100, see https://github.com/fla-org/flash-linear-attention/blob/main/FAQs.md + # cu_seqlens requires batch ssize to be 1, i.e. flattened bacthes + attn_out, _ = chunk_kda( + q=q, + k=k, + v=v, + g=g_kernel, + beta=beta, + initial_state=None, + output_final_state=False, + use_qk_l2norm_in_kernel=True, + cu_seqlens=cu_seqlens, + ) + + attn_out = attn_out.to(residual_dtype) + + g_out = self.g_b_proj(self.g_a_proj(hidden_states)) # bs x seq x n_local_heads x head dim + g_out = self._reshape_heads(g_out) + if sequence_first: + g_out = g_out.transpose(0, 1) + + attn_out = rearrange(attn_out.squeeze(0), "(b s) h d -> b s h d", b=batch_size, s=sequence_length) + attn_out = self.norm(attn_out, g_out) + attn_out = rearrange(attn_out, "b s h d -> b s (h d)") + if sequence_first: + attn_out = attn_out.transpose(0, 1) + attn_out = self.o_proj(attn_out) + + return attn_out + + def get_compute_usage(self, input_: TensorMeta, kwargs: dict[str, typing.Any], config: ResourceUsageConfig) -> int: + raise NotImplementedError() + + def _preprocess_for_varlen(self, batch: torch.Tensor, kwargs: dict[str, typing.Any]) -> None: + sequence_lengths = kwargs[LinearAttentionKwargs.sequence_lengths] + if sequence_lengths is None: + raise ValueError("sequence_lengths must be provided in kwargs for variable-length sequences.") + seqlens = torch.cat(sequence_lengths) + cu_seqlens = torch.cat( + ( + torch.zeros(1, dtype=torch.long, device=batch.device), + torch.cumsum(seqlens, dim=0, dtype=torch.long).to(batch.device), + ) + ) + # this is supposed to be flattened, see https://github.com/fla-org/flash-linear-attention/blob/71260ecd573cfaaa94305b726465143199e99734/fla/ops/kda/chunk.py#L303 + # also whenever cu_seqlens is used, batchs size must be forced to 1: see https://github.com/fla-org/flash-linear-attention/blob/71260ecd573cfaaa94305b726465143199e99734/fla/ops/kda/chunk.py#L347 + kwargs[LinearAttentionKwargs.cu_seqlens] = cu_seqlens + # seq_idx has to be (bs, seqlen), but bs is forced to 1 + kwargs[LinearAttentionKwargs.seq_idx] = ( + ( + torch.cat( + [ + torch.arange(n, dtype=cu_seqlens.dtype, device=cu_seqlens.device) + for n in (torch.diff(cu_seqlens).to(torch.int32)) + ], + dim=0, + ) + .eq(0) + .cumsum(0) + - 1 + ) + .to(torch.int32) + .unsqueeze(0) + ) + + def preprocess(self, batch: torch.Tensor, kwargs: dict[str, typing.Any]) -> None: + if LinearAttentionKwargs.sequence_lengths in kwargs: + # TODO: packing is enabled by default, i.e. its always used? + # only get here when cross_document_attention is False + self._preprocess_for_varlen(batch, kwargs) diff --git a/fast_llm/models/gpt/conversion/apriel.py b/fast_llm/models/gpt/conversion/apriel.py index 215cc5257..e247009d0 100644 --- a/fast_llm/models/gpt/conversion/apriel.py +++ b/fast_llm/models/gpt/conversion/apriel.py @@ -268,7 +268,7 @@ def export_config(cls, config: GatedDeltaNetConfig) -> dict: "gdn_linear_conv_kernel_size": config.convolution_layer.kernel_size, }, } - + @classmethod def get_converters( cls, @@ -321,6 +321,128 @@ def get_converters( ] +class KimiDeltaAttentionConverter: + @classmethod + def import_config(cls, config: dict) -> dict: + return { + "type": "kda", + "head_dim": config["linear_attn_config"]["head_dim"], + "heads": config["linear_attn_config"]["num_heads"], + "convolution_layer": { + "kernel_size": config["linear_attn_config"]["short_conv_kernel_size"], + }, + } + + @classmethod + def export_config(cls, config: KimiDeltaAttentionConfig) -> dict: + return { + "linear_attn_config": { + "head_dim": config.head_dim, + "num_heads": config.heads, + "short_conv_kernel_size": config.convolution_layer.kernel_size, + }, + } + + @classmethod + def get_converters( + cls, + config: KimiDeltaAttentionConfig, + fast_llm_prefix: str, + hf_prefix: str, + drop_on_export: bool = False, + ) -> list[WeightConverter]: + return [ + *get_weight_and_bias_converters( + f"{fast_llm_prefix}.q_proj", + f"{hf_prefix}.q_proj", + False, + drop_on_export=drop_on_export, + ), + *get_weight_and_bias_converters( + f"{fast_llm_prefix}.k_proj", + f"{hf_prefix}.k_proj", + False, + drop_on_export=drop_on_export, + ), + *get_weight_and_bias_converters( + f"{fast_llm_prefix}.v_proj", + f"{hf_prefix}.v_proj", + False, + drop_on_export=drop_on_export, + ), + *get_weight_and_bias_converters( + f"{fast_llm_prefix}.q_conv", + f"{hf_prefix}.q_conv", + False, + drop_on_export=drop_on_export, + ), + *get_weight_and_bias_converters( + f"{fast_llm_prefix}.k_conv", + f"{hf_prefix}.k_conv", + False, + drop_on_export=drop_on_export, + ), + *get_weight_and_bias_converters( + f"{fast_llm_prefix}.v_conv", + f"{hf_prefix}.v_conv", + False, + drop_on_export=drop_on_export, + ), + *get_weight_and_bias_converters( + f"{fast_llm_prefix}.f_a_proj", + f"{hf_prefix}.f_a_proj", + False, + drop_on_export=drop_on_export, + ), + *get_weight_and_bias_converters( + f"{fast_llm_prefix}.f_b_proj", + f"{hf_prefix}.f_b_proj", + False, + drop_on_export=drop_on_export, + ), + *get_weight_and_bias_converters( + f"{fast_llm_prefix}.g_a_proj", + f"{hf_prefix}.g_a_proj", + False, + drop_on_export=drop_on_export, + ), + *get_weight_and_bias_converters( + f"{fast_llm_prefix}.g_b_proj", + f"{hf_prefix}.g_b_proj", + False, + drop_on_export=drop_on_export, + ), + *get_weight_and_bias_converters( + f"{fast_llm_prefix}.beta_proj", + f"{hf_prefix}.beta_proj", + False, + drop_on_export=drop_on_export, + ), + *get_weight_and_bias_converters( + f"{fast_llm_prefix}.o_proj", + f"{hf_prefix}.o_proj", + False, + drop_on_export=drop_on_export, + ), + get_parameter_converter( + f"{fast_llm_prefix}.A_log", + f"{hf_prefix}.A_log", + drop_on_export=drop_on_export, + ), + get_parameter_converter( + f"{fast_llm_prefix}.dt_bias", + f"{hf_prefix}.dt_bias", + drop_on_export=drop_on_export, + ), + *get_weight_and_bias_converters( + f"{fast_llm_prefix}.norm", + f"{hf_prefix}.norm", + False, + drop_on_export=drop_on_export, + ), + ] + + class AprielBlockConverterBase(MistralBlockConverter): mlp_converter_class: typing.ClassVar[type[AprielMLPConverter]] = AprielMLPConverter @@ -330,6 +452,11 @@ class AprielDiscreteMamba2BlockConverter(AprielBlockConverterBase): hf_mixer_name: typing.ClassVar[str] = "mixer" +class AprielKimiDeltaAttentionBlockConverter(AprielBlockConverterBase): + mixer_converter_class: typing.ClassVar[type[KimiDeltaAttentionConverter]] = KimiDeltaAttentionConverter + hf_mixer_name: typing.ClassVar[str] = "mixer" + + class AprielMamba2BlockConverter(MistralBlockConverter): mixer_converter_class: typing.ClassVar[type[AprielMamba2Converter]] = AprielMamba2Converter hf_mixer_name: typing.ClassVar[str] = "mixer" @@ -351,6 +478,7 @@ class AprielBlockConverter: _converter_classes = { AttentionConfig: MistralBlockConverter, Mamba2Config: AprielMamba2BlockConverter, + KimiDeltaAttentionConfig: AprielKimiDeltaAttentionBlockConverter, DiscreteMamba2Config: AprielDiscreteMamba2BlockConverter, GatedDeltaNetConfig: AprielGatedDeltaNetBlockConverter, } From 9a53c5b93b5e7cd85fa94f644425ac216084e632 Mon Sep 17 00:00:00 2001 From: oleksost Date: Wed, 26 Nov 2025 13:53:05 +0000 Subject: [PATCH 19/43] clean up --- fast_llm/models/gpt/conversion/apriel.py | 12 +++--------- 1 file changed, 3 insertions(+), 9 deletions(-) diff --git a/fast_llm/models/gpt/conversion/apriel.py b/fast_llm/models/gpt/conversion/apriel.py index 215cc5257..d9ddf57d1 100644 --- a/fast_llm/models/gpt/conversion/apriel.py +++ b/fast_llm/models/gpt/conversion/apriel.py @@ -9,12 +9,7 @@ from fast_llm.layers.block.config import BlockSequenceConfig, FixedBlockSequenceConfig, PatternBlockSequenceConfig from fast_llm.layers.decoder.config import DecoderBlockConfig from fast_llm.layers.decoder.mlp.config import MLPConfig -from fast_llm.layers.ssm.config import ( - DiscreteMamba2Config, - GatedDeltaNetConfig, - KimiDeltaAttentionConfig, - Mamba2Config, -) +from fast_llm.layers.ssm.config import DiscreteMamba2Config, GatedDeltaNetConfig, Mamba2Config from fast_llm.models.gpt.config import GPTModelConfig from fast_llm.models.gpt.conversion.config import AprielHybridSSMCheckpointFormat from fast_llm.models.gpt.conversion.llama import get_parameter_converter, get_weight_and_bias_converters @@ -268,11 +263,11 @@ def export_config(cls, config: GatedDeltaNetConfig) -> dict: "gdn_linear_conv_kernel_size": config.convolution_layer.kernel_size, }, } - + @classmethod def get_converters( cls, - config: KimiDeltaAttentionConfig, + config: GatedDeltaNetConfig, fast_llm_prefix: str, hf_prefix: str, drop_on_export: bool = False, @@ -345,7 +340,6 @@ class AprielBlockConverter: AttentionConfig: "t", Mamba2Config: "m2", DiscreteMamba2Config: "m2d", - KimiDeltaAttentionConfig: "kda", GatedDeltaNetConfig: "gdn", } _converter_classes = { From 80041ce7f7c40a8b6d8382370999df3ed5039c39 Mon Sep 17 00:00:00 2001 From: oleksost Date: Wed, 26 Nov 2025 13:59:29 +0000 Subject: [PATCH 20/43] conversion --- fast_llm/models/gpt/conversion/apriel.py | 22 ++-------------------- 1 file changed, 2 insertions(+), 20 deletions(-) diff --git a/fast_llm/models/gpt/conversion/apriel.py b/fast_llm/models/gpt/conversion/apriel.py index d9ddf57d1..41c444df1 100644 --- a/fast_llm/models/gpt/conversion/apriel.py +++ b/fast_llm/models/gpt/conversion/apriel.py @@ -8,7 +8,6 @@ from fast_llm.layers.attention.config import AttentionConfig from fast_llm.layers.block.config import BlockSequenceConfig, FixedBlockSequenceConfig, PatternBlockSequenceConfig from fast_llm.layers.decoder.config import DecoderBlockConfig -from fast_llm.layers.decoder.mlp.config import MLPConfig from fast_llm.layers.ssm.config import DiscreteMamba2Config, GatedDeltaNetConfig, Mamba2Config from fast_llm.models.gpt.config import GPTModelConfig from fast_llm.models.gpt.conversion.config import AprielHybridSSMCheckpointFormat @@ -225,19 +224,6 @@ def get_converters( ] -class AprielMLPConverter(LlamaMLPConverter): - @classmethod - def import_config(cls, config: dict) -> dict: - config["mlp_bias"] = False - return super().import_config(config) - - @classmethod - def export_config(cls, config: MLPConfig) -> dict: - out = super().export_config(config) - del out["mlp_bias"] - return out - - class GatedDeltaNetConverter: @classmethod def import_config(cls, config: dict) -> dict: @@ -316,11 +302,7 @@ def get_converters( ] -class AprielBlockConverterBase(MistralBlockConverter): - mlp_converter_class: typing.ClassVar[type[AprielMLPConverter]] = AprielMLPConverter - - -class AprielDiscreteMamba2BlockConverter(AprielBlockConverterBase): +class AprielDiscreteMamba2BlockConverter(MistralBlockConverter): mixer_converter_class: typing.ClassVar[type[AprielDiscreteMamba2Converter]] = AprielDiscreteMamba2Converter hf_mixer_name: typing.ClassVar[str] = "mixer" @@ -330,7 +312,7 @@ class AprielMamba2BlockConverter(MistralBlockConverter): hf_mixer_name: typing.ClassVar[str] = "mixer" -class AprielGatedDeltaNetBlockConverter(AprielBlockConverterBase): +class AprielGatedDeltaNetBlockConverter(MistralBlockConverter): mixer_converter_class: typing.ClassVar[type[GatedDeltaNetConverter]] = GatedDeltaNetConverter hf_mixer_name: typing.ClassVar[str] = "mixer" From d6677b08f0baf5a79f25e9adbfba8dc111631cb1 Mon Sep 17 00:00:00 2001 From: oleksost Date: Wed, 26 Nov 2025 16:16:35 +0000 Subject: [PATCH 21/43] comments on the layour + HF forward equivalence test --- fast_llm/layers/ssm/gdn.py | 85 ++++++++++------- tests/layers/test_gdn_equivalence.py | 134 +++++++++++++++++++++++++++ 2 files changed, 187 insertions(+), 32 deletions(-) create mode 100644 tests/layers/test_gdn_equivalence.py diff --git a/fast_llm/layers/ssm/gdn.py b/fast_llm/layers/ssm/gdn.py index d9db6ad4e..9f3a55263 100644 --- a/fast_llm/layers/ssm/gdn.py +++ b/fast_llm/layers/ssm/gdn.py @@ -117,6 +117,9 @@ class GatedDeltaNet[ConfigType: GatedDeltaNetConfig](BlockWithBias[ConfigType]): - For tensor parallel implementtion (no sequnece prallel): we scatter teh heads accross ranks. - Sequence Tensor parallel: in_proj_qkvz all reduces across sequence dim. --> each rank performs work on full sequence but only a subset of heads (standrd TP). + Note, Qwen3_Next follows a different layout, where gdn_qkvz is assumed to be layed out as [h0: Q,K,V,Z][h1: Q,K,V,Z][h2: Q,K,V,Z] + Here we follow a more natural layout for gdn_qkvz: [Q_all_heads | K_all_heads | V_all_heads | Z_all_heads]. If we want to apply MIL init here it should be easier like this. + """ _config: ConfigType @@ -131,6 +134,7 @@ def __init__( peft: PeftConfig | None, return_bias: bool = True, ): + super().__init__( config, distributed_config, hidden_dim=hidden_dim, lr_scale=lr_scale, peft=peft, return_bias=return_bias ) @@ -141,6 +145,7 @@ def __init__( self._key_heads_dim = TensorDim( "gdn_key_heads", self._config.key_heads, self._parallel_dim if self._config.key_heads > 1 else None ) + self._value_head_dim = TensorDim("gdn_value_head_dim", self._config.value_head_dim) self._key_head_dim = TensorDim("gdn_key_head_dim", self._config.key_head_dim) self._local_value_heads = self._value_heads_dim.size @@ -150,8 +155,18 @@ def __init__( query_dim = CompositeTensorDim("gdn_query", (self._key_heads_dim, self._key_head_dim)) key_dim = CompositeTensorDim("gdn_key", (self._key_heads_dim, self._key_head_dim)) value_dim = CompositeTensorDim("gdn_value", (self._value_heads_dim, self._value_head_dim)) + z_dim = CompositeTensorDim("gdn_z", (self._value_heads_dim, self._value_head_dim)) qkvz_dim = ConcatenatedTensorDim("gdn_qkvz", (query_dim, key_dim, value_dim, z_dim)) + # for Qwen's layour use soemthing like this instead: + # n_vheads_per_k_head = self._config.value_heads // self._config.key_heads + # head_size = 2 * self._config.key_head_dim + 2 * self._config.value_head_dim * n_vheads_per_k_head + # n_heads = self._config.key_heads + # qkvz_dim = TensorDim(e + # "gdn_qkvz", + # n_heads * head_size, + # self._parallel_dim if n_heads > 1 else None, + # ) ba_dim = ConcatenatedTensorDim( "gdn_ba", ( @@ -159,6 +174,12 @@ def __init__( CompositeTensorDim("gdn_alpha", (self._value_heads_dim,)), ), ) + # for Qwen's layour use something like this instead: + # ba_dim = TensorDim( + # "gdn_ba", + # 2 * self._config.value_heads, + # self._parallel_dim if 2 * self._config.value_heads > 1 else None, + # ) qkv_channels_dim = ConcatenatedTensorDim("gdn_qkv", (query_dim, key_dim, value_dim)) @@ -221,42 +242,42 @@ def __init__( "Fast paths for GatedDeltaNet are not available. Please ensure that 'causal_conv1d' and 'fla' are properly installed." ) - # def fix_query_key_value_ordering(self, mixed_qkvz, mixed_ba): - # """Derives query, key and value tensors from mixed_qkvz and mixed_ba.""" - # new_tensor_shape_qkvz = mixed_qkvz.size()[:-1] + ( - # self._local_key_heads, - # 2 * self._config.key_head_dim - # + 2 * self._config.value_head_dim * self._local_value_heads // self._local_key_heads, - # ) - # new_tensor_shape_ba = mixed_ba.size()[:-1] + ( - # self._local_key_heads, - # 2 * self._local_value_heads // self._local_key_heads, - # ) - # mixed_qkvz = mixed_qkvz.view(*new_tensor_shape_qkvz) - # mixed_ba = mixed_ba.view(*new_tensor_shape_ba) - # split_arg_list_qkvz = [ - # self._config.key_head_dim, - # self._config.key_head_dim, - # (self._local_value_heads // self._local_key_heads * self._config.value_head_dim), - # (self._local_value_heads // self._local_key_heads * self._config.value_head_dim), - # ] - # split_arg_list_ba = [ - # self._local_value_heads // self._local_key_heads, - # self._local_value_heads // self._local_key_heads, - # ] - # query, key, value, z = torch.split(mixed_qkvz, split_arg_list_qkvz, dim=3) - # b, a = torch.split(mixed_ba, split_arg_list_ba, dim=3) - # # [b, sq, ng, np/ng * hn] -> [b, sq, np, hn] - # value = value.reshape(value.size(0), value.size(1), -1, self._config.value_head_dim) - # z = z.reshape(z.size(0), z.size(1), -1, self._config.value_head_dim) - # b = b.reshape(b.size(0), b.size(1), self._local_value_heads) - # a = a.reshape(a.size(0), a.size(1), self._local_value_heads) - # return query, key, value, z, b, a + def fix_query_key_value_ordering(self, mixed_qkvz, mixed_ba): + """Derives query, key and value tensors from mixed_qkvz and mixed_ba.""" + new_tensor_shape_qkvz = mixed_qkvz.size()[:-1] + ( + self._local_key_heads, + 2 * self._config.key_head_dim + + 2 * self._config.value_head_dim * self._local_value_heads // self._local_key_heads, + ) + new_tensor_shape_ba = mixed_ba.size()[:-1] + ( + self._local_key_heads, + 2 * self._local_value_heads // self._local_key_heads, + ) + mixed_qkvz = mixed_qkvz.view(*new_tensor_shape_qkvz) + mixed_ba = mixed_ba.view(*new_tensor_shape_ba) + split_arg_list_qkvz = [ + self._config.key_head_dim, + self._config.key_head_dim, + (self._local_value_heads // self._local_key_heads * self._config.value_head_dim), + (self._local_value_heads // self._local_key_heads * self._config.value_head_dim), + ] + split_arg_list_ba = [ + self._local_value_heads // self._local_key_heads, + self._local_value_heads // self._local_key_heads, + ] + query, key, value, z = torch.split(mixed_qkvz, split_arg_list_qkvz, dim=3) + b, a = torch.split(mixed_ba, split_arg_list_ba, dim=3) + # [b, sq, ng, np/ng * hn] -> [b, sq, np, hn] + value = value.reshape(value.size(0), value.size(1), -1, self._config.value_head_dim) + z = z.reshape(z.size(0), z.size(1), -1, self._config.value_head_dim) + b = b.reshape(b.size(0), b.size(1), self._local_value_heads) + a = a.reshape(a.size(0), a.size(1), self._local_value_heads) + return query, key, value, z, b, a def fix_query_key_value_ordering(self, mixed_qkvz, mixed_ba): """ - note this must be the right way to split the TP, because TP splits each subdimention of ConcatenatedTensorDim("gdn_qkvz", (query_dim, key_dim, value_dim, z_dim)) seperately. Derives `query`, `key` and `value` tensors from `mixed_qkvz` and `mixed_ba`. + Replaces fix_query_key_value_ordering from Qwen due to layout differences. """ local_qkv_sizes = ( diff --git a/tests/layers/test_gdn_equivalence.py b/tests/layers/test_gdn_equivalence.py new file mode 100644 index 000000000..9886056ea --- /dev/null +++ b/tests/layers/test_gdn_equivalence.py @@ -0,0 +1,134 @@ +import pytest +import torch + +from fast_llm.engine.config_utils.tensor_dim import TensorDim +from fast_llm.engine.distributed.config import DistributedConfig +from fast_llm.engine.distributed.distributed import Distributed +from fast_llm.functional.config import ActivationType +from fast_llm.layers.block.config import BlockKwargs +from fast_llm.layers.ssm.config import GatedDeltaNetConfig + +try: + from transformers.models.qwen3_next.modeling_qwen3_next import Qwen3NextConfig, Qwen3NextGatedDeltaNet +except ImportError: + Qwen3NextConfig, Qwen3NextGatedDeltaNet = None, None + + +def _materialize_mixer_tensors(module: torch.nn.Module, distributed: Distributed, device: torch.device) -> None: + """ + Instantiate meta-allocated parameters on the requested device so the layer can run standalone. + """ + for name, param in module.named_parameters(): + if param.device.type != "meta": + continue + param_data = torch.empty_like(param, device=device) + param.init_parameter(param_data, distributed) + module_path, param_name = name.rsplit(".", 1) if "." in name else (None, name) + target = module + if module_path is not None: + for part in module_path.split("."): + target = getattr(target, part) + new_param = torch.nn.Parameter(param_data, requires_grad=param.requires_grad) + new_param.grad = None + new_param.grad_buffer = torch.zeros_like(param_data) + new_param.param_grad_is_zero = True + target._parameters[param_name] = new_param + + +@pytest.mark.slow +@pytest.mark.skipif(not torch.cuda.is_available(), reason="Varlen test needs CUDA") +@pytest.mark.skipif(Qwen3NextConfig is None, reason="transformers with Qwen3-Next not installed") +def test_fast_llm_gdn_matches_qwen3_next_forward(): + torch.manual_seed(0) + device = torch.device("cuda") + dtype = torch.bfloat16 + + hidden_size = 16 + seq_len = 6 + num_k_heads = 2 + num_v_heads = 4 + head_k_dim = 4 + head_v_dim = 4 + kernel_size = 4 + + hf_config = Qwen3NextConfig( + hidden_size=hidden_size, + linear_num_key_heads=num_k_heads, + linear_num_value_heads=num_v_heads, + linear_key_head_dim=head_k_dim, + linear_value_head_dim=head_v_dim, + linear_conv_kernel_dim=kernel_size, + hidden_act="silu", + rms_norm_eps=1e-6, + dtype=dtype, + ) + hf_layer = Qwen3NextGatedDeltaNet(hf_config, layer_idx=0).to(device=device, dtype=dtype).eval() + + fast_config = GatedDeltaNetConfig( + value_heads=num_v_heads, + key_heads=num_k_heads, + value_head_dim=head_v_dim, + key_head_dim=head_k_dim, + activation=ActivationType.silu, + normalization={"epsilon": hf_config.rms_norm_eps}, + convolution_layer={"kernel_size": kernel_size, "activation": ActivationType.silu}, + ) + distributed_config = DistributedConfig( + tensor_parallel=1, + pipeline_parallel=1, + sequence_data_parallel=1, + local_world_size=1, + world_size=1, + ) + hidden_dim = TensorDim("hidden", hidden_size) + fast_layer = fast_config.get_layer(distributed_config, hidden_dim, lr_scale=None, peft=None, return_bias=False) + distributed = Distributed(config=distributed_config) + fast_layer.setup(distributed) + _materialize_mixer_tensors(fast_layer, distributed, device) + fast_layer.to(device=device, dtype=dtype).eval() + + with torch.no_grad(): + fast_layer.in_proj_qkvz.weight.copy_(hf_layer.in_proj_qkvz.weight) + fast_layer.in_proj_ba.weight.copy_(hf_layer.in_proj_ba.weight) + fast_layer.convolution.weight.copy_(hf_layer.conv1d.weight) + if fast_layer.convolution.bias is not None and hf_layer.conv1d.bias is not None: + fast_layer.convolution.bias.copy_(hf_layer.conv1d.bias) + fast_layer.out_proj.weight.copy_(hf_layer.out_proj.weight) + fast_layer.A_log.copy_(hf_layer.A_log) + fast_layer.dt_bias.copy_(hf_layer.dt_bias) + fast_layer.norm.weight.copy_(hf_layer.norm.weight) + + hidden_states = torch.randn(1, seq_len, hidden_size, device=device, dtype=dtype, requires_grad=False) + + param_map = { + "in_proj_qkvz.weight": "in_proj_qkvz.weight", + "in_proj_ba.weight": "in_proj_ba.weight", + "convolution.weight": "conv1d.weight", + "convolution.bias": "conv1d.bias", + "out_proj.weight": "out_proj.weight", + "A_log": "A_log", + "dt_bias": "dt_bias", + "norm.weight": "norm.weight", + } + for k, p in fast_layer.state_dict().items(): + torch.testing.assert_close(p, hf_layer.state_dict()[param_map[k]], atol=1e-6, rtol=1e-6) + + # need to monkey patch the hf implementation with our fix_query_key_value_ordering due to the layout differences + hf_layer.fix_query_key_value_ordering = fast_layer.fix_query_key_value_ordering + hf_layer._local_key_heads = fast_layer._local_key_heads + hf_layer._local_value_heads = fast_layer._local_value_heads + hf_layer._config = fast_layer._config + + hf_out = hf_layer(hidden_states) + + fast_kwargs = { + BlockKwargs.sequence_first: False, + BlockKwargs.hidden_dims: (hidden_dim,), + } + fast_out = fast_layer(hidden_states, fast_kwargs) + + torch.testing.assert_close(fast_out, hf_out, atol=1e-5, rtol=1e-5) + + +if __name__ == "__main__": + pytest.main([__file__]) From 5d3b6d07e9808072293c6b90c7b811649b65ea71 Mon Sep 17 00:00:00 2001 From: oleksost Date: Wed, 26 Nov 2025 18:27:19 +0000 Subject: [PATCH 22/43] wip --- fast_llm/layers/ssm/kda.py | 29 ++++++++++++++------------ tests/utils/model_configs.py | 40 +++++++++++++++++++++++++++++++++++- 2 files changed, 55 insertions(+), 14 deletions(-) diff --git a/fast_llm/layers/ssm/kda.py b/fast_llm/layers/ssm/kda.py index 03a682852..ff3ed99e8 100644 --- a/fast_llm/layers/ssm/kda.py +++ b/fast_llm/layers/ssm/kda.py @@ -264,7 +264,7 @@ def _forward( g_kernel = g_kernel.transpose(0, 1) g_kernel = rearrange(g_kernel, "b s ... -> (b s) ...").unsqueeze(0) - g_kernel = fused_kda_gate(g_kernel, self.A_log, self._config.head_dim, g_bias=self.dt_bias) + g_kernel = fused_kda_gate(g_kernel, self.A_log, dt_bias=self.dt_bias) beta = torch.sigmoid(self.beta_proj(hidden_states).float()) q = self._reshape_heads(q) @@ -307,17 +307,23 @@ def _forward( def get_compute_usage(self, input_: TensorMeta, kwargs: dict[str, typing.Any], config: ResourceUsageConfig) -> int: raise NotImplementedError() - def _preprocess_for_varlen(self, batch: torch.Tensor, kwargs: dict[str, typing.Any]) -> None: + def _preprocess_for_varlen(self, kwargs: dict[str, typing.Any]) -> None: sequence_lengths = kwargs[LinearAttentionKwargs.sequence_lengths] + device = kwargs.get("device", None) if sequence_lengths is None: raise ValueError("sequence_lengths must be provided in kwargs for variable-length sequences.") - seqlens = torch.cat(sequence_lengths) - cu_seqlens = torch.cat( - ( - torch.zeros(1, dtype=torch.long, device=batch.device), - torch.cumsum(seqlens, dim=0, dtype=torch.long).to(batch.device), - ) + seqlens = torch.tensor( + [ + 0, + *( + sequence_length + for sequence_lengths in kwargs[LinearAttentionKwargs.sequence_lengths] + for sequence_length in sequence_lengths + ), + ], + dtype=torch.int32, ) + cu_seqlens = seqlens.cumsum_(0).to(device) # this is supposed to be flattened, see https://github.com/fla-org/flash-linear-attention/blob/71260ecd573cfaaa94305b726465143199e99734/fla/ops/kda/chunk.py#L303 # also whenever cu_seqlens is used, batchs size must be forced to 1: see https://github.com/fla-org/flash-linear-attention/blob/71260ecd573cfaaa94305b726465143199e99734/fla/ops/kda/chunk.py#L347 kwargs[LinearAttentionKwargs.cu_seqlens] = cu_seqlens @@ -339,8 +345,5 @@ def _preprocess_for_varlen(self, batch: torch.Tensor, kwargs: dict[str, typing.A .unsqueeze(0) ) - def preprocess(self, batch: torch.Tensor, kwargs: dict[str, typing.Any]) -> None: - if LinearAttentionKwargs.sequence_lengths in kwargs: - # TODO: packing is enabled by default, i.e. its always used? - # only get here when cross_document_attention is False - self._preprocess_for_varlen(batch, kwargs) + def preprocess(self, kwargs: dict[str, typing.Any]) -> None: + self._preprocess_for_varlen(kwargs) diff --git a/tests/utils/model_configs.py b/tests/utils/model_configs.py index 1eacb7840..843c51e2a 100644 --- a/tests/utils/model_configs.py +++ b/tests/utils/model_configs.py @@ -14,7 +14,6 @@ from fast_llm.engine.multi_stage.config import FastLLMModelConfig from fast_llm.engine.training.config import TrainerConfig from fast_llm.models.gpt.conversion.config import ( - Apriel2CheckpointFormat, AprielHybridSSMCheckpointFormat, DiffusionDreamCheckpointFormat, DiffusionLlamaCheckpointFormat, @@ -767,6 +766,45 @@ def _update_and_add_testing_config( ) +_update_and_add_testing_config( + # Tests hybrid with gated delta net mixer. + "llama", + "hybrid_kda", + updates={ + ("model", "base_model", "decoder"): { + "type": "pattern", + "blocks": { + "t": copy.deepcopy(_llama_block), + "kda": { + **copy.deepcopy(_llama_block), + "mixer": { + "type": "kda", + "heads": 4, + "head_dim": 16, + }, + }, + }, + "num_blocks": 2, + "pattern": ["t", "kda"], + }, + }, + megatron_args=None, + checkpoint_format=AprielHybridSSMCheckpointFormat, + groups={ + ModelTestingGroup.basic: ModelTestingGroupAction.normal, + ModelTestingGroup.checkpoint: ModelTestingGroupAction.normal, + ModelTestingGroup.convert: ModelTestingGroupAction.normal, + ModelTestingGroup.generate: ModelTestingGroupAction.not_implemented, + ModelTestingGroup.megatron: ModelTestingGroupAction.not_implemented, + ModelTestingGroup.distributed: ModelTestingGroupAction.normal, + }, + compare_factor=10.0, # with compare_factor 2 fails fp16 and bf16 tests in the normalizaiton layer when using rms_norm_gated from fla + # note: tp is excluded because there is currently no gradient reductions implemented for tp norm in gdn.py (STP works though). + # we should be using STP with this model, not TP! + skip_tests=(r"sdp", r"ms", r"^tp2$"), +) + + @pytest.fixture(scope="session", params=MODEL_CONFIGS.keys()) def model_testing_config(request) -> ModelTestingConfig: models = request.config.getoption("--models") From 0d41dce0729d14fb2a99b870bb275ace34b0a824 Mon Sep 17 00:00:00 2001 From: oleksost Date: Wed, 26 Nov 2025 20:12:38 +0000 Subject: [PATCH 23/43] wip --- fast_llm/layers/ssm/kda.py | 20 +++++------ tests/test_varlen.py | 69 +++++++++++++++++------------------- tests/utils/model_configs.py | 2 +- 3 files changed, 42 insertions(+), 49 deletions(-) diff --git a/fast_llm/layers/ssm/kda.py b/fast_llm/layers/ssm/kda.py index ff3ed99e8..e444df908 100644 --- a/fast_llm/layers/ssm/kda.py +++ b/fast_llm/layers/ssm/kda.py @@ -222,17 +222,15 @@ def _forward( losses: dict[str, typing.Any] | None = None, metrics: dict[str, typing.Any] | None = None, ) -> tuple[torch.Tensor, torch.Tensor | None]: - # TODO: make sure varlen is supported - # TODO: do we need to deal with padding tokens? + """ + Same as in gdn, the idea is to always do forward pass in a packed way, whcih is required for varlen support. + """ sequence_first = kwargs[BlockKwargs.sequence_first] hidden_states = input_ - # padding_tokens = kwargs.get(LinearAttentionKwargs.padding_tokens, None) - - # if padding_tokens is not None: cu_seqlens = kwargs.get(LinearAttentionKwargs.cu_seqlens, None) seq_idx = kwargs.get(LinearAttentionKwargs.seq_idx, None) - # TODO: can be made more efficeint by rearranging hidden states directly + # TODO: can be made more efficeint by rearranging hidden states directly and only once residual_dtype = hidden_states.dtype q = self.q_proj(hidden_states) @@ -240,15 +238,11 @@ def _forward( v = self.v_proj(hidden_states) if sequence_first: - # make bs first dim again q = q.transpose(0, 1) k = k.transpose(0, 1) v = v.transpose(0, 1) batch_size, sequence_length, _ = q.size() - - # work with bs = 1 to make sure varlen works correctly, only needed if micro batch size is > 1 - # can this be applied once to hidden state only? pr q = rearrange(q, "b s ... -> (b s) ...").unsqueeze(0) k = rearrange(k, "b s ... -> (b s) ...").unsqueeze(0) v = rearrange(v, "b s ... -> (b s) ...").unsqueeze(0) @@ -263,8 +257,9 @@ def _forward( if sequence_first: g_kernel = g_kernel.transpose(0, 1) g_kernel = rearrange(g_kernel, "b s ... -> (b s) ...").unsqueeze(0) + g_kernel = self._reshape_heads(g_kernel) - g_kernel = fused_kda_gate(g_kernel, self.A_log, dt_bias=self.dt_bias) + g_kernel = fused_kda_gate(g_kernel, self.A_log.float(), dt_bias=self.dt_bias) beta = torch.sigmoid(self.beta_proj(hidden_states).float()) q = self._reshape_heads(q) @@ -312,13 +307,14 @@ def _preprocess_for_varlen(self, kwargs: dict[str, typing.Any]) -> None: device = kwargs.get("device", None) if sequence_lengths is None: raise ValueError("sequence_lengths must be provided in kwargs for variable-length sequences.") + seqlens = torch.tensor( [ 0, *( sequence_length for sequence_lengths in kwargs[LinearAttentionKwargs.sequence_lengths] - for sequence_length in sequence_lengths + for sequence_length in sequence_lengths # bs ), ], dtype=torch.int32, diff --git a/tests/test_varlen.py b/tests/test_varlen.py index 0b5a6caca..e6ad61ecf 100644 --- a/tests/test_varlen.py +++ b/tests/test_varlen.py @@ -1,5 +1,3 @@ -import itertools - import pytest import torch @@ -84,7 +82,7 @@ def pack(hidden_states, cu_seqlens, batch_size): return packed_hidden_states -def generate_random_cu_seqlens(seq_len, packages_num=2): +def generate_random_seq_len(seq_len, packages_num=2): if packages_num < 1: raise ValueError("packages_num must be at least 1") @@ -92,24 +90,27 @@ def generate_random_cu_seqlens(seq_len, packages_num=2): base, rem = divmod(seq_len, packages_num) # lengths: e.g. for seq_len=10, packages=3 → [4,3,3] lengths = [base + 1 if i < rem else base for i in range(packages_num)] + assert sum(lengths) == seq_len + assert len(lengths) == packages_num + return lengths - # split points exclude the final cumulative (seq_len) - split_points = list(itertools.accumulate(lengths))[:-1] + # # split points exclude the final cumulative (seq_len) + # split_points = list(itertools.accumulate(lengths))[:-1] - cu_seqlens = [0] + split_points + [seq_len] - # cu_seqlens = split_points # + [seq_len] + # cu_seqlens = split_points + [seq_len] + # # cu_seqlens = split_points # + [seq_len] - # index: for each chunk, we emit 0,1,...,length-1 - index = [] - for length in lengths: - index.extend(range(length)) + # # index: for each chunk, we emit 0,1,...,length-1 + # index = [] + # for length in lengths: + # index.extend(range(length)) - # sanity check - assert len(cu_seqlens) - 1 == packages_num - assert sum(lengths) == seq_len - assert len(index) == seq_len + # # sanity check + # assert len(cu_seqlens) == packages_num + # assert sum(lengths) == seq_len + # assert len(index) == seq_len - return cu_seqlens, index + # return cu_seqlens, index def _materialize_mixer_tensors(module: torch.nn.Module, distributed: Distributed, device: torch.device) -> None: @@ -149,7 +150,8 @@ def _param_grad(param: torch.nn.Parameter) -> torch.Tensor | None: [ pytest.param(GatedDeltaNetConfig(value_heads=4, key_heads=2, key_head_dim=16, value_head_dim=16), False), pytest.param(GatedDeltaNetConfig(value_heads=4, key_heads=2, key_head_dim=16, value_head_dim=16), True), - # pytest.param(KimiDeltaAttentionConfig) + # pytest.param(KimiDeltaAttentionConfig(heads=4, head_dim=16), False), + # pytest.param(KimiDeltaAttentionConfig(heads=4, head_dim=16), True), ], ) def test_mixer_varlen_stacking_equivalence(config: MixerConfig, sequence_first: bool, distributed_config, distributed): @@ -174,34 +176,23 @@ def test_mixer_varlen_stacking_equivalence(config: MixerConfig, sequence_first: seq_len = 15 packages_num = torch.tensor([2, 3], device=device, dtype=torch.long) sequence_lengths = [ - torch.tensor( - generate_random_cu_seqlens(seq_len, packages_num=packages_num[i].item())[0], - device=device, - dtype=torch.long, - ).diff() - for i in range(batch_size) + generate_random_seq_len(seq_len, packages_num=packages_num[i].item()) for i in range(batch_size) ] - seqlens = torch.cat(sequence_lengths) - cu_seqlen = torch.cat( - ( - torch.zeros(1, dtype=torch.long, device=device), - torch.cumsum(seqlens, dim=0, dtype=torch.long).to(device), - ) - ) packed = torch.randn(batch_size, seq_len, hidden_size, device=device, dtype=dtype, requires_grad=True) if sequence_first: packed = packed.transpose(0, 1) kwargs_packed = { + BlockKwargs.device: device, BlockKwargs.sequence_lengths: sequence_lengths, BlockKwargs.sequence_first: sequence_first, BlockKwargs.hidden_dims: (hidden_dim,), } - mixer_packed.preprocess(packed, kwargs_packed) - assert torch.all(kwargs_packed["cu_seqlens"] == cu_seqlen) + mixer_packed.preprocess(kwargs_packed) kwargs_ref = { + BlockKwargs.device: device, BlockKwargs.sequence_first: False, BlockKwargs.hidden_dims: (hidden_dim,), } @@ -215,13 +206,13 @@ def test_mixer_varlen_stacking_equivalence(config: MixerConfig, sequence_first: out_batch = [] length = sequence_lengths[b] if sequence_first: - ref_seqs = torch.split(packed[:, b].unsqueeze(0), length.tolist(), dim=1) + ref_seqs = torch.split(packed[:, b].unsqueeze(0), length, dim=1) else: - ref_seqs = torch.split(packed[b].unsqueeze(0), length.tolist(), dim=1) + ref_seqs = torch.split(packed[b].unsqueeze(0), length, dim=1) for seq in ref_seqs: kwargs_ref_seq = { **kwargs_ref, - BlockKwargs.sequence_q_dim: TensorDim("sequence_q", seq.shape[1]), + BlockKwargs.sequence_lengths: [seq.shape[1]], } out_batch.append(mixer_ref(seq, kwargs_ref_seq)) ref_outs.append(torch.cat(out_batch, dim=1)) @@ -236,7 +227,13 @@ def test_mixer_varlen_stacking_equivalence(config: MixerConfig, sequence_first: for (name, param), (_, param_ref) in zip(mixer_packed.named_parameters(), mixer_ref.named_parameters()): if param.requires_grad: - assert torch.allclose(_param_grad(param), _param_grad(param_ref), atol=1e-3, rtol=1e-3) + torch.testing.assert_close( + _param_grad(param), + _param_grad(param_ref), + atol=1e-3, + rtol=1e-3, + msg=f"Grad mismatch for parameter {name}", + ) if __name__ == "__main__": diff --git a/tests/utils/model_configs.py b/tests/utils/model_configs.py index 843c51e2a..70233705a 100644 --- a/tests/utils/model_configs.py +++ b/tests/utils/model_configs.py @@ -798,7 +798,7 @@ def _update_and_add_testing_config( ModelTestingGroup.megatron: ModelTestingGroupAction.not_implemented, ModelTestingGroup.distributed: ModelTestingGroupAction.normal, }, - compare_factor=10.0, # with compare_factor 2 fails fp16 and bf16 tests in the normalizaiton layer when using rms_norm_gated from fla + compare_factor=10.0, # similar to gdn with compare_factor 2 fails fp16 and bf16 tests in the normalizaiton layer when using rms_norm_gated from fla # note: tp is excluded because there is currently no gradient reductions implemented for tp norm in gdn.py (STP works though). # we should be using STP with this model, not TP! skip_tests=(r"sdp", r"ms", r"^tp2$"), From 6e2c1fe90406f1f6a64ceae1f1ed87d4201151dc Mon Sep 17 00:00:00 2001 From: oleksost Date: Wed, 26 Nov 2025 20:13:48 +0000 Subject: [PATCH 24/43] varlen test --- tests/test_varlen.py | 57 +++++++++++++------------------------------- 1 file changed, 17 insertions(+), 40 deletions(-) diff --git a/tests/test_varlen.py b/tests/test_varlen.py index 0b5a6caca..126a3e1e5 100644 --- a/tests/test_varlen.py +++ b/tests/test_varlen.py @@ -1,5 +1,3 @@ -import itertools - import pytest import torch @@ -84,7 +82,7 @@ def pack(hidden_states, cu_seqlens, batch_size): return packed_hidden_states -def generate_random_cu_seqlens(seq_len, packages_num=2): +def generate_random_seq_len(seq_len, packages_num=2): if packages_num < 1: raise ValueError("packages_num must be at least 1") @@ -92,24 +90,9 @@ def generate_random_cu_seqlens(seq_len, packages_num=2): base, rem = divmod(seq_len, packages_num) # lengths: e.g. for seq_len=10, packages=3 → [4,3,3] lengths = [base + 1 if i < rem else base for i in range(packages_num)] - - # split points exclude the final cumulative (seq_len) - split_points = list(itertools.accumulate(lengths))[:-1] - - cu_seqlens = [0] + split_points + [seq_len] - # cu_seqlens = split_points # + [seq_len] - - # index: for each chunk, we emit 0,1,...,length-1 - index = [] - for length in lengths: - index.extend(range(length)) - - # sanity check - assert len(cu_seqlens) - 1 == packages_num assert sum(lengths) == seq_len - assert len(index) == seq_len - - return cu_seqlens, index + assert len(lengths) == packages_num + return lengths def _materialize_mixer_tensors(module: torch.nn.Module, distributed: Distributed, device: torch.device) -> None: @@ -149,7 +132,6 @@ def _param_grad(param: torch.nn.Parameter) -> torch.Tensor | None: [ pytest.param(GatedDeltaNetConfig(value_heads=4, key_heads=2, key_head_dim=16, value_head_dim=16), False), pytest.param(GatedDeltaNetConfig(value_heads=4, key_heads=2, key_head_dim=16, value_head_dim=16), True), - # pytest.param(KimiDeltaAttentionConfig) ], ) def test_mixer_varlen_stacking_equivalence(config: MixerConfig, sequence_first: bool, distributed_config, distributed): @@ -174,34 +156,23 @@ def test_mixer_varlen_stacking_equivalence(config: MixerConfig, sequence_first: seq_len = 15 packages_num = torch.tensor([2, 3], device=device, dtype=torch.long) sequence_lengths = [ - torch.tensor( - generate_random_cu_seqlens(seq_len, packages_num=packages_num[i].item())[0], - device=device, - dtype=torch.long, - ).diff() - for i in range(batch_size) + generate_random_seq_len(seq_len, packages_num=packages_num[i].item()) for i in range(batch_size) ] - seqlens = torch.cat(sequence_lengths) - cu_seqlen = torch.cat( - ( - torch.zeros(1, dtype=torch.long, device=device), - torch.cumsum(seqlens, dim=0, dtype=torch.long).to(device), - ) - ) packed = torch.randn(batch_size, seq_len, hidden_size, device=device, dtype=dtype, requires_grad=True) if sequence_first: packed = packed.transpose(0, 1) kwargs_packed = { + BlockKwargs.device: device, BlockKwargs.sequence_lengths: sequence_lengths, BlockKwargs.sequence_first: sequence_first, BlockKwargs.hidden_dims: (hidden_dim,), } - mixer_packed.preprocess(packed, kwargs_packed) - assert torch.all(kwargs_packed["cu_seqlens"] == cu_seqlen) + mixer_packed.preprocess(kwargs_packed) kwargs_ref = { + BlockKwargs.device: device, BlockKwargs.sequence_first: False, BlockKwargs.hidden_dims: (hidden_dim,), } @@ -215,13 +186,13 @@ def test_mixer_varlen_stacking_equivalence(config: MixerConfig, sequence_first: out_batch = [] length = sequence_lengths[b] if sequence_first: - ref_seqs = torch.split(packed[:, b].unsqueeze(0), length.tolist(), dim=1) + ref_seqs = torch.split(packed[:, b].unsqueeze(0), length, dim=1) else: - ref_seqs = torch.split(packed[b].unsqueeze(0), length.tolist(), dim=1) + ref_seqs = torch.split(packed[b].unsqueeze(0), length, dim=1) for seq in ref_seqs: kwargs_ref_seq = { **kwargs_ref, - BlockKwargs.sequence_q_dim: TensorDim("sequence_q", seq.shape[1]), + BlockKwargs.sequence_lengths: [seq.shape[1]], } out_batch.append(mixer_ref(seq, kwargs_ref_seq)) ref_outs.append(torch.cat(out_batch, dim=1)) @@ -236,7 +207,13 @@ def test_mixer_varlen_stacking_equivalence(config: MixerConfig, sequence_first: for (name, param), (_, param_ref) in zip(mixer_packed.named_parameters(), mixer_ref.named_parameters()): if param.requires_grad: - assert torch.allclose(_param_grad(param), _param_grad(param_ref), atol=1e-3, rtol=1e-3) + torch.testing.assert_close( + _param_grad(param), + _param_grad(param_ref), + atol=1e-3, + rtol=1e-3, + msg=f"Grad mismatch for parameter {name}", + ) if __name__ == "__main__": From 8938a1d63174e810b246233dbd9b512be9f5f573 Mon Sep 17 00:00:00 2001 From: oleksost Date: Wed, 26 Nov 2025 20:32:10 +0000 Subject: [PATCH 25/43] varlen test --- tests/test_varlen.py | 42 +++++++++++++++++++++++++++++------------- 1 file changed, 29 insertions(+), 13 deletions(-) diff --git a/tests/test_varlen.py b/tests/test_varlen.py index e6ad61ecf..f69268bb7 100644 --- a/tests/test_varlen.py +++ b/tests/test_varlen.py @@ -7,7 +7,8 @@ from fast_llm.layers.block.config import BlockKwargs from fast_llm.layers.decoder.config import MixerConfig from fast_llm.layers.ssm import gdn as gdn_module -from fast_llm.layers.ssm.config import GatedDeltaNetConfig +from fast_llm.layers.ssm import kda as kda_module +from fast_llm.layers.ssm.config import GatedDeltaNetConfig, KimiDeltaAttentionConfig @pytest.fixture @@ -141,17 +142,33 @@ def _param_grad(param: torch.nn.Parameter) -> torch.Tensor | None: # TODO: include mamba varlen @pytest.mark.slow @pytest.mark.skipif(not torch.cuda.is_available(), reason="Varlen test needs CUDA") -@pytest.mark.skipif( - gdn_module.chunk_gated_delta_rule is None, - reason="Gated Delta Net fused kernels not available", -) @pytest.mark.parametrize( "config, sequence_first", [ - pytest.param(GatedDeltaNetConfig(value_heads=4, key_heads=2, key_head_dim=16, value_head_dim=16), False), - pytest.param(GatedDeltaNetConfig(value_heads=4, key_heads=2, key_head_dim=16, value_head_dim=16), True), - # pytest.param(KimiDeltaAttentionConfig(heads=4, head_dim=16), False), - # pytest.param(KimiDeltaAttentionConfig(heads=4, head_dim=16), True), + pytest.param( + GatedDeltaNetConfig(value_heads=4, key_heads=2, key_head_dim=16, value_head_dim=16), + False, + marks=pytest.mark.skipif( + gdn_module.chunk_gated_delta_rule is None, reason="GDN fused kernels not available" + ), + ), + pytest.param( + GatedDeltaNetConfig(value_heads=4, key_heads=2, key_head_dim=16, value_head_dim=16), + True, + marks=pytest.mark.skipif( + gdn_module.chunk_gated_delta_rule is None, reason="GDN fused kernels not available" + ), + ), + pytest.param( + KimiDeltaAttentionConfig(heads=4, head_dim=16), + False, + marks=pytest.mark.skipif(kda_module.chunk_kda is None, reason="KDA fused kernels not available"), + ), + pytest.param( + KimiDeltaAttentionConfig(heads=4, head_dim=16), + True, + marks=pytest.mark.skipif(kda_module.chunk_kda is None, reason="KDA fused kernels not available"), + ), ], ) def test_mixer_varlen_stacking_equivalence(config: MixerConfig, sequence_first: bool, distributed_config, distributed): @@ -227,13 +244,12 @@ def test_mixer_varlen_stacking_equivalence(config: MixerConfig, sequence_first: for (name, param), (_, param_ref) in zip(mixer_packed.named_parameters(), mixer_ref.named_parameters()): if param.requires_grad: - torch.testing.assert_close( + assert torch.allclose( _param_grad(param), _param_grad(param_ref), - atol=1e-3, + atol=2e-3, rtol=1e-3, - msg=f"Grad mismatch for parameter {name}", - ) + ), f"Grad mismatch for parameter {name}" if __name__ == "__main__": From 6c2bd46fafa4a102463be62cb2d6c8043b319aa8 Mon Sep 17 00:00:00 2001 From: oleksost Date: Wed, 26 Nov 2025 20:46:42 +0000 Subject: [PATCH 26/43] wip --- tests/layers/test_kda_equivalence.py | 29 ++++++++++++++++++++++++++++ 1 file changed, 29 insertions(+) create mode 100644 tests/layers/test_kda_equivalence.py diff --git a/tests/layers/test_kda_equivalence.py b/tests/layers/test_kda_equivalence.py new file mode 100644 index 000000000..04ad75ad0 --- /dev/null +++ b/tests/layers/test_kda_equivalence.py @@ -0,0 +1,29 @@ +import torch + +from fast_llm.engine.distributed.distributed import Distributed + +try: + from transformers.models.qwen3_next.modeling_qwen3_next import Qwen3NextConfig, Qwen3NextGatedDeltaNet +except ImportError: + Qwen3NextConfig, Qwen3NextGatedDeltaNet = None, None + + +def _materialize_mixer_tensors(module: torch.nn.Module, distributed: Distributed, device: torch.device) -> None: + """ + Instantiate meta-allocated parameters on the requested device so the layer can run standalone. + """ + for name, param in module.named_parameters(): + if param.device.type != "meta": + continue + param_data = torch.empty_like(param, device=device) + param.init_parameter(param_data, distributed) + module_path, param_name = name.rsplit(".", 1) if "." in name else (None, name) + target = module + if module_path is not None: + for part in module_path.split("."): + target = getattr(target, part) + new_param = torch.nn.Parameter(param_data, requires_grad=param.requires_grad) + new_param.grad = None + new_param.grad_buffer = torch.zeros_like(param_data) + new_param.param_grad_is_zero = True + target._parameters[param_name] = new_param From 28a6176ea84474dbff23658a5621b07a826b3e66 Mon Sep 17 00:00:00 2001 From: oleksost Date: Wed, 26 Nov 2025 20:53:18 +0000 Subject: [PATCH 27/43] wip --- .../modeling_apriel_hybrid_ssm.py | 242 +++++++++++++++++- 1 file changed, 239 insertions(+), 3 deletions(-) diff --git a/fast_llm_external_models/apriel_hybrid_ssm/modeling_apriel_hybrid_ssm.py b/fast_llm_external_models/apriel_hybrid_ssm/modeling_apriel_hybrid_ssm.py index e63584433..624390a46 100644 --- a/fast_llm_external_models/apriel_hybrid_ssm/modeling_apriel_hybrid_ssm.py +++ b/fast_llm_external_models/apriel_hybrid_ssm/modeling_apriel_hybrid_ssm.py @@ -23,9 +23,12 @@ from fast_llm_external_models.apriel_hybrid_ssm.configuration_apriel_hybrid_ssm import AprielHybridSSMConfig -# from vllm.model_executor.layers.mamba.ops.mamba_ssm import selective_scan_fn as varlen_selective_scan_fn -# from vllm.model_executor.layers.mamba.ops.causal_conv1d import causal_conv1d_fn as varlen_causal_conv1d_fn - +try: + from fla.modules import FusedRMSNormGated, ShortConvolution + from fla.ops.kda import chunk_kda, fused_recurrent_kda + from fla.ops.kda.gate import fused_kda_gate +except ImportError: + raise ImportError("Plese run `pip install -U fla-core`") logger = logging.get_logger(__name__) @@ -45,6 +48,239 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) +def index_first_axis(x, indices): + other_shape = x.shape[1:] + second_dim = other_shape.numel() + return torch.gather( + rearrange(x, "b ... -> b (...)"), + 0, + repeat(indices, "z -> z d", d=second_dim), + ).reshape(-1, *other_shape) + + +def index_put_first_axis(x, indices, first_axis_dim): + y = torch.zeros(first_axis_dim, *x.shape[1:], device=x.device, dtype=x.dtype) + # TODO [2022-03-04] For some reason torch.scatter is a bit faster than indexing. + y[indices] = x + # y.scatter_(0, repeat(indices, 'z -> z d', d=x.shape[1]), x) + return y + + +@tensor_cache +def get_unpad_data( + attention_mask: torch.Tensor, +) -> tuple[torch.Tensor, torch.Tensor, int]: + lens = prepare_lens_from_mask(attention_mask) + indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten() + max_seqlen_in_batch = lens.max().item() + cu_seqlens = prepare_cu_seqlens_from_mask(attention_mask) + return indices, cu_seqlens, max_seqlen_in_batch + + +def unpad_input( + q: torch.Tensor, + states: tuple[torch.Tensor], + attention_mask: torch.Tensor, + q_len: int, + keepdim: bool = False, +): + indices_k, cu_seqlens_k, max_seqlen_in_batch_k = get_unpad_data(attention_mask) + batch_size, seq_len, *_ = states[0].shape + + state = tuple(index_first_axis(rearrange(s, "b s ... -> (b s) ..."), indices_k) for s in states) + + if q_len == seq_len: + q = index_first_axis(rearrange(q, "b s ... -> (b s) ..."), indices_k) + cu_seqlens_q = cu_seqlens_k + max_seqlen_in_batch_q = max_seqlen_in_batch_k + indices_q = indices_k + elif q_len == 1: + max_seqlen_in_batch_q = 1 + cu_seqlens_q = torch.arange(batch_size + 1, dtype=torch.int32, device=q.device) + indices_q = cu_seqlens_q[:-1] + q = q.squeeze(1) + else: + raise NotImplementedError("We only support either q_len == k_len (prefilling) or q_len == 1 (decoding)") + + if keepdim: + q = q.unsqueeze(0) + state = tuple(s.unsqueeze(0) for s in state) + + return ( + q, + state, + indices_q, + (cu_seqlens_q, cu_seqlens_k), + (max_seqlen_in_batch_q, max_seqlen_in_batch_k), + ) + + +def pad_input( + hidden_states: torch.Tensor, + indices: torch.LongTensor, + batch_size: int, + seq_len: int, +) -> torch.Tensor: + output = index_put_first_axis(hidden_states, indices, batch_size * seq_len) + return rearrange(output, "(b s) ... -> b s ...", b=batch_size) + + +class KimiDeltaAttention(nn.Module): + def __init__(self, config: AprielSSMHybridConfig, layer_idx: int): + super().__init__() + self.config = config + self.mode = "chunk" + + self.hidden_size = config.hidden_size + self.conv_size = config.short_conv_kernel_size + self.head_dim = config.head_dim + self.num_heads = config.num_heads + self.head_k_dim = self.head_dim + self.num_k_heads = self.num_heads + + self.layer_idx = layer_idx + + assert self.mode in ["chunk", "fused_recurrent"], f"Not suppoerted mode `{self.mode}`." + + projection_k_size = self.head_k_dim * self.num_k_heads + projection_size = self.head_dim * self.num_heads + + self.q_proj = nn.Linear(self.hidden_size, projection_k_size, bias=False) + self.k_proj = nn.Linear(self.hidden_size, projection_k_size, bias=False) + self.v_proj = nn.Linear(self.hidden_size, projection_size, bias=False) + + self.q_conv1d = ShortConvolution( + hidden_size=projection_k_size, + kernel_size=self.conv_size, + activation="silu", + ) + self.k_conv1d = ShortConvolution( + hidden_size=projection_k_size, + kernel_size=self.conv_size, + activation="silu", + ) + self.v_conv1d = ShortConvolution( + hidden_size=projection_size, + kernel_size=self.conv_size, + activation="silu", + ) + + self.A_log = torch.nn.Parameter( + torch.log(torch.empty(self.num_heads, dtype=torch.float32).uniform_(1, 16)).view(1, 1, -1, 1) + ) + + self.f_a_proj = nn.Linear(self.hidden_size, self.head_dim, bias=False) + self.f_b_proj = nn.Linear(self.head_dim, projection_size, bias=False) + + self.dt_bias = nn.Parameter(torch.empty(projection_size, dtype=torch.float32)) + + self.b_proj = nn.Linear(self.hidden_size, self.num_heads, bias=False) + + self.g_a_proj = nn.Linear(self.hidden_size, self.head_dim, bias=False) + self.g_b_proj = nn.Linear(self.head_dim, projection_size, bias=False) + + self.o_norm = FusedRMSNormGated(self.head_dim, eps=config.rms_norm_eps, activation="sigmoid") + self.o_proj = nn.Linear(projection_size, self.hidden_size, bias=False) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor | None = None, + cache_params=None, + **kwargs: Unpack[dict], + ) -> tuple[torch.Tensor, torch.Tensor | None, Cache | None]: + if attention_mask is not None: + if attention_mask.dim() != 2: + attention_mask = kwargs.get("padding_mask") + + if attention_mask is not None and attention_mask.dim() != 2: + raise ValueError( + "attention_mask must be a 0-1 matrix of shape [batch_size, seq_len] " + "(0 = padding). 3D masks are not supported here.", + ) + use_cache = cache_params is not None + batch_size, q_len, _ = hidden_states.shape + mode = "fused_recurrent" if q_len <= 64 else self.mode + if self.training: + assert mode == "chunk", "Only chunk mode is supported in training." + + cu_seqlens = kwargs.get("cu_seqlens") + indices = None + if attention_mask is not None: + indices, cu_seqlens, _ = get_unpad_data(attention_mask[:, -q_len:]) + hidden_states = index_first_axis(rearrange(hidden_states, "b s ... -> (b s) ..."), indices).unsqueeze(0) + + conv_state_q, conv_state_k, conv_state_v = None, None, None + recurrent_state = None + if cache_params is not None: + if cache_params.conv_states[self.layer_idx] is not None: + conv_state_q, conv_state_k, conv_state_v = cache_params.conv_states[self.layer_idx] + recurrent_state = cache_params.recurrent_states[self.layer_idx] + q, conv_state_q = self.q_conv1d( + x=self.q_proj(hidden_states), + cache=conv_state_q, + output_final_state=use_cache, + cu_seqlens=cu_seqlens, + ) + k, conv_state_k = self.k_conv1d( + x=self.k_proj(hidden_states), + cache=conv_state_k, + output_final_state=use_cache, + cu_seqlens=cu_seqlens, + ) + v, conv_state_v = self.v_conv1d( + x=self.v_proj(hidden_states), + cache=conv_state_v, + output_final_state=use_cache, + cu_seqlens=cu_seqlens, + ) + g = self.f_b_proj(self.f_a_proj(hidden_states)) + g = fused_kda_gate(g, self.A_log, self.head_dim, g_bias=self.dt_bias) + beta = self.b_proj(hidden_states).float().sigmoid() + + q, k = map(lambda x: rearrange(x, "... (h d) -> ... h d", d=self.head_k_dim), (q, k)) + v = rearrange(v, "... (h d) -> ... h d", d=self.head_dim) + + if mode == "chunk": + o, recurrent_state = chunk_kda( + q=q, + k=k, + v=v, + g=g, + beta=beta, + initial_state=recurrent_state, + output_final_state=True, + use_qk_l2norm_in_kernel=True, + cu_seqlens=cu_seqlens, + ) + else: + o, recurrent_state = fused_recurrent_kda( + q=q, + k=k, + v=v, + g=g, + beta=beta, + initial_state=recurrent_state, + output_final_state=True, + use_qk_l2norm_in_kernel=True, + cu_seqlens=cu_seqlens, + ) + if cache_params is not None: + cache_params.recurrent_states[self.layer_idx] = recurrent_state + cache_params.conv_states[self.layer_idx] = (conv_state_q, conv_state_k, conv_state_v) + + g = self.g_b_proj(self.g_a_proj(hidden_states)) + g = rearrange(g, "... (h d) -> ... h d", d=self.head_dim) + o = self.o_norm(o, g) + + o = rearrange(o, "b t h d -> b t (h d)") + o = self.o_proj(o) + if attention_mask is not None: + o = pad_input(o.squeeze(0), indices, batch_size, q_len) + + return o + + class HybridMambaAttentionStaticCache(Cache): def __init__(self, config: AprielHybridSSMConfig, batch_size, max_length, dtype=torch.float16, device=None): super().__init__() # config, batch_size, max_length, device, dtype) From 2a30bac24ffb0002d81abf0438b2088fbb03ba8e Mon Sep 17 00:00:00 2001 From: oleksost Date: Wed, 26 Nov 2025 21:40:41 +0000 Subject: [PATCH 28/43] wip --- .../configuration_apriel_hybrid_ssm.py | 40 +++++- .../modeling_apriel_hybrid_ssm.py | 4 +- tests/layers/test_kda_equivalence.py | 126 +++++++++++++++++- 3 files changed, 166 insertions(+), 4 deletions(-) diff --git a/fast_llm_external_models/apriel_hybrid_ssm/configuration_apriel_hybrid_ssm.py b/fast_llm_external_models/apriel_hybrid_ssm/configuration_apriel_hybrid_ssm.py index 12ee343ef..7f61fcda5 100644 --- a/fast_llm_external_models/apriel_hybrid_ssm/configuration_apriel_hybrid_ssm.py +++ b/fast_llm_external_models/apriel_hybrid_ssm/configuration_apriel_hybrid_ssm.py @@ -28,10 +28,37 @@ } +class AprielGDNConfig: + def __init__( + self, + linear_num_key_heads=16, + linear_num_value_heads=32, + linear_key_head_dim=128, + linear_value_head_dim=128, + linear_conv_kernel_dim=4, + kl_short_conv_kernel_size=4, + kl_num_heads=32, + kl_head_dim=128, + ): + self.linear_num_key_heads = linear_num_key_heads + self.linear_num_value_heads = linear_num_value_heads + self.linear_key_head_dim = linear_key_head_dim + self.linear_value_head_dim = linear_value_head_dim + self.linear_conv_kernel_dim = linear_conv_kernel_dim + + # Kimi LInear + self.short_conv_kernel_size = kl_short_conv_kernel_size + self.head_dim = kl_head_dim + self.num_heads = kl_num_heads + + +LAYER_TYPES = {"t": "full_attention", "swa": "sliding_attention", "gdn": "gated_delta_net", "kl": "kimi_linear"} + + class AprielHybridSSMConfig(MistralConfig): model_type = "apriel_hybrid_ssm" - def __init__(self, hybrid_block_layout=["m2d"], ssm_cfg=None, **kwargs): + def __init__(self, hybrid_block_layout=["t"], ssm_cfg=None, gdn_cfg=None, **kwargs): super().__init__(**kwargs) self.hybrid_block_layout = hybrid_block_layout self.head_dim = self.head_dim or self.hidden_size // self.num_attention_heads # as in transformers 4.51.3 @@ -40,3 +67,14 @@ def __init__(self, hybrid_block_layout=["m2d"], ssm_cfg=None, **kwargs): for k, v in ssm_config_default.items(): if k not in self.ssm_cfg: self.ssm_cfg[k] = v # to make sure all elements are present in the config + + gdn_config: AprielGDNConfig = ( + AprielGDNConfig(**gdn_cfg) if isinstance(gdn_cfg, dict) else gdn_cfg or AprielGDNConfig() + ) + # for compatibility with vllm + self.layer_types = [LAYER_TYPES[lt] for lt in hybrid_block_layout] # this is for vllm compatibility + self.linear_attn_config = { + "short_conv_kernel_size": gdn_config.short_conv_kernel_size, + "head_dim": gdn_config.head_dim, + "num_heads": gdn_config.num_heads, + } diff --git a/fast_llm_external_models/apriel_hybrid_ssm/modeling_apriel_hybrid_ssm.py b/fast_llm_external_models/apriel_hybrid_ssm/modeling_apriel_hybrid_ssm.py index 624390a46..a91bae82a 100644 --- a/fast_llm_external_models/apriel_hybrid_ssm/modeling_apriel_hybrid_ssm.py +++ b/fast_llm_external_models/apriel_hybrid_ssm/modeling_apriel_hybrid_ssm.py @@ -27,6 +27,8 @@ from fla.modules import FusedRMSNormGated, ShortConvolution from fla.ops.kda import chunk_kda, fused_recurrent_kda from fla.ops.kda.gate import fused_kda_gate + from fla.ops.utils.index import prepare_cu_seqlens_from_mask, prepare_lens_from_mask + from fla.utils import tensor_cache except ImportError: raise ImportError("Plese run `pip install -U fla-core`") @@ -126,7 +128,7 @@ def pad_input( class KimiDeltaAttention(nn.Module): - def __init__(self, config: AprielSSMHybridConfig, layer_idx: int): + def __init__(self, config: AprielHybridSSMConfig, layer_idx: int): super().__init__() self.config = config self.mode = "chunk" diff --git a/tests/layers/test_kda_equivalence.py b/tests/layers/test_kda_equivalence.py index 04ad75ad0..97c0a102b 100644 --- a/tests/layers/test_kda_equivalence.py +++ b/tests/layers/test_kda_equivalence.py @@ -1,11 +1,18 @@ +import pytest import torch +import fast_llm.layers.ssm.kda as kda_module +from fast_llm.engine.config_utils.tensor_dim import TensorDim +from fast_llm.engine.distributed.config import DistributedConfig from fast_llm.engine.distributed.distributed import Distributed +from fast_llm.layers.block.config import BlockKwargs +from fast_llm.layers.ssm.config import KimiDeltaAttentionConfig try: - from transformers.models.qwen3_next.modeling_qwen3_next import Qwen3NextConfig, Qwen3NextGatedDeltaNet + from fast_llm_external_models.apriel_hybrid_ssm.configuration_apriel_hybrid_ssm import AprielHybridSSMConfig + from fast_llm_external_models.apriel_hybrid_ssm.modeling_apriel_hybrid_ssm import KimiDeltaAttention except ImportError: - Qwen3NextConfig, Qwen3NextGatedDeltaNet = None, None + AprielHybridSSMConfig, KimiDeltaAttention = None, None def _materialize_mixer_tensors(module: torch.nn.Module, distributed: Distributed, device: torch.device) -> None: @@ -27,3 +34,118 @@ def _materialize_mixer_tensors(module: torch.nn.Module, distributed: Distributed new_param.grad_buffer = torch.zeros_like(param_data) new_param.param_grad_is_zero = True target._parameters[param_name] = new_param + + +@pytest.mark.slow +@pytest.mark.skipif(not torch.cuda.is_available(), reason="KDA equivalence test needs CUDA") +@pytest.mark.skipif(KimiDeltaAttention is None or AprielHybridSSMConfig is None, reason="Apriel KDA deps missing") +@pytest.mark.skipif(kda_module.chunk_kda is None, reason="KDA fused kernels not available") +def test_fast_llm_kda_matches_apriel_forward(): + torch.manual_seed(0) + device = torch.device("cuda") + dtype = torch.bfloat16 + + hidden_size = 16 + seq_len = 6 + num_heads = 4 + head_dim = 4 + kernel_size = 4 + + hf_config = AprielHybridSSMConfig( + hidden_size=hidden_size, + num_attention_heads=num_heads, + num_hidden_layers=1, + rms_norm_eps=1e-6, + ) + # Populate fields expected by the HF implementation. + hf_config.short_conv_kernel_size = kernel_size + hf_config.head_dim = head_dim + hf_config.num_heads = num_heads + hf_layer = KimiDeltaAttention(hf_config, layer_idx=0).to(device=device, dtype=dtype).eval() + + fast_config = KimiDeltaAttentionConfig( + heads=num_heads, + head_dim=head_dim, + convolution_layer={"kernel_size": kernel_size, "activation": "silu"}, + normalization={"epsilon": hf_config.rms_norm_eps}, + ) + distributed_config = DistributedConfig( + tensor_parallel=1, + pipeline_parallel=1, + sequence_data_parallel=1, + local_world_size=1, + world_size=1, + ) + hidden_dim = TensorDim("hidden", hidden_size) + fast_layer = fast_config.get_layer(distributed_config, hidden_dim, lr_scale=None, peft=None, return_bias=False) + distributed = Distributed(config=distributed_config) + fast_layer.setup(distributed) + _materialize_mixer_tensors(fast_layer, distributed, device) + fast_layer.to(device=device, dtype=dtype).eval() + + with torch.no_grad(): + fast_layer.q_proj.weight.copy_(hf_layer.q_proj.weight) + fast_layer.k_proj.weight.copy_(hf_layer.k_proj.weight) + fast_layer.v_proj.weight.copy_(hf_layer.v_proj.weight) + fast_layer.q_conv.weight.copy_(hf_layer.q_conv1d.weight) + fast_layer.k_conv.weight.copy_(hf_layer.k_conv1d.weight) + fast_layer.v_conv.weight.copy_(hf_layer.v_conv1d.weight) + if fast_layer.q_conv.bias is not None and hf_layer.q_conv1d.bias is not None: + fast_layer.q_conv.bias.copy_(hf_layer.q_conv1d.bias) + if fast_layer.k_conv.bias is not None and hf_layer.k_conv1d.bias is not None: + fast_layer.k_conv.bias.copy_(hf_layer.k_conv1d.bias) + if fast_layer.v_conv.bias is not None and hf_layer.v_conv1d.bias is not None: + fast_layer.v_conv.bias.copy_(hf_layer.v_conv1d.bias) + fast_layer.f_a_proj.weight.copy_(hf_layer.f_a_proj.weight) + fast_layer.f_b_proj.weight.copy_(hf_layer.f_b_proj.weight) + fast_layer.g_a_proj.weight.copy_(hf_layer.g_a_proj.weight) + fast_layer.g_b_proj.weight.copy_(hf_layer.g_b_proj.weight) + fast_layer.beta_proj.weight.copy_(hf_layer.b_proj.weight) + fast_layer.o_proj.weight.copy_(hf_layer.o_proj.weight) + fast_layer.A_log.copy_(hf_layer.A_log.reshape_as(fast_layer.A_log)) + fast_layer.dt_bias.copy_(hf_layer.dt_bias.reshape_as(fast_layer.dt_bias)) + fast_layer.norm.weight.copy_(hf_layer.o_norm.weight) + + param_map = { + "q_proj.weight": "q_proj.weight", + "k_proj.weight": "k_proj.weight", + "v_proj.weight": "v_proj.weight", + "q_conv.weight": "q_conv1d.weight", + "k_conv.weight": "k_conv1d.weight", + "v_conv.weight": "v_conv1d.weight", + "f_a_proj.weight": "f_a_proj.weight", + "f_b_proj.weight": "f_b_proj.weight", + "g_a_proj.weight": "g_a_proj.weight", + "g_b_proj.weight": "g_b_proj.weight", + "beta_proj.weight": "b_proj.weight", + "o_proj.weight": "o_proj.weight", + "A_log": "A_log", + "dt_bias": "dt_bias", + "norm.weight": "o_norm.weight", + } + for fast_name, hf_name in param_map.items(): + fast_param = fast_layer.state_dict()[fast_name] + hf_param = hf_layer.state_dict()[hf_name] + if fast_param.shape != hf_param.shape: + hf_param = hf_param.reshape_as(fast_param) + torch.testing.assert_close(fast_param, hf_param, atol=1e-6, rtol=1e-6) + + hidden_states = torch.randn(2, seq_len, hidden_size, device=device, dtype=dtype, requires_grad=False) + + hf_out = hf_layer(hidden_states) + + sequence_lengths = [[seq_len] for _ in range(hidden_states.size(0))] + fast_kwargs = { + BlockKwargs.device: device, + BlockKwargs.sequence_first: False, + BlockKwargs.sequence_lengths: sequence_lengths, + BlockKwargs.hidden_dims: (hidden_dim,), + } + fast_layer.preprocess(fast_kwargs) + fast_out = fast_layer(hidden_states, fast_kwargs) + + torch.testing.assert_close(fast_out, hf_out, atol=1e-5, rtol=1e-5) + + +if __name__ == "__main__": + pytest.main([__file__]) From cad93ab48d47c2fe3ca52b46b149046c4ccc7fa8 Mon Sep 17 00:00:00 2001 From: oleksost Date: Thu, 27 Nov 2025 14:19:41 +0000 Subject: [PATCH 29/43] kda equivalence test --- fast_llm/functional/config.py | 3 +++ fast_llm/layers/common/normalization/config.py | 7 +++++++ fast_llm/layers/common/normalization/normalization.py | 5 ++--- fast_llm/layers/ssm/config.py | 2 +- fast_llm/layers/ssm/kda.py | 2 +- .../apriel_hybrid_ssm/modeling_apriel_hybrid_ssm.py | 8 +++++--- tests/layers/test_kda_equivalence.py | 7 ++++--- 7 files changed, 23 insertions(+), 11 deletions(-) diff --git a/fast_llm/functional/config.py b/fast_llm/functional/config.py index 684193848..e5b59a572 100644 --- a/fast_llm/functional/config.py +++ b/fast_llm/functional/config.py @@ -42,6 +42,7 @@ class ActivationType(enum.StrEnum): gelu = "gelu" silu = "silu" relu = "relu" + sigmoid = "sigmoid" squared_relu = "squared_relu" identity = "identity" @@ -70,6 +71,7 @@ def _set_activation_fn_map() -> None: ActivationType.gelu: lambda x: torch.nn.functional.gelu(x, approximate="tanh"), ActivationType.silu: torch.nn.functional.silu, ActivationType.relu: torch.nn.functional.relu, + ActivationType.sigmoid: torch.nn.functional.sigmoid, ActivationType.squared_relu: lambda x: torch.pow(torch.nn.functional.relu(x), 2), ActivationType.identity: lambda x: x, } @@ -83,6 +85,7 @@ def _set_activation_fn_map() -> None: ActivationType.relu: "relu", ActivationType.squared_relu: "relu2", ActivationType.identity: "identity", + ActivationType.sigmoid: "sigmoid", } _ACTIVATION_HF_NAMES_INV = {value: key for key, value in _ACTIVATION_HF_NAMES.items()} diff --git a/fast_llm/layers/common/normalization/config.py b/fast_llm/layers/common/normalization/config.py index 4ecb7a3be..4b8edaebe 100644 --- a/fast_llm/layers/common/normalization/config.py +++ b/fast_llm/layers/common/normalization/config.py @@ -5,6 +5,7 @@ from fast_llm.config import Field, FieldHint, check_field, config_class from fast_llm.engine.base_model.config import ModuleConfig from fast_llm.engine.config_utils.parameter import ParameterConfig, combine_lr_scales +from fast_llm.functional.config import ActivationType from fast_llm.layers.common.peft.config import PeftConfig from fast_llm.utils import Assert @@ -133,6 +134,12 @@ def module_class(self): class GatedRMSNormalizationConfig(RMSNormalizationConfig): _abstract = False + activation: ActivationType = Field( + default=ActivationType.silu, + desc="The MLP intermediate activation type. Default: SiLU for gated MLP, GeLU otherwise.", + hint=FieldHint.core, + ) + @property def module_class(self): from fast_llm.layers.common.normalization.normalization import GatedRMSNormalization diff --git a/fast_llm/layers/common/normalization/normalization.py b/fast_llm/layers/common/normalization/normalization.py index 651d8e4b1..b1e875707 100644 --- a/fast_llm/layers/common/normalization/normalization.py +++ b/fast_llm/layers/common/normalization/normalization.py @@ -1,7 +1,6 @@ import abc import torch -import torch.nn.functional as F from fast_llm.config import Configurable from fast_llm.engine.config_utils.initialization import init_ones_, init_zeros_ @@ -324,7 +323,7 @@ def _forward_fla(self, input_: torch.Tensor, gate: torch.Tensor) -> torch.Tensor gate, self.weight, None, - activation="silu", + activation=self._config.activation.hf_name, eps=self._config.epsilon, residual=None, prenorm=False, @@ -333,4 +332,4 @@ def _forward_fla(self, input_: torch.Tensor, gate: torch.Tensor) -> torch.Tensor def _forward_local(self, input_: torch.Tensor, gate: torch.Tensor) -> torch.Tensor: normalized = self._forward(input_) - return normalized * F.silu(gate) + return normalized * self._config.activation.activation_fn(gate) diff --git a/fast_llm/layers/ssm/config.py b/fast_llm/layers/ssm/config.py index f9db4d350..24d0c3928 100644 --- a/fast_llm/layers/ssm/config.py +++ b/fast_llm/layers/ssm/config.py @@ -202,7 +202,7 @@ def _validate(self) -> None: if "epsilon" not in self.normalization._explicit_fields: self.normalization.epsilon = 1.0e-5 if "activation" not in self.convolution_layer._explicit_fields: - self.convolution_layer.activation = "silu" + self.convolution_layer.activation = "sigmoid" if "kernel_size" not in self.convolution_layer._explicit_fields: self.convolution_layer.kernel_size = 4 diff --git a/fast_llm/layers/ssm/kda.py b/fast_llm/layers/ssm/kda.py index e444df908..d67152143 100644 --- a/fast_llm/layers/ssm/kda.py +++ b/fast_llm/layers/ssm/kda.py @@ -256,8 +256,8 @@ def _forward( g_kernel = self.f_b_proj(self.f_a_proj(hidden_states)) if sequence_first: g_kernel = g_kernel.transpose(0, 1) - g_kernel = rearrange(g_kernel, "b s ... -> (b s) ...").unsqueeze(0) g_kernel = self._reshape_heads(g_kernel) + g_kernel = rearrange(g_kernel, "b s ... -> (b s) ...").unsqueeze(0) g_kernel = fused_kda_gate(g_kernel, self.A_log.float(), dt_bias=self.dt_bias) diff --git a/fast_llm_external_models/apriel_hybrid_ssm/modeling_apriel_hybrid_ssm.py b/fast_llm_external_models/apriel_hybrid_ssm/modeling_apriel_hybrid_ssm.py index a91bae82a..1aca40da0 100644 --- a/fast_llm_external_models/apriel_hybrid_ssm/modeling_apriel_hybrid_ssm.py +++ b/fast_llm_external_models/apriel_hybrid_ssm/modeling_apriel_hybrid_ssm.py @@ -237,7 +237,9 @@ def forward( cu_seqlens=cu_seqlens, ) g = self.f_b_proj(self.f_a_proj(hidden_states)) - g = fused_kda_gate(g, self.A_log, self.head_dim, g_bias=self.dt_bias) + g = rearrange(g, "... (h d) -> ... h d", d=self.head_dim) + + g = fused_kda_gate(g, self.A_log.float().squeeze(), dt_bias=self.dt_bias) beta = self.b_proj(hidden_states).float().sigmoid() q, k = map(lambda x: rearrange(x, "... (h d) -> ... h d", d=self.head_k_dim), (q, k)) @@ -250,8 +252,8 @@ def forward( v=v, g=g, beta=beta, - initial_state=recurrent_state, - output_final_state=True, + initial_state=None, + output_final_state=False, use_qk_l2norm_in_kernel=True, cu_seqlens=cu_seqlens, ) diff --git a/tests/layers/test_kda_equivalence.py b/tests/layers/test_kda_equivalence.py index 97c0a102b..8cc19af92 100644 --- a/tests/layers/test_kda_equivalence.py +++ b/tests/layers/test_kda_equivalence.py @@ -46,7 +46,7 @@ def test_fast_llm_kda_matches_apriel_forward(): dtype = torch.bfloat16 hidden_size = 16 - seq_len = 6 + seq_len = 65 num_heads = 4 head_dim = 4 kernel_size = 4 @@ -67,7 +67,7 @@ def test_fast_llm_kda_matches_apriel_forward(): heads=num_heads, head_dim=head_dim, convolution_layer={"kernel_size": kernel_size, "activation": "silu"}, - normalization={"epsilon": hf_config.rms_norm_eps}, + normalization={"epsilon": hf_config.rms_norm_eps, "activation": "sigmoid"}, ) distributed_config = DistributedConfig( tensor_parallel=1, @@ -128,10 +128,11 @@ def test_fast_llm_kda_matches_apriel_forward(): hf_param = hf_layer.state_dict()[hf_name] if fast_param.shape != hf_param.shape: hf_param = hf_param.reshape_as(fast_param) + print(f"Comparing parameter {fast_name} with shape {fast_param.shape}") torch.testing.assert_close(fast_param, hf_param, atol=1e-6, rtol=1e-6) hidden_states = torch.randn(2, seq_len, hidden_size, device=device, dtype=dtype, requires_grad=False) - + hf_layer.training = True hf_out = hf_layer(hidden_states) sequence_lengths = [[seq_len] for _ in range(hidden_states.size(0))] From 8f957a438983e9aead2bfe310a19d7a583b743da Mon Sep 17 00:00:00 2001 From: oleksost Date: Thu, 27 Nov 2025 14:35:59 +0000 Subject: [PATCH 30/43] nightly requirements --- Dockerfile | 30 +++++++++++++++++++++++++++--- requirements-kda-nightly.txt | 13 +++++++++++++ 2 files changed, 40 insertions(+), 3 deletions(-) create mode 100644 requirements-kda-nightly.txt diff --git a/Dockerfile b/Dockerfile index 6bc900ae7..0ae20efe1 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,5 +1,8 @@ # syntax=docker/dockerfile:1.7-labs FROM nvcr.io/nvidia/pytorch:25.05-py3 +ARG KDA_NIGHTLY=0 +ARG TORCH_CUDA_ARCH_LIST="8.0;8.6;9.0" +ENV KDA_NIGHTLY=${KDA_NIGHTLY} TORCH_CUDA_ARCH_LIST=${TORCH_CUDA_ARCH_LIST} # Install dependencies. RUN apt-get update \ @@ -29,8 +32,24 @@ ENV PIP_CONSTRAINT="" # There is no pre-build mamba image for pytorch 2.8, we build it before the rest to avoid rebuilds. # We need to compile from the repo because of https://github.com/state-spaces/mamba/issues/720 (same for causal-conv1d) # We set the number of workers to avoid OOM when compiling on laptop. (TODO: Can we make it configurable?) -RUN MAX_JOBS=2 pip install --no-build-isolation "causal-conv1d@git+https://github.com/Dao-AILab/causal-conv1d@2a288a1" -RUN MAX_JOBS=2 pip install --no-build-isolation "mamba_ssm[causal-conv1d]@git+https://github.com/state-spaces/mamba@4a8a2a2" +RUN if [ "$KDA_NIGHTLY" = "1" ]; then \ + pip install -U --pre torch --index-url https://download.pytorch.org/whl/nightly/cu128 && \ + pip uninstall -y triton pytorch-triton && \ + pip install -U triton-nightly --index-url https://pypi.fla-org.com/simple; \ + fi + +RUN if [ "$KDA_NIGHTLY" = "1" ]; then \ + MAX_JOBS=2 pip install --no-build-isolation --no-binary :all: "causal-conv1d@git+https://github.com/Dao-AILab/causal-conv1d@2a288a1"; \ + else \ + MAX_JOBS=2 pip install --no-build-isolation "causal-conv1d@git+https://github.com/Dao-AILab/causal-conv1d@2a288a1"; \ + fi +RUN if [ "$KDA_NIGHTLY" = "1" ]; then \ + MAX_JOBS=2 pip install --no-build-isolation --no-binary :all: "mamba_ssm[causal-conv1d]@git+https://github.com/state-spaces/mamba@4a8a2a2"; \ + else \ + MAX_JOBS=2 pip install --no-build-isolation "mamba_ssm[causal-conv1d]@git+https://github.com/state-spaces/mamba@4a8a2a2"; \ + fi +# Optional KDA nightly requirements file for reproducibility. +COPY --chmod=777 requirements-kda-nightly.txt ./ # Copy dependency files with universal write permissions for all users. COPY --chmod=777 setup.py setup.cfg pyproject.toml ./ COPY --chmod=777 ./fast_llm_external_models/__init__.py fast_llm_external_models/ @@ -38,7 +57,12 @@ COPY --chmod=777 ./fast_llm/__init__.py fast_llm/ COPY --chmod=777 ./fast_llm/csrc/ fast_llm/csrc/ # Install dependencies within the virtual environment. -RUN pip install --no-cache-dir --no-build-isolation -e ".[CORE,OPTIONAL,HUGGINGFACE,SSM,VISION,GENERATION,DEV]" triton==3.1.0 +RUN if [ "$KDA_NIGHTLY" = "1" ]; then \ + pip install --no-cache-dir --no-build-isolation -e ".[CORE,OPTIONAL,HUGGINGFACE,SSM,VISION,GENERATION,DEV]" && \ + MAX_JOBS=2 pip install --no-build-isolation --no-binary :all: flash-attn; \ + else \ + pip install --no-cache-dir --no-build-isolation -e ".[CORE,OPTIONAL,HUGGINGFACE,SSM,VISION,GENERATION,DEV]" triton==3.1.0; \ + fi # Copy the remaining source code with universal write permissions. COPY --chmod=777 ./Megatron-LM Megatron-LM diff --git a/requirements-kda-nightly.txt b/requirements-kda-nightly.txt new file mode 100644 index 000000000..89e3b67bd --- /dev/null +++ b/requirements-kda-nightly.txt @@ -0,0 +1,13 @@ +--index-url https://download.pytorch.org/whl/nightly/cu128 +--extra-index-url https://pypi.org/simple +--extra-index-url https://pypi.fla-org.com/simple + +# Core nightly stack +--pre torch +triton-nightly + +# KDA deps compiled against the nightly toolchain +flash-linear-attention @ git+https://github.com/fla-org/flash-linear-attention@main +causal-conv1d@git+https://github.com/Dao-AILab/causal-conv1d@2a288a1 +mamba_ssm[causal-conv1d]==2.2.4 +flash-attn==2.7.3 From 82c9cc4a341f1a560cc4acfea33265a6761f1efc Mon Sep 17 00:00:00 2001 From: oleksost Date: Thu, 27 Nov 2025 14:44:16 +0000 Subject: [PATCH 31/43] docker --- .dockerignore | 1 + 1 file changed, 1 insertion(+) diff --git a/.dockerignore b/.dockerignore index 500fbe11c..3bb9ed756 100644 --- a/.dockerignore +++ b/.dockerignore @@ -11,6 +11,7 @@ !tools !tests !pyproject.toml +!requirements-kda-nightly.txt # Exclude Python cache directories and shared object files within included directories **/__pycache__/ From d25994eacab1d299d4f994a5adaac027d25ac1fc Mon Sep 17 00:00:00 2001 From: oleksost Date: Thu, 27 Nov 2025 16:24:34 +0000 Subject: [PATCH 32/43] manual build --- .github/workflows/manual-build.yml | 24 +++++++++++++++++------- 1 file changed, 17 insertions(+), 7 deletions(-) diff --git a/.github/workflows/manual-build.yml b/.github/workflows/manual-build.yml index 8240087a2..14d30404b 100644 --- a/.github/workflows/manual-build.yml +++ b/.github/workflows/manual-build.yml @@ -22,6 +22,14 @@ on: required: false default: true type: boolean + kda_nightly: + description: 'Enable KDA nightly builds (1 to enable, 0 to disable)' + required: false + default: '0' + type: choice + options: + - '1' + - '0' jobs: manual-docker-build: @@ -34,12 +42,12 @@ jobs: sudo rm -rf /usr/share/dotnet || true sudo rm -rf /opt/ghc || true sudo rm -rf /usr/local/.ghcup || true - + - name: Checkout repository uses: actions/checkout@v4 with: ref: ${{ inputs.commit_sha != '' && inputs.commit_sha || inputs.branch }} - + - name: Get commit info id: commit_info run: | @@ -48,7 +56,7 @@ jobs: echo "full_sha=${COMMIT_SHA}" >> $GITHUB_OUTPUT echo "short_sha=${COMMIT_SHORT}" >> $GITHUB_OUTPUT echo "Building from commit: ${COMMIT_SHA}" - + - name: Docker meta id: meta uses: docker/metadata-action@v5 @@ -59,10 +67,10 @@ jobs: type=raw,value=${{ inputs.branch }}-${{ inputs.tag_suffix }} type=raw,value=${{ inputs.branch }}-${{ inputs.tag_suffix }}-${{ steps.commit_info.outputs.short_sha }} type=raw,value=latest-${{ inputs.tag_suffix }},enable=${{ inputs.branch == 'main' && inputs.commit_sha == '' }} - + - name: Set up Docker Buildx uses: docker/setup-buildx-action@v3 - + - name: Login to GHCR if: ${{ inputs.push_image }} uses: docker/login-action@v3 @@ -70,7 +78,7 @@ jobs: registry: ghcr.io username: ${{ github.repository_owner }} password: ${{ secrets.GITHUB_TOKEN }} - + - name: Build and push uses: docker/build-push-action@v6 with: @@ -80,7 +88,9 @@ jobs: labels: ${{ steps.meta.outputs.labels }} cache-from: type=registry,ref=ghcr.io/servicenow/fast-llm:cache cache-to: type=registry,ref=ghcr.io/servicenow/fast-llm:cache,mode=max - + build-args: | + KDA_NIGHTLY=${{ inputs.kda_nightly }} + - name: Output build info run: | echo "Built Docker image with tags:" From 5a44097642b685924263f9938e878f1289faaf26 Mon Sep 17 00:00:00 2001 From: oleksost Date: Mon, 1 Dec 2025 14:07:29 -0500 Subject: [PATCH 33/43] two docker files --- Dockerfile | 32 +++------------------ Dockerfile.kda-nightly | 64 ++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 68 insertions(+), 28 deletions(-) create mode 100644 Dockerfile.kda-nightly diff --git a/Dockerfile b/Dockerfile index 0ae20efe1..58df12e3b 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,8 +1,5 @@ # syntax=docker/dockerfile:1.7-labs FROM nvcr.io/nvidia/pytorch:25.05-py3 -ARG KDA_NIGHTLY=0 -ARG TORCH_CUDA_ARCH_LIST="8.0;8.6;9.0" -ENV KDA_NIGHTLY=${KDA_NIGHTLY} TORCH_CUDA_ARCH_LIST=${TORCH_CUDA_ARCH_LIST} # Install dependencies. RUN apt-get update \ @@ -32,24 +29,8 @@ ENV PIP_CONSTRAINT="" # There is no pre-build mamba image for pytorch 2.8, we build it before the rest to avoid rebuilds. # We need to compile from the repo because of https://github.com/state-spaces/mamba/issues/720 (same for causal-conv1d) # We set the number of workers to avoid OOM when compiling on laptop. (TODO: Can we make it configurable?) -RUN if [ "$KDA_NIGHTLY" = "1" ]; then \ - pip install -U --pre torch --index-url https://download.pytorch.org/whl/nightly/cu128 && \ - pip uninstall -y triton pytorch-triton && \ - pip install -U triton-nightly --index-url https://pypi.fla-org.com/simple; \ - fi - -RUN if [ "$KDA_NIGHTLY" = "1" ]; then \ - MAX_JOBS=2 pip install --no-build-isolation --no-binary :all: "causal-conv1d@git+https://github.com/Dao-AILab/causal-conv1d@2a288a1"; \ - else \ - MAX_JOBS=2 pip install --no-build-isolation "causal-conv1d@git+https://github.com/Dao-AILab/causal-conv1d@2a288a1"; \ - fi -RUN if [ "$KDA_NIGHTLY" = "1" ]; then \ - MAX_JOBS=2 pip install --no-build-isolation --no-binary :all: "mamba_ssm[causal-conv1d]@git+https://github.com/state-spaces/mamba@4a8a2a2"; \ - else \ - MAX_JOBS=2 pip install --no-build-isolation "mamba_ssm[causal-conv1d]@git+https://github.com/state-spaces/mamba@4a8a2a2"; \ - fi -# Optional KDA nightly requirements file for reproducibility. -COPY --chmod=777 requirements-kda-nightly.txt ./ +RUN MAX_JOBS=2 pip install --no-build-isolation "causal-conv1d@git+https://github.com/Dao-AILab/causal-conv1d@2a288a1" +RUN MAX_JOBS=2 pip install --no-build-isolation "mamba_ssm[causal-conv1d]@git+https://github.com/state-spaces/mamba@4a8a2a2" # Copy dependency files with universal write permissions for all users. COPY --chmod=777 setup.py setup.cfg pyproject.toml ./ COPY --chmod=777 ./fast_llm_external_models/__init__.py fast_llm_external_models/ @@ -57,12 +38,7 @@ COPY --chmod=777 ./fast_llm/__init__.py fast_llm/ COPY --chmod=777 ./fast_llm/csrc/ fast_llm/csrc/ # Install dependencies within the virtual environment. -RUN if [ "$KDA_NIGHTLY" = "1" ]; then \ - pip install --no-cache-dir --no-build-isolation -e ".[CORE,OPTIONAL,HUGGINGFACE,SSM,VISION,GENERATION,DEV]" && \ - MAX_JOBS=2 pip install --no-build-isolation --no-binary :all: flash-attn; \ - else \ - pip install --no-cache-dir --no-build-isolation -e ".[CORE,OPTIONAL,HUGGINGFACE,SSM,VISION,GENERATION,DEV]" triton==3.1.0; \ - fi +RUN pip install --no-cache-dir --no-build-isolation -e ".[CORE,OPTIONAL,HUGGINGFACE,SSM,VISION,GENERATION,DEV]" triton==3.1.0 # Copy the remaining source code with universal write permissions. COPY --chmod=777 ./Megatron-LM Megatron-LM @@ -75,4 +51,4 @@ COPY --chmod=777 --exclude=./fast_llm/csrc/ ./fast_llm/ fast_llm/ # Set a dummy default user so we don't run in root by default. # The image is still compatible with any user id. RUN useradd user -USER user +USER user \ No newline at end of file diff --git a/Dockerfile.kda-nightly b/Dockerfile.kda-nightly new file mode 100644 index 000000000..0480f52e2 --- /dev/null +++ b/Dockerfile.kda-nightly @@ -0,0 +1,64 @@ +# syntax=docker/dockerfile:1.7-labs +FROM nvcr.io/nvidia/pytorch:25.05-py3 +ARG KDA_NIGHTLY=1 +ARG TORCH_CUDA_ARCH_LIST="8.0;8.6;9.0" +ENV KDA_NIGHTLY=${KDA_NIGHTLY} TORCH_CUDA_ARCH_LIST=${TORCH_CUDA_ARCH_LIST} + +# Install dependencies. +RUN apt-get update \ + && apt-get install --no-install-recommends -y acl git-lfs \ + && rm -rf /var/lib/apt/lists/* \ + && git lfs install + +# Set the working directory. +WORKDIR /app +# Set the permission to 777 for all files and directories in `/app`, `/home` and python install directories: +# 1. Create directories explicitly because docker use the wrong permission for explicit creation. +# 2. For the rest, set the default ACL to 777 for all users. +RUN mkdir -m 777 /app/Megatron-LM /app/examples /app/fast_llm /app/tests /app/tools \ + && setfacl -m d:u::rwx,d:g::rwx,d:o::rwx,u::rwx,g::rwx,o::rwx \ + /app \ + /home \ + /usr \ + /usr/local \ + /usr/local/bin \ + /usr/local/lib \ + /usr/local/lib/python3.12 \ + /usr/local/lib/python3.12/dist-packages \ + /usr/local/lib/python3.12/dist-packages/__pycache__ + +# The base image enforces versions for things like pytest for no good reason. +ENV PIP_CONSTRAINT="" +# There is no pre-build mamba image for pytorch 2.8, we build it before the rest to avoid rebuilds. +# We need to compile from the repo because of https://github.com/state-spaces/mamba/issues/720 (same for causal-conv1d) +# We set the number of workers to avoid OOM when compiling on laptop. (TODO: Can we make it configurable?) +RUN pip install -U --pre torch --index-url https://download.pytorch.org/whl/nightly/cu128 \ + && pip uninstall -y triton pytorch-triton \ + && pip install -U triton-nightly --index-url https://pypi.fla-org.com/simple + +RUN MAX_JOBS=2 pip install --no-build-isolation --no-binary :all: "causal-conv1d@git+https://github.com/Dao-AILab/causal-conv1d@2a288a1" +RUN MAX_JOBS=2 pip install --no-build-isolation --no-binary :all: "mamba_ssm[causal-conv1d]@git+https://github.com/state-spaces/mamba@4a8a2a2" +# Optional KDA nightly requirements file for reproducibility. +COPY --chmod=777 requirements-kda-nightly.txt ./ +# Copy dependency files with universal write permissions for all users. +COPY --chmod=777 setup.py setup.cfg pyproject.toml ./ +COPY --chmod=777 ./fast_llm_external_models/__init__.py fast_llm_external_models/ +COPY --chmod=777 ./fast_llm/__init__.py fast_llm/ +COPY --chmod=777 ./fast_llm/csrc/ fast_llm/csrc/ + +# Install dependencies within the virtual environment. +RUN pip install --no-cache-dir --no-build-isolation -e ".[CORE,OPTIONAL,HUGGINGFACE,SSM,VISION,GENERATION,DEV]" \ + && MAX_JOBS=2 pip install --no-build-isolation --no-binary :all: flash-attn + +# Copy the remaining source code with universal write permissions. +COPY --chmod=777 ./Megatron-LM Megatron-LM +COPY --chmod=777 ./examples examples +COPY --chmod=777 ./tests tests +COPY --chmod=777 ./tools tools +COPY --chmod=777 ./fast_llm_external_models fast_llm_external_models +COPY --chmod=777 --exclude=./fast_llm/csrc/ ./fast_llm/ fast_llm/ + +# Set a dummy default user so we don't run in root by default. +# The image is still compatible with any user id. +RUN useradd user +USER user From a164a2b080e84af835e6dceb2d91d3b8d85f7872 Mon Sep 17 00:00:00 2001 From: oleksost Date: Tue, 2 Dec 2025 16:49:31 +0000 Subject: [PATCH 34/43] test import fix --- tests/utils/model_configs.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/utils/model_configs.py b/tests/utils/model_configs.py index 73f0504fc..00375bba5 100644 --- a/tests/utils/model_configs.py +++ b/tests/utils/model_configs.py @@ -15,6 +15,7 @@ from fast_llm.engine.multi_stage.config import FastLLMModelConfig from fast_llm.engine.training.config import TrainerConfig from fast_llm.models.gpt.conversion.config import ( + Apriel2CheckpointFormat, AprielHybridSSMCheckpointFormat, DiffusionDreamCheckpointFormat, DiffusionLlamaCheckpointFormat, From c4aa9b159c74742abde731718f018a4a1f5e3319 Mon Sep 17 00:00:00 2001 From: oleksost Date: Tue, 2 Dec 2025 18:05:17 +0000 Subject: [PATCH 35/43] set correct activations --- fast_llm/layers/ssm/config.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/fast_llm/layers/ssm/config.py b/fast_llm/layers/ssm/config.py index 24d0c3928..8c85ce8be 100644 --- a/fast_llm/layers/ssm/config.py +++ b/fast_llm/layers/ssm/config.py @@ -202,9 +202,11 @@ def _validate(self) -> None: if "epsilon" not in self.normalization._explicit_fields: self.normalization.epsilon = 1.0e-5 if "activation" not in self.convolution_layer._explicit_fields: - self.convolution_layer.activation = "sigmoid" + self.convolution_layer.activation = "silu" if "kernel_size" not in self.convolution_layer._explicit_fields: self.convolution_layer.kernel_size = 4 + if "activation" not in self.normalization._explicit_fields: + self.normalization.activation = "sigmoid" super()._validate() From a8849cbd1b12bfbea37e807d4c6e817321fea811 Mon Sep 17 00:00:00 2001 From: oleksost Date: Tue, 2 Dec 2025 19:05:25 +0000 Subject: [PATCH 36/43] import --- fast_llm/layers/ssm/config.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/fast_llm/layers/ssm/config.py b/fast_llm/layers/ssm/config.py index 8c85ce8be..1a5d5274b 100644 --- a/fast_llm/layers/ssm/config.py +++ b/fast_llm/layers/ssm/config.py @@ -193,7 +193,7 @@ class KimiDeltaAttentionConfig(MixerConfig): @property def layer_class(self) -> "type": - from kda import KimiDeltaAttention + from fast_llm.layers.ssm.kda import KimiDeltaAttention return KimiDeltaAttention From 5f32ba745daf9ab00e470673b4fb1b977168b8c4 Mon Sep 17 00:00:00 2001 From: oleksost Date: Fri, 5 Dec 2025 11:49:10 -0500 Subject: [PATCH 37/43] kda docker file --- Dockerfile.kda-nightly | 88 ++++++++++++++++++++++++++++-------------- 1 file changed, 60 insertions(+), 28 deletions(-) diff --git a/Dockerfile.kda-nightly b/Dockerfile.kda-nightly index 0480f52e2..0a5dc7e40 100644 --- a/Dockerfile.kda-nightly +++ b/Dockerfile.kda-nightly @@ -1,15 +1,20 @@ # syntax=docker/dockerfile:1.7-labs -FROM nvcr.io/nvidia/pytorch:25.05-py3 -ARG KDA_NIGHTLY=1 -ARG TORCH_CUDA_ARCH_LIST="8.0;8.6;9.0" -ENV KDA_NIGHTLY=${KDA_NIGHTLY} TORCH_CUDA_ARCH_LIST=${TORCH_CUDA_ARCH_LIST} - -# Install dependencies. -RUN apt-get update \ - && apt-get install --no-install-recommends -y acl git-lfs \ +# FROM nvcr.io/nvidia/pytorch:25.05-py3 +FROM nvidia/cuda:12.8.0-cudnn-devel-ubuntu24.04 +ENV TORCH_CUDA_ARCH_LIST="9.0" +# ARG TORCH_CUDA_ARCH_LIST="8.0;8.6;9.0" +ENV FLASH_ATTENTION_CUDA_ARCHS="90" +ENV PIP_BREAK_SYSTEM_PACKAGES=1 + +# Basic system deps +RUN apt-get update && apt-get install --no-install-recommends -y \ + python3 python3-pip python3-dev git-lfs build-essential acl \ + libjpeg-dev zlib1g-dev libpng-dev libtiff5-dev \ && rm -rf /var/lib/apt/lists/* \ && git lfs install +RUN update-alternatives --install /usr/bin/python python /usr/bin/python3 1 + # Set the working directory. WORKDIR /app # Set the permission to 777 for all files and directories in `/app`, `/home` and python install directories: @@ -17,27 +22,40 @@ WORKDIR /app # 2. For the rest, set the default ACL to 777 for all users. RUN mkdir -m 777 /app/Megatron-LM /app/examples /app/fast_llm /app/tests /app/tools \ && setfacl -m d:u::rwx,d:g::rwx,d:o::rwx,u::rwx,g::rwx,o::rwx \ - /app \ - /home \ - /usr \ - /usr/local \ - /usr/local/bin \ - /usr/local/lib \ - /usr/local/lib/python3.12 \ - /usr/local/lib/python3.12/dist-packages \ - /usr/local/lib/python3.12/dist-packages/__pycache__ + /app \ + /home \ + /usr \ + /usr/local \ + /usr/local/bin \ + /usr/local/lib # The base image enforces versions for things like pytest for no good reason. ENV PIP_CONSTRAINT="" -# There is no pre-build mamba image for pytorch 2.8, we build it before the rest to avoid rebuilds. -# We need to compile from the repo because of https://github.com/state-spaces/mamba/issues/720 (same for causal-conv1d) -# We set the number of workers to avoid OOM when compiling on laptop. (TODO: Can we make it configurable?) -RUN pip install -U --pre torch --index-url https://download.pytorch.org/whl/nightly/cu128 \ - && pip uninstall -y triton pytorch-triton \ - && pip install -U triton-nightly --index-url https://pypi.fla-org.com/simple - -RUN MAX_JOBS=2 pip install --no-build-isolation --no-binary :all: "causal-conv1d@git+https://github.com/Dao-AILab/causal-conv1d@2a288a1" -RUN MAX_JOBS=2 pip install --no-build-isolation --no-binary :all: "mamba_ssm[causal-conv1d]@git+https://github.com/state-spaces/mamba@4a8a2a2" + +RUN python -m pip install --no-cache-dir --pre \ + torch torchvision torchaudio \ + --index-url https://download.pytorch.org/whl/nightly/cu128 + +RUN python -c "import torch, sys; print('PYTHON:', sys.executable); print('TORCH VERSION:', torch.__version__); print('TORCH FILE:', torch.__file__); print('CUDA:', torch.version.cuda)" + +RUN python -m pip install --no-cache-dir packaging \ + && python -m pip install --no-build-isolation --no-cache-dir git+https://github.com/NVIDIA/apex.git + +# Install flash-linear-attention prerequisites and build from source. +RUN python -m pip install --no-cache-dir einops ninja datasets transformers numpy \ + && (python -m pip uninstall -y flash-linear-attention || true) \ + && python -m pip install -U --no-use-pep517 --no-deps --no-cache-dir \ + git+https://github.com/fla-org/flash-linear-attention + +# Optional toolchain pieces for flash-attention. +RUN python -m pip install packaging psutil ninja +RUN python -m pip install --no-cache-dir flit-core packaging + +RUN python -m pip uninstall -y causal-conv1d || true +RUN MAX_JOBS=2 python -m pip install --no-build-isolation --no-cache-dir "causal-conv1d@git+https://github.com/Dao-AILab/causal-conv1d@2a288a1" +RUN MAX_JOBS=2 python -m pip install --no-build-isolation --no-cache-dir --no-deps \ + "mamba_ssm@git+https://github.com/state-spaces/mamba@4a8a2a2" + # Optional KDA nightly requirements file for reproducibility. COPY --chmod=777 requirements-kda-nightly.txt ./ # Copy dependency files with universal write permissions for all users. @@ -46,9 +64,23 @@ COPY --chmod=777 ./fast_llm_external_models/__init__.py fast_llm_external_models COPY --chmod=777 ./fast_llm/__init__.py fast_llm/ COPY --chmod=777 ./fast_llm/csrc/ fast_llm/csrc/ +RUN python -m pip install --no-cache-dir pybind11 # Install dependencies within the virtual environment. -RUN pip install --no-cache-dir --no-build-isolation -e ".[CORE,OPTIONAL,HUGGINGFACE,SSM,VISION,GENERATION,DEV]" \ - && MAX_JOBS=2 pip install --no-build-isolation --no-binary :all: flash-attn +RUN python -m pip install --no-cache-dir --no-build-isolation -e ".[CORE,OPTIONAL,HUGGINGFACE,SSM,GENERATION,DEV]" + +RUN python -m pip install --no-cache-dir --no-build-isolation \ + -e ".[VISION]" + + +# We only care about H100 (sm_90) +# RUN MAX_JOBS=1 python -m pip install --no-deps --no-cache-dir --no-build-isolation flash-attn \ +# && python -m pip install pytest +RUN MAX_JOBS=1 python -m pip install pytest + +RUN git clone https://github.com/NVIDIA/apex /tmp/apex \ + && cd /tmp/apex \ + && MAX_JOBS=1 APEX_CPP_EXT=1 APEX_CUDA_EXT=1 pip install -v --no-build-isolation . \ + && rm -rf /tmp/apex # Copy the remaining source code with universal write permissions. COPY --chmod=777 ./Megatron-LM Megatron-LM From d33d6d7f12b386e23f5f0457574c7aab8cc82a28 Mon Sep 17 00:00:00 2001 From: oleksost Date: Fri, 5 Dec 2025 18:29:53 +0000 Subject: [PATCH 38/43] revert workflow change --- .github/workflows/manual-build.yml | 10 ---------- 1 file changed, 10 deletions(-) diff --git a/.github/workflows/manual-build.yml b/.github/workflows/manual-build.yml index 14d30404b..2d7eb315c 100644 --- a/.github/workflows/manual-build.yml +++ b/.github/workflows/manual-build.yml @@ -22,14 +22,6 @@ on: required: false default: true type: boolean - kda_nightly: - description: 'Enable KDA nightly builds (1 to enable, 0 to disable)' - required: false - default: '0' - type: choice - options: - - '1' - - '0' jobs: manual-docker-build: @@ -88,8 +80,6 @@ jobs: labels: ${{ steps.meta.outputs.labels }} cache-from: type=registry,ref=ghcr.io/servicenow/fast-llm:cache cache-to: type=registry,ref=ghcr.io/servicenow/fast-llm:cache,mode=max - build-args: | - KDA_NIGHTLY=${{ inputs.kda_nightly }} - name: Output build info run: | From 05abc035579da3485d3409a364594a4382584456 Mon Sep 17 00:00:00 2001 From: oleksost Date: Fri, 5 Dec 2025 18:32:12 +0000 Subject: [PATCH 39/43] removed unused requirements file --- .dockerignore | 1 - Dockerfile.kda-nightly | 6 ++---- requirements-kda-nightly.txt | 13 ------------- 3 files changed, 2 insertions(+), 18 deletions(-) delete mode 100644 requirements-kda-nightly.txt diff --git a/.dockerignore b/.dockerignore index 3bb9ed756..500fbe11c 100644 --- a/.dockerignore +++ b/.dockerignore @@ -11,7 +11,6 @@ !tools !tests !pyproject.toml -!requirements-kda-nightly.txt # Exclude Python cache directories and shared object files within included directories **/__pycache__/ diff --git a/Dockerfile.kda-nightly b/Dockerfile.kda-nightly index 0a5dc7e40..582ac830c 100644 --- a/Dockerfile.kda-nightly +++ b/Dockerfile.kda-nightly @@ -27,7 +27,7 @@ RUN mkdir -m 777 /app/Megatron-LM /app/examples /app/fast_llm /app/tests /app/to /usr \ /usr/local \ /usr/local/bin \ - /usr/local/lib + /usr/local/lib # The base image enforces versions for things like pytest for no good reason. ENV PIP_CONSTRAINT="" @@ -56,8 +56,6 @@ RUN MAX_JOBS=2 python -m pip install --no-build-isolation --no-cache-dir "causal RUN MAX_JOBS=2 python -m pip install --no-build-isolation --no-cache-dir --no-deps \ "mamba_ssm@git+https://github.com/state-spaces/mamba@4a8a2a2" -# Optional KDA nightly requirements file for reproducibility. -COPY --chmod=777 requirements-kda-nightly.txt ./ # Copy dependency files with universal write permissions for all users. COPY --chmod=777 setup.py setup.cfg pyproject.toml ./ COPY --chmod=777 ./fast_llm_external_models/__init__.py fast_llm_external_models/ @@ -66,7 +64,7 @@ COPY --chmod=777 ./fast_llm/csrc/ fast_llm/csrc/ RUN python -m pip install --no-cache-dir pybind11 # Install dependencies within the virtual environment. -RUN python -m pip install --no-cache-dir --no-build-isolation -e ".[CORE,OPTIONAL,HUGGINGFACE,SSM,GENERATION,DEV]" +RUN python -m pip install --no-cache-dir --no-build-isolation -e ".[CORE,OPTIONAL,HUGGINGFACE,SSM,GENERATION,DEV]" RUN python -m pip install --no-cache-dir --no-build-isolation \ -e ".[VISION]" diff --git a/requirements-kda-nightly.txt b/requirements-kda-nightly.txt deleted file mode 100644 index 89e3b67bd..000000000 --- a/requirements-kda-nightly.txt +++ /dev/null @@ -1,13 +0,0 @@ ---index-url https://download.pytorch.org/whl/nightly/cu128 ---extra-index-url https://pypi.org/simple ---extra-index-url https://pypi.fla-org.com/simple - -# Core nightly stack ---pre torch -triton-nightly - -# KDA deps compiled against the nightly toolchain -flash-linear-attention @ git+https://github.com/fla-org/flash-linear-attention@main -causal-conv1d@git+https://github.com/Dao-AILab/causal-conv1d@2a288a1 -mamba_ssm[causal-conv1d]==2.2.4 -flash-attn==2.7.3 From 7b30a360276729bf79448cd9a7a03e9aac21b2b9 Mon Sep 17 00:00:00 2001 From: Torsten Scholak Date: Sun, 7 Dec 2025 17:19:20 +0000 Subject: [PATCH 40/43] Bump base image and dependencies for KDA support MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Update to nvcr.io/nvidia/pytorch:25.11-py3 which includes: - PyTorch 2.10 - CUDA 13.0 - flash-attn 2.7.4.post1 (pre-installed, no compilation needed) Dependency updates: - causal-conv1d: v1.5.4 (was pinned to commit 2a288a1) - mamba-ssm: 2.2.6.post3 (was pinned to commit 4a8a2a2) - flash-linear-attention: pin to commit 67eee20 (was @main) - flash-attn: 2.7.4.post1 to match base image (was 2.7.3) - triton: 3.5.1 in Dockerfile (was 3.1.0) These updates enable Kimi Delta Attention (KDA) support via the flash-linear-attention library. The pinned versions are tested and working, unlike the nightly/unpinned approach in #395. Note: Dropless MoE kernel remains broken with triton >= 3.2.0 and needs a complete rewrite (also limited to 32 experts). This is tracked separately and doesn't block KDA work. 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude --- Dockerfile | 9 +++++---- setup.cfg | 12 ++++++------ 2 files changed, 11 insertions(+), 10 deletions(-) diff --git a/Dockerfile b/Dockerfile index 6bc900ae7..5804d0e47 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,5 +1,5 @@ # syntax=docker/dockerfile:1.7-labs -FROM nvcr.io/nvidia/pytorch:25.05-py3 +FROM nvcr.io/nvidia/pytorch:25.11-py3 # Install dependencies. RUN apt-get update \ @@ -29,8 +29,9 @@ ENV PIP_CONSTRAINT="" # There is no pre-build mamba image for pytorch 2.8, we build it before the rest to avoid rebuilds. # We need to compile from the repo because of https://github.com/state-spaces/mamba/issues/720 (same for causal-conv1d) # We set the number of workers to avoid OOM when compiling on laptop. (TODO: Can we make it configurable?) -RUN MAX_JOBS=2 pip install --no-build-isolation "causal-conv1d@git+https://github.com/Dao-AILab/causal-conv1d@2a288a1" -RUN MAX_JOBS=2 pip install --no-build-isolation "mamba_ssm[causal-conv1d]@git+https://github.com/state-spaces/mamba@4a8a2a2" +RUN MAX_JOBS=2 pip install --no-build-isolation "causal-conv1d @ git+https://github.com/Dao-AILab/causal-conv1d@v1.5.4" +RUN MAX_JOBS=2 pip install --no-build-isolation mamba-ssm==2.2.6.post3 +RUN MAX_JOBS=2 pip install --no-build-isolation "flash-linear-attention @ git+https://github.com/fla-org/flash-linear-attention@67eee20c8503cd19eeb52aa1b99821308e9260c5" # Copy dependency files with universal write permissions for all users. COPY --chmod=777 setup.py setup.cfg pyproject.toml ./ COPY --chmod=777 ./fast_llm_external_models/__init__.py fast_llm_external_models/ @@ -38,7 +39,7 @@ COPY --chmod=777 ./fast_llm/__init__.py fast_llm/ COPY --chmod=777 ./fast_llm/csrc/ fast_llm/csrc/ # Install dependencies within the virtual environment. -RUN pip install --no-cache-dir --no-build-isolation -e ".[CORE,OPTIONAL,HUGGINGFACE,SSM,VISION,GENERATION,DEV]" triton==3.1.0 +RUN pip install --no-cache-dir --no-build-isolation -e ".[CORE,OPTIONAL,HUGGINGFACE,SSM,VISION,GENERATION,DEV]" triton==3.5.1 # Copy the remaining source code with universal write permissions. COPY --chmod=777 ./Megatron-LM Megatron-LM diff --git a/setup.cfg b/setup.cfg index f4b2c904b..58f8ea2d1 100644 --- a/setup.cfg +++ b/setup.cfg @@ -25,10 +25,10 @@ CORE = # Used for checkpoints safetensors>=0.5.3 # Update the base image (version fixed to ensure there is a wheel for the base image), may need --no-build-isolation - flash-attn==2.7.3 - # Dropless MLP is broken with triton 3.2.0, 3.3.0 and 3.3.1. TODO: Remove once a working triton version is released. - # TODO: Removed because it breaks cpu-only installs and pip dependency resolution. - # triton==3.1.0 + flash-attn==2.7.4.post1 + # Dropless MoE kernel is broken with triton >= 3.2.0 and needs a rewrite (also limited to 32 experts). + # Not pinning triton here as it breaks cpu-only installs and pip dependency resolution. + # triton==3.5.1 # Small packages required for some optional features and tools. @@ -52,8 +52,8 @@ HUGGINGFACE = # To install on cpu environment (ex. for IDE support): # MAMBA_FORCE_BUILD=TRUE CAUSAL_CONV1D_FORCE_BUILD=TRUE CAUSAL_CONV1D_SKIP_CUDA_BUILD=TRUE pip install -e ".[CORE,SSM]" --no-build-isolation SSM = - mamba_ssm[causal-conv1d]==2.2.4 - flash-linear-attention @ git+https://github.com/fla-org/flash-linear-attention@main + mamba_ssm[causal-conv1d]==2.2.6.post3 + flash-linear-attention @ git+https://github.com/fla-org/flash-linear-attention@67eee20c8503cd19eeb52aa1b99821308e9260c5 GENERATION = lm_eval>=0.4.9 From eb52cc7c3fe23d18d180fde97481b651658dcb39 Mon Sep 17 00:00:00 2001 From: oleksost Date: Mon, 8 Dec 2025 16:06:25 +0000 Subject: [PATCH 41/43] fixes --- Dockerfile | 2 +- fast_llm/layers/ssm/config.py | 32 +---- fast_llm/layers/ssm/gdn.py | 3 +- fast_llm/layers/ssm/kda.py | 7 +- fast_llm/models/auto.py | 10 +- tests/layers/test_gdn_equivalence.py | 167 +++++++++++++-------------- tests/layers/test_kda_equivalence.py | 106 ++++++++--------- tests/layers/test_lm_head.py | 2 +- tests/utils/model_configs.py | 2 +- 9 files changed, 146 insertions(+), 185 deletions(-) diff --git a/Dockerfile b/Dockerfile index 58df12e3b..6bc900ae7 100644 --- a/Dockerfile +++ b/Dockerfile @@ -51,4 +51,4 @@ COPY --chmod=777 --exclude=./fast_llm/csrc/ ./fast_llm/ fast_llm/ # Set a dummy default user so we don't run in root by default. # The image is still compatible with any user id. RUN useradd user -USER user \ No newline at end of file +USER user diff --git a/fast_llm/layers/ssm/config.py b/fast_llm/layers/ssm/config.py index 1a5d5274b..450591216 100644 --- a/fast_llm/layers/ssm/config.py +++ b/fast_llm/layers/ssm/config.py @@ -4,7 +4,6 @@ from fast_llm.config import Field, FieldHint, check_field, config_class from fast_llm.engine.config_utils.initialization import InitializationConfig, Initializer, LambdaInitializer from fast_llm.engine.config_utils.parameter import ParameterConfig -from fast_llm.functional.config import ActivationType from fast_llm.layers.block.config import BlockKwargs from fast_llm.layers.common.linear.config import AffineLinearConfig, CausalConv1dConfig, LinearConfig from fast_llm.layers.common.normalization.config import GatedRMSNormalizationConfig @@ -15,6 +14,8 @@ import torch from fast_llm.layers.ssm.discrete_mamba2 import DiscreteMamba2 + from fast_llm.layers.ssm.gdn import GatedDeltaNet + from fast_llm.layers.ssm.kda import KimiDeltaAttention from fast_llm.layers.ssm.mamba import Mamba from fast_llm.tensor import ParameterMeta @@ -84,37 +85,18 @@ class GatedDeltaNetConfig(MixerConfig): hint=FieldHint.architecture, valid=check_field(Assert.gt, 0), ) - norm_epsilon: float = Field( - default=1e-6, - desc="Epsilon used by the gated RMS norm.", - hint=FieldHint.architecture, - valid=check_field(Assert.gt, 0), - ) - activation: ActivationType = Field( - default=ActivationType.silu, - desc="Activation used after the convolution.", - hint=FieldHint.architecture, - ) def _validate(self) -> None: super()._validate() Assert.multiple(self.value_heads, self.key_heads) @property - def layer_class(self) -> "type": + def layer_class(self) -> "type[GatedDeltaNet]": from fast_llm.layers.ssm.gdn import GatedDeltaNet return GatedDeltaNet def _validate(self) -> None: - with self._set_implicit_default(): - if "epsilon" not in self.normalization._explicit_fields: - self.normalization.epsilon = 1.0e-5 - if "activation" not in self.convolution_layer._explicit_fields: - self.convolution_layer.activation = "silu" - if "kernel_size" not in self.convolution_layer._explicit_fields: - self.convolution_layer.kernel_size = 4 - super()._validate() @@ -192,19 +174,13 @@ class KimiDeltaAttentionConfig(MixerConfig): ) @property - def layer_class(self) -> "type": + def layer_class(self) -> "type[KimiDeltaAttention]": from fast_llm.layers.ssm.kda import KimiDeltaAttention return KimiDeltaAttention def _validate(self) -> None: with self._set_implicit_default(): - if "epsilon" not in self.normalization._explicit_fields: - self.normalization.epsilon = 1.0e-5 - if "activation" not in self.convolution_layer._explicit_fields: - self.convolution_layer.activation = "silu" - if "kernel_size" not in self.convolution_layer._explicit_fields: - self.convolution_layer.kernel_size = 4 if "activation" not in self.normalization._explicit_fields: self.normalization.activation = "sigmoid" diff --git a/fast_llm/layers/ssm/gdn.py b/fast_llm/layers/ssm/gdn.py index 9f3a55263..40f15837c 100644 --- a/fast_llm/layers/ssm/gdn.py +++ b/fast_llm/layers/ssm/gdn.py @@ -9,6 +9,7 @@ from fast_llm.engine.config_utils.initialization import LambdaInitializer, init_normal_, init_ones_ from fast_llm.engine.config_utils.tensor_dim import CompositeTensorDim, ConcatenatedTensorDim, TensorDim from fast_llm.engine.distributed.config import DistributedConfig, DistributedDimNames +from fast_llm.functional.config import ActivationType from fast_llm.layers.block.config import BlockKwargs from fast_llm.layers.common.peft.config import PeftConfig from fast_llm.layers.decoder.block import BlockWithBias @@ -204,7 +205,7 @@ def __init__( self.convolution = self._config.convolution_layer.get_layer( qkv_channels_dim, default_add_bias=False, - default_activation=self._config.activation, + default_activation=ActivationType.silu, lr_scale=self._lr_scale, peft=self._peft, ) diff --git a/fast_llm/layers/ssm/kda.py b/fast_llm/layers/ssm/kda.py index d67152143..323e1ad13 100644 --- a/fast_llm/layers/ssm/kda.py +++ b/fast_llm/layers/ssm/kda.py @@ -8,6 +8,7 @@ from fast_llm.engine.config_utils.initialization import LambdaInitializer, init_normal_, init_ones_ from fast_llm.engine.config_utils.tensor_dim import CompositeTensorDim, TensorDim from fast_llm.engine.distributed.config import DistributedConfig, DistributedDimNames +from fast_llm.functional.config import ActivationType from fast_llm.layers.block.config import BlockKwargs from fast_llm.layers.common.peft.config import PeftConfig from fast_llm.layers.decoder.block import BlockWithBias @@ -102,21 +103,21 @@ def __init__( self.q_conv = self._config.convolution_layer.get_layer( self._projection_dim, default_add_bias=False, - default_activation=self._config.convolution_layer.activation, + default_activation=ActivationType.silu, lr_scale=self._lr_scale, peft=self._peft, ) self.k_conv = self._config.convolution_layer.get_layer( self._projection_dim, default_add_bias=False, - default_activation=self._config.convolution_layer.activation, + default_activation=ActivationType.silu, lr_scale=self._lr_scale, peft=self._peft, ) self.v_conv = self._config.convolution_layer.get_layer( self._projection_dim, default_add_bias=False, - default_activation=self._config.convolution_layer.activation, + default_activation=ActivationType.silu, lr_scale=self._lr_scale, peft=self._peft, ) diff --git a/fast_llm/models/auto.py b/fast_llm/models/auto.py index f7c34a973..e8160ec87 100644 --- a/fast_llm/models/auto.py +++ b/fast_llm/models/auto.py @@ -3,12 +3,14 @@ """ from fast_llm.layers.attention.config import AttentionConfig # isort: skip -from fast_llm.layers.ssm.config import ( - MambaConfig, - Mamba2Config, +from fast_llm.layers.ssm.config import ( # isort: skip DiscreteMamba2Config, GatedDeltaNetConfig, -) # isort: skip + KimiDeltaAttentionConfig, + Mamba2Config, + MambaConfig, +) + from fast_llm.layers.attention.config import AttentionConfig # isort: skip from fast_llm.models.gpt.config import GPTModelConfig, GPTTrainerConfig # isort: skip from fast_llm.models.multimodal.config import MultiModalModelConfig, MultiModalTrainerConfig # isort: skip diff --git a/tests/layers/test_gdn_equivalence.py b/tests/layers/test_gdn_equivalence.py index 9886056ea..f4d116349 100644 --- a/tests/layers/test_gdn_equivalence.py +++ b/tests/layers/test_gdn_equivalence.py @@ -1,104 +1,89 @@ import pytest import torch -from fast_llm.engine.config_utils.tensor_dim import TensorDim -from fast_llm.engine.distributed.config import DistributedConfig -from fast_llm.engine.distributed.distributed import Distributed -from fast_llm.functional.config import ActivationType +from fast_llm.config import UpdateType from fast_llm.layers.block.config import BlockKwargs -from fast_llm.layers.ssm.config import GatedDeltaNetConfig - -try: - from transformers.models.qwen3_next.modeling_qwen3_next import Qwen3NextConfig, Qwen3NextGatedDeltaNet -except ImportError: - Qwen3NextConfig, Qwen3NextGatedDeltaNet = None, None - - -def _materialize_mixer_tensors(module: torch.nn.Module, distributed: Distributed, device: torch.device) -> None: - """ - Instantiate meta-allocated parameters on the requested device so the layer can run standalone. - """ - for name, param in module.named_parameters(): - if param.device.type != "meta": - continue - param_data = torch.empty_like(param, device=device) - param.init_parameter(param_data, distributed) - module_path, param_name = name.rsplit(".", 1) if "." in name else (None, name) - target = module - if module_path is not None: - for part in module_path.split("."): - target = getattr(target, part) - new_param = torch.nn.Parameter(param_data, requires_grad=param.requires_grad) - new_param.grad = None - new_param.grad_buffer = torch.zeros_like(param_data) - new_param.param_grad_is_zero = True - target._parameters[param_name] = new_param +from fast_llm.models.gpt.config import GPTBaseModelConfig, GPTModelConfig +from fast_llm_external_models.apriel2.modeling_apriel2 import Apriel2GatedDeltaNet +from tests.utils.utils import get_base_model, get_stage, requires_cuda + +VOCAB_SIZE = 500 +HIDDEN_SIZE = 16 +SEQ_LEN = 65 +NUM_V_HEADS = 4 +NUM_K_HEADS = 2 +HEAD_DIM = 4 +KERNEL_SIZE = 4 @pytest.mark.slow -@pytest.mark.skipif(not torch.cuda.is_available(), reason="Varlen test needs CUDA") -@pytest.mark.skipif(Qwen3NextConfig is None, reason="transformers with Qwen3-Next not installed") +@requires_cuda def test_fast_llm_gdn_matches_qwen3_next_forward(): torch.manual_seed(0) device = torch.device("cuda") dtype = torch.bfloat16 - hidden_size = 16 - seq_len = 6 - num_k_heads = 2 - num_v_heads = 4 - head_k_dim = 4 - head_v_dim = 4 - kernel_size = 4 - - hf_config = Qwen3NextConfig( - hidden_size=hidden_size, - linear_num_key_heads=num_k_heads, - linear_num_value_heads=num_v_heads, - linear_key_head_dim=head_k_dim, - linear_value_head_dim=head_v_dim, - linear_conv_kernel_dim=kernel_size, - hidden_act="silu", - rms_norm_eps=1e-6, - dtype=dtype, + config_dict_hf = { + "num_value_heads": NUM_V_HEADS, + "num_key_heads": NUM_K_HEADS, + "key_head_dim": HEAD_DIM, + "value_head_dim": HEAD_DIM, + "conv_kernel_size": KERNEL_SIZE, + "activation": "silu", + "norm_eps": 1e-5, + } + + hf_layer = ( + Apriel2GatedDeltaNet(HIDDEN_SIZE, config_dict_hf, layer_idx=0, dtype=dtype) + .to(device=device, dtype=dtype) + .eval() ) - hf_layer = Qwen3NextGatedDeltaNet(hf_config, layer_idx=0).to(device=device, dtype=dtype).eval() - - fast_config = GatedDeltaNetConfig( - value_heads=num_v_heads, - key_heads=num_k_heads, - value_head_dim=head_v_dim, - key_head_dim=head_k_dim, - activation=ActivationType.silu, - normalization={"epsilon": hf_config.rms_norm_eps}, - convolution_layer={"kernel_size": kernel_size, "activation": ActivationType.silu}, + + config = GPTBaseModelConfig.from_dict( + { + "decoder": { + "num_blocks": 1, + "block": { + "mixer": { + "type": "gdn", + "value_heads": NUM_V_HEADS, + "key_heads": NUM_K_HEADS, + "key_head_dim": HEAD_DIM, + "value_head_dim": HEAD_DIM, + "convolution_layer": {"kernel_size": KERNEL_SIZE, "activation": "silu"}, + } + }, + }, + "embeddings": {"vocab_size": VOCAB_SIZE}, + "hidden_size": HIDDEN_SIZE, + }, + update_type=UpdateType.update, ) - distributed_config = DistributedConfig( - tensor_parallel=1, - pipeline_parallel=1, - sequence_data_parallel=1, - local_world_size=1, - world_size=1, + + model, distributed = get_base_model( + GPTModelConfig.from_dict( + { + "base_model": config, + "distributed": {}, + }, + ) ) - hidden_dim = TensorDim("hidden", hidden_size) - fast_layer = fast_config.get_layer(distributed_config, hidden_dim, lr_scale=None, peft=None, return_bias=False) - distributed = Distributed(config=distributed_config) - fast_layer.setup(distributed) - _materialize_mixer_tensors(fast_layer, distributed, device) + fast_layer = model.decoder[0].mixer + get_stage([fast_layer], distributed, [], {}) fast_layer.to(device=device, dtype=dtype).eval() with torch.no_grad(): - fast_layer.in_proj_qkvz.weight.copy_(hf_layer.in_proj_qkvz.weight) - fast_layer.in_proj_ba.weight.copy_(hf_layer.in_proj_ba.weight) - fast_layer.convolution.weight.copy_(hf_layer.conv1d.weight) - if fast_layer.convolution.bias is not None and hf_layer.conv1d.bias is not None: - fast_layer.convolution.bias.copy_(hf_layer.conv1d.bias) - fast_layer.out_proj.weight.copy_(hf_layer.out_proj.weight) - fast_layer.A_log.copy_(hf_layer.A_log) - fast_layer.dt_bias.copy_(hf_layer.dt_bias) - fast_layer.norm.weight.copy_(hf_layer.norm.weight) - - hidden_states = torch.randn(1, seq_len, hidden_size, device=device, dtype=dtype, requires_grad=False) + fast_layer.in_proj_qkvz.weight.copy_(hf_layer.gdn.in_proj_qkvz.weight) + fast_layer.in_proj_ba.weight.copy_(hf_layer.gdn.in_proj_ba.weight) + fast_layer.convolution.weight.copy_(hf_layer.gdn.conv1d.weight) + if fast_layer.convolution.bias is not None and hf_layer.gdn.conv1d.bias is not None: + fast_layer.convolution.bias.copy_(hf_layer.gdn.conv1d.bias) + fast_layer.out_proj.weight.copy_(hf_layer.gdn.out_proj.weight) + fast_layer.A_log.copy_(hf_layer.gdn.A_log) + fast_layer.dt_bias.copy_(hf_layer.gdn.dt_bias) + fast_layer.norm.weight.copy_(hf_layer.gdn.norm.weight) + + hidden_states = torch.randn(1, SEQ_LEN, HIDDEN_SIZE, device=device, dtype=dtype, requires_grad=False) param_map = { "in_proj_qkvz.weight": "in_proj_qkvz.weight", @@ -110,22 +95,28 @@ def test_fast_llm_gdn_matches_qwen3_next_forward(): "dt_bias": "dt_bias", "norm.weight": "norm.weight", } + hf_state_dict = hf_layer.gdn.state_dict() for k, p in fast_layer.state_dict().items(): - torch.testing.assert_close(p, hf_layer.state_dict()[param_map[k]], atol=1e-6, rtol=1e-6) + torch.testing.assert_close(p, hf_state_dict[param_map[k]], atol=1e-5, rtol=1e-5) # need to monkey patch the hf implementation with our fix_query_key_value_ordering due to the layout differences - hf_layer.fix_query_key_value_ordering = fast_layer.fix_query_key_value_ordering + hf_layer.gdn.fix_query_key_value_ordering = fast_layer.fix_query_key_value_ordering hf_layer._local_key_heads = fast_layer._local_key_heads hf_layer._local_value_heads = fast_layer._local_value_heads hf_layer._config = fast_layer._config - hf_out = hf_layer(hidden_states) + hf_out = hf_layer(hidden_states)[0] + sequence_lengths = [[SEQ_LEN] for _ in range(hidden_states.size(0))] fast_kwargs = { + BlockKwargs.device: device, BlockKwargs.sequence_first: False, - BlockKwargs.hidden_dims: (hidden_dim,), + BlockKwargs.hidden_dims: (HIDDEN_SIZE,), + BlockKwargs.sequence_length: SEQ_LEN, + BlockKwargs.sequence_lengths: sequence_lengths, } - fast_out = fast_layer(hidden_states, fast_kwargs) + fast_layer.preprocess(fast_kwargs) + fast_out, _ = fast_layer(hidden_states, fast_kwargs) torch.testing.assert_close(fast_out, hf_out, atol=1e-5, rtol=1e-5) diff --git a/tests/layers/test_kda_equivalence.py b/tests/layers/test_kda_equivalence.py index 8cc19af92..feb47a511 100644 --- a/tests/layers/test_kda_equivalence.py +++ b/tests/layers/test_kda_equivalence.py @@ -2,11 +2,10 @@ import torch import fast_llm.layers.ssm.kda as kda_module -from fast_llm.engine.config_utils.tensor_dim import TensorDim -from fast_llm.engine.distributed.config import DistributedConfig -from fast_llm.engine.distributed.distributed import Distributed +from fast_llm.config import UpdateType from fast_llm.layers.block.config import BlockKwargs -from fast_llm.layers.ssm.config import KimiDeltaAttentionConfig +from fast_llm.models.gpt.config import GPTBaseModelConfig, GPTModelConfig +from tests.utils.utils import get_base_model, get_stage, requires_cuda try: from fast_llm_external_models.apriel_hybrid_ssm.configuration_apriel_hybrid_ssm import AprielHybridSSMConfig @@ -14,30 +13,16 @@ except ImportError: AprielHybridSSMConfig, KimiDeltaAttention = None, None - -def _materialize_mixer_tensors(module: torch.nn.Module, distributed: Distributed, device: torch.device) -> None: - """ - Instantiate meta-allocated parameters on the requested device so the layer can run standalone. - """ - for name, param in module.named_parameters(): - if param.device.type != "meta": - continue - param_data = torch.empty_like(param, device=device) - param.init_parameter(param_data, distributed) - module_path, param_name = name.rsplit(".", 1) if "." in name else (None, name) - target = module - if module_path is not None: - for part in module_path.split("."): - target = getattr(target, part) - new_param = torch.nn.Parameter(param_data, requires_grad=param.requires_grad) - new_param.grad = None - new_param.grad_buffer = torch.zeros_like(param_data) - new_param.param_grad_is_zero = True - target._parameters[param_name] = new_param +VOCAB_SIZE = 500 +HIDDEN_SIZE = 16 +SEQ_LEN = 65 +NUM_HEADS = 4 +HEAD_DIM = 4 +KERNEL_SIZE = 4 @pytest.mark.slow -@pytest.mark.skipif(not torch.cuda.is_available(), reason="KDA equivalence test needs CUDA") +@requires_cuda @pytest.mark.skipif(KimiDeltaAttention is None or AprielHybridSSMConfig is None, reason="Apriel KDA deps missing") @pytest.mark.skipif(kda_module.chunk_kda is None, reason="KDA fused kernels not available") def test_fast_llm_kda_matches_apriel_forward(): @@ -45,42 +30,47 @@ def test_fast_llm_kda_matches_apriel_forward(): device = torch.device("cuda") dtype = torch.bfloat16 - hidden_size = 16 - seq_len = 65 - num_heads = 4 - head_dim = 4 - kernel_size = 4 - hf_config = AprielHybridSSMConfig( - hidden_size=hidden_size, - num_attention_heads=num_heads, + hidden_size=HIDDEN_SIZE, + num_attention_heads=NUM_HEADS, num_hidden_layers=1, rms_norm_eps=1e-6, ) - # Populate fields expected by the HF implementation. - hf_config.short_conv_kernel_size = kernel_size - hf_config.head_dim = head_dim - hf_config.num_heads = num_heads + hf_config.short_conv_kernel_size = KERNEL_SIZE + hf_config.head_dim = HEAD_DIM + hf_config.num_heads = NUM_HEADS hf_layer = KimiDeltaAttention(hf_config, layer_idx=0).to(device=device, dtype=dtype).eval() - fast_config = KimiDeltaAttentionConfig( - heads=num_heads, - head_dim=head_dim, - convolution_layer={"kernel_size": kernel_size, "activation": "silu"}, - normalization={"epsilon": hf_config.rms_norm_eps, "activation": "sigmoid"}, + config = GPTBaseModelConfig.from_dict( + { + "decoder": { + "num_blocks": 1, + "block": { + "mixer": { + "type": "kda", + "heads": NUM_HEADS, + "head_dim": HEAD_DIM, + "convolution_layer": {"kernel_size": KERNEL_SIZE, "activation": "silu"}, + "normalization": {"epsilon": hf_config.rms_norm_eps, "activation": "sigmoid"}, + } + }, + }, + "embeddings": {"vocab_size": VOCAB_SIZE}, + "hidden_size": HIDDEN_SIZE, + }, + update_type=UpdateType.update, ) - distributed_config = DistributedConfig( - tensor_parallel=1, - pipeline_parallel=1, - sequence_data_parallel=1, - local_world_size=1, - world_size=1, + + model, distributed = get_base_model( + GPTModelConfig.from_dict( + { + "base_model": config, + "distributed": {}, + }, + ) ) - hidden_dim = TensorDim("hidden", hidden_size) - fast_layer = fast_config.get_layer(distributed_config, hidden_dim, lr_scale=None, peft=None, return_bias=False) - distributed = Distributed(config=distributed_config) - fast_layer.setup(distributed) - _materialize_mixer_tensors(fast_layer, distributed, device) + fast_layer = model.decoder[0].mixer + get_stage([fast_layer], distributed, [], {}) fast_layer.to(device=device, dtype=dtype).eval() with torch.no_grad(): @@ -129,21 +119,21 @@ def test_fast_llm_kda_matches_apriel_forward(): if fast_param.shape != hf_param.shape: hf_param = hf_param.reshape_as(fast_param) print(f"Comparing parameter {fast_name} with shape {fast_param.shape}") - torch.testing.assert_close(fast_param, hf_param, atol=1e-6, rtol=1e-6) + torch.testing.assert_close(fast_param, hf_param, atol=1e-5, rtol=1e-5) - hidden_states = torch.randn(2, seq_len, hidden_size, device=device, dtype=dtype, requires_grad=False) + hidden_states = torch.randn(2, SEQ_LEN, HIDDEN_SIZE, device=device, dtype=dtype, requires_grad=False) hf_layer.training = True hf_out = hf_layer(hidden_states) - sequence_lengths = [[seq_len] for _ in range(hidden_states.size(0))] + sequence_lengths = [[SEQ_LEN] for _ in range(hidden_states.size(0))] fast_kwargs = { BlockKwargs.device: device, BlockKwargs.sequence_first: False, BlockKwargs.sequence_lengths: sequence_lengths, - BlockKwargs.hidden_dims: (hidden_dim,), + BlockKwargs.hidden_dims: (HIDDEN_SIZE,), } fast_layer.preprocess(fast_kwargs) - fast_out = fast_layer(hidden_states, fast_kwargs) + fast_out, _ = fast_layer(hidden_states, fast_kwargs) torch.testing.assert_close(fast_out, hf_out, atol=1e-5, rtol=1e-5) diff --git a/tests/layers/test_lm_head.py b/tests/layers/test_lm_head.py index c431bb26d..6383e6aae 100644 --- a/tests/layers/test_lm_head.py +++ b/tests/layers/test_lm_head.py @@ -39,7 +39,7 @@ def _reverse_kl_loss( loss_per_sample = torch.nn.functional.kl_div( teacher_log_probs, student_log_probs, reduction="none", log_target=True ).sum(dim=-1) - loss = (loss_per_sample * loss_mask.flatten()).mean() + loss = (loss_per_sample * loss_mask.flatten()).sum() / loss_mask.sum() return loss diff --git a/tests/utils/model_configs.py b/tests/utils/model_configs.py index 4de60ae1c..843a18e5a 100644 --- a/tests/utils/model_configs.py +++ b/tests/utils/model_configs.py @@ -950,7 +950,7 @@ def _update_and_add_testing_config( _update_and_add_testing_config( - # Tests hybrid with gated delta net mixer. + # Tests hybrid with KDA mixer. "llama", "hybrid_kda", updates={ From 372771bc2f0f90965d39e04ae933b5a95109c483 Mon Sep 17 00:00:00 2001 From: oleksost Date: Mon, 8 Dec 2025 16:23:59 +0000 Subject: [PATCH 42/43] removed kda docker since we probably do not need it --- Dockerfile.kda-nightly | 94 ------------------------------------------ 1 file changed, 94 deletions(-) delete mode 100644 Dockerfile.kda-nightly diff --git a/Dockerfile.kda-nightly b/Dockerfile.kda-nightly deleted file mode 100644 index 582ac830c..000000000 --- a/Dockerfile.kda-nightly +++ /dev/null @@ -1,94 +0,0 @@ -# syntax=docker/dockerfile:1.7-labs -# FROM nvcr.io/nvidia/pytorch:25.05-py3 -FROM nvidia/cuda:12.8.0-cudnn-devel-ubuntu24.04 -ENV TORCH_CUDA_ARCH_LIST="9.0" -# ARG TORCH_CUDA_ARCH_LIST="8.0;8.6;9.0" -ENV FLASH_ATTENTION_CUDA_ARCHS="90" -ENV PIP_BREAK_SYSTEM_PACKAGES=1 - -# Basic system deps -RUN apt-get update && apt-get install --no-install-recommends -y \ - python3 python3-pip python3-dev git-lfs build-essential acl \ - libjpeg-dev zlib1g-dev libpng-dev libtiff5-dev \ - && rm -rf /var/lib/apt/lists/* \ - && git lfs install - -RUN update-alternatives --install /usr/bin/python python /usr/bin/python3 1 - -# Set the working directory. -WORKDIR /app -# Set the permission to 777 for all files and directories in `/app`, `/home` and python install directories: -# 1. Create directories explicitly because docker use the wrong permission for explicit creation. -# 2. For the rest, set the default ACL to 777 for all users. -RUN mkdir -m 777 /app/Megatron-LM /app/examples /app/fast_llm /app/tests /app/tools \ - && setfacl -m d:u::rwx,d:g::rwx,d:o::rwx,u::rwx,g::rwx,o::rwx \ - /app \ - /home \ - /usr \ - /usr/local \ - /usr/local/bin \ - /usr/local/lib - -# The base image enforces versions for things like pytest for no good reason. -ENV PIP_CONSTRAINT="" - -RUN python -m pip install --no-cache-dir --pre \ - torch torchvision torchaudio \ - --index-url https://download.pytorch.org/whl/nightly/cu128 - -RUN python -c "import torch, sys; print('PYTHON:', sys.executable); print('TORCH VERSION:', torch.__version__); print('TORCH FILE:', torch.__file__); print('CUDA:', torch.version.cuda)" - -RUN python -m pip install --no-cache-dir packaging \ - && python -m pip install --no-build-isolation --no-cache-dir git+https://github.com/NVIDIA/apex.git - -# Install flash-linear-attention prerequisites and build from source. -RUN python -m pip install --no-cache-dir einops ninja datasets transformers numpy \ - && (python -m pip uninstall -y flash-linear-attention || true) \ - && python -m pip install -U --no-use-pep517 --no-deps --no-cache-dir \ - git+https://github.com/fla-org/flash-linear-attention - -# Optional toolchain pieces for flash-attention. -RUN python -m pip install packaging psutil ninja -RUN python -m pip install --no-cache-dir flit-core packaging - -RUN python -m pip uninstall -y causal-conv1d || true -RUN MAX_JOBS=2 python -m pip install --no-build-isolation --no-cache-dir "causal-conv1d@git+https://github.com/Dao-AILab/causal-conv1d@2a288a1" -RUN MAX_JOBS=2 python -m pip install --no-build-isolation --no-cache-dir --no-deps \ - "mamba_ssm@git+https://github.com/state-spaces/mamba@4a8a2a2" - -# Copy dependency files with universal write permissions for all users. -COPY --chmod=777 setup.py setup.cfg pyproject.toml ./ -COPY --chmod=777 ./fast_llm_external_models/__init__.py fast_llm_external_models/ -COPY --chmod=777 ./fast_llm/__init__.py fast_llm/ -COPY --chmod=777 ./fast_llm/csrc/ fast_llm/csrc/ - -RUN python -m pip install --no-cache-dir pybind11 -# Install dependencies within the virtual environment. -RUN python -m pip install --no-cache-dir --no-build-isolation -e ".[CORE,OPTIONAL,HUGGINGFACE,SSM,GENERATION,DEV]" - -RUN python -m pip install --no-cache-dir --no-build-isolation \ - -e ".[VISION]" - - -# We only care about H100 (sm_90) -# RUN MAX_JOBS=1 python -m pip install --no-deps --no-cache-dir --no-build-isolation flash-attn \ -# && python -m pip install pytest -RUN MAX_JOBS=1 python -m pip install pytest - -RUN git clone https://github.com/NVIDIA/apex /tmp/apex \ - && cd /tmp/apex \ - && MAX_JOBS=1 APEX_CPP_EXT=1 APEX_CUDA_EXT=1 pip install -v --no-build-isolation . \ - && rm -rf /tmp/apex - -# Copy the remaining source code with universal write permissions. -COPY --chmod=777 ./Megatron-LM Megatron-LM -COPY --chmod=777 ./examples examples -COPY --chmod=777 ./tests tests -COPY --chmod=777 ./tools tools -COPY --chmod=777 ./fast_llm_external_models fast_llm_external_models -COPY --chmod=777 --exclude=./fast_llm/csrc/ ./fast_llm/ fast_llm/ - -# Set a dummy default user so we don't run in root by default. -# The image is still compatible with any user id. -RUN useradd user -USER user From 2ce4c078b7eb6ac6a97b3b599be3828f9519d37b Mon Sep 17 00:00:00 2001 From: oleksost Date: Mon, 8 Dec 2025 19:36:48 +0000 Subject: [PATCH 43/43] clean --- tests/layers/test_gdn_equivalence.py | 4 ---- tests/layers/test_kda_equivalence.py | 4 ---- 2 files changed, 8 deletions(-) diff --git a/tests/layers/test_gdn_equivalence.py b/tests/layers/test_gdn_equivalence.py index f4d116349..4af68ea16 100644 --- a/tests/layers/test_gdn_equivalence.py +++ b/tests/layers/test_gdn_equivalence.py @@ -119,7 +119,3 @@ def test_fast_llm_gdn_matches_qwen3_next_forward(): fast_out, _ = fast_layer(hidden_states, fast_kwargs) torch.testing.assert_close(fast_out, hf_out, atol=1e-5, rtol=1e-5) - - -if __name__ == "__main__": - pytest.main([__file__]) diff --git a/tests/layers/test_kda_equivalence.py b/tests/layers/test_kda_equivalence.py index feb47a511..8745236d4 100644 --- a/tests/layers/test_kda_equivalence.py +++ b/tests/layers/test_kda_equivalence.py @@ -136,7 +136,3 @@ def test_fast_llm_kda_matches_apriel_forward(): fast_out, _ = fast_layer(hidden_states, fast_kwargs) torch.testing.assert_close(fast_out, hf_out, atol=1e-5, rtol=1e-5) - - -if __name__ == "__main__": - pytest.main([__file__])