Kda mixer #395
Changes from all commits
```diff
@@ -4,7 +4,6 @@
 from fast_llm.config import Field, FieldHint, check_field, config_class
 from fast_llm.engine.config_utils.initialization import InitializationConfig, Initializer, LambdaInitializer
 from fast_llm.engine.config_utils.parameter import ParameterConfig
-from fast_llm.functional.config import ActivationType
 from fast_llm.layers.block.config import BlockKwargs
 from fast_llm.layers.common.linear.config import AffineLinearConfig, CausalConv1dConfig, LinearConfig
 from fast_llm.layers.common.normalization.config import GatedRMSNormalizationConfig
```
```diff
@@ -15,6 +14,8 @@
     import torch

     from fast_llm.layers.ssm.discrete_mamba2 import DiscreteMamba2
+    from fast_llm.layers.ssm.gdn import GatedDeltaNet
+    from fast_llm.layers.ssm.kda import KimiDeltaAttention
     from fast_llm.layers.ssm.mamba import Mamba
     from fast_llm.tensor import ParameterMeta
```
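The two added imports sit in the type-checking-only block (alongside `import torch`), so they exist for static type checkers but are never imported at runtime; each config instead imports its layer class lazily inside `layer_class` (see the next hunk), which avoids an import cycle between the config and layer modules. A minimal sketch of the pattern, using the class names from this PR:

```python
import typing

if typing.TYPE_CHECKING:
    # Evaluated by type checkers only; never executed at runtime.
    from fast_llm.layers.ssm.kda import KimiDeltaAttention


class KimiDeltaAttentionConfig:
    @property
    def layer_class(self) -> "type[KimiDeltaAttention]":  # string annotation, no runtime cost
        # Lazy import on first access, after both modules are fully initialized.
        from fast_llm.layers.ssm.kda import KimiDeltaAttention

        return KimiDeltaAttention
```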
```diff
@@ -84,36 +85,104 @@ class GatedDeltaNetConfig(MixerConfig):
         hint=FieldHint.architecture,
         valid=check_field(Assert.gt, 0),
     )
-    norm_epsilon: float = Field(
-        default=1e-6,
-        desc="Epsilon used by the gated RMS norm.",
-        hint=FieldHint.architecture,
-        valid=check_field(Assert.gt, 0),
-    )
-    activation: ActivationType = Field(
-        default=ActivationType.silu,
-        desc="Activation used after the convolution.",
-        hint=FieldHint.architecture,
-    )
-
-    def _validate(self) -> None:
-        super()._validate()
-        Assert.multiple(self.value_heads, self.key_heads)
-
     @property
-    def layer_class(self) -> "type":
+    def layer_class(self) -> "type[GatedDeltaNet]":
         from fast_llm.layers.ssm.gdn import GatedDeltaNet

         return GatedDeltaNet

+    def _validate(self) -> None:
+        super()._validate()
+
+
+@config_class(dynamic_type={MixerConfig: "kda"})
+class KimiDeltaAttentionConfig(MixerConfig):
+    """
+    Configuration for the KimiDeltaAttention mixer inspired by the Kimi Linear models.
+    """
+
+    _abstract = False
+    normalization: GatedRMSNormalizationConfig = Field(
+        desc="Configuration for the gated normalization applied to the KDA output.",
+        hint=FieldHint.architecture,
+    )
+    q_projection_layer: AffineLinearConfig = Field(
+        desc="Projection that produces query vectors.",
+        hint=FieldHint.architecture,
+    )
+    k_projection_layer: AffineLinearConfig = Field(
+        desc="Projection that produces key vectors.",
+        hint=FieldHint.architecture,
+    )
+    v_projection_layer: AffineLinearConfig = Field(
+        desc="Projection that produces value vectors.",
+        hint=FieldHint.architecture,
+    )
+    f_a_projection_layer: AffineLinearConfig = Field(
+        desc="Projection used for the Delta gating pre-activation.",
+        hint=FieldHint.architecture,
+    )
+    f_b_projection_layer: AffineLinearConfig = Field(
+        desc="Projection used for the Delta gating expansion.",
+        hint=FieldHint.architecture,
+    )
+    g_a_projection_layer: AffineLinearConfig = Field(
+        desc="Projection used for the output gating pre-activation.",
+        hint=FieldHint.architecture,
+    )
+    g_b_projection_layer: AffineLinearConfig = Field(
+        desc="Projection used for the output gating expansion.",
+        hint=FieldHint.architecture,
+    )
+    beta_projection_layer: AffineLinearConfig = Field(
+        desc="Projection that produces the Beta gate.",
+        hint=FieldHint.architecture,
+    )
+    output_projection_layer: AffineLinearConfig = Field(
+        desc="Projection applied after the Delta recurrence and gated normalization.",
+        hint=FieldHint.architecture,
+    )
+    convolution_layer: CausalConv1dConfig = Field(
+        desc="Depth-wise convolution applied independently on each Q, K and V stream.",
+        hint=FieldHint.architecture,
+    )
+    dt_bias_weight: ParameterConfig = Field(
+        desc="Parameter configuration for the Delta gate bias.",
+        hint=FieldHint.architecture,
+    )
+    a_log_weight: ParameterConfig = Field(
+        desc="Parameter configuration for the decay rates.",
+        hint=FieldHint.architecture,
+    )
+    heads: int = Field(
+        default=16,
+        desc="Number of attention heads.",
+        hint=FieldHint.architecture,
+        valid=check_field(Assert.gt, 0),
+    )
+    head_dim: int = Field(
+        default=64,
+        desc="Dimension of each head.",
+        hint=FieldHint.architecture,
+        valid=check_field(Assert.gt, 0),
+    )
+
+    @property
+    def layer_class(self) -> "type[KimiDeltaAttention]":
+        from fast_llm.layers.ssm.kda import KimiDeltaAttention
+
+        return KimiDeltaAttention
```
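Taken together, the field descriptions imply the following dataflow for the mixer. This is a schematic sketch only: module names mirror the config fields, the recurrence kernel is a stub, and the gating formulas are assumptions based on the field descriptions and the Kimi Linear design, not code from this PR.

```python
import torch


def delta_rule_recurrence(q, k, v, g, beta):
    """Stub for the chunked delta-rule kernel, which this config does not pin down."""
    raise NotImplementedError


def kda_forward_sketch(x, m):
    # m: hypothetical namespace of modules built from the config fields above.
    # Depth-wise causal convolution applied independently on each stream.
    q = m.convolution_q(m.q_projection(x))
    k = m.convolution_k(m.k_projection(x))
    v = m.convolution_v(m.v_projection(x))

    # Delta gate: low-rank f_a -> f_b projection, shifted by dt_bias and
    # scaled by learned per-channel decay rates exp(a_log).
    g = -torch.exp(m.a_log) * torch.nn.functional.softplus(
        m.f_b_projection(m.f_a_projection(x)) + m.dt_bias
    )
    beta = torch.sigmoid(m.beta_projection(x))  # per-head write strength

    y = delta_rule_recurrence(q, k, v, g, beta)

    # Output gating: low-rank g_a -> g_b projection feeding the gated RMS norm
    # (sigmoid activation per the implicit default in _validate below).
    y = m.normalization(y, gate=m.g_b_projection(m.g_a_projection(x)))
    return m.output_projection(y)
```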
```diff
+    def _validate(self) -> None:
+        with self._set_implicit_default():
+            if "epsilon" not in self.normalization._explicit_fields:
+                self.normalization.epsilon = 1.0e-5
+            if "activation" not in self.convolution_layer._explicit_fields:
+                self.convolution_layer.activation = "silu"
+            if "kernel_size" not in self.convolution_layer._explicit_fields:
+                self.convolution_layer.kernel_size = 4
+            if "activation" not in self.normalization._explicit_fields:
+                self.normalization.activation = "sigmoid"
+
+        super()._validate()
```

> **Collaborator** (on `with self._set_implicit_default():`): Not sure that's a good idea; it makes configs hard to understand. Better to have the user specify these explicitly. (And most of the time we're creating from HF, so that's not a problem.)
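Concretely, the implicit defaults mean a config that omits these sub-fields still validates to a fully specified mixer, while anything the user sets explicitly is left alone. A sketch of the intended behavior; the `from_dict`/`validate` calls are assumed entry points, not taken from this diff:

```python
# Hypothetical usage: exact construction API assumed.
mixer = KimiDeltaAttentionConfig.from_dict({})
mixer.validate()  # triggers _validate() above

assert mixer.normalization.epsilon == 1.0e-5          # filled in implicitly
assert mixer.normalization.activation == "sigmoid"
assert mixer.convolution_layer.activation == "silu"
assert mixer.convolution_layer.kernel_size == 4

# Explicit values win: only fields absent from _explicit_fields are touched.
mixer = KimiDeltaAttentionConfig.from_dict({"convolution_layer": {"kernel_size": 2}})
mixer.validate()
assert mixer.convolution_layer.kernel_size == 2
```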
> **Collaborator** (on `@config_class(dynamic_type={MixerConfig: "kda"})`): `"kimi_delta_attention"`