diff --git a/.github/workflows/manual-build.yml b/.github/workflows/manual-build.yml index 8240087a2..2d7eb315c 100644 --- a/.github/workflows/manual-build.yml +++ b/.github/workflows/manual-build.yml @@ -34,12 +34,12 @@ jobs: sudo rm -rf /usr/share/dotnet || true sudo rm -rf /opt/ghc || true sudo rm -rf /usr/local/.ghcup || true - + - name: Checkout repository uses: actions/checkout@v4 with: ref: ${{ inputs.commit_sha != '' && inputs.commit_sha || inputs.branch }} - + - name: Get commit info id: commit_info run: | @@ -48,7 +48,7 @@ jobs: echo "full_sha=${COMMIT_SHA}" >> $GITHUB_OUTPUT echo "short_sha=${COMMIT_SHORT}" >> $GITHUB_OUTPUT echo "Building from commit: ${COMMIT_SHA}" - + - name: Docker meta id: meta uses: docker/metadata-action@v5 @@ -59,10 +59,10 @@ jobs: type=raw,value=${{ inputs.branch }}-${{ inputs.tag_suffix }} type=raw,value=${{ inputs.branch }}-${{ inputs.tag_suffix }}-${{ steps.commit_info.outputs.short_sha }} type=raw,value=latest-${{ inputs.tag_suffix }},enable=${{ inputs.branch == 'main' && inputs.commit_sha == '' }} - + - name: Set up Docker Buildx uses: docker/setup-buildx-action@v3 - + - name: Login to GHCR if: ${{ inputs.push_image }} uses: docker/login-action@v3 @@ -70,7 +70,7 @@ jobs: registry: ghcr.io username: ${{ github.repository_owner }} password: ${{ secrets.GITHUB_TOKEN }} - + - name: Build and push uses: docker/build-push-action@v6 with: @@ -80,7 +80,7 @@ jobs: labels: ${{ steps.meta.outputs.labels }} cache-from: type=registry,ref=ghcr.io/servicenow/fast-llm:cache cache-to: type=registry,ref=ghcr.io/servicenow/fast-llm:cache,mode=max - + - name: Output build info run: | echo "Built Docker image with tags:" diff --git a/Dockerfile b/Dockerfile index 6bc900ae7..5804d0e47 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,5 +1,5 @@ # syntax=docker/dockerfile:1.7-labs -FROM nvcr.io/nvidia/pytorch:25.05-py3 +FROM nvcr.io/nvidia/pytorch:25.11-py3 # Install dependencies. RUN apt-get update \ @@ -29,8 +29,9 @@ ENV PIP_CONSTRAINT="" # There is no pre-build mamba image for pytorch 2.8, we build it before the rest to avoid rebuilds. # We need to compile from the repo because of https://github.com/state-spaces/mamba/issues/720 (same for causal-conv1d) # We set the number of workers to avoid OOM when compiling on laptop. (TODO: Can we make it configurable?) -RUN MAX_JOBS=2 pip install --no-build-isolation "causal-conv1d@git+https://github.com/Dao-AILab/causal-conv1d@2a288a1" -RUN MAX_JOBS=2 pip install --no-build-isolation "mamba_ssm[causal-conv1d]@git+https://github.com/state-spaces/mamba@4a8a2a2" +RUN MAX_JOBS=2 pip install --no-build-isolation "causal-conv1d @ git+https://github.com/Dao-AILab/causal-conv1d@v1.5.4" +RUN MAX_JOBS=2 pip install --no-build-isolation mamba-ssm==2.2.6.post3 +RUN MAX_JOBS=2 pip install --no-build-isolation "flash-linear-attention @ git+https://github.com/fla-org/flash-linear-attention@67eee20c8503cd19eeb52aa1b99821308e9260c5" # Copy dependency files with universal write permissions for all users. COPY --chmod=777 setup.py setup.cfg pyproject.toml ./ COPY --chmod=777 ./fast_llm_external_models/__init__.py fast_llm_external_models/ @@ -38,7 +39,7 @@ COPY --chmod=777 ./fast_llm/__init__.py fast_llm/ COPY --chmod=777 ./fast_llm/csrc/ fast_llm/csrc/ # Install dependencies within the virtual environment. 
-RUN pip install --no-cache-dir --no-build-isolation -e ".[CORE,OPTIONAL,HUGGINGFACE,SSM,VISION,GENERATION,DEV]" triton==3.1.0 +RUN pip install --no-cache-dir --no-build-isolation -e ".[CORE,OPTIONAL,HUGGINGFACE,SSM,VISION,GENERATION,DEV]" triton==3.5.1 # Copy the remaining source code with universal write permissions. COPY --chmod=777 ./Megatron-LM Megatron-LM diff --git a/fast_llm/functional/config.py b/fast_llm/functional/config.py index dd6276bf8..4cfc3b61d 100644 --- a/fast_llm/functional/config.py +++ b/fast_llm/functional/config.py @@ -42,6 +42,7 @@ class ActivationType(enum.StrEnum): gelu = "gelu" silu = "silu" relu = "relu" + sigmoid = "sigmoid" squared_relu = "squared_relu" identity = "identity" @@ -70,6 +71,7 @@ def _set_activation_fn_map() -> None: ActivationType.gelu: lambda x: torch.nn.functional.gelu(x, approximate="tanh"), ActivationType.silu: torch.nn.functional.silu, ActivationType.relu: torch.nn.functional.relu, + ActivationType.sigmoid: torch.nn.functional.sigmoid, ActivationType.squared_relu: lambda x: torch.pow(torch.nn.functional.relu(x), 2), ActivationType.identity: lambda x: x, } @@ -83,6 +85,7 @@ def _set_activation_fn_map() -> None: ActivationType.relu: "relu", ActivationType.squared_relu: "relu2", ActivationType.identity: "identity", + ActivationType.sigmoid: "sigmoid", } _ACTIVATION_HF_NAMES_INV = {value: key for key, value in _ACTIVATION_HF_NAMES.items()} _ACTIVATION_HF_NAMES_INV["gelu"] = ActivationType.gelu diff --git a/fast_llm/layers/common/normalization/config.py b/fast_llm/layers/common/normalization/config.py index 4ecb7a3be..4b8edaebe 100644 --- a/fast_llm/layers/common/normalization/config.py +++ b/fast_llm/layers/common/normalization/config.py @@ -5,6 +5,7 @@ from fast_llm.config import Field, FieldHint, check_field, config_class from fast_llm.engine.base_model.config import ModuleConfig from fast_llm.engine.config_utils.parameter import ParameterConfig, combine_lr_scales +from fast_llm.functional.config import ActivationType from fast_llm.layers.common.peft.config import PeftConfig from fast_llm.utils import Assert @@ -133,6 +134,12 @@ def module_class(self): class GatedRMSNormalizationConfig(RMSNormalizationConfig): _abstract = False + activation: ActivationType = Field( + default=ActivationType.silu, + desc="Activation applied to the gate of the gated normalization. 
Default: SiLU.", + hint=FieldHint.core, + ) + @property def module_class(self): from fast_llm.layers.common.normalization.normalization import GatedRMSNormalization diff --git a/fast_llm/layers/common/normalization/normalization.py b/fast_llm/layers/common/normalization/normalization.py index 651d8e4b1..b1e875707 100644 --- a/fast_llm/layers/common/normalization/normalization.py +++ b/fast_llm/layers/common/normalization/normalization.py @@ -1,7 +1,6 @@ import abc import torch -import torch.nn.functional as F from fast_llm.config import Configurable from fast_llm.engine.config_utils.initialization import init_ones_, init_zeros_ @@ -324,7 +323,7 @@ def _forward_fla(self, input_: torch.Tensor, gate: torch.Tensor) -> torch.Tensor gate, self.weight, None, - activation="silu", + activation=self._config.activation.hf_name, eps=self._config.epsilon, residual=None, prenorm=False, @@ -333,4 +332,4 @@ def _forward_fla(self, input_: torch.Tensor, gate: torch.Tensor) -> torch.Tensor def _forward_local(self, input_: torch.Tensor, gate: torch.Tensor) -> torch.Tensor: normalized = self._forward(input_) - return normalized * F.silu(gate) + return normalized * self._config.activation.activation_fn(gate) diff --git a/fast_llm/layers/ssm/config.py b/fast_llm/layers/ssm/config.py index 6f36321ec..450591216 100644 --- a/fast_llm/layers/ssm/config.py +++ b/fast_llm/layers/ssm/config.py @@ -4,7 +4,6 @@ from fast_llm.config import Field, FieldHint, check_field, config_class from fast_llm.engine.config_utils.initialization import InitializationConfig, Initializer, LambdaInitializer from fast_llm.engine.config_utils.parameter import ParameterConfig -from fast_llm.functional.config import ActivationType from fast_llm.layers.block.config import BlockKwargs from fast_llm.layers.common.linear.config import AffineLinearConfig, CausalConv1dConfig, LinearConfig from fast_llm.layers.common.normalization.config import GatedRMSNormalizationConfig @@ -15,6 +14,8 @@ import torch from fast_llm.layers.ssm.discrete_mamba2 import DiscreteMamba2 + from fast_llm.layers.ssm.gdn import GatedDeltaNet + from fast_llm.layers.ssm.kda import KimiDeltaAttention from fast_llm.layers.ssm.mamba import Mamba from fast_llm.tensor import ParameterMeta @@ -84,36 +85,104 @@ class GatedDeltaNetConfig(MixerConfig): hint=FieldHint.architecture, valid=check_field(Assert.gt, 0), ) - norm_epsilon: float = Field( - default=1e-6, - desc="Epsilon used by the gated RMS norm.", - hint=FieldHint.architecture, - valid=check_field(Assert.gt, 0), - ) - activation: ActivationType = Field( - default=ActivationType.silu, - desc="Activation used after the convolution.", - hint=FieldHint.architecture, - ) def _validate(self) -> None: super()._validate() Assert.multiple(self.value_heads, self.key_heads) @property - def layer_class(self) -> "type": + def layer_class(self) -> "type[GatedDeltaNet]": from fast_llm.layers.ssm.gdn import GatedDeltaNet return GatedDeltaNet + def _validate(self) -> None: + super()._validate() + + +@config_class(dynamic_type={MixerConfig: "kda"}) +class KimiDeltaAttentionConfig(MixerConfig): + """ + Configuration for the KimiDeltaAttention mixer inspired by the Kimi Linear models. 
+ """ + + _abstract = False + normalization: GatedRMSNormalizationConfig = Field( + desc="Configuration for the gated normalization applied to the KDA output.", + hint=FieldHint.architecture, + ) + q_projection_layer: AffineLinearConfig = Field( + desc="Projection that produces query vectors.", + hint=FieldHint.architecture, + ) + k_projection_layer: AffineLinearConfig = Field( + desc="Projection that produces key vectors.", + hint=FieldHint.architecture, + ) + v_projection_layer: AffineLinearConfig = Field( + desc="Projection that produces value vectors.", + hint=FieldHint.architecture, + ) + f_a_projection_layer: AffineLinearConfig = Field( + desc="Projection used for the Delta gating pre-activation.", + hint=FieldHint.architecture, + ) + f_b_projection_layer: AffineLinearConfig = Field( + desc="Projection used for the Delta gating expansion.", + hint=FieldHint.architecture, + ) + g_a_projection_layer: AffineLinearConfig = Field( + desc="Projection used for the output gating pre-activation.", + hint=FieldHint.architecture, + ) + g_b_projection_layer: AffineLinearConfig = Field( + desc="Projection used for the output gating expansion.", + hint=FieldHint.architecture, + ) + beta_projection_layer: AffineLinearConfig = Field( + desc="Projection that produces the Beta gate.", + hint=FieldHint.architecture, + ) + output_projection_layer: AffineLinearConfig = Field( + desc="Projection applied after the Delta recurrence and gated normalization.", + hint=FieldHint.architecture, + ) + convolution_layer: CausalConv1dConfig = Field( + desc="Depth-wise convolution applied independently on each Q, K and V stream.", + hint=FieldHint.architecture, + ) + dt_bias_weight: ParameterConfig = Field( + desc="Parameter configuration for the Delta gate bias.", + hint=FieldHint.architecture, + ) + a_log_weight: ParameterConfig = Field( + desc="Parameter configuration for the decay rates.", + hint=FieldHint.architecture, + ) + + heads: int = Field( + default=16, + desc="Number of attention heads.", + hint=FieldHint.architecture, + valid=check_field(Assert.gt, 0), + ) + head_dim: int = Field( + default=64, + desc="Dimension of each head.", + hint=FieldHint.architecture, + valid=check_field(Assert.gt, 0), + ) + + @property + def layer_class(self) -> "type[KimiDeltaAttention]": + from fast_llm.layers.ssm.kda import KimiDeltaAttention + + return KimiDeltaAttention + def _validate(self) -> None: with self._set_implicit_default(): - if "epsilon" not in self.normalization._explicit_fields: - self.normalization.epsilon = 1.0e-5 - if "activation" not in self.convolution_layer._explicit_fields: - self.convolution_layer.activation = "silu" - if "kernel_size" not in self.convolution_layer._explicit_fields: - self.convolution_layer.kernel_size = 4 + if "activation" not in self.normalization._explicit_fields: + self.normalization.activation = "sigmoid" super()._validate() diff --git a/fast_llm/layers/ssm/gdn.py b/fast_llm/layers/ssm/gdn.py index 9f3a55263..40f15837c 100644 --- a/fast_llm/layers/ssm/gdn.py +++ b/fast_llm/layers/ssm/gdn.py @@ -9,6 +9,7 @@ from fast_llm.engine.config_utils.initialization import LambdaInitializer, init_normal_, init_ones_ from fast_llm.engine.config_utils.tensor_dim import CompositeTensorDim, ConcatenatedTensorDim, TensorDim from fast_llm.engine.distributed.config import DistributedConfig, DistributedDimNames +from fast_llm.functional.config import ActivationType from fast_llm.layers.block.config import BlockKwargs from fast_llm.layers.common.peft.config import PeftConfig from 
fast_llm.layers.decoder.block import BlockWithBias @@ -204,7 +205,7 @@ def __init__( self.convolution = self._config.convolution_layer.get_layer( qkv_channels_dim, default_add_bias=False, - default_activation=self._config.activation, + default_activation=ActivationType.silu, lr_scale=self._lr_scale, peft=self._peft, ) diff --git a/fast_llm/layers/ssm/kda.py b/fast_llm/layers/ssm/kda.py new file mode 100644 index 000000000..323e1ad13 --- /dev/null +++ b/fast_llm/layers/ssm/kda.py @@ -0,0 +1,346 @@ +import logging +import typing + +import torch +from einops import rearrange, repeat + +from fast_llm.engine.base_model.config import ResourceUsageConfig +from fast_llm.engine.config_utils.initialization import LambdaInitializer, init_normal_, init_ones_ +from fast_llm.engine.config_utils.tensor_dim import CompositeTensorDim, TensorDim +from fast_llm.engine.distributed.config import DistributedConfig, DistributedDimNames +from fast_llm.functional.config import ActivationType +from fast_llm.layers.block.config import BlockKwargs +from fast_llm.layers.common.peft.config import PeftConfig +from fast_llm.layers.decoder.block import BlockWithBias +from fast_llm.layers.ssm.config import KimiDeltaAttentionConfig, LinearAttentionKwargs +from fast_llm.tensor import ParameterMeta, TensorMeta + +logger = logging.getLogger(__name__) + +try: + from fla.ops.kda import chunk_kda + from fla.ops.kda.gate import fused_kda_gate +except ImportError: + chunk_kda = None + fused_kda_gate = None + + +def index_first_axis(x, indices): + other_shape = x.shape[1:] + second_dim = other_shape.numel() + return torch.gather( + rearrange(x, "b ... -> b (...)"), + 0, + repeat(indices, "z -> z d", d=second_dim), + ).reshape(-1, *other_shape) + + +class KimiDeltaAttention[ConfigType: KimiDeltaAttentionConfig](BlockWithBias[ConfigType]): + """ + Implementation of the Kimi Delta Attention mixer. + Reference Implementation: https://huggingface.co/moonshotai/Kimi-Linear-48B-A3B-Instruct/blob/main/modeling_kimi.py + """ + + _config: ConfigType + + def __init__( + self, + config: ConfigType, + distributed_config: DistributedConfig, + *, + hidden_dim: TensorDim, + lr_scale: float | None, + peft: PeftConfig | None, + return_bias: bool = True, + ): + super().__init__( + config, distributed_config, hidden_dim=hidden_dim, lr_scale=lr_scale, peft=peft, return_bias=return_bias + ) + if chunk_kda is None or fused_kda_gate is None: + raise ImportError( + "KimiDeltaAttention requires the `fla-core` package. " + "Please install it with `pip install -U fla-core`." 
+ ) + + self._parallel_dim = self._distributed_config.get_distributed_dim(DistributedDimNames.tensor) + self._heads_dim = TensorDim( + "kda_heads", self._config.heads, self._parallel_dim if self._config.heads > 1 else None + ) + self._head_dim = TensorDim("kda_head_dim", self._config.head_dim) + self._projection_dim = CompositeTensorDim("kda_projection", (self._heads_dim, self._head_dim)) + self._local_heads = self._heads_dim.size + self._projection_size = self._projection_dim.size + + init = init_normal_(std=self._hidden_size**-0.5) + self.q_proj = self._config.q_projection_layer.get_layer( + hidden_dim, + self._projection_dim, + default_weight_initialization=init, + default_add_bias=False, + sequence_parallel=self._sequence_parallel, + lr_scale=self._lr_scale, + peft=self._peft, + ) + self.k_proj = self._config.k_projection_layer.get_layer( + hidden_dim, + self._projection_dim, + default_weight_initialization=init, + default_add_bias=False, + sequence_parallel=self._sequence_parallel, + lr_scale=self._lr_scale, + peft=self._peft, + ) + self.v_proj = self._config.v_projection_layer.get_layer( + hidden_dim, + self._projection_dim, + default_weight_initialization=init, + default_add_bias=False, + sequence_parallel=self._sequence_parallel, + lr_scale=self._lr_scale, + peft=self._peft, + ) + + self.q_conv = self._config.convolution_layer.get_layer( + self._projection_dim, + default_add_bias=False, + default_activation=ActivationType.silu, + lr_scale=self._lr_scale, + peft=self._peft, + ) + self.k_conv = self._config.convolution_layer.get_layer( + self._projection_dim, + default_add_bias=False, + default_activation=ActivationType.silu, + lr_scale=self._lr_scale, + peft=self._peft, + ) + self.v_conv = self._config.convolution_layer.get_layer( + self._projection_dim, + default_add_bias=False, + default_activation=ActivationType.silu, + lr_scale=self._lr_scale, + peft=self._peft, + ) + + self.f_a_proj = self._config.f_a_projection_layer.get_layer( + hidden_dim, + self._head_dim, + default_weight_initialization=init, + default_add_bias=False, + sequence_parallel=False, # self._sequence_parallel, + lr_scale=self._lr_scale, + peft=self._peft, + ) + self.f_b_proj = self._config.f_b_projection_layer.get_layer( + self._head_dim, + self._projection_dim, + default_weight_initialization=init, + default_add_bias=False, + sequence_parallel=self._sequence_parallel, + lr_scale=self._lr_scale, + peft=self._peft, + ) + self.g_a_proj = self._config.g_a_projection_layer.get_layer( + hidden_dim, + self._head_dim, + default_weight_initialization=init, + default_add_bias=False, + sequence_parallel=False, # self._sequence_parallel, + lr_scale=self._lr_scale, + peft=self._peft, + ) + self.g_b_proj = self._config.g_b_projection_layer.get_layer( + self._head_dim, + self._projection_dim, + default_weight_initialization=init, + default_add_bias=False, + sequence_parallel=self._sequence_parallel, + lr_scale=self._lr_scale, + peft=self._peft, + ) + + self.beta_proj = self._config.beta_projection_layer.get_layer( + hidden_dim, + self._heads_dim, + default_weight_initialization=init, + default_add_bias=False, + sequence_parallel=self._sequence_parallel, + lr_scale=self._lr_scale, + peft=self._peft, + ) + self.o_proj = self._config.output_projection_layer.get_layer( + self._projection_dim, + hidden_dim, + default_weight_initialization=init, + default_add_bias=False, + sequence_parallel=self._sequence_parallel, + lr_scale=self._lr_scale, + peft=self._peft, + ) + + self.dt_bias: ParameterMeta = 
self._config.dt_bias_weight.get_parameter( + (self._projection_dim,), + default_initialization=init_ones_, + lr_scale=self._lr_scale, + peft=self._peft, + ) + self.A_log: ParameterMeta = self._config.a_log_weight.get_parameter( + (self._heads_dim,), + default_initialization=LambdaInitializer( + lambda _, tensor, generator: tensor.uniform_(1, 16, generator=generator).log_() + ), + lr_scale=self._lr_scale, + peft=self._peft, + ) + self.norm = self._config.normalization.get_layer( + self._head_dim, + lr_scale=self._lr_scale, + peft=self._peft, + ) + + def _apply_conv(self, tensor: torch.Tensor, conv: torch.nn.Module, seq_idx: torch.Tensor = None) -> torch.Tensor: + """ + Applies the convolution. + Note: the reference code uses ShortConvolution from flash-linear-attention/fla/modules/convolution.py, but that one just calls causal_conv1d anyway. + Varlen: + - seq_idx is only supported in channel-last layout, i.e. no transpose + """ + tensor = rearrange(tensor, "b t d -> b d t") + # tensor = tensor.transpose(1, 2).contiguous() if seq_idx is None else tensor.transpose(1, 2) + tensor = conv(tensor, seq_idx=seq_idx) + return tensor.transpose(1, 2).contiguous() + + def _reshape_heads(self, tensor: torch.Tensor) -> torch.Tensor: + tensor = tensor.contiguous() + # since head_dim is the same for q, k and v + # same as rearrange(v, '... (h d) -> ... h d', d=self.head_dim) + return tensor.view(tensor.shape[0], tensor.shape[1], self._local_heads, self._config.head_dim) + + def _forward( + self, + input_: torch.Tensor, + kwargs: dict[str, typing.Any], + losses: dict[str, typing.Any] | None = None, + metrics: dict[str, typing.Any] | None = None, + ) -> tuple[torch.Tensor, torch.Tensor | None]: + """ + Same as in gdn, the idea is to always do the forward pass in a packed way, which is required for varlen support. + """ + sequence_first = kwargs[BlockKwargs.sequence_first] + hidden_states = input_ + + cu_seqlens = kwargs.get(LinearAttentionKwargs.cu_seqlens, None) + seq_idx = kwargs.get(LinearAttentionKwargs.seq_idx, None) + # TODO: can be made more efficient by rearranging hidden states directly and only once + residual_dtype = hidden_states.dtype + + q = self.q_proj(hidden_states) + k = self.k_proj(hidden_states) + v = self.v_proj(hidden_states) + + if sequence_first: + q = q.transpose(0, 1) + k = k.transpose(0, 1) + v = v.transpose(0, 1) + + batch_size, sequence_length, _ = q.size() + q = rearrange(q, "b s ... -> (b s) ...").unsqueeze(0) + k = rearrange(k, "b s ... -> (b s) ...").unsqueeze(0) + v = rearrange(v, "b s ... -> (b s) ...").unsqueeze(0) + + # because we use cu_seqlens, chunk_kda requires batch size to be 1 (flatten, see https://github.com/fla-org/flash-linear-attention/blob/71260ecd573cfaaa94305b726465143199e99734/fla/ops/kda/chunk.py#L303) + # similarly to ShortConvolution from fla we already operate on flattened batches here (https://github.com/fla-org/flash-linear-attention/blob/71260ecd573cfaaa94305b726465143199e99734/fla/modules/convolution.py#L914) + q = self._apply_conv(q, self.q_conv, seq_idx=seq_idx) + k = self._apply_conv(k, self.k_conv, seq_idx=seq_idx) + v = self._apply_conv(v, self.v_conv, seq_idx=seq_idx) + + g_kernel = self.f_b_proj(self.f_a_proj(hidden_states)) + if sequence_first: + g_kernel = g_kernel.transpose(0, 1) + g_kernel = self._reshape_heads(g_kernel) + g_kernel = rearrange(g_kernel, "b s ... 
-> (b s) ...").unsqueeze(0) + + g_kernel = fused_kda_gate(g_kernel, self.A_log.float(), dt_bias=self.dt_bias) + + beta = torch.sigmoid(self.beta_proj(hidden_states).float()) + q = self._reshape_heads(q) + k = self._reshape_heads(k) + v = self._reshape_heads(v) + if sequence_first: + beta = beta.transpose(0, 1) + beta = rearrange(beta, "b s h -> (b s) h").unsqueeze(0) + + # need to install nightly triton for this to work on H100, see https://github.com/fla-org/flash-linear-attention/blob/main/FAQs.md + # cu_seqlens requires batch size to be 1, i.e. flattened batches + attn_out, _ = chunk_kda( + q=q, + k=k, + v=v, + g=g_kernel, + beta=beta, + initial_state=None, + output_final_state=False, + use_qk_l2norm_in_kernel=True, + cu_seqlens=cu_seqlens, + ) + + attn_out = attn_out.to(residual_dtype) + + g_out = self.g_b_proj(self.g_a_proj(hidden_states)) # bs x seq x n_local_heads x head dim + g_out = self._reshape_heads(g_out) + if sequence_first: + g_out = g_out.transpose(0, 1) + + attn_out = rearrange(attn_out.squeeze(0), "(b s) h d -> b s h d", b=batch_size, s=sequence_length) + attn_out = self.norm(attn_out, g_out) + attn_out = rearrange(attn_out, "b s h d -> b s (h d)") + if sequence_first: + attn_out = attn_out.transpose(0, 1) + attn_out = self.o_proj(attn_out) + + return attn_out + + def get_compute_usage(self, input_: TensorMeta, kwargs: dict[str, typing.Any], config: ResourceUsageConfig) -> int: + raise NotImplementedError() + + def _preprocess_for_varlen(self, kwargs: dict[str, typing.Any]) -> None: + sequence_lengths = kwargs[LinearAttentionKwargs.sequence_lengths] + device = kwargs.get("device", None) + if sequence_lengths is None: + raise ValueError("sequence_lengths must be provided in kwargs for variable-length sequences.") + + seqlens = torch.tensor( + [ + 0, + *( + sequence_length + for sequence_lengths in kwargs[LinearAttentionKwargs.sequence_lengths] + for sequence_length in sequence_lengths # bs + ), + ], + dtype=torch.int32, + ) + cu_seqlens = seqlens.cumsum_(0).to(device) + # this is supposed to be flattened, see https://github.com/fla-org/flash-linear-attention/blob/71260ecd573cfaaa94305b726465143199e99734/fla/ops/kda/chunk.py#L303 + # also whenever cu_seqlens is used, batch size must be forced to 1: see https://github.com/fla-org/flash-linear-attention/blob/71260ecd573cfaaa94305b726465143199e99734/fla/ops/kda/chunk.py#L347 + kwargs[LinearAttentionKwargs.cu_seqlens] = cu_seqlens + # seq_idx has to be (bs, seqlen), but bs is forced to 1 + kwargs[LinearAttentionKwargs.seq_idx] = ( + ( + torch.cat( + [ + torch.arange(n, dtype=cu_seqlens.dtype, device=cu_seqlens.device) + for n in (torch.diff(cu_seqlens).to(torch.int32)) + ], + dim=0, + ) + .eq(0) + .cumsum(0) + - 1 + ) + .to(torch.int32) + .unsqueeze(0) + ) + + def preprocess(self, kwargs: dict[str, typing.Any]) -> None: + self._preprocess_for_varlen(kwargs) diff --git a/fast_llm/models/auto.py b/fast_llm/models/auto.py index f7c34a973..e8160ec87 100644 --- a/fast_llm/models/auto.py +++ b/fast_llm/models/auto.py @@ -3,12 +3,14 @@ """ from fast_llm.layers.attention.config import AttentionConfig # isort: skip -from fast_llm.layers.ssm.config import ( - MambaConfig, - Mamba2Config, +from fast_llm.layers.ssm.config import ( # isort: skip DiscreteMamba2Config, GatedDeltaNetConfig, -) # isort: skip + KimiDeltaAttentionConfig, + Mamba2Config, + MambaConfig, +) + from fast_llm.layers.attention.config import AttentionConfig # isort: skip from fast_llm.models.gpt.config import GPTModelConfig, GPTTrainerConfig # isort: skip from 
fast_llm.models.multimodal.config import MultiModalModelConfig, MultiModalTrainerConfig # isort: skip diff --git a/fast_llm/models/gpt/conversion/apriel.py b/fast_llm/models/gpt/conversion/apriel.py index c93e2e966..359adfd9d 100644 --- a/fast_llm/models/gpt/conversion/apriel.py +++ b/fast_llm/models/gpt/conversion/apriel.py @@ -8,7 +8,12 @@ from fast_llm.layers.attention.config import AttentionConfig from fast_llm.layers.block.config import BlockSequenceConfig, FixedBlockSequenceConfig, PatternBlockSequenceConfig from fast_llm.layers.decoder.config import DecoderBlockConfig -from fast_llm.layers.ssm.config import DiscreteMamba2Config, GatedDeltaNetConfig, Mamba2Config +from fast_llm.layers.ssm.config import ( + DiscreteMamba2Config, + GatedDeltaNetConfig, + KimiDeltaAttentionConfig, + Mamba2Config, +) from fast_llm.models.gpt.config import GPTModelConfig from fast_llm.models.gpt.conversion.config import AprielHybridSSMCheckpointFormat from fast_llm.models.gpt.conversion.llama import get_parameter_converter, get_weight_and_bias_converters @@ -302,11 +307,138 @@ def get_converters( ] +class KimiDeltaAttentionConverter: + @classmethod + def import_config(cls, config: dict) -> dict: + return { + "type": "kda", + "head_dim": config["linear_attn_config"]["head_dim"], + "heads": config["linear_attn_config"]["num_heads"], + "convolution_layer": { + "kernel_size": config["linear_attn_config"]["short_conv_kernel_size"], + }, + } + + @classmethod + def export_config(cls, config: KimiDeltaAttentionConfig) -> dict: + return { + "linear_attn_config": { + "head_dim": config.head_dim, + "num_heads": config.heads, + "short_conv_kernel_size": config.convolution_layer.kernel_size, + }, + } + + @classmethod + def get_converters( + cls, + config: KimiDeltaAttentionConfig, + fast_llm_prefix: str, + hf_prefix: str, + drop_on_export: bool = False, + ) -> list[WeightConverter]: + return [ + *get_weight_and_bias_converters( + f"{fast_llm_prefix}.q_proj", + f"{hf_prefix}.q_proj", + False, + drop_on_export=drop_on_export, + ), + *get_weight_and_bias_converters( + f"{fast_llm_prefix}.k_proj", + f"{hf_prefix}.k_proj", + False, + drop_on_export=drop_on_export, + ), + *get_weight_and_bias_converters( + f"{fast_llm_prefix}.v_proj", + f"{hf_prefix}.v_proj", + False, + drop_on_export=drop_on_export, + ), + *get_weight_and_bias_converters( + f"{fast_llm_prefix}.q_conv", + f"{hf_prefix}.q_conv", + False, + drop_on_export=drop_on_export, + ), + *get_weight_and_bias_converters( + f"{fast_llm_prefix}.k_conv", + f"{hf_prefix}.k_conv", + False, + drop_on_export=drop_on_export, + ), + *get_weight_and_bias_converters( + f"{fast_llm_prefix}.v_conv", + f"{hf_prefix}.v_conv", + False, + drop_on_export=drop_on_export, + ), + *get_weight_and_bias_converters( + f"{fast_llm_prefix}.f_a_proj", + f"{hf_prefix}.f_a_proj", + False, + drop_on_export=drop_on_export, + ), + *get_weight_and_bias_converters( + f"{fast_llm_prefix}.f_b_proj", + f"{hf_prefix}.f_b_proj", + False, + drop_on_export=drop_on_export, + ), + *get_weight_and_bias_converters( + f"{fast_llm_prefix}.g_a_proj", + f"{hf_prefix}.g_a_proj", + False, + drop_on_export=drop_on_export, + ), + *get_weight_and_bias_converters( + f"{fast_llm_prefix}.g_b_proj", + f"{hf_prefix}.g_b_proj", + False, + drop_on_export=drop_on_export, + ), + *get_weight_and_bias_converters( + f"{fast_llm_prefix}.beta_proj", + f"{hf_prefix}.beta_proj", + False, + drop_on_export=drop_on_export, + ), + *get_weight_and_bias_converters( + f"{fast_llm_prefix}.o_proj", + f"{hf_prefix}.o_proj", + False, + 
drop_on_export=drop_on_export, + ), + get_parameter_converter( + f"{fast_llm_prefix}.A_log", + f"{hf_prefix}.A_log", + drop_on_export=drop_on_export, + ), + get_parameter_converter( + f"{fast_llm_prefix}.dt_bias", + f"{hf_prefix}.dt_bias", + drop_on_export=drop_on_export, + ), + *get_weight_and_bias_converters( + f"{fast_llm_prefix}.norm", + f"{hf_prefix}.norm", + False, + drop_on_export=drop_on_export, + ), + ] + + class AprielDiscreteMamba2BlockConverter(MistralBlockConverter): mixer_converter_class: typing.ClassVar[type[AprielDiscreteMamba2Converter]] = AprielDiscreteMamba2Converter hf_mixer_name: typing.ClassVar[str] = "mixer" +class AprielKimiDeltaAttentionBlockConverter(MistralBlockConverter): + mixer_converter_class: typing.ClassVar[type[KimiDeltaAttentionConverter]] = KimiDeltaAttentionConverter + hf_mixer_name: typing.ClassVar[str] = "mixer" + + class AprielMamba2BlockConverter(MistralBlockConverter): mixer_converter_class: typing.ClassVar[type[AprielMamba2Converter]] = AprielMamba2Converter hf_mixer_name: typing.ClassVar[str] = "mixer" @@ -327,6 +459,7 @@ class AprielBlockConverter: _converter_classes = { AttentionConfig: MistralBlockConverter, Mamba2Config: AprielMamba2BlockConverter, + KimiDeltaAttentionConfig: AprielKimiDeltaAttentionBlockConverter, DiscreteMamba2Config: AprielDiscreteMamba2BlockConverter, GatedDeltaNetConfig: AprielGatedDeltaNetBlockConverter, } diff --git a/fast_llm_external_models/apriel_hybrid_ssm/configuration_apriel_hybrid_ssm.py b/fast_llm_external_models/apriel_hybrid_ssm/configuration_apriel_hybrid_ssm.py index 12ee343ef..7f61fcda5 100644 --- a/fast_llm_external_models/apriel_hybrid_ssm/configuration_apriel_hybrid_ssm.py +++ b/fast_llm_external_models/apriel_hybrid_ssm/configuration_apriel_hybrid_ssm.py @@ -28,10 +28,37 @@ } +class AprielGDNConfig: + def __init__( + self, + linear_num_key_heads=16, + linear_num_value_heads=32, + linear_key_head_dim=128, + linear_value_head_dim=128, + linear_conv_kernel_dim=4, + kl_short_conv_kernel_size=4, + kl_num_heads=32, + kl_head_dim=128, + ): + self.linear_num_key_heads = linear_num_key_heads + self.linear_num_value_heads = linear_num_value_heads + self.linear_key_head_dim = linear_key_head_dim + self.linear_value_head_dim = linear_value_head_dim + self.linear_conv_kernel_dim = linear_conv_kernel_dim + + # Kimi LInear + self.short_conv_kernel_size = kl_short_conv_kernel_size + self.head_dim = kl_head_dim + self.num_heads = kl_num_heads + + +LAYER_TYPES = {"t": "full_attention", "swa": "sliding_attention", "gdn": "gated_delta_net", "kl": "kimi_linear"} + + class AprielHybridSSMConfig(MistralConfig): model_type = "apriel_hybrid_ssm" - def __init__(self, hybrid_block_layout=["m2d"], ssm_cfg=None, **kwargs): + def __init__(self, hybrid_block_layout=["t"], ssm_cfg=None, gdn_cfg=None, **kwargs): super().__init__(**kwargs) self.hybrid_block_layout = hybrid_block_layout self.head_dim = self.head_dim or self.hidden_size // self.num_attention_heads # as in transformers 4.51.3 @@ -40,3 +67,14 @@ def __init__(self, hybrid_block_layout=["m2d"], ssm_cfg=None, **kwargs): for k, v in ssm_config_default.items(): if k not in self.ssm_cfg: self.ssm_cfg[k] = v # to make sure all elements are present in the config + + gdn_config: AprielGDNConfig = ( + AprielGDNConfig(**gdn_cfg) if isinstance(gdn_cfg, dict) else gdn_cfg or AprielGDNConfig() + ) + # for compatibility with vllm + self.layer_types = [LAYER_TYPES[lt] for lt in hybrid_block_layout] # this is for vllm compatibility + self.linear_attn_config = { + "short_conv_kernel_size": 
gdn_config.short_conv_kernel_size, + "head_dim": gdn_config.head_dim, + "num_heads": gdn_config.num_heads, + } diff --git a/fast_llm_external_models/apriel_hybrid_ssm/modeling_apriel_hybrid_ssm.py b/fast_llm_external_models/apriel_hybrid_ssm/modeling_apriel_hybrid_ssm.py index e63584433..1aca40da0 100644 --- a/fast_llm_external_models/apriel_hybrid_ssm/modeling_apriel_hybrid_ssm.py +++ b/fast_llm_external_models/apriel_hybrid_ssm/modeling_apriel_hybrid_ssm.py @@ -23,9 +23,14 @@ from fast_llm_external_models.apriel_hybrid_ssm.configuration_apriel_hybrid_ssm import AprielHybridSSMConfig -# from vllm.model_executor.layers.mamba.ops.mamba_ssm import selective_scan_fn as varlen_selective_scan_fn -# from vllm.model_executor.layers.mamba.ops.causal_conv1d import causal_conv1d_fn as varlen_causal_conv1d_fn - +try: + from fla.modules import FusedRMSNormGated, ShortConvolution + from fla.ops.kda import chunk_kda, fused_recurrent_kda + from fla.ops.kda.gate import fused_kda_gate + from fla.ops.utils.index import prepare_cu_seqlens_from_mask, prepare_lens_from_mask + from fla.utils import tensor_cache +except ImportError: + raise ImportError("Please run `pip install -U fla-core`") logger = logging.get_logger(__name__) @@ -45,6 +50,241 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) +def index_first_axis(x, indices): + other_shape = x.shape[1:] + second_dim = other_shape.numel() + return torch.gather( + rearrange(x, "b ... -> b (...)"), + 0, + repeat(indices, "z -> z d", d=second_dim), + ).reshape(-1, *other_shape) + + +def index_put_first_axis(x, indices, first_axis_dim): + y = torch.zeros(first_axis_dim, *x.shape[1:], device=x.device, dtype=x.dtype) + # TODO [2022-03-04] For some reason torch.scatter is a bit faster than indexing. + y[indices] = x + # y.scatter_(0, repeat(indices, 'z -> z d', d=x.shape[1]), x) + return y + + +@tensor_cache +def get_unpad_data( + attention_mask: torch.Tensor, +) -> tuple[torch.Tensor, torch.Tensor, int]: + lens = prepare_lens_from_mask(attention_mask) + indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten() + max_seqlen_in_batch = lens.max().item() + cu_seqlens = prepare_cu_seqlens_from_mask(attention_mask) + return indices, cu_seqlens, max_seqlen_in_batch + + +def unpad_input( + q: torch.Tensor, + states: tuple[torch.Tensor], + attention_mask: torch.Tensor, + q_len: int, + keepdim: bool = False, +): + indices_k, cu_seqlens_k, max_seqlen_in_batch_k = get_unpad_data(attention_mask) + batch_size, seq_len, *_ = states[0].shape + + state = tuple(index_first_axis(rearrange(s, "b s ... -> (b s) ..."), indices_k) for s in states) + + if q_len == seq_len: + q = index_first_axis(rearrange(q, "b s ... 
-> (b s) ..."), indices_k) + cu_seqlens_q = cu_seqlens_k + max_seqlen_in_batch_q = max_seqlen_in_batch_k + indices_q = indices_k + elif q_len == 1: + max_seqlen_in_batch_q = 1 + cu_seqlens_q = torch.arange(batch_size + 1, dtype=torch.int32, device=q.device) + indices_q = cu_seqlens_q[:-1] + q = q.squeeze(1) + else: + raise NotImplementedError("We only support either q_len == k_len (prefilling) or q_len == 1 (decoding)") + + if keepdim: + q = q.unsqueeze(0) + state = tuple(s.unsqueeze(0) for s in state) + + return ( + q, + state, + indices_q, + (cu_seqlens_q, cu_seqlens_k), + (max_seqlen_in_batch_q, max_seqlen_in_batch_k), + ) + + +def pad_input( + hidden_states: torch.Tensor, + indices: torch.LongTensor, + batch_size: int, + seq_len: int, +) -> torch.Tensor: + output = index_put_first_axis(hidden_states, indices, batch_size * seq_len) + return rearrange(output, "(b s) ... -> b s ...", b=batch_size) + + +class KimiDeltaAttention(nn.Module): + def __init__(self, config: AprielHybridSSMConfig, layer_idx: int): + super().__init__() + self.config = config + self.mode = "chunk" + + self.hidden_size = config.hidden_size + self.conv_size = config.short_conv_kernel_size + self.head_dim = config.head_dim + self.num_heads = config.num_heads + self.head_k_dim = self.head_dim + self.num_k_heads = self.num_heads + + self.layer_idx = layer_idx + + assert self.mode in ["chunk", "fused_recurrent"], f"Unsupported mode `{self.mode}`." + + projection_k_size = self.head_k_dim * self.num_k_heads + projection_size = self.head_dim * self.num_heads + + self.q_proj = nn.Linear(self.hidden_size, projection_k_size, bias=False) + self.k_proj = nn.Linear(self.hidden_size, projection_k_size, bias=False) + self.v_proj = nn.Linear(self.hidden_size, projection_size, bias=False) + + self.q_conv1d = ShortConvolution( + hidden_size=projection_k_size, + kernel_size=self.conv_size, + activation="silu", + ) + self.k_conv1d = ShortConvolution( + hidden_size=projection_k_size, + kernel_size=self.conv_size, + activation="silu", + ) + self.v_conv1d = ShortConvolution( + hidden_size=projection_size, + kernel_size=self.conv_size, + activation="silu", + ) + + self.A_log = torch.nn.Parameter( + torch.log(torch.empty(self.num_heads, dtype=torch.float32).uniform_(1, 16)).view(1, 1, -1, 1) + ) + + self.f_a_proj = nn.Linear(self.hidden_size, self.head_dim, bias=False) + self.f_b_proj = nn.Linear(self.head_dim, projection_size, bias=False) + + self.dt_bias = nn.Parameter(torch.empty(projection_size, dtype=torch.float32)) + + self.b_proj = nn.Linear(self.hidden_size, self.num_heads, bias=False) + + self.g_a_proj = nn.Linear(self.hidden_size, self.head_dim, bias=False) + self.g_b_proj = nn.Linear(self.head_dim, projection_size, bias=False) + + self.o_norm = FusedRMSNormGated(self.head_dim, eps=config.rms_norm_eps, activation="sigmoid") + self.o_proj = nn.Linear(projection_size, self.hidden_size, bias=False) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor | None = None, + cache_params=None, + **kwargs: Unpack[dict], + ) -> tuple[torch.Tensor, torch.Tensor | None, Cache | None]: + if attention_mask is not None: + if attention_mask.dim() != 2: + attention_mask = kwargs.get("padding_mask") + + if attention_mask is not None and attention_mask.dim() != 2: + raise ValueError( + "attention_mask must be a 0-1 matrix of shape [batch_size, seq_len] " + "(0 = padding). 
3D masks are not supported here.", + ) + use_cache = cache_params is not None + batch_size, q_len, _ = hidden_states.shape + mode = "fused_recurrent" if q_len <= 64 else self.mode + if self.training: + assert mode == "chunk", "Only chunk mode is supported in training." + + cu_seqlens = kwargs.get("cu_seqlens") + indices = None + if attention_mask is not None: + indices, cu_seqlens, _ = get_unpad_data(attention_mask[:, -q_len:]) + hidden_states = index_first_axis(rearrange(hidden_states, "b s ... -> (b s) ..."), indices).unsqueeze(0) + + conv_state_q, conv_state_k, conv_state_v = None, None, None + recurrent_state = None + if cache_params is not None: + if cache_params.conv_states[self.layer_idx] is not None: + conv_state_q, conv_state_k, conv_state_v = cache_params.conv_states[self.layer_idx] + recurrent_state = cache_params.recurrent_states[self.layer_idx] + q, conv_state_q = self.q_conv1d( + x=self.q_proj(hidden_states), + cache=conv_state_q, + output_final_state=use_cache, + cu_seqlens=cu_seqlens, + ) + k, conv_state_k = self.k_conv1d( + x=self.k_proj(hidden_states), + cache=conv_state_k, + output_final_state=use_cache, + cu_seqlens=cu_seqlens, + ) + v, conv_state_v = self.v_conv1d( + x=self.v_proj(hidden_states), + cache=conv_state_v, + output_final_state=use_cache, + cu_seqlens=cu_seqlens, + ) + g = self.f_b_proj(self.f_a_proj(hidden_states)) + g = rearrange(g, "... (h d) -> ... h d", d=self.head_dim) + + g = fused_kda_gate(g, self.A_log.float().squeeze(), dt_bias=self.dt_bias) + beta = self.b_proj(hidden_states).float().sigmoid() + + q, k = map(lambda x: rearrange(x, "... (h d) -> ... h d", d=self.head_k_dim), (q, k)) + v = rearrange(v, "... (h d) -> ... h d", d=self.head_dim) + + if mode == "chunk": + o, recurrent_state = chunk_kda( + q=q, + k=k, + v=v, + g=g, + beta=beta, + initial_state=None, + output_final_state=False, + use_qk_l2norm_in_kernel=True, + cu_seqlens=cu_seqlens, + ) + else: + o, recurrent_state = fused_recurrent_kda( + q=q, + k=k, + v=v, + g=g, + beta=beta, + initial_state=recurrent_state, + output_final_state=True, + use_qk_l2norm_in_kernel=True, + cu_seqlens=cu_seqlens, + ) + if cache_params is not None: + cache_params.recurrent_states[self.layer_idx] = recurrent_state + cache_params.conv_states[self.layer_idx] = (conv_state_q, conv_state_k, conv_state_v) + + g = self.g_b_proj(self.g_a_proj(hidden_states)) + g = rearrange(g, "... (h d) -> ... h d", d=self.head_dim) + o = self.o_norm(o, g) + + o = rearrange(o, "b t h d -> b t (h d)") + o = self.o_proj(o) + if attention_mask is not None: + o = pad_input(o.squeeze(0), indices, batch_size, q_len) + + return o + + class HybridMambaAttentionStaticCache(Cache): def __init__(self, config: AprielHybridSSMConfig, batch_size, max_length, dtype=torch.float16, device=None): super().__init__() # config, batch_size, max_length, device, dtype) diff --git a/setup.cfg b/setup.cfg index f4b2c904b..58f8ea2d1 100644 --- a/setup.cfg +++ b/setup.cfg @@ -25,10 +25,10 @@ CORE = # Used for checkpoints safetensors>=0.5.3 # Update the base image (version fixed to ensure there is a wheel for the base image), may need --no-build-isolation - flash-attn==2.7.3 - # Dropless MLP is broken with triton 3.2.0, 3.3.0 and 3.3.1. TODO: Remove once a working triton version is released. - # TODO: Removed because it breaks cpu-only installs and pip dependency resolution. - # triton==3.1.0 + flash-attn==2.7.4.post1 + # Dropless MoE kernel is broken with triton >= 3.2.0 and needs a rewrite (also limited to 32 experts). 
+ # Not pinning triton here as it breaks cpu-only installs and pip dependency resolution. + # triton==3.5.1 # Small packages required for some optional features and tools. @@ -52,8 +52,8 @@ HUGGINGFACE = # To install on cpu environment (ex. for IDE support): # MAMBA_FORCE_BUILD=TRUE CAUSAL_CONV1D_FORCE_BUILD=TRUE CAUSAL_CONV1D_SKIP_CUDA_BUILD=TRUE pip install -e ".[CORE,SSM]" --no-build-isolation SSM = - mamba_ssm[causal-conv1d]==2.2.4 - flash-linear-attention @ git+https://github.com/fla-org/flash-linear-attention@main + mamba_ssm[causal-conv1d]==2.2.6.post3 + flash-linear-attention @ git+https://github.com/fla-org/flash-linear-attention@67eee20c8503cd19eeb52aa1b99821308e9260c5 GENERATION = lm_eval>=0.4.9 diff --git a/tests/layers/test_gdn_equivalence.py b/tests/layers/test_gdn_equivalence.py index 9886056ea..4af68ea16 100644 --- a/tests/layers/test_gdn_equivalence.py +++ b/tests/layers/test_gdn_equivalence.py @@ -1,104 +1,89 @@ import pytest import torch -from fast_llm.engine.config_utils.tensor_dim import TensorDim -from fast_llm.engine.distributed.config import DistributedConfig -from fast_llm.engine.distributed.distributed import Distributed -from fast_llm.functional.config import ActivationType +from fast_llm.config import UpdateType from fast_llm.layers.block.config import BlockKwargs -from fast_llm.layers.ssm.config import GatedDeltaNetConfig +from fast_llm.models.gpt.config import GPTBaseModelConfig, GPTModelConfig +from fast_llm_external_models.apriel2.modeling_apriel2 import Apriel2GatedDeltaNet +from tests.utils.utils import get_base_model, get_stage, requires_cuda -try: - from transformers.models.qwen3_next.modeling_qwen3_next import Qwen3NextConfig, Qwen3NextGatedDeltaNet -except ImportError: - Qwen3NextConfig, Qwen3NextGatedDeltaNet = None, None - - -def _materialize_mixer_tensors(module: torch.nn.Module, distributed: Distributed, device: torch.device) -> None: - """ - Instantiate meta-allocated parameters on the requested device so the layer can run standalone. - """ - for name, param in module.named_parameters(): - if param.device.type != "meta": - continue - param_data = torch.empty_like(param, device=device) - param.init_parameter(param_data, distributed) - module_path, param_name = name.rsplit(".", 1) if "." 
in name else (None, name) - target = module - if module_path is not None: - for part in module_path.split("."): - target = getattr(target, part) - new_param = torch.nn.Parameter(param_data, requires_grad=param.requires_grad) - new_param.grad = None - new_param.grad_buffer = torch.zeros_like(param_data) - new_param.param_grad_is_zero = True - target._parameters[param_name] = new_param +VOCAB_SIZE = 500 +HIDDEN_SIZE = 16 +SEQ_LEN = 65 +NUM_V_HEADS = 4 +NUM_K_HEADS = 2 +HEAD_DIM = 4 +KERNEL_SIZE = 4 @pytest.mark.slow -@pytest.mark.skipif(not torch.cuda.is_available(), reason="Varlen test needs CUDA") -@pytest.mark.skipif(Qwen3NextConfig is None, reason="transformers with Qwen3-Next not installed") +@requires_cuda def test_fast_llm_gdn_matches_qwen3_next_forward(): torch.manual_seed(0) device = torch.device("cuda") dtype = torch.bfloat16 - hidden_size = 16 - seq_len = 6 - num_k_heads = 2 - num_v_heads = 4 - head_k_dim = 4 - head_v_dim = 4 - kernel_size = 4 + config_dict_hf = { + "num_value_heads": NUM_V_HEADS, + "num_key_heads": NUM_K_HEADS, + "key_head_dim": HEAD_DIM, + "value_head_dim": HEAD_DIM, + "conv_kernel_size": KERNEL_SIZE, + "activation": "silu", + "norm_eps": 1e-5, + } - hf_config = Qwen3NextConfig( - hidden_size=hidden_size, - linear_num_key_heads=num_k_heads, - linear_num_value_heads=num_v_heads, - linear_key_head_dim=head_k_dim, - linear_value_head_dim=head_v_dim, - linear_conv_kernel_dim=kernel_size, - hidden_act="silu", - rms_norm_eps=1e-6, - dtype=dtype, + hf_layer = ( + Apriel2GatedDeltaNet(HIDDEN_SIZE, config_dict_hf, layer_idx=0, dtype=dtype) + .to(device=device, dtype=dtype) + .eval() ) - hf_layer = Qwen3NextGatedDeltaNet(hf_config, layer_idx=0).to(device=device, dtype=dtype).eval() - fast_config = GatedDeltaNetConfig( - value_heads=num_v_heads, - key_heads=num_k_heads, - value_head_dim=head_v_dim, - key_head_dim=head_k_dim, - activation=ActivationType.silu, - normalization={"epsilon": hf_config.rms_norm_eps}, - convolution_layer={"kernel_size": kernel_size, "activation": ActivationType.silu}, + config = GPTBaseModelConfig.from_dict( + { + "decoder": { + "num_blocks": 1, + "block": { + "mixer": { + "type": "gdn", + "value_heads": NUM_V_HEADS, + "key_heads": NUM_K_HEADS, + "key_head_dim": HEAD_DIM, + "value_head_dim": HEAD_DIM, + "convolution_layer": {"kernel_size": KERNEL_SIZE, "activation": "silu"}, + } + }, + }, + "embeddings": {"vocab_size": VOCAB_SIZE}, + "hidden_size": HIDDEN_SIZE, + }, + update_type=UpdateType.update, ) - distributed_config = DistributedConfig( - tensor_parallel=1, - pipeline_parallel=1, - sequence_data_parallel=1, - local_world_size=1, - world_size=1, + + model, distributed = get_base_model( + GPTModelConfig.from_dict( + { + "base_model": config, + "distributed": {}, + }, + ) ) - hidden_dim = TensorDim("hidden", hidden_size) - fast_layer = fast_config.get_layer(distributed_config, hidden_dim, lr_scale=None, peft=None, return_bias=False) - distributed = Distributed(config=distributed_config) - fast_layer.setup(distributed) - _materialize_mixer_tensors(fast_layer, distributed, device) + fast_layer = model.decoder[0].mixer + get_stage([fast_layer], distributed, [], {}) fast_layer.to(device=device, dtype=dtype).eval() with torch.no_grad(): - fast_layer.in_proj_qkvz.weight.copy_(hf_layer.in_proj_qkvz.weight) - fast_layer.in_proj_ba.weight.copy_(hf_layer.in_proj_ba.weight) - fast_layer.convolution.weight.copy_(hf_layer.conv1d.weight) - if fast_layer.convolution.bias is not None and hf_layer.conv1d.bias is not None: - 
fast_layer.convolution.bias.copy_(hf_layer.conv1d.bias) - fast_layer.out_proj.weight.copy_(hf_layer.out_proj.weight) - fast_layer.A_log.copy_(hf_layer.A_log) - fast_layer.dt_bias.copy_(hf_layer.dt_bias) - fast_layer.norm.weight.copy_(hf_layer.norm.weight) - - hidden_states = torch.randn(1, seq_len, hidden_size, device=device, dtype=dtype, requires_grad=False) + fast_layer.in_proj_qkvz.weight.copy_(hf_layer.gdn.in_proj_qkvz.weight) + fast_layer.in_proj_ba.weight.copy_(hf_layer.gdn.in_proj_ba.weight) + fast_layer.convolution.weight.copy_(hf_layer.gdn.conv1d.weight) + if fast_layer.convolution.bias is not None and hf_layer.gdn.conv1d.bias is not None: + fast_layer.convolution.bias.copy_(hf_layer.gdn.conv1d.bias) + fast_layer.out_proj.weight.copy_(hf_layer.gdn.out_proj.weight) + fast_layer.A_log.copy_(hf_layer.gdn.A_log) + fast_layer.dt_bias.copy_(hf_layer.gdn.dt_bias) + fast_layer.norm.weight.copy_(hf_layer.gdn.norm.weight) + + hidden_states = torch.randn(1, SEQ_LEN, HIDDEN_SIZE, device=device, dtype=dtype, requires_grad=False) param_map = { "in_proj_qkvz.weight": "in_proj_qkvz.weight", @@ -110,25 +95,27 @@ def test_fast_llm_gdn_matches_qwen3_next_forward(): "dt_bias": "dt_bias", "norm.weight": "norm.weight", } + hf_state_dict = hf_layer.gdn.state_dict() for k, p in fast_layer.state_dict().items(): - torch.testing.assert_close(p, hf_layer.state_dict()[param_map[k]], atol=1e-6, rtol=1e-6) + torch.testing.assert_close(p, hf_state_dict[param_map[k]], atol=1e-5, rtol=1e-5) # need to monkey patch the hf implementation with our fix_query_key_value_ordering due to the layout differences - hf_layer.fix_query_key_value_ordering = fast_layer.fix_query_key_value_ordering + hf_layer.gdn.fix_query_key_value_ordering = fast_layer.fix_query_key_value_ordering hf_layer._local_key_heads = fast_layer._local_key_heads hf_layer._local_value_heads = fast_layer._local_value_heads hf_layer._config = fast_layer._config - hf_out = hf_layer(hidden_states) + hf_out = hf_layer(hidden_states)[0] + sequence_lengths = [[SEQ_LEN] for _ in range(hidden_states.size(0))] fast_kwargs = { + BlockKwargs.device: device, BlockKwargs.sequence_first: False, - BlockKwargs.hidden_dims: (hidden_dim,), + BlockKwargs.hidden_dims: (HIDDEN_SIZE,), + BlockKwargs.sequence_length: SEQ_LEN, + BlockKwargs.sequence_lengths: sequence_lengths, } - fast_out = fast_layer(hidden_states, fast_kwargs) + fast_layer.preprocess(fast_kwargs) + fast_out, _ = fast_layer(hidden_states, fast_kwargs) torch.testing.assert_close(fast_out, hf_out, atol=1e-5, rtol=1e-5) - - -if __name__ == "__main__": - pytest.main([__file__]) diff --git a/tests/layers/test_kda_equivalence.py b/tests/layers/test_kda_equivalence.py new file mode 100644 index 000000000..8745236d4 --- /dev/null +++ b/tests/layers/test_kda_equivalence.py @@ -0,0 +1,138 @@ +import pytest +import torch + +import fast_llm.layers.ssm.kda as kda_module +from fast_llm.config import UpdateType +from fast_llm.layers.block.config import BlockKwargs +from fast_llm.models.gpt.config import GPTBaseModelConfig, GPTModelConfig +from tests.utils.utils import get_base_model, get_stage, requires_cuda + +try: + from fast_llm_external_models.apriel_hybrid_ssm.configuration_apriel_hybrid_ssm import AprielHybridSSMConfig + from fast_llm_external_models.apriel_hybrid_ssm.modeling_apriel_hybrid_ssm import KimiDeltaAttention +except ImportError: + AprielHybridSSMConfig, KimiDeltaAttention = None, None + +VOCAB_SIZE = 500 +HIDDEN_SIZE = 16 +SEQ_LEN = 65 +NUM_HEADS = 4 +HEAD_DIM = 4 +KERNEL_SIZE = 4 + + +@pytest.mark.slow 
+@requires_cuda +@pytest.mark.skipif(KimiDeltaAttention is None or AprielHybridSSMConfig is None, reason="Apriel KDA deps missing") +@pytest.mark.skipif(kda_module.chunk_kda is None, reason="KDA fused kernels not available") +def test_fast_llm_kda_matches_apriel_forward(): + torch.manual_seed(0) + device = torch.device("cuda") + dtype = torch.bfloat16 + + hf_config = AprielHybridSSMConfig( + hidden_size=HIDDEN_SIZE, + num_attention_heads=NUM_HEADS, + num_hidden_layers=1, + rms_norm_eps=1e-6, + ) + hf_config.short_conv_kernel_size = KERNEL_SIZE + hf_config.head_dim = HEAD_DIM + hf_config.num_heads = NUM_HEADS + hf_layer = KimiDeltaAttention(hf_config, layer_idx=0).to(device=device, dtype=dtype).eval() + + config = GPTBaseModelConfig.from_dict( + { + "decoder": { + "num_blocks": 1, + "block": { + "mixer": { + "type": "kda", + "heads": NUM_HEADS, + "head_dim": HEAD_DIM, + "convolution_layer": {"kernel_size": KERNEL_SIZE, "activation": "silu"}, + "normalization": {"epsilon": hf_config.rms_norm_eps, "activation": "sigmoid"}, + } + }, + }, + "embeddings": {"vocab_size": VOCAB_SIZE}, + "hidden_size": HIDDEN_SIZE, + }, + update_type=UpdateType.update, + ) + + model, distributed = get_base_model( + GPTModelConfig.from_dict( + { + "base_model": config, + "distributed": {}, + }, + ) + ) + fast_layer = model.decoder[0].mixer + get_stage([fast_layer], distributed, [], {}) + fast_layer.to(device=device, dtype=dtype).eval() + + with torch.no_grad(): + fast_layer.q_proj.weight.copy_(hf_layer.q_proj.weight) + fast_layer.k_proj.weight.copy_(hf_layer.k_proj.weight) + fast_layer.v_proj.weight.copy_(hf_layer.v_proj.weight) + fast_layer.q_conv.weight.copy_(hf_layer.q_conv1d.weight) + fast_layer.k_conv.weight.copy_(hf_layer.k_conv1d.weight) + fast_layer.v_conv.weight.copy_(hf_layer.v_conv1d.weight) + if fast_layer.q_conv.bias is not None and hf_layer.q_conv1d.bias is not None: + fast_layer.q_conv.bias.copy_(hf_layer.q_conv1d.bias) + if fast_layer.k_conv.bias is not None and hf_layer.k_conv1d.bias is not None: + fast_layer.k_conv.bias.copy_(hf_layer.k_conv1d.bias) + if fast_layer.v_conv.bias is not None and hf_layer.v_conv1d.bias is not None: + fast_layer.v_conv.bias.copy_(hf_layer.v_conv1d.bias) + fast_layer.f_a_proj.weight.copy_(hf_layer.f_a_proj.weight) + fast_layer.f_b_proj.weight.copy_(hf_layer.f_b_proj.weight) + fast_layer.g_a_proj.weight.copy_(hf_layer.g_a_proj.weight) + fast_layer.g_b_proj.weight.copy_(hf_layer.g_b_proj.weight) + fast_layer.beta_proj.weight.copy_(hf_layer.b_proj.weight) + fast_layer.o_proj.weight.copy_(hf_layer.o_proj.weight) + fast_layer.A_log.copy_(hf_layer.A_log.reshape_as(fast_layer.A_log)) + fast_layer.dt_bias.copy_(hf_layer.dt_bias.reshape_as(fast_layer.dt_bias)) + fast_layer.norm.weight.copy_(hf_layer.o_norm.weight) + + param_map = { + "q_proj.weight": "q_proj.weight", + "k_proj.weight": "k_proj.weight", + "v_proj.weight": "v_proj.weight", + "q_conv.weight": "q_conv1d.weight", + "k_conv.weight": "k_conv1d.weight", + "v_conv.weight": "v_conv1d.weight", + "f_a_proj.weight": "f_a_proj.weight", + "f_b_proj.weight": "f_b_proj.weight", + "g_a_proj.weight": "g_a_proj.weight", + "g_b_proj.weight": "g_b_proj.weight", + "beta_proj.weight": "b_proj.weight", + "o_proj.weight": "o_proj.weight", + "A_log": "A_log", + "dt_bias": "dt_bias", + "norm.weight": "o_norm.weight", + } + for fast_name, hf_name in param_map.items(): + fast_param = fast_layer.state_dict()[fast_name] + hf_param = hf_layer.state_dict()[hf_name] + if fast_param.shape != hf_param.shape: + hf_param = 
hf_param.reshape_as(fast_param) + print(f"Comparing parameter {fast_name} with shape {fast_param.shape}") + torch.testing.assert_close(fast_param, hf_param, atol=1e-5, rtol=1e-5) + + hidden_states = torch.randn(2, SEQ_LEN, HIDDEN_SIZE, device=device, dtype=dtype, requires_grad=False) + hf_layer.training = True + hf_out = hf_layer(hidden_states) + + sequence_lengths = [[SEQ_LEN] for _ in range(hidden_states.size(0))] + fast_kwargs = { + BlockKwargs.device: device, + BlockKwargs.sequence_first: False, + BlockKwargs.sequence_lengths: sequence_lengths, + BlockKwargs.hidden_dims: (HIDDEN_SIZE,), + } + fast_layer.preprocess(fast_kwargs) + fast_out, _ = fast_layer(hidden_states, fast_kwargs) + + torch.testing.assert_close(fast_out, hf_out, atol=1e-5, rtol=1e-5) diff --git a/tests/layers/test_lm_head.py b/tests/layers/test_lm_head.py index c431bb26d..6383e6aae 100644 --- a/tests/layers/test_lm_head.py +++ b/tests/layers/test_lm_head.py @@ -39,7 +39,7 @@ def _reverse_kl_loss( loss_per_sample = torch.nn.functional.kl_div( teacher_log_probs, student_log_probs, reduction="none", log_target=True ).sum(dim=-1) - loss = (loss_per_sample * loss_mask.flatten()).mean() + loss = (loss_per_sample * loss_mask.flatten()).sum() / loss_mask.sum() return loss diff --git a/tests/test_varlen.py b/tests/test_varlen.py index 126a3e1e5..ed51d93a2 100644 --- a/tests/test_varlen.py +++ b/tests/test_varlen.py @@ -7,7 +7,8 @@ from fast_llm.layers.block.config import BlockKwargs from fast_llm.layers.decoder.config import MixerConfig from fast_llm.layers.ssm import gdn as gdn_module -from fast_llm.layers.ssm.config import GatedDeltaNetConfig +from fast_llm.layers.ssm import kda as kda_module +from fast_llm.layers.ssm.config import GatedDeltaNetConfig, KimiDeltaAttentionConfig @pytest.fixture @@ -123,15 +124,33 @@ def _param_grad(param: torch.nn.Parameter) -> torch.Tensor | None: # TODO: include mamba varlen @pytest.mark.slow @pytest.mark.skipif(not torch.cuda.is_available(), reason="Varlen test needs CUDA") -@pytest.mark.skipif( - gdn_module.chunk_gated_delta_rule is None, - reason="Gated Delta Net fused kernels not available", -) @pytest.mark.parametrize( "config, sequence_first", [ - pytest.param(GatedDeltaNetConfig(value_heads=4, key_heads=2, key_head_dim=16, value_head_dim=16), False), - pytest.param(GatedDeltaNetConfig(value_heads=4, key_heads=2, key_head_dim=16, value_head_dim=16), True), + pytest.param( + GatedDeltaNetConfig(value_heads=4, key_heads=2, key_head_dim=16, value_head_dim=16), + False, + marks=pytest.mark.skipif( + gdn_module.chunk_gated_delta_rule is None, reason="GDN fused kernels not available" + ), + ), + pytest.param( + GatedDeltaNetConfig(value_heads=4, key_heads=2, key_head_dim=16, value_head_dim=16), + True, + marks=pytest.mark.skipif( + gdn_module.chunk_gated_delta_rule is None, reason="GDN fused kernels not available" + ), + ), + pytest.param( + KimiDeltaAttentionConfig(heads=4, head_dim=16), + False, + marks=pytest.mark.skipif(kda_module.chunk_kda is None, reason="KDA fused kernels not available"), + ), + pytest.param( + KimiDeltaAttentionConfig(heads=4, head_dim=16), + True, + marks=pytest.mark.skipif(kda_module.chunk_kda is None, reason="KDA fused kernels not available"), + ), ], ) def test_mixer_varlen_stacking_equivalence(config: MixerConfig, sequence_first: bool, distributed_config, distributed): @@ -207,13 +226,12 @@ def test_mixer_varlen_stacking_equivalence(config: MixerConfig, sequence_first: for (name, param), (_, param_ref) in zip(mixer_packed.named_parameters(), 
mixer_ref.named_parameters()): if param.requires_grad: - torch.testing.assert_close( + assert torch.allclose( _param_grad(param), _param_grad(param_ref), - atol=1e-3, + atol=2e-3, rtol=1e-3, - msg=f"Grad mismatch for parameter {name}", - ) + ), f"Grad mismatch for parameter {name}" if __name__ == "__main__": diff --git a/tests/utils/model_configs.py b/tests/utils/model_configs.py index 278544d71..428f7522c 100644 --- a/tests/utils/model_configs.py +++ b/tests/utils/model_configs.py @@ -1020,6 +1020,45 @@ def _update_and_add_testing_config( ) +_update_and_add_testing_config( + # Tests hybrid with KDA mixer. + "llama", + "hybrid_kda", + updates={ + ("model", "base_model", "decoder"): { + "type": "pattern", + "blocks": { + "t": copy.deepcopy(_llama_block), + "kda": { + **copy.deepcopy(_llama_block), + "mixer": { + "type": "kda", + "heads": 4, + "head_dim": 16, + }, + }, + }, + "num_blocks": 2, + "pattern": ["t", "kda"], + }, + }, + megatron_args=None, + checkpoint_format=AprielHybridSSMCheckpointFormat, + groups={ + ModelTestingGroup.basic: ModelTestingGroupAction.normal, + ModelTestingGroup.checkpoint: ModelTestingGroupAction.normal, + ModelTestingGroup.convert: ModelTestingGroupAction.normal, + ModelTestingGroup.generate: ModelTestingGroupAction.not_implemented, + ModelTestingGroup.megatron: ModelTestingGroupAction.not_implemented, + ModelTestingGroup.distributed: ModelTestingGroupAction.normal, + }, + compare_factor=10.0, # similar to gdn: with compare_factor 2, the fp16 and bf16 tests fail in the normalization layer when using rms_norm_gated from fla + # note: tp is excluded because there are currently no gradient reductions implemented for tp norm in gdn.py (STP works though). + # we should be using STP with this model, not TP! + skip_tests=(r"sdp", r"ms", r"^tp2$"), +) + + @pytest.fixture(scope="session", params=MODEL_CONFIGS.keys()) def model_testing_config(request) -> ModelTestingConfig: models = request.config.getoption("--models")