Add KDA to external Apriel 2 modelling files and Fast-LLM converters#409
Add KDA to external Apriel 2 modelling files and Fast-LLM converters#409
Conversation
e1ba7e6 to
640a43d
Compare
External Module (fast_llm_external_models/apriel2): - Implement KimiDeltaAttention mixer using fla.ops.kda kernels - Add KIL (Kimi Initialization from LLM) converter: attention → KDA - Refactor converters.py with unified per-mixer plan functions - Add GatedRMSNormalization activation parameter (silu/sigmoid) - Add KDA to stochastic supernet and example surgery configs - Update train_supernet_small.yaml with runtime mixer switching demo Fast-LLM Core (fast_llm/models): - Add Apriel2KimiDeltaAttentionConverter for checkpoint import/export - Update StochasticMixer and Block converters for KDA support - Fix auto_map: use AutoModelForImageTextToText for VLM models Tests: - Refactor test architecture with shared fixtures (conftest.py) - Add comprehensive KDA tests (cache, equivalence, expression plans) - Remove redundant test_cache_routing.py (merged into test_cache.py) - Add KDA to apriel2_text_all_hybrid test config 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
640a43d to
310c311
Compare
There was a problem hiding this comment.
Pull request overview
This PR adds comprehensive KimiDeltaAttention (KDA) support to the Apriel2 architecture, including a new KIL (Kimi Initialization from LLM) converter that enables attention-to-KDA transformations. The implementation spans the external modeling module, core Fast-LLM converters, cache system enhancements, and extensive test coverage.
Key Changes:
- Full KDA mixer implementation with FLA kernel integration and tuple conv state handling
- KIL converter for attention→KDA weight transformation with GQA tiling support
- Cache system enhanced to handle KDA's triple-tuple conv states throughout beam search operations
- Refactored converter architecture with unified per-mixer plan functions replacing scattered logic
- VLM auto_map fix to use correct
AutoModelForImageTextToTextclass
Reviewed changes
Copilot reviewed 18 out of 18 changed files in this pull request and generated 4 comments.
Show a summary per file
| File | Description |
|---|---|
modeling_apriel2.py |
Added complete KimiDeltaAttention class with q/k/v conv, gate projections, and FLA kernel integration |
converters.py |
Major refactor: unified per-mixer planners + new plan_kil_attention_to_kda converter |
cache.py |
Enhanced beam operations to handle KDA tuple conv states (q, k, v) |
gpt/conversion/apriel2.py |
Added Apriel2KimiDeltaAttentionConverter for Fast-LLM checkpoint import/export |
multimodal/conversion/apriel2.py |
Fixed VLM auto_map to use AutoModelForImageTextToText |
test_mixer_equivalence.py |
Added KDA equivalence tests vs FLA, determinism tests, comprehensive documentation |
test_cache.py |
Complete rewrite with 1258 lines covering all cache scenarios including KDA tuples |
test_expr_plan.py |
Added KIL plan tests for MHA and GQA scenarios |
| Example configs | Added KDA to stochastic supernet, comprehensive, and new hybrid_kil.yaml |
No critical issues found. The implementation is well-structured, thoroughly tested, and properly integrated across all system layers.
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
- Remove unused `projection_size` variable in test_expr_plan.py - Remove unused `attention_config` parameter and unpacking in test_mixer_equivalence.py test_causal_vs_mistral - Add @requires_cuda to test_stochastic_supernet_yaml_end_to_end since KDA requires CUDA (FLA kernel fails on CPU-only environments) 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
- Enhance StochasticMixer debug logging to include iteration number and use logger.info for consistency with other model debug logging - Increase bf16 forward pass tolerance from 1e-2/1e-3 to 1.5e-2/1.5e-3 to account for precision differences with KDA/GDN FLA kernels - Add commented model_debug_level option in test config for easier debugging of stochastic mixer selection 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
When model_debug_level > 0, the vision encoder components would crash with shape mismatch errors (e.g., "1024 != 5120") because the debug logging tried to verify tensor shapes against incorrect hidden dims. The root cause: VisionKwargs.hidden_dims was set to the decoder hidden size (5120) but embeddings and encoder output vision hidden size (1024). Fix: - Expose _vision_hidden_dim (1024) in VisionEncoder alongside the existing _hidden_dim (5120, used for adapter output) - Use _vision_hidden_dim for the hidden_dims kwarg passed to vision encoder components (embeddings, encoder blocks) - For adapter MLP which projects from 1024 to 5120, pass dims=None when output_dim != hidden_dim so _debug infers dims from tensor shape - Make _get_meta robust to missing hidden_dims/sequence_q_dim in kwargs Also enables model_debug_level: 1 in train_supernet_small.yaml example. 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
- Configure lr_scale: 0.0 for MLP, normalization, embeddings, head, and vision_encoder to freeze all components except the mixer during training - Add reference_models section with teacher model (attention-only) for activation-level distillation - Set activation_distillation_factor: 0.1 to guide alternative mixers (GDN, KDA) to produce similar activations to attention - Update prerequisites to include teacher model conversion step - Increase train_iters to 100 for extended training run 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Merge branch 'main' (3b50720) into tscholak/apriel2-kda Changes in this commit: - Refactor test_gdn_equivalence.py, test_kda_equivalence.py to follow consistent cookie-cutter pattern - Add new test_mamba_equivalence.py with parameterized tests for add_linear_biases × repeat_kv_before_conv configurations - Fix apriel2 model config: add missing auto_model_class for multimodal - Fix apriel2 skip_tests: add bf2_df2 (depends on skipped df4) - CausalConv1d refactor in modeling_apriel2.py Test pattern standardization: - All use try/except imports with @skipif decorators - All use _copy_weights() helper functions - All use Assert.rms_close() from fast_llm.utils - All use consistent constants (BATCH_SIZE=2, seed=42) - Removed debug prints 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
| self._debug(out, None, kwargs.get(BlockKwargs.hidden_dims), kwargs, bias=bias) | ||
| # Use None for dims when output_dim differs from hidden_dim (e.g., adapter projections) | ||
| # to let _debug infer dims from actual tensor shape | ||
| dims = None if self._output_dim != self._hidden_dim else kwargs.get(BlockKwargs.hidden_dims) |
There was a problem hiding this comment.
This won't work, it will produce incorrect results in distributed settings
| hidden_dims = { | ||
| dim.name: dim for dim in kwargs[BlockKwargs.hidden_dims] + (kwargs[BlockKwargs.sequence_q_dim],) | ||
| } | ||
| hidden_dims = {} |
There was a problem hiding this comment.
These are required kwargs, why would they be missing?
| @@ -1,41 +1,60 @@ | |||
| """Test numerical equivalence between Fast-LLM GDN and Apriel2 GatedDeltaNet.""" | |||
There was a problem hiding this comment.
I also rewrote them in #408 and made them a lot simpler. Any significant change I need to keep in mind, other than the addition of mamba?
| # Micro-sequence split and sequence-first not supported for Mamba. | ||
| # TP excluded because no gradient reductions implemented for TP norm in GDN (use STP instead). | ||
| skip_tests=("sdp", "ms", "bf4", "df4", TP_NO_STP), | ||
| # bf2_df2 depends on df4, so must also be skipped. |
There was a problem hiding this comment.
| fast_out, _ = fast_layer(hidden_states, fast_kwargs) | ||
|
|
||
| # Compare outputs (slightly looser tolerance for Mamba due to numerical differences) | ||
| Assert.rms_close(fast_out, hf_out, 1e-4) |
There was a problem hiding this comment.
This is actually a ~1% difference. Are we ok with it?
Summary
fla.ops.kdakernelsauto_mapto useAutoModelForImageTextToTextfor VLM modelsChanges by Area
External Module (
fast_llm_external_models/apriel2)modeling_apriel2.py: Full KDA implementation with Q/K/V projections, convolutions, gating, and fla kernelsupport
conversion/converters.py: Refactored with per-mixer plan functions; added KIL convertercache.py: KDA state management supportexamples/hybrid_kil.yamlsurgery configstochastic_supernet.yaml,comprehensive.yaml,train_supernet_small.yamlFast-LLM Core (
fast_llm/models)gpt/conversion/apriel2.py:Apriel2KimiDeltaAttentionConverterfor checkpoint handlingmultimodal/conversion/apriel2.py: Fixedauto_mapfor proper VLM auto-class supportTests
conftest.pywith shared mixer fixturestest_cache.py(absorbedtest_cache_routing.py)test_mixer_equivalence.pyandtest_expr_plan.pyapriel2_text_all_hybridtest configTest plan
pytest tests/ -k "apriel2"- fast-llm apriel2 testspytest fast_llm_external_models/tests/test_apriel2/- external module teststrain_supernet_small.yamlinstructions to test full pipeline with runtime mixer switching