
Kda mixer #395

Merged
oleksost merged 51 commits into main from kda on Dec 8, 2025

Conversation

@oleksost
Contributor

@oleksost oleksost commented Nov 26, 2025

✨ Description

Should be merged after GDN (#392).

Adds the KDA mixer from Kimi Linear.

Note: for now this requires nightly Triton and PyTorch; see https://github.com/fla-org/flash-linear-attention/blob/main/FAQs.md.

Merged #404 into this branch. Tests for both hybrid_kda and apriel2_text_gdn_hybrid models pass when using the new Docker image on the toolkit.

🔍 Type of change

Select all that apply:

  • 🐛 Bug fix (non-breaking change that addresses a specific issue)
  • 🚀 New feature (non-breaking change that adds functionality)
  • ⚠️ Breaking change (a change that could affect existing functionality)
  • 📈 Performance improvement/optimization (improves speed, memory usage, or efficiency)
  • 🛠️ Code refactor (non-functional changes that improve code readability, structure, etc.)
  • 📦 Dependency bump (updates dependencies, including Dockerfile or package changes)
  • 📝 Documentation change (updates documentation, including new content or typo fixes)
  • 🔧 Infrastructure/Build change (affects build process, CI/CD, or dependencies)

📝 Changes

✅ Checklist

Make sure the following tasks are completed before submitting the PR:

General

  • 📜 I have read and followed the contributing guidelines.
  • 🏷️ I am using a clear and descriptive PR title that summarizes the key change or feature introduced.
  • 🎉 The functionality is complete, and I have tested the changes.
  • 📝 I have updated the documentation if needed.
  • ⚠️ The change does not introduce any new issues (e.g., runtime warnings, type checker errors, linting problems, unhandled edge cases).
  • 🧩 I have commented my code, especially in hard-to-understand areas.

Dependencies and Configuration

  • 🐋 I have updated the Docker configuration or dependencies, if applicable.
  • 🔄 I have ensured compatibility with the existing setup after dependency changes.

Testing

  • 🧪 I have added or updated tests to cover my changes.
  • ✔️ New and existing tests pass locally with my changes.
  • 🚦 I have tested these changes on GPUs and verified training stability.
  • 🏋️ I have tested the changes on realistic training workloads, if applicable.

Performance Impact

  • 📊 I have run benchmarks where applicable to evaluate the performance impact.
  • ✅ The benchmarks show no performance regression.
  • 🚀 The benchmarks indicate a potential performance improvement.
  • ⚠️ The benchmarks indicate a potential performance degradation.
  • 📈 I have provided benchmark results and detailed any performance impact below, if applicable.

📊 Performance Impact Details

If there is any impact on performance, describe it and provide benchmark results, if applicable:


🗒️ Additional Notes

Include any additional context, information, or considerations here, such as known issues, follow-up tasks, or backward compatibility concerns.

@oleksost oleksost marked this pull request as ready for review November 26, 2025 20:40
Collaborator

@jlamypoirier jlamypoirier left a comment

Some comments, most also apply to GDA

Comment thread Dockerfile Outdated
# The image is still compatible with any user id.
RUN useradd user
USER user
USER user No newline at end of file
Collaborator

Unnecessary diff

super()._validate()


@config_class(dynamic_type={MixerConfig: "kda"})
Collaborator

"kimi_delta_attention"

desc="Configuration for the gated normalization applied to the KDA output.",
hint=FieldHint.architecture,
)
q_projection_layer: AffineLinearConfig = Field(
Collaborator

projection seems unnecessary in these fields.

Comment thread fast_llm/layers/ssm/config.py Outdated
)

@property
def layer_class(self) -> "type":
Collaborator

type["KimiDeltaAttention"]

return KimiDeltaAttention

def _validate(self) -> None:
with self._set_implicit_default():
Collaborator

Not sure that's a good idea; it makes configs hard to understand. Better to require the user to specify these explicitly. (And most of the time we're creating from HF, so that's not a problem.)

Comment thread tests/layers/test_kda_equivalence.py Outdated


@pytest.mark.slow
@pytest.mark.skipif(not torch.cuda.is_available(), reason="KDA equivalence test needs CUDA")
Collaborator

pytest.mark.requires_cuda
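For reference, a project-specific marker like `requires_cuda` is usually backed by a `conftest.py` hook; a minimal sketch of the idea (hypothetical, not the actual Fast-LLM conftest):

```python
# Hypothetical conftest.py sketch: turn a `requires_cuda` marker into a skip
# on machines without a GPU. The real Fast-LLM conftest may differ.
import pytest


def _cuda_available() -> bool:
    try:
        import torch

        return torch.cuda.is_available()
    except ImportError:
        return False


def should_skip(has_requires_cuda_marker: bool, cuda_available: bool) -> bool:
    # Skip exactly the tests that ask for CUDA on machines without it.
    return has_requires_cuda_marker and not cuda_available


def pytest_collection_modifyitems(config, items):
    cuda_available = _cuda_available()
    skip_cuda = pytest.mark.skip(reason="test requires CUDA")
    for item in items:
        if should_skip("requires_cuda" in item.keywords, cuda_available):
            item.add_marker(skip_cuda)
```

This keeps the per-test skip condition in one place instead of repeating the `torch.cuda.is_available()` check on every test.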

Comment thread tests/layers/test_kda_equivalence.py Outdated
AprielHybridSSMConfig, KimiDeltaAttention = None, None


def _materialize_mixer_tensors(module: torch.nn.Module, distributed: Distributed, device: torch.device) -> None:
Collaborator

Please use get_stage; it already does this. See the example here: https://github.com/ServiceNow/Fast-LLM/blob/main/tests/layers/test_lm_head.py#L264

Also, please don't copy utils into every file; they can go in utils.

@pytest.mark.skipif(not torch.cuda.is_available(), reason="KDA equivalence test needs CUDA")
@pytest.mark.skipif(KimiDeltaAttention is None or AprielHybridSSMConfig is None, reason="Apriel KDA deps missing")
@pytest.mark.skipif(kda_module.chunk_kda is None, reason="KDA fused kernels not available")
def test_fast_llm_kda_matches_apriel_forward():
Collaborator

Not sure we need this test at all. test_huggingface_model already tests the equivalence

Contributor Author

I agree that we won't need these eventually.

test_huggingface_model seems to be a heavier integration test compared to the isolated unit tests test_kda_equivalence and test_gda_equivalence. The latter are more useful for development.

Can we keep them for some time, until the GDN and KDA implementations are production-tested?

Collaborator

I believe these tests are extremely valuable and should remain.
In fact, I think we should extend them to non-FLA backup implementations of GDN and KDA.

ModelTestingGroup.convert: ModelTestingGroupAction.normal,
ModelTestingGroup.generate: ModelTestingGroupAction.not_implemented,
ModelTestingGroup.megatron: ModelTestingGroupAction.not_implemented,
ModelTestingGroup.distributed: ModelTestingGroupAction.normal,
Collaborator

We might want to test once and then leave it as unimportant; this has a huge impact on testing time.

Contributor Author

Leaving them here for now, until we’ve used KDA and GDN enough to be confident they’re stable and free of issues.

Update to nvcr.io/nvidia/pytorch:25.11-py3 which includes:
- PyTorch 2.10
- CUDA 13.0
- flash-attn 2.7.4.post1 (pre-installed, no compilation needed)

Dependency updates:
- causal-conv1d: v1.5.4 (was pinned to commit 2a288a1)
- mamba-ssm: 2.2.6.post3 (was pinned to commit 4a8a2a2)
- flash-linear-attention: pin to commit 67eee20 (was @main)
- flash-attn: 2.7.4.post1 to match base image (was 2.7.3)
- triton: 3.5.1 in Dockerfile (was 3.1.0)

These updates enable Kimi Delta Attention (KDA) support via the
flash-linear-attention library. The pinned versions are tested and
working, unlike the nightly/unpinned approach in #395.

Note: Dropless MoE kernel remains broken with triton >= 3.2.0 and
needs a complete rewrite (also limited to 32 experts). This is
tracked separately and doesn't block KDA work.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
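Expressed as pip pins, the dependency combination above might look like this (illustrative only; the actual Dockerfile lines are not shown in this excerpt, and flash-attn comes pre-installed in the base image):

```shell
# Hypothetical pin set matching the versions listed in the commit message.
pip install "causal-conv1d==1.5.4" "mamba-ssm==2.2.6.post3" "triton==3.5.1" \
    "git+https://github.com/fla-org/flash-linear-attention.git@67eee20"
```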
@oleksost oleksost requested a review from jlamypoirier December 8, 2025 16:29
Collaborator

@tscholak tscholak left a comment

very good work, thank you!

I'd like us to have a non-FLA fallback for GDN and KDA, similar to our torch-compiled attention backup implementation.
You should be able to reuse much of the GDN torch code from upstream Qwen3Next, and similarly from Kimi Linear.
And then we add a config option for GDN and KDA like so:

class AttentionImplementation(enum.StrEnum):
    auto = "auto"
    flash = "flash"
    backup = "backup"
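For context, dispatch on such a config option could look like the following sketch (the enum values are from the comment above; `resolve_implementation` and its fallback logic are hypothetical, not code from this PR):

```python
import enum


class AttentionImplementation(str, enum.Enum):
    # str mixin instead of enum.StrEnum so the sketch also runs on Python < 3.11
    auto = "auto"
    flash = "flash"
    backup = "backup"


def resolve_implementation(
    requested: AttentionImplementation, fla_available: bool
) -> AttentionImplementation:
    # Hypothetical resolution logic: "auto" picks the fused FLA kernels when
    # they are importable and falls back to the torch-only path otherwise.
    if requested is AttentionImplementation.auto:
        return AttentionImplementation.flash if fla_available else AttentionImplementation.backup
    if requested is AttentionImplementation.flash and not fla_available:
        raise RuntimeError("flash implementation requested but fla is not installed")
    return requested
```

With this shape, explicitly requesting `flash` on a machine without FLA fails loudly instead of silently switching implementations.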

# same as rearrange(v, '... (h d) -> ... h d', d=self.head_dim)
return tensor.view(tensor.shape[0], tensor.shape[1], self._local_heads, self._config.head_dim)

def _forward(
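The view/rearrange equivalence claimed in the comment above can be checked with a small numpy sketch (shapes are made up for illustration; the real code operates on torch tensors):

```python
import numpy as np

# Hypothetical shapes: batch, sequence, heads, per-head dimension.
batch, seq, heads, head_dim = 2, 3, 4, 5
x = np.arange(batch * seq * heads * head_dim, dtype=np.float32).reshape(
    batch, seq, heads * head_dim
)

# view-style split of the flat head axis: '... (h d) -> ... h d'
split = x.reshape(batch, seq, heads, head_dim)

# Head h occupies the contiguous slice [h * head_dim, (h + 1) * head_dim)
# of the flat axis, which is exactly what rearrange with d=head_dim does.
assert split.shape == (batch, seq, heads, head_dim)
assert np.array_equal(split[:, :, 1, :], x[:, :, head_dim : 2 * head_dim])
```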
Collaborator

can we please have a torch-only compiled fallback in case fla isn't available?

Contributor Author

Added these as todos in #406.

fast_layer.preprocess(fast_kwargs)
fast_out, _ = fast_layer(hidden_states, fast_kwargs)

torch.testing.assert_close(fast_out, hf_out, atol=1e-5, rtol=1e-5)
Collaborator

can we please structure this test like:

def test_attention_implementations(cross_document_attention: bool, causal: bool, window_size: int | None):

We should add a backup implementation for gdn in case fla isn't available

fast_layer.preprocess(fast_kwargs)
fast_out, _ = fast_layer(hidden_states, fast_kwargs)

torch.testing.assert_close(fast_out, hf_out, atol=1e-5, rtol=1e-5)
Collaborator

can we please structure this test like:

def test_attention_implementations(cross_document_attention: bool, causal: bool, window_size: int | None):

We should add a backup implementation for kda in case fla isn't available

Comment thread tests/layers/test_kda_equivalence.py Outdated
torch.testing.assert_close(fast_out, hf_out, atol=1e-5, rtol=1e-5)


if __name__ == "__main__":
Collaborator

let's remove that

Comment thread tests/layers/test_gdn_equivalence.py Outdated
Collaborator

let's remove that

@oleksost oleksost merged commit 9d12e9c into main Dec 8, 2025
4 checks passed
@oleksost oleksost deleted the kda branch December 8, 2025 19:45