Merged
Commits
51 commits
1a219c4  wip  (oleksost, Nov 18, 2025)
5242eb6  added gdn  (oleksost, Nov 18, 2025)
bec22de  gdn layer  (oleksost, Nov 19, 2025)
7f79909  kda  (oleksost, Nov 19, 2025)
8636f09  wip  (oleksost, Nov 19, 2025)
a20c958  convertion kda  (oleksost, Nov 20, 2025)
8ac5167  tp and sequence tp  (oleksost, Nov 21, 2025)
f1a51f2  varlen kda  (oleksost, Nov 22, 2025)
3b367d8  gdn only: varlen test  (oleksost, Nov 24, 2025)
c48d4ee  clean up  (oleksost, Nov 24, 2025)
e2bb25c  test config  (oleksost, Nov 25, 2025)
d4f9b85  wip  (oleksost, Nov 25, 2025)
8017a80  gdn tests  (oleksost, Nov 25, 2025)
1e01601  tests  (oleksost, Nov 25, 2025)
ca8cb5c  tests  (oleksost, Nov 25, 2025)
694d287  nvm  (oleksost, Nov 25, 2025)
d3bd916  requirements  (oleksost, Nov 26, 2025)
3ff7799  wip  (oleksost, Nov 26, 2025)
9a53c5b  clean up  (oleksost, Nov 26, 2025)
80041ce  conversion  (oleksost, Nov 26, 2025)
67a234a  Merge branch 'gdn' into kda  (oleksost, Nov 26, 2025)
d6677b0  comments on the layour + HF forward equivalence test  (oleksost, Nov 26, 2025)
a75cd9f  Merge branch 'gdn' into kda  (oleksost, Nov 26, 2025)
5d3b6d0  wip  (oleksost, Nov 26, 2025)
0d41dce  wip  (oleksost, Nov 26, 2025)
6e2c1fe  varlen test  (oleksost, Nov 26, 2025)
8938a1d  varlen test  (oleksost, Nov 26, 2025)
6c2bd46  wip  (oleksost, Nov 26, 2025)
28a6176  wip  (oleksost, Nov 26, 2025)
2a30bac  wip  (oleksost, Nov 26, 2025)
cad93ab  kda equivalence test  (oleksost, Nov 27, 2025)
8f957a4  nightly requirements  (oleksost, Nov 27, 2025)
82c9cc4  docker  (oleksost, Nov 27, 2025)
d25994e  manual build  (oleksost, Nov 27, 2025)
3651b06  Merge branch 'main' into gdn  (oleksost, Dec 1, 2025)
7b31e78  Merge branch 'gdn' into kda  (oleksost, Dec 1, 2025)
5a44097  two docker files  (oleksost, Dec 1, 2025)
a164a2b  test import fix  (oleksost, Dec 2, 2025)
c4aa9b1  set correct activations  (oleksost, Dec 2, 2025)
a8849cb  import  (oleksost, Dec 2, 2025)
5f32ba7  kda docker file  (oleksost, Dec 5, 2025)
1dde2a9  Merge branch 'main' into kda  (oleksost, Dec 5, 2025)
685f351  Merge remote-tracking branch 'origin/main' into kda  (oleksost, Dec 5, 2025)
d33d6d7  revert workflow change  (oleksost, Dec 5, 2025)
05abc03  removed unused requirements file  (oleksost, Dec 5, 2025)
7b30a36  Bump base image and dependencies for KDA support  (tscholak, Dec 7, 2025)
eb52cc7  fixes  (oleksost, Dec 8, 2025)
372771b  removed kda docker since we probably do not need it  (oleksost, Dec 8, 2025)
aea8b1e  Merge branch 'main' into kda  (oleksost, Dec 8, 2025)
407be82  Merge remote-tracking branch 'origin/tscholak/bump-dependencies' into…  (oleksost, Dec 8, 2025)
2ce4c07  clean  (oleksost, Dec 8, 2025)
14 changes: 7 additions & 7 deletions .github/workflows/manual-build.yml
@@ -34,12 +34,12 @@ jobs:
sudo rm -rf /usr/share/dotnet || true
sudo rm -rf /opt/ghc || true
sudo rm -rf /usr/local/.ghcup || true

- name: Checkout repository
uses: actions/checkout@v4
with:
ref: ${{ inputs.commit_sha != '' && inputs.commit_sha || inputs.branch }}

- name: Get commit info
id: commit_info
run: |
@@ -48,7 +48,7 @@ jobs:
echo "full_sha=${COMMIT_SHA}" >> $GITHUB_OUTPUT
echo "short_sha=${COMMIT_SHORT}" >> $GITHUB_OUTPUT
echo "Building from commit: ${COMMIT_SHA}"

- name: Docker meta
id: meta
uses: docker/metadata-action@v5
@@ -59,18 +59,18 @@ jobs:
type=raw,value=${{ inputs.branch }}-${{ inputs.tag_suffix }}
type=raw,value=${{ inputs.branch }}-${{ inputs.tag_suffix }}-${{ steps.commit_info.outputs.short_sha }}
type=raw,value=latest-${{ inputs.tag_suffix }},enable=${{ inputs.branch == 'main' && inputs.commit_sha == '' }}

- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3

- name: Login to GHCR
if: ${{ inputs.push_image }}
uses: docker/login-action@v3
with:
registry: ghcr.io
username: ${{ github.repository_owner }}
password: ${{ secrets.GITHUB_TOKEN }}

- name: Build and push
uses: docker/build-push-action@v6
with:
@@ -80,7 +80,7 @@ jobs:
labels: ${{ steps.meta.outputs.labels }}
cache-from: type=registry,ref=ghcr.io/servicenow/fast-llm:cache
cache-to: type=registry,ref=ghcr.io/servicenow/fast-llm:cache,mode=max

- name: Output build info
run: |
echo "Built Docker image with tags:"
9 changes: 5 additions & 4 deletions Dockerfile
@@ -1,5 +1,5 @@
# syntax=docker/dockerfile:1.7-labs
FROM nvcr.io/nvidia/pytorch:25.05-py3
FROM nvcr.io/nvidia/pytorch:25.11-py3

# Install dependencies.
RUN apt-get update \
@@ -29,16 +29,17 @@ ENV PIP_CONSTRAINT=""
# There is no pre-build mamba image for pytorch 2.8, we build it before the rest to avoid rebuilds.
# We need to compile from the repo because of https://github.com/state-spaces/mamba/issues/720 (same for causal-conv1d)
# We set the number of workers to avoid OOM when compiling on laptop. (TODO: Can we make it configurable?)
RUN MAX_JOBS=2 pip install --no-build-isolation "causal-conv1d@git+https://github.com/Dao-AILab/causal-conv1d@2a288a1"
RUN MAX_JOBS=2 pip install --no-build-isolation "mamba_ssm[causal-conv1d]@git+https://github.com/state-spaces/mamba@4a8a2a2"
RUN MAX_JOBS=2 pip install --no-build-isolation "causal-conv1d @ git+https://github.com/Dao-AILab/causal-conv1d@v1.5.4"
RUN MAX_JOBS=2 pip install --no-build-isolation mamba-ssm==2.2.6.post3
RUN MAX_JOBS=2 pip install --no-build-isolation "flash-linear-attention @ git+https://github.com/fla-org/flash-linear-attention@67eee20c8503cd19eeb52aa1b99821308e9260c5"
# Copy dependency files with universal write permissions for all users.
COPY --chmod=777 setup.py setup.cfg pyproject.toml ./
COPY --chmod=777 ./fast_llm_external_models/__init__.py fast_llm_external_models/
COPY --chmod=777 ./fast_llm/__init__.py fast_llm/
COPY --chmod=777 ./fast_llm/csrc/ fast_llm/csrc/

# Install dependencies within the virtual environment.
RUN pip install --no-cache-dir --no-build-isolation -e ".[CORE,OPTIONAL,HUGGINGFACE,SSM,VISION,GENERATION,DEV]" triton==3.1.0
RUN pip install --no-cache-dir --no-build-isolation -e ".[CORE,OPTIONAL,HUGGINGFACE,SSM,VISION,GENERATION,DEV]" triton==3.5.1

# Copy the remaining source code with universal write permissions.
COPY --chmod=777 ./Megatron-LM Megatron-LM
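As a small optional aid (not part of this PR), a Python sketch for checking inside the rebuilt image that the bumped dependencies resolved as expected; the distribution names below are assumptions based on the pip specifiers above.

```python
# Print the installed versions of the packages whose pins changed in this Dockerfile.
import importlib.metadata as md

for dist in ("mamba-ssm", "causal-conv1d", "flash-linear-attention", "triton"):
    try:
        print(dist, md.version(dist))
    except md.PackageNotFoundError:
        print(dist, "not installed")
```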
3 changes: 3 additions & 0 deletions fast_llm/functional/config.py
@@ -42,6 +42,7 @@ class ActivationType(enum.StrEnum):
gelu = "gelu"
silu = "silu"
relu = "relu"
sigmoid = "sigmoid"
squared_relu = "squared_relu"
identity = "identity"

@@ -70,6 +71,7 @@ def _set_activation_fn_map() -> None:
ActivationType.gelu: lambda x: torch.nn.functional.gelu(x, approximate="tanh"),
ActivationType.silu: torch.nn.functional.silu,
ActivationType.relu: torch.nn.functional.relu,
ActivationType.sigmoid: torch.nn.functional.sigmoid,
ActivationType.squared_relu: lambda x: torch.pow(torch.nn.functional.relu(x), 2),
ActivationType.identity: lambda x: x,
}
@@ -83,6 +85,7 @@ def _set_activation_fn_map() -> None:
ActivationType.relu: "relu",
ActivationType.squared_relu: "relu2",
ActivationType.identity: "identity",
ActivationType.sigmoid: "sigmoid",
}
_ACTIVATION_HF_NAMES_INV = {value: key for key, value in _ACTIVATION_HF_NAMES.items()}
_ACTIVATION_HF_NAMES_INV["gelu"] = ActivationType.gelu
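For orientation, a minimal sketch (not code from this PR) of what the new enum entry resolves to; it assumes `ActivationType` exposes the `activation_fn` accessor, which its use in `normalization.py` below suggests.

```python
import torch

from fast_llm.functional.config import ActivationType

x = torch.randn(4, 8)
# Assumed accessor, mirroring self._config.activation.activation_fn(...) in normalization.py.
y = ActivationType.sigmoid.activation_fn(x)
assert torch.allclose(y, torch.sigmoid(x))
```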
7 changes: 7 additions & 0 deletions fast_llm/layers/common/normalization/config.py
@@ -5,6 +5,7 @@
from fast_llm.config import Field, FieldHint, check_field, config_class
from fast_llm.engine.base_model.config import ModuleConfig
from fast_llm.engine.config_utils.parameter import ParameterConfig, combine_lr_scales
from fast_llm.functional.config import ActivationType
from fast_llm.layers.common.peft.config import PeftConfig
from fast_llm.utils import Assert

@@ -133,6 +134,12 @@ def module_class(self):
class GatedRMSNormalizationConfig(RMSNormalizationConfig):
_abstract = False

activation: ActivationType = Field(
default=ActivationType.silu,
desc="The MLP intermediate activation type. Default: SiLU for gated MLP, GeLU otherwise.",
hint=FieldHint.core,
)

@property
def module_class(self):
from fast_llm.layers.common.normalization.normalization import GatedRMSNormalization
5 changes: 2 additions & 3 deletions fast_llm/layers/common/normalization/normalization.py
@@ -1,7 +1,6 @@
import abc

import torch
import torch.nn.functional as F

from fast_llm.config import Configurable
from fast_llm.engine.config_utils.initialization import init_ones_, init_zeros_
@@ -324,7 +323,7 @@ def _forward_fla(self, input_: torch.Tensor, gate: torch.Tensor) -> torch.Tensor
gate,
self.weight,
None,
activation="silu",
activation=self._config.activation.hf_name,
eps=self._config.epsilon,
residual=None,
prenorm=False,
@@ -333,4 +332,4 @@ def _forward_local(self, input_: torch.Tensor, gate: torch.Tensor) -> torch.Tensor

def _forward_local(self, input_: torch.Tensor, gate: torch.Tensor) -> torch.Tensor:
normalized = self._forward(input_)
return normalized * F.silu(gate)
return normalized * self._config.activation.activation_fn(gate)
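To make the change concrete, here is a minimal framework-free sketch of the gated RMS norm math that the configurable activation now controls; it is an illustration, not Fast-LLM's implementation.

```python
import torch

def gated_rms_norm(x, gate, weight, activation=torch.nn.functional.silu, eps=1e-5):
    # RMS-normalize over the last dimension, then modulate by the activated gate.
    normalized = x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps) * weight
    return normalized * activation(gate)

x, gate, weight = torch.randn(2, 8, 64), torch.randn(2, 8, 64), torch.ones(64)
out_silu = gated_rms_norm(x, gate, weight)                     # previous hard-coded behavior
out_sigmoid = gated_rms_norm(x, gate, weight, torch.sigmoid)   # what KDA configures below
```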
107 changes: 88 additions & 19 deletions fast_llm/layers/ssm/config.py
@@ -4,7 +4,6 @@
from fast_llm.config import Field, FieldHint, check_field, config_class
from fast_llm.engine.config_utils.initialization import InitializationConfig, Initializer, LambdaInitializer
from fast_llm.engine.config_utils.parameter import ParameterConfig
from fast_llm.functional.config import ActivationType
from fast_llm.layers.block.config import BlockKwargs
from fast_llm.layers.common.linear.config import AffineLinearConfig, CausalConv1dConfig, LinearConfig
from fast_llm.layers.common.normalization.config import GatedRMSNormalizationConfig
@@ -15,6 +14,8 @@
import torch

from fast_llm.layers.ssm.discrete_mamba2 import DiscreteMamba2
from fast_llm.layers.ssm.gdn import GatedDeltaNet
from fast_llm.layers.ssm.kda import KimiDeltaAttention
from fast_llm.layers.ssm.mamba import Mamba
from fast_llm.tensor import ParameterMeta

@@ -84,36 +85,104 @@ class GatedDeltaNetConfig(MixerConfig):
hint=FieldHint.architecture,
valid=check_field(Assert.gt, 0),
)
norm_epsilon: float = Field(
default=1e-6,
desc="Epsilon used by the gated RMS norm.",
hint=FieldHint.architecture,
valid=check_field(Assert.gt, 0),
)
activation: ActivationType = Field(
default=ActivationType.silu,
desc="Activation used after the convolution.",
hint=FieldHint.architecture,
)

def _validate(self) -> None:
super()._validate()
Assert.multiple(self.value_heads, self.key_heads)

@property
def layer_class(self) -> "type":
def layer_class(self) -> "type[GatedDeltaNet]":
from fast_llm.layers.ssm.gdn import GatedDeltaNet

return GatedDeltaNet

def _validate(self) -> None:
super()._validate()


@config_class(dynamic_type={MixerConfig: "kda"})
Review comment (Collaborator): "kimi_delta_attention"

class KimiDeltaAttentionConfig(MixerConfig):
"""
Configuration for the KimiDeltaAttention mixer inspired by the Kimi Linear models.
"""

_abstract = False
normalization: GatedRMSNormalizationConfig = Field(
desc="Configuration for the gated normalization applied to the KDA output.",
hint=FieldHint.architecture,
)
q_projection_layer: AffineLinearConfig = Field(
Review comment (Collaborator): projection seems unnecessary in these fields.

desc="Projection that produces query vectors.",
hint=FieldHint.architecture,
)
k_projection_layer: AffineLinearConfig = Field(
desc="Projection that produces key vectors.",
hint=FieldHint.architecture,
)
v_projection_layer: AffineLinearConfig = Field(
desc="Projection that produces value vectors.",
hint=FieldHint.architecture,
)
f_a_projection_layer: AffineLinearConfig = Field(
desc="Projection used for the Delta gating pre-activation.",
hint=FieldHint.architecture,
)
f_b_projection_layer: AffineLinearConfig = Field(
desc="Projection used for the Delta gating expansion.",
hint=FieldHint.architecture,
)
g_a_projection_layer: AffineLinearConfig = Field(
desc="Projection used for the output gating pre-activation.",
hint=FieldHint.architecture,
)
g_b_projection_layer: AffineLinearConfig = Field(
desc="Projection used for the output gating expansion.",
hint=FieldHint.architecture,
)
beta_projection_layer: AffineLinearConfig = Field(
desc="Projection that produces the Beta gate.",
hint=FieldHint.architecture,
)
output_projection_layer: AffineLinearConfig = Field(
desc="Projection applied after the Delta recurrence and gated normalization.",
hint=FieldHint.architecture,
)
convolution_layer: CausalConv1dConfig = Field(
desc="Depth-wise convolution applied independently on each Q, K and V stream.",
hint=FieldHint.architecture,
)
dt_bias_weight: ParameterConfig = Field(
desc="Parameter configuration for the Delta gate bias.",
hint=FieldHint.architecture,
)
a_log_weight: ParameterConfig = Field(
desc="Parameter configuration for the decay rates.",
hint=FieldHint.architecture,
)

heads: int = Field(
default=16,
desc="Number of attention heads.",
hint=FieldHint.architecture,
valid=check_field(Assert.gt, 0),
)
head_dim: int = Field(
default=64,
desc="Dimension of each head.",
hint=FieldHint.architecture,
valid=check_field(Assert.gt, 0),
)

@property
def layer_class(self) -> "type[KimiDeltaAttention]":
from fast_llm.layers.ssm.kda import KimiDeltaAttention

return KimiDeltaAttention

def _validate(self) -> None:
with self._set_implicit_default():
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure that's a good idea, it makes configs hard to understand. Better assume the user to specify these explicitly. (and most of the time we're creating from HF so that's not a problem)

if "epsilon" not in self.normalization._explicit_fields:
self.normalization.epsilon = 1.0e-5
if "activation" not in self.convolution_layer._explicit_fields:
self.convolution_layer.activation = "silu"
if "kernel_size" not in self.convolution_layer._explicit_fields:
self.convolution_layer.kernel_size = 4
if "activation" not in self.normalization._explicit_fields:
self.normalization.activation = "sigmoid"

super()._validate()

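As a reading aid for the review thread above, this is roughly what spelling the defaults out explicitly would look like on the user side; the dict is a hypothetical config payload whose keys mirror the fields in this diff, not a tested configuration.

```python
# Hypothetical explicit configuration for the "kda" mixer, instead of relying on
# _validate() to fill in implicit defaults. Values are the ones the diff injects.
kda_mixer = {
    "type": "kda",  # dynamic_type key; the review above suggests "kimi_delta_attention"
    "heads": 16,
    "head_dim": 64,
    "normalization": {"epsilon": 1.0e-5, "activation": "sigmoid"},
    "convolution_layer": {"kernel_size": 4, "activation": "silu"},
}
```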
3 changes: 2 additions & 1 deletion fast_llm/layers/ssm/gdn.py
@@ -9,6 +9,7 @@
from fast_llm.engine.config_utils.initialization import LambdaInitializer, init_normal_, init_ones_
from fast_llm.engine.config_utils.tensor_dim import CompositeTensorDim, ConcatenatedTensorDim, TensorDim
from fast_llm.engine.distributed.config import DistributedConfig, DistributedDimNames
from fast_llm.functional.config import ActivationType
from fast_llm.layers.block.config import BlockKwargs
from fast_llm.layers.common.peft.config import PeftConfig
from fast_llm.layers.decoder.block import BlockWithBias
@@ -204,7 +205,7 @@ def __init__(
self.convolution = self._config.convolution_layer.get_layer(
qkv_channels_dim,
default_add_bias=False,
default_activation=self._config.activation,
default_activation=ActivationType.silu,
lr_scale=self._lr_scale,
peft=self._peft,
)
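For reference, a self-contained sketch (assumptions, not Fast-LLM code) of the operation this default configures: a depth-wise causal convolution over the packed q/k/v channels followed by the SiLU that is now hard-coded as the default.

```python
import torch

def causal_depthwise_conv_silu(x, conv_weight):
    # x: (batch, channels, seq_len); conv_weight: (channels, 1, kernel_size).
    kernel_size = conv_weight.shape[-1]
    x = torch.nn.functional.pad(x, (kernel_size - 1, 0))  # left padding keeps it causal
    x = torch.nn.functional.conv1d(x, conv_weight, groups=x.shape[1])
    return torch.nn.functional.silu(x)

# Example shapes only: q/k/v packed along the channel dimension.
x = torch.randn(2, 3 * 16 * 64, 128)
w = torch.randn(3 * 16 * 64, 1, 4)  # kernel_size 4, one filter per channel
assert causal_depthwise_conv_silu(x, w).shape == x.shape
```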