Conversation
jlamypoirier left a comment
Some comments; most also apply to GDN.
    # The image is still compatible with any user id.
    RUN useradd user
    USER user
    super()._validate()

    @config_class(dynamic_type={MixerConfig: "kda"})
"kimi_delta_attention"
    desc="Configuration for the gated normalization applied to the KDA output.",
    hint=FieldHint.architecture,
    )
    q_projection_layer: AffineLinearConfig = Field(
projection seems unnecessary in these field names.
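For example, keeping the same Field pattern from the surrounding lines but dropping the suffix (the descriptions below are illustrative):

```python
q_layer: AffineLinearConfig = Field(
    desc="Configuration for the query projection.",
    hint=FieldHint.architecture,
)
k_layer: AffineLinearConfig = Field(
    desc="Configuration for the key projection.",
    hint=FieldHint.architecture,
)
```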
    )

    @property
    def layer_class(self) -> "type":
type["KimiDeltaAttention"]
    return KimiDeltaAttention

    def _validate(self) -> None:
        with self._set_implicit_default():
Not sure that's a good idea; it makes configs hard to understand. Better to require the user to specify these explicitly (and most of the time we're creating configs from HF, so that's not a problem).
    @pytest.mark.slow
    @pytest.mark.skipif(not torch.cuda.is_available(), reason="KDA equivalence test needs CUDA")
Use pytest.mark.requires_cuda instead.
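Assuming the project-wide marker, the decorators would become:

```python
@pytest.mark.slow
@pytest.mark.requires_cuda
def test_fast_llm_kda_matches_apriel_forward():
    ...
```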
    AprielHybridSSMConfig, KimiDeltaAttention = None, None

    def _materialize_mixer_tensors(module: torch.nn.Module, distributed: Distributed, device: torch.device) -> None:
Please use get_stage, it already does this. See the example here: https://github.com/ServiceNow/Fast-LLM/blob/main/tests/layers/test_lm_head.py#L264
Also, please don't copy utils into every file; they can go in utils.
    @pytest.mark.skipif(not torch.cuda.is_available(), reason="KDA equivalence test needs CUDA")
    @pytest.mark.skipif(KimiDeltaAttention is None or AprielHybridSSMConfig is None, reason="Apriel KDA deps missing")
    @pytest.mark.skipif(kda_module.chunk_kda is None, reason="KDA fused kernels not available")
    def test_fast_llm_kda_matches_apriel_forward():
Not sure we need this test at all. test_huggingface_model already tests the equivalence
I agree that we will not need those eventually.
test_huggingface_model seems to be a heavier integration test compared to the isolated unit tests test_kda_equivalence and test_gda_equivalence; the latter are more useful for development.
Can we keep them for some time, until the GDN and KDA implementations are production-tested?
I believe these tests are extremely valuable and should remain.
In fact, I think we should extend them to non-FLA backup implementations of GDN and KDA.
    ModelTestingGroup.convert: ModelTestingGroupAction.normal,
    ModelTestingGroup.generate: ModelTestingGroupAction.not_implemented,
    ModelTestingGroup.megatron: ModelTestingGroupAction.not_implemented,
    ModelTestingGroup.distributed: ModelTestingGroupAction.normal,
We might want to test once and then leave it as unimportant; this has a huge impact on testing time.
Leaving them here for now, until we've used KDA and GDN enough to be confident they're stable and free of issues.
Update to nvcr.io/nvidia/pytorch:25.11-py3, which includes:
- PyTorch 2.10
- CUDA 13.0
- flash-attn 2.7.4.post1 (pre-installed, no compilation needed)

Dependency updates:
- causal-conv1d: v1.5.4 (was pinned to commit 2a288a1)
- mamba-ssm: 2.2.6.post3 (was pinned to commit 4a8a2a2)
- flash-linear-attention: pin to commit 67eee20 (was @main)
- flash-attn: 2.7.4.post1 to match the base image (was 2.7.3)
- triton: 3.5.1 in the Dockerfile (was 3.1.0)

These updates enable Kimi Delta Attention (KDA) support via the flash-linear-attention library. The pinned versions are tested and working, unlike the nightly/unpinned approach in #395.

Note: the dropless MoE kernel remains broken with triton >= 3.2.0 and needs a complete rewrite (it is also limited to 32 experts). This is tracked separately and doesn't block KDA work.

🤖 Generated with [Claude Code](https://claude.com/claude-code)
Co-Authored-By: Claude <noreply@anthropic.com>
tscholak left a comment
very good work, thank you!
I'd like us to have a non-FLA fallback for GDN and KDA, similar to our torch-compiled attention backup implementation.
You should be able to reuse much of the GDN torch code from upstream Qwen3Next, and similarly from Kimi Linear for KDA.
And then we add a config option for GDN and KDA like so:
Fast-LLM/fast_llm/layers/attention/config.py, lines 34 to 37 (at cc009a4)
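A sketch of what such an option could look like; the names and values below are illustrative, not the existing Fast-LLM API (the referenced config.py lines show the pattern to follow):

```python
import enum


class MixerImplementation(str, enum.Enum):
    # "auto" would prefer the fused flash-linear-attention kernels and fall back to a
    # torch-compiled reference implementation when fla isn't installed.
    auto = "auto"
    fla = "fla"
    torch = "torch"


# Hypothetical field on the GDN/KDA mixer configs, mirroring how the attention
# implementation is selected in fast_llm/layers/attention/config.py:
# implementation: MixerImplementation = Field(default=MixerImplementation.auto, ...)
```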
    # same as rearrange(v, '... (h d) -> ... h d', d=self.head_dim)
    return tensor.view(tensor.shape[0], tensor.shape[1], self._local_heads, self._config.head_dim)

    def _forward(
can we please have a torch-only compiled fallback in case fla isn't available?
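A minimal sketch of what such a fallback could compute, assuming a gated (per-channel decay) delta-rule recurrence; tensor layouts, argument names, and the exact KDA parameterization are assumptions here, not the Fast-LLM or fla API:

```python
import torch


def naive_gated_delta_rule(
    q: torch.Tensor,  # [batch, seq, heads, key_dim]
    k: torch.Tensor,  # [batch, seq, heads, key_dim]
    v: torch.Tensor,  # [batch, seq, heads, value_dim]
    g: torch.Tensor,  # [batch, seq, heads, key_dim], log-decay (<= 0)
    beta: torch.Tensor,  # [batch, seq, heads], write strength in (0, 1)
) -> torch.Tensor:
    """Step-by-step reference recurrence; could be wrapped in torch.compile like the attention fallback."""
    batch, seq, heads, key_dim = q.shape
    value_dim = v.shape[-1]
    state = q.new_zeros(batch, heads, key_dim, value_dim)
    outputs = []
    for t in range(seq):
        q_t, k_t, v_t = q[:, t], k[:, t], v[:, t]
        # Per-channel (diagonal) decay of the recurrent state.
        state = state * g[:, t].exp().unsqueeze(-1)
        # Delta rule: replace the value currently associated with k_t by a blend with v_t.
        v_old = torch.einsum("bhkv,bhk->bhv", state, k_t)
        v_new = beta[:, t].unsqueeze(-1) * (v_t - v_old)
        state = state + torch.einsum("bhk,bhv->bhkv", k_t, v_new)
        outputs.append(torch.einsum("bhkv,bhk->bhv", state, q_t))
    return torch.stack(outputs, dim=1)
```

A fused path would then only be an optimization over a reference like this, which would also give the equivalence tests something to compare against when fla is missing.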
    fast_layer.preprocess(fast_kwargs)
    fast_out, _ = fast_layer(hidden_states, fast_kwargs)

    torch.testing.assert_close(fast_out, hf_out, atol=1e-5, rtol=1e-5)
Can we please structure this test like Fast-LLM/tests/test_attention.py, line 60 (at cc009a4)?
We should add a backup implementation for GDN in case fla isn't available.
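A hypothetical skeleton of that structure (placeholder names; the real test should mirror tests/test_attention.py and compare the fused fla path against the backup implementation):

```python
import pytest
import torch


@pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16])
@pytest.mark.parametrize("sequence_length", [32, 128])
def test_gdn_fused_matches_backup(dtype, sequence_length):
    if not torch.cuda.is_available():
        pytest.skip("fused GDN kernels need CUDA")
    torch.manual_seed(0)
    # 1. Build one config and instantiate the layer twice (fused and backup implementation).
    # 2. Copy the same weights into both and run the same inputs through each.
    # 3. Compare the outputs, with tolerances loosened for bfloat16.
    ...
```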
    fast_layer.preprocess(fast_kwargs)
    fast_out, _ = fast_layer(hidden_states, fast_kwargs)

    torch.testing.assert_close(fast_out, hf_out, atol=1e-5, rtol=1e-5)
Can we please structure this test like Fast-LLM/tests/test_attention.py, line 60 (at cc009a4)?
We should add a backup implementation for KDA in case fla isn't available.
    torch.testing.assert_close(fast_out, hf_out, atol=1e-5, rtol=1e-5)

    if __name__ == "__main__":
✨ Description
Should be merged after GDN #392.
Adds the KDA mixer from Kimi Linear.
Note: for now this requires nightly triton and pytorch; see https://github.com/fla-org/flash-linear-attention/blob/main/FAQs.md.
Merged #404 here. Tests for both hybrid_kda and apriel2_text_gdn_hybrid models pass when using the new docker image on toolkit.

🔍 Type of change
Select all that apply:
📝 Changes
✅ Checklist
Make sure the following tasks are completed before submitting the PR:
General
Dependencies and Configuration
Testing
Performance Impact
📊 Performance Impact Details
If there is any impact on performance, describe it and provide benchmark results, if applicable:
🗒️ Additional Notes
Include any additional context, information, or considerations here, such as known issues, follow-up tasks, or backward compatibility concerns.