diff --git a/fast_llm/layers/block/block.py b/fast_llm/layers/block/block.py index 3a0f7cc59..a1942cab1 100644 --- a/fast_llm/layers/block/block.py +++ b/fast_llm/layers/block/block.py @@ -83,9 +83,12 @@ def _get_meta( return None if dims is None: dims = tuple(f"dim_{i}" for i in range(tensor.ndim)) - hidden_dims = { - dim.name: dim for dim in kwargs[BlockKwargs.hidden_dims] + (kwargs[BlockKwargs.sequence_q_dim],) - } + hidden_dims = {} + if BlockKwargs.hidden_dims in kwargs: + for dim in kwargs[BlockKwargs.hidden_dims]: + hidden_dims[dim.name] = dim + if BlockKwargs.sequence_q_dim in kwargs: + hidden_dims[kwargs[BlockKwargs.sequence_q_dim].name] = kwargs[BlockKwargs.sequence_q_dim] return TensorMeta.from_dims( tuple( ( diff --git a/fast_llm/layers/decoder/mlp/mlp.py b/fast_llm/layers/decoder/mlp/mlp.py index b4da15b45..882963ce9 100644 --- a/fast_llm/layers/decoder/mlp/mlp.py +++ b/fast_llm/layers/decoder/mlp/mlp.py @@ -137,5 +137,8 @@ def _forward( transposed_layer_2_weight=self.layer_2.transposed_weight, ) bias = self.layer_2.bias if self._parallel_dim.group else None - self._debug(out, None, kwargs.get(BlockKwargs.hidden_dims), kwargs, bias=bias) + # Use None for dims when output_dim differs from hidden_dim (e.g., adapter projections) + # to let _debug infer dims from actual tensor shape + dims = None if self._output_dim != self._hidden_dim else kwargs.get(BlockKwargs.hidden_dims) + self._debug(out, None, dims, kwargs, bias=bias) return out, bias diff --git a/fast_llm/layers/decoder/stochastic_mixer.py b/fast_llm/layers/decoder/stochastic_mixer.py index 32633f218..673c64034 100644 --- a/fast_llm/layers/decoder/stochastic_mixer.py +++ b/fast_llm/layers/decoder/stochastic_mixer.py @@ -118,7 +118,13 @@ def _forward( mixer_name = self._sample_mixer_name(kwargs) if get_model_debug_level() > 0: - logger.debug(f"StochasticMixer selecting mixer {mixer_name}: {type(self.mixers[mixer_name]).__name__}") + from fast_llm.layers.block.config import BlockKwargs + + iteration = kwargs.get(BlockKwargs.iteration, "?") + logger.info( + f"StochasticMixer iter={iteration} selecting mixer '{mixer_name}' " + f"({type(self.mixers[mixer_name]).__name__})" + ) return self.mixers[mixer_name]._forward(input_, kwargs, losses, metrics) diff --git a/fast_llm/layers/vision/vision_encoder.py b/fast_llm/layers/vision/vision_encoder.py index 03acfdde7..1bd499f97 100644 --- a/fast_llm/layers/vision/vision_encoder.py +++ b/fast_llm/layers/vision/vision_encoder.py @@ -27,7 +27,9 @@ def __init__( peft: PeftConfig | None, ): super().__init__(config, distributed_config, hidden_dim=hidden_dim, lr_scale=lr_scale, peft=peft) - vision_hidden_dim = TensorDim("hidden", self._config.hidden_size) + # Internal hidden dimension for embeddings and encoder (may differ from output hidden_dim for adapter) + self._vision_hidden_dim = TensorDim("hidden", self._config.hidden_size) + vision_hidden_dim = self._vision_hidden_dim self.embeddings = self._config.embeddings.get_layer( distributed_config, vision_hidden_dim, diff --git a/fast_llm/models/gpt/conversion/apriel2.py b/fast_llm/models/gpt/conversion/apriel2.py index 1b60e8834..7682196c8 100644 --- a/fast_llm/models/gpt/conversion/apriel2.py +++ b/fast_llm/models/gpt/conversion/apriel2.py @@ -9,7 +9,7 @@ from fast_llm.engine.checkpoint.huggingface import HuggingfaceStateDictCheckpointHandler from fast_llm.layers.attention.config import AttentionConfig from fast_llm.layers.decoder.config import DecoderBlockConfig, StochasticMixerConfig -from fast_llm.layers.ssm.config import 
GatedDeltaNetConfig, Mamba2Config +from fast_llm.layers.ssm.config import GatedDeltaNetConfig, KimiDeltaAttentionConfig, Mamba2Config from fast_llm.models.gpt.config import GPTBaseModelConfig, GPTModelConfig from fast_llm.models.gpt.conversion.config import Apriel2TextCheckpointFormat from fast_llm.models.gpt.conversion.llama import ( @@ -271,6 +271,144 @@ def get_converters( ] +class Apriel2KimiDeltaAttentionConverter: + @classmethod + def import_config(cls, config: dict) -> dict: + result = { + "type": "kda", + "heads": config["heads"], + "head_dim": config["head_dim"], + } + if "convolution_layer" in config: + result["convolution_layer"] = config["convolution_layer"] + if "normalization" in config: + result["normalization"] = config["normalization"] + return result + + @classmethod + def export_config(cls, config: KimiDeltaAttentionConfig) -> dict: + return { + "type": "kda", + "heads": config.heads, + "head_dim": config.head_dim, + "convolution_layer": { + "kernel_size": config.convolution_layer.kernel_size, + }, + "normalization": { + "epsilon": config.normalization.epsilon, + }, + } + + @classmethod + def get_converters( + cls, + config: KimiDeltaAttentionConfig, + fast_llm_prefix: str, + hf_prefix: str, + drop_on_export: bool = False, + ) -> list[WeightConverter]: + # Fast-LLM KDA uses abbreviated names matching the external module: + # q_proj, k_proj, v_proj, q_conv, k_conv, v_conv, f_a_proj, f_b_proj, + # g_a_proj, g_b_proj, beta_proj, o_proj, A_log, dt_bias, norm + return [ + # Q/K/V projections + *get_weight_and_bias_converters( + f"{fast_llm_prefix}.q_proj", + f"{hf_prefix}.q_proj", + False, + drop_on_export=drop_on_export, + ), + *get_weight_and_bias_converters( + f"{fast_llm_prefix}.k_proj", + f"{hf_prefix}.k_proj", + False, + drop_on_export=drop_on_export, + ), + *get_weight_and_bias_converters( + f"{fast_llm_prefix}.v_proj", + f"{hf_prefix}.v_proj", + False, + drop_on_export=drop_on_export, + ), + # Convolutions (Q, K, V) + *get_weight_and_bias_converters( + f"{fast_llm_prefix}.q_conv", + f"{hf_prefix}.q_conv", + False, + drop_on_export=drop_on_export, + ), + *get_weight_and_bias_converters( + f"{fast_llm_prefix}.k_conv", + f"{hf_prefix}.k_conv", + False, + drop_on_export=drop_on_export, + ), + *get_weight_and_bias_converters( + f"{fast_llm_prefix}.v_conv", + f"{hf_prefix}.v_conv", + False, + drop_on_export=drop_on_export, + ), + # Gate projections (f_a, f_b, g_a, g_b) + *get_weight_and_bias_converters( + f"{fast_llm_prefix}.f_a_proj", + f"{hf_prefix}.f_a_proj", + False, + drop_on_export=drop_on_export, + ), + *get_weight_and_bias_converters( + f"{fast_llm_prefix}.f_b_proj", + f"{hf_prefix}.f_b_proj", + False, + drop_on_export=drop_on_export, + ), + *get_weight_and_bias_converters( + f"{fast_llm_prefix}.g_a_proj", + f"{hf_prefix}.g_a_proj", + False, + drop_on_export=drop_on_export, + ), + *get_weight_and_bias_converters( + f"{fast_llm_prefix}.g_b_proj", + f"{hf_prefix}.g_b_proj", + False, + drop_on_export=drop_on_export, + ), + # Beta projection + *get_weight_and_bias_converters( + f"{fast_llm_prefix}.beta_proj", + f"{hf_prefix}.beta_proj", + False, + drop_on_export=drop_on_export, + ), + # Output projection + *get_weight_and_bias_converters( + f"{fast_llm_prefix}.o_proj", + f"{hf_prefix}.o_proj", + False, + drop_on_export=drop_on_export, + ), + # Learnable parameters + get_parameter_converter( + f"{fast_llm_prefix}.A_log", + f"{hf_prefix}.A_log", + drop_on_export=drop_on_export, + ), + get_parameter_converter( + f"{fast_llm_prefix}.dt_bias", + f"{hf_prefix}.dt_bias", + 
drop_on_export=drop_on_export, + ), + # Normalization + *LlamaNormalizationConverter.get_converters( + config.normalization, + f"{fast_llm_prefix}.norm", + f"{hf_prefix}.norm", + drop_on_export=drop_on_export, + ), + ] + + class Apriel2StochasticMixerConverter: @classmethod def import_config(cls, config: dict) -> dict: @@ -283,6 +421,8 @@ def import_config(cls, config: dict) -> dict: mixers[name] = Apriel2MambaConverter.import_config(sub_mixer_config) elif mixer_type == "gdn": mixers[name] = Apriel2GatedDeltaNetConverter.import_config(sub_mixer_config) + elif mixer_type == "kda": + mixers[name] = Apriel2KimiDeltaAttentionConverter.import_config(sub_mixer_config) else: raise ValueError(f"Unknown sub-mixer type: {mixer_type}") @@ -306,6 +446,8 @@ def export_config(cls, config: StochasticMixerConfig) -> dict: mixers[name] = Apriel2MambaConverter.export_config(sub_mixer) elif mixer_type is GatedDeltaNetConfig: mixers[name] = Apriel2GatedDeltaNetConverter.export_config(sub_mixer) + elif mixer_type is KimiDeltaAttentionConfig: + mixers[name] = Apriel2KimiDeltaAttentionConverter.export_config(sub_mixer) else: raise ValueError(f"Unknown sub-mixer type: {mixer_type}") @@ -336,6 +478,9 @@ def get_converters( elif mixer_type is GatedDeltaNetConfig: converter_class = Apriel2GatedDeltaNetConverter hf_sub_mixer_prefix = f"{hf_prefix}.mixers.{name}" + elif mixer_type is KimiDeltaAttentionConfig: + converter_class = Apriel2KimiDeltaAttentionConverter + hf_sub_mixer_prefix = f"{hf_prefix}.mixers.{name}" else: raise ValueError(f"Unknown sub-mixer type: {mixer_type}") converters.extend( @@ -364,6 +509,8 @@ def import_config(cls, config: dict, block_config: dict) -> dict: mixer = Apriel2StochasticMixerConverter.import_config(mixer_config) elif mixer_type == "gdn": mixer = Apriel2GatedDeltaNetConverter.import_config(mixer_config) + elif mixer_type == "kda": + mixer = Apriel2KimiDeltaAttentionConverter.import_config(mixer_config) else: raise ValueError(f"Unknown mixer type: {mixer_type}") @@ -404,6 +551,8 @@ def export_config(cls, config: DecoderBlockConfig) -> dict: mixer = Apriel2StochasticMixerConverter.export_config(config.mixer) elif mixer_type is GatedDeltaNetConfig: mixer = Apriel2GatedDeltaNetConverter.export_config(config.mixer) + elif mixer_type is KimiDeltaAttentionConfig: + mixer = Apriel2KimiDeltaAttentionConverter.export_config(config.mixer) else: raise ValueError(f"Unknown mixer type: {mixer_type}") @@ -460,6 +609,9 @@ def get_converters( elif mixer_type is GatedDeltaNetConfig: converter_class = Apriel2GatedDeltaNetConverter hf_mixer_prefix = f"{hf_prefix}.mixer" + elif mixer_type is KimiDeltaAttentionConfig: + converter_class = Apriel2KimiDeltaAttentionConverter + hf_mixer_prefix = f"{hf_prefix}.mixer" else: raise ValueError(f"Unknown mixer type: {mixer_type}") diff --git a/fast_llm/models/multimodal/conversion/apriel2.py b/fast_llm/models/multimodal/conversion/apriel2.py index 88ea01220..b4147a8bf 100644 --- a/fast_llm/models/multimodal/conversion/apriel2.py +++ b/fast_llm/models/multimodal/conversion/apriel2.py @@ -406,7 +406,7 @@ def _export_config(cls, config: MultiModalModelConfig) -> dict[str, typing.Any]: "auto_map": { "AutoConfig": "configuration_apriel2.Apriel2Config", "AutoModel": "modeling_apriel2.Apriel2Model", - "AutoModelForCausalLM": "modeling_apriel2.Apriel2ForConditionalGeneration", + "AutoModelForImageTextToText": "modeling_apriel2.Apriel2ForConditionalGeneration", }, }, ) diff --git a/fast_llm/models/multimodal/model.py b/fast_llm/models/multimodal/model.py index 
f8251e212..890d5760e 100644 --- a/fast_llm/models/multimodal/model.py +++ b/fast_llm/models/multimodal/model.py @@ -134,10 +134,11 @@ def preprocess_meta( TensorDim("patch_width", self._config.vision_encoder.embeddings.patch_width), ) ) + # Use vision encoder's internal hidden dim (for embeddings/encoder), not the output dim (for adapter) hidden_dims = ( - (hidden_batch_and_sequence_q_dim, scalar_dim, self.vision_encoder._hidden_dim) + (hidden_batch_and_sequence_q_dim, scalar_dim, self.vision_encoder._vision_hidden_dim) if (sequence_first := kwargs[LanguageModelKwargs.sequence_first]) - else (scalar_dim, hidden_batch_and_sequence_q_dim, self.vision_encoder._hidden_dim) + else (scalar_dim, hidden_batch_and_sequence_q_dim, self.vision_encoder._vision_hidden_dim) ) kwargs[self._vision_encoder_namespace] = { VisionKwargs.sequence_first: sequence_first, diff --git a/fast_llm_external_models/apriel2/cache.py b/fast_llm_external_models/apriel2/cache.py index 3181a4268..86c67a085 100644 --- a/fast_llm_external_models/apriel2/cache.py +++ b/fast_llm_external_models/apriel2/cache.py @@ -162,7 +162,11 @@ def _reorder_cache_obj(self, cache, beam_idx): cache.value = cache.value.index_select(0, beam_idx.to(cache.value.device)) elif isinstance(cache, _SSMCache): if cache.conv is not None: - cache.conv = cache.conv.index_select(0, beam_idx.to(cache.conv.device)) + # Handle both single tensor (GDN/Mamba) and tuple (KDA) conv states + if isinstance(cache.conv, tuple): + cache.conv = tuple(c.index_select(0, beam_idx.to(c.device)) for c in cache.conv) + else: + cache.conv = cache.conv.index_select(0, beam_idx.to(cache.conv.device)) if cache.recurrent is not None: cache.recurrent = cache.recurrent.index_select(0, beam_idx.to(cache.recurrent.device)) @@ -208,7 +212,11 @@ def _batch_repeat_cache_obj(self, cache, repeats): cache.value = cache.value.repeat_interleave(repeats, dim=0) elif isinstance(cache, _SSMCache): if cache.conv is not None: - cache.conv = cache.conv.repeat_interleave(repeats, dim=0) + # Handle both single tensor (GDN/Mamba) and tuple (KDA) conv states + if isinstance(cache.conv, tuple): + cache.conv = tuple(c.repeat_interleave(repeats, dim=0) for c in cache.conv) + else: + cache.conv = cache.conv.repeat_interleave(repeats, dim=0) if cache.recurrent is not None: cache.recurrent = cache.recurrent.repeat_interleave(repeats, dim=0) @@ -227,7 +235,11 @@ def _batch_select_cache_obj(self, cache, indices): cache.value = cache.value.index_select(0, indices.to(cache.value.device)) elif isinstance(cache, _SSMCache): if cache.conv is not None: - cache.conv = cache.conv.index_select(0, indices.to(cache.conv.device)) + # Handle both single tensor (GDN/Mamba) and tuple (KDA) conv states + if isinstance(cache.conv, tuple): + cache.conv = tuple(c.index_select(0, indices.to(c.device)) for c in cache.conv) + else: + cache.conv = cache.conv.index_select(0, indices.to(cache.conv.device)) if cache.recurrent is not None: cache.recurrent = cache.recurrent.index_select(0, indices.to(cache.recurrent.device)) @@ -274,11 +286,17 @@ def max_batch_size(self): if isinstance(cache, _AttentionCache) and cache.key is not None: return cache.key.shape[0] if isinstance(cache, _SSMCache) and cache.conv is not None: + # Handle both single tensor and tuple conv states + if isinstance(cache.conv, tuple): + return cache.conv[0].shape[0] return cache.conv.shape[0] else: if isinstance(layer, _AttentionCache) and layer.key is not None: return layer.key.shape[0] if isinstance(layer, _SSMCache) and layer.conv is not None: + # Handle both 
single tensor and tuple conv states + if isinstance(layer.conv, tuple): + return layer.conv[0].shape[0] return layer.conv.shape[0] return None diff --git a/fast_llm_external_models/apriel2/conversion/__init__.py b/fast_llm_external_models/apriel2/conversion/__init__.py index 633125e86..983a632e0 100644 --- a/fast_llm_external_models/apriel2/conversion/__init__.py +++ b/fast_llm_external_models/apriel2/conversion/__init__.py @@ -120,8 +120,9 @@ # Plan builders (generic) from fast_llm_external_models.apriel2.conversion.converters import ( - plan_attention_to_gated_delta_net, plan_mil_attention_to_mamba, + plan_dil_attention_to_gdn, + plan_kil_attention_to_kda, plan_surgery, ) @@ -170,7 +171,8 @@ # Plan builders (generic) "plan_surgery", "plan_mil_attention_to_mamba", - "plan_attention_to_gated_delta_net", + "plan_dil_attention_to_gdn", + "plan_kil_attention_to_kda", # Config composition "compose_configs", # Source-specific converters diff --git a/fast_llm_external_models/apriel2/conversion/config.py b/fast_llm_external_models/apriel2/conversion/config.py index 74089c3fa..48f8ff44b 100644 --- a/fast_llm_external_models/apriel2/conversion/config.py +++ b/fast_llm_external_models/apriel2/conversion/config.py @@ -31,6 +31,7 @@ - attention → sliding_window: preserve heads, head_groups, head_size - attention → gdn: heads → value_heads, head_groups → key_heads - attention → mamba: derive d_inner, d_xb, dt_rank from hidden_size + - attention → kda: preserve heads, head_size → head_dim **Stochastic Mixer Composition** Two semantics based on whether surgery declares `type: stochastic`: @@ -439,6 +440,22 @@ def _compose_single_mixer(source: dict, surgery: dict, hidden_size: int) -> dict result["init"] = surgery["init"] return result + elif target_type == "kda": + # Attention → KDA: derive heads/head_dim from attention geometry + result = { + "type": "kda", + "heads": surgery.get("heads", heads), + "head_dim": surgery.get("head_dim", head_size), + } + # Copy KDA-specific fields from surgery + for key in ["convolution_layer", "normalization"]: + if key in surgery: + result[key] = surgery[key] + # Preserve init + if "init" in surgery: + result["init"] = surgery["init"] + return result + # Fallback: start fresh with surgery, no inheritance result = copy.deepcopy(surgery) result["type"] = target_type diff --git a/fast_llm_external_models/apriel2/conversion/converters.py b/fast_llm_external_models/apriel2/conversion/converters.py index 9b0afeec3..6d1350c54 100644 --- a/fast_llm_external_models/apriel2/conversion/converters.py +++ b/fast_llm_external_models/apriel2/conversion/converters.py @@ -13,7 +13,7 @@ architecture modifications (adding Mamba layers, stochastic mixers, etc.). The surgery_spec's `init` field controls weight handling: - - `init: transfer` → use converters (MIL, DIL, passthrough) + - `init: transfer` → use converters (MIL, DIL, KIL, passthrough) - `init: random` → use random initialization If `init: transfer` is requested but no converter exists for the type pair @@ -38,6 +38,10 @@ Converts attention → gated_delta_net by mapping Q/K/V/O projections to the fused in_proj_qkvz and out_proj, respecting GQA head grouping. +**KIL (Kimi Initialization from LLM)** + Converts attention → kda by mapping Q/K/V/O projections directly, + with random initialization for gates, convolutions, and learnable params. 
+ Stochastic Mixer Handling ========================= @@ -68,8 +72,313 @@ ) +# ============================================================================= +# SECTION 1: Per-Mixer Plan Functions +# ============================================================================= +# Each mixer type has ONE function that handles both random init and passthrough. +# This is the single source of truth for each mixer's weight schema. + + +def _plan_attention_mixer( + *, + prefix: W, + config: dict, + hidden_size: int, + source_prefix: W | None = None, +) -> ExprPlan: + """Plan for attention/sliding_window mixer. + + Weight schema: + - q_proj.weight: (q_size, hidden_size) + - k_proj.weight: (kv_size, hidden_size) + - v_proj.weight: (kv_size, hidden_size) + - o_proj.weight: (hidden_size, q_size) + + Args: + prefix: Target weight path prefix. + config: Mixer config dict. + hidden_size: Model hidden size. + source_prefix: If provided, passthrough from source. If None, random init. + """ + if source_prefix is not None: + # Passthrough + return ExprPlan(mappings={ + prefix / proj / "weight": Ref(key=source_prefix / proj / "weight") + for proj in ["q_proj", "k_proj", "v_proj", "o_proj"] + }) + + # Random init + heads = config["heads"] + head_groups = config["head_groups"] + head_size = config["head_size"] + q_size = heads * head_size + kv_size = head_groups * head_size + + return ExprPlan(mappings={ + prefix / "q_proj" / "weight": Init(shape=(q_size, hidden_size), init_type="kaiming"), + prefix / "k_proj" / "weight": Init(shape=(kv_size, hidden_size), init_type="kaiming"), + prefix / "v_proj" / "weight": Init(shape=(kv_size, hidden_size), init_type="kaiming"), + prefix / "o_proj" / "weight": Init(shape=(hidden_size, q_size), init_type="kaiming"), + }) + + +def _plan_mamba_mixer( + *, + prefix: W, + config: dict, + hidden_size: int, + source_prefix: W | None = None, +) -> ExprPlan: + """Plan for mamba mixer. + + Weight schema: + - in_proj.weight: (2*d_inner + 2*d_xb, hidden_size) + - out_proj.weight: (hidden_size, d_inner) + - dt_in_proj.weight: (dt_rank, hidden_size) + - dt_proj.weight: (d_inner, dt_rank) + - dt_proj.bias: (d_inner,) [optional] + - conv1d.weight: (conv_channels, 1, d_conv) + - conv1d.bias: (conv_channels,) [optional] + - A_log: (d_inner, d_state) + - D: (d_inner,) + + Args: + prefix: Target weight path prefix. + config: Mixer config dict. + hidden_size: Model hidden size. + source_prefix: If provided, passthrough from source. If None, random init. 
+ """ + if source_prefix is not None: + # Passthrough - include all possible weights + return ExprPlan(mappings={ + prefix / name: Ref(key=source_prefix / name) + for name in [ + "in_proj.weight", + "out_proj.weight", + "dt_in_proj.weight", + "dt_proj.weight", + "dt_proj.bias", + "conv1d.weight", + "conv1d.bias", + "A_log", + "D", + ] + }) + + # Random init + d_inner = config["d_inner"] + d_state = config["d_state"] + dt_rank = config["dt_rank"] + d_xb = config["d_xb"] + d_conv = config["d_conv"] + repeat_kv_before_conv = config["repeat_kv_before_conv"] + conv_bias = config["conv_bias"] + dt_bias = config["dt_proj_bias"] + dt_min = config["dt_min"] + dt_max = config["dt_max"] + dt_init_floor = config["dt_init_floor"] + + conv_channels = d_inner if repeat_kv_before_conv else d_xb + + mappings: dict[W, Expr] = { + prefix / "in_proj" / "weight": Init( + shape=(2 * d_inner + 2 * d_xb, hidden_size), init_type="kaiming" + ), + prefix / "out_proj" / "weight": Init(shape=(hidden_size, d_inner), init_type="kaiming"), + prefix / "dt_in_proj" / "weight": Init(shape=(dt_rank, hidden_size), init_type="kaiming"), + prefix / "dt_proj" / "weight": Init(shape=(d_inner, dt_rank), init_type="kaiming"), + prefix / "conv1d" / "weight": Init(shape=(conv_channels, 1, d_conv), init_type="kaiming"), + prefix / "A_log": Init(shape=(d_inner, d_state), init_type="s4d"), + prefix / "D": Init(shape=(d_inner,), init_type="ones"), + } + + if conv_bias: + mappings[prefix / "conv1d" / "bias"] = Init(shape=(conv_channels,), init_type="zeros") + if dt_bias: + mappings[prefix / "dt_proj" / "bias"] = Init( + shape=(d_inner,), + init_type="dt_bias", + init_params={"dt_min": dt_min, "dt_max": dt_max, "dt_init_floor": dt_init_floor}, + ) + + return ExprPlan(mappings=mappings) + + +def _plan_gdn_mixer( + *, + prefix: W, + config: dict, + hidden_size: int, + source_prefix: W | None = None, +) -> ExprPlan: + """Plan for gated_delta_net (GDN) mixer. + + Weight schema: + - in_proj_qkvz.weight: (qkvz_size, hidden_size) + - in_proj_ba.weight: (2*num_v_heads, hidden_size) + - out_proj.weight: (hidden_size, value_dim) + - convolution.weight: (conv_dim, 1, kernel_size) + - A_log: (num_v_heads,) + - dt_bias: (num_v_heads,) + - norm.weight: (head_v_dim,) + + Args: + prefix: Target weight path prefix. + config: Mixer config dict. + hidden_size: Model hidden size. + source_prefix: If provided, passthrough from source. If None, random init. 
+ """ + if source_prefix is not None: + # Passthrough + return ExprPlan(mappings={ + prefix / name: Ref(key=source_prefix / name) + for name in [ + "in_proj_qkvz.weight", + "in_proj_ba.weight", + "out_proj.weight", + "convolution.weight", + "A_log", + "dt_bias", + "norm.weight", + ] + }) + + # Random init + num_v_heads = config["value_heads"] + num_k_heads = config["key_heads"] + head_k_dim = config["key_head_dim"] + head_v_dim = config["value_head_dim"] + conv_kernel_size = config["convolution_layer"]["kernel_size"] + + key_dim = head_k_dim * num_k_heads + value_dim = head_v_dim * num_v_heads + conv_dim = key_dim * 2 + value_dim + qkvz_size = key_dim * 2 + value_dim * 2 # Q, K both key_dim; V, Z both value_dim + + return ExprPlan(mappings={ + prefix / "in_proj_qkvz" / "weight": Init(shape=(qkvz_size, hidden_size), init_type="kaiming"), + prefix / "in_proj_ba" / "weight": Init(shape=(num_v_heads * 2, hidden_size), init_type="zeros"), + prefix / "out_proj" / "weight": Init(shape=(hidden_size, value_dim), init_type="kaiming"), + prefix / "convolution" / "weight": Init( + shape=(conv_dim, 1, conv_kernel_size), init_type="scaled_identity_conv" + ), + prefix / "A_log": Init(shape=(num_v_heads,), init_type="slow_decay"), + prefix / "dt_bias": Init(shape=(num_v_heads,), init_type="zeros"), + prefix / "norm" / "weight": Init(shape=(head_v_dim,), init_type="ones"), + }) + + +def _plan_kda_mixer( + *, + prefix: W, + config: dict, + hidden_size: int, + source_prefix: W | None = None, +) -> ExprPlan: + """Plan for Kimi Delta Attention (KDA) mixer. + + Weight schema: + - q_proj.weight, k_proj.weight, v_proj.weight: (projection_size, hidden_size) + - o_proj.weight: (hidden_size, projection_size) + - q_conv.weight, k_conv.weight, v_conv.weight: (projection_size, 1, kernel_size) + - f_a_proj.weight: (head_dim, hidden_size) + - f_b_proj.weight: (projection_size, head_dim) + - g_a_proj.weight: (head_dim, hidden_size) + - g_b_proj.weight: (projection_size, head_dim) + - beta_proj.weight: (num_heads, hidden_size) + - A_log: (num_heads,) + - dt_bias: (projection_size,) + - norm.weight: (head_dim,) + + Args: + prefix: Target weight path prefix. + config: Mixer config dict. + hidden_size: Model hidden size. + source_prefix: If provided, passthrough from source. If None, random init. 
+ """ + if source_prefix is not None: + # Passthrough + return ExprPlan(mappings={ + prefix / name: Ref(key=source_prefix / name) + for name in [ + "q_proj.weight", + "k_proj.weight", + "v_proj.weight", + "o_proj.weight", + "q_conv.weight", + "k_conv.weight", + "v_conv.weight", + "f_a_proj.weight", + "f_b_proj.weight", + "g_a_proj.weight", + "g_b_proj.weight", + "beta_proj.weight", + "A_log", + "dt_bias", + "norm.weight", + ] + }) + + # Random init + num_heads = config["heads"] + head_dim = config["head_dim"] + projection_size = num_heads * head_dim + conv_kernel_size = config.get("convolution_layer", {}).get("kernel_size", 4) + + return ExprPlan(mappings={ + # Main projections + prefix / "q_proj" / "weight": Init(shape=(projection_size, hidden_size), init_type="kaiming"), + prefix / "k_proj" / "weight": Init(shape=(projection_size, hidden_size), init_type="kaiming"), + prefix / "v_proj" / "weight": Init(shape=(projection_size, hidden_size), init_type="kaiming"), + prefix / "o_proj" / "weight": Init(shape=(hidden_size, projection_size), init_type="kaiming"), + # Convolutions + prefix / "q_conv" / "weight": Init( + shape=(projection_size, 1, conv_kernel_size), init_type="scaled_identity_conv" + ), + prefix / "k_conv" / "weight": Init( + shape=(projection_size, 1, conv_kernel_size), init_type="scaled_identity_conv" + ), + prefix / "v_conv" / "weight": Init( + shape=(projection_size, 1, conv_kernel_size), init_type="scaled_identity_conv" + ), + # Gate kernels (low-rank factorization) + prefix / "f_a_proj" / "weight": Init(shape=(head_dim, hidden_size), init_type="kaiming"), + prefix / "f_b_proj" / "weight": Init(shape=(projection_size, head_dim), init_type="kaiming"), + # Output gate (low-rank factorization) + prefix / "g_a_proj" / "weight": Init(shape=(head_dim, hidden_size), init_type="kaiming"), + prefix / "g_b_proj" / "weight": Init(shape=(projection_size, head_dim), init_type="kaiming"), + # Beta projection + prefix / "beta_proj" / "weight": Init(shape=(num_heads, hidden_size), init_type="kaiming"), + # Learnable parameters + prefix / "A_log": Init(shape=(num_heads,), init_type="slow_decay"), + prefix / "dt_bias": Init(shape=(projection_size,), init_type="zeros"), + # Normalization + prefix / "norm" / "weight": Init(shape=(head_dim,), init_type="ones"), + }) + + +# Dispatcher for per-mixer plan functions +_MIXER_PLANNERS = { + "attention": _plan_attention_mixer, + "sliding_window": _plan_attention_mixer, + "mamba": _plan_mamba_mixer, + "gdn": _plan_gdn_mixer, + "kda": _plan_kda_mixer, +} + +# Types that are attention-like (can be source for MIL/DIL/KIL) +_ATTENTION_TYPES = frozenset({"attention", "sliding_window"}) + + +# ============================================================================= +# SECTION 2: Cross-Type Converters (attention → X) +# ============================================================================= +# These are public functions for converting from attention to other mixer types. +# They handle the complex logic of slicing/tiling attention weights. + + def plan_mil_attention_to_mamba( - layer_idx: int, + *, hidden_size: int, d_inner: int, d_xb: int, @@ -85,19 +394,31 @@ def plan_mil_attention_to_mamba( source_prefix: W, target_prefix: W, ) -> ExprPlan: - """MIL: Q→C, K→B, V→x, O→out_proj, z/conv/dt/A_log/D→random.""" - # in_proj layout: [z, x, B, C] sizes [d_inner, d_xb, d_xb, d_inner] + """MIL: Mamba Initialization from LLM. 
+ + Converts attention → mamba by mapping: + - Q → C (readout) + - K → B (input-dependent state transition) + - V → x (input) + - O → out_proj + - z, conv1d, dt_proj, A_log, D → random initialization + + in_proj layout: [z, x, B, C] with sizes [d_inner, d_xb, d_xb, d_inner] + """ in_proj_expr = Concat( exprs=( Init(shape=(d_inner, hidden_size), init_type="kaiming"), # z: random Slice( - expr=Ref(key=source_prefix / "v_proj" / "weight"), slices=((0, d_xb, None), (None, None, None)) + expr=Ref(key=source_prefix / "v_proj" / "weight"), + slices=((0, d_xb, None), (None, None, None)) ), # x <- V Slice( - expr=Ref(key=source_prefix / "k_proj" / "weight"), slices=((0, d_xb, None), (None, None, None)) + expr=Ref(key=source_prefix / "k_proj" / "weight"), + slices=((0, d_xb, None), (None, None, None)) ), # B <- K Slice( - expr=Ref(key=source_prefix / "q_proj" / "weight"), slices=((0, d_inner, None), (None, None, None)) + expr=Ref(key=source_prefix / "q_proj" / "weight"), + slices=((0, d_inner, None), (None, None, None)) ), # C <- Q ), dim=0, @@ -105,7 +426,7 @@ def plan_mil_attention_to_mamba( conv_channels = d_inner if repeat_kv_before_conv else d_xb - result = { + mappings: dict[W, Expr] = { target_prefix / "in_proj" / "weight": in_proj_expr, target_prefix / "out_proj" / "weight": Ref(key=source_prefix / "o_proj" / "weight"), target_prefix / "dt_in_proj" / "weight": Init(shape=(dt_rank, hidden_size), init_type="kaiming"), @@ -116,19 +437,19 @@ def plan_mil_attention_to_mamba( } if dt_bias: - result[target_prefix / "dt_proj" / "bias"] = Init( + mappings[target_prefix / "dt_proj" / "bias"] = Init( shape=(d_inner,), init_type="dt_bias", init_params={"dt_min": dt_min, "dt_max": dt_max, "dt_init_floor": dt_init_floor}, ) if conv_bias: - result[target_prefix / "conv1d" / "bias"] = Init(shape=(conv_channels,), init_type="zeros") + mappings[target_prefix / "conv1d" / "bias"] = Init(shape=(conv_channels,), init_type="zeros") - return ExprPlan(mappings=result) + return ExprPlan(mappings=mappings) -def plan_attention_to_gated_delta_net( +def plan_dil_attention_to_gdn( *, hidden_size: int, num_v_heads: int, @@ -142,7 +463,10 @@ def plan_attention_to_gated_delta_net( source_prefix: W, target_prefix: W, ) -> ExprPlan: - """DIL: Q/K/V→in_proj_qkvz (tiled for GQA), O→out_proj, Z/ba/conv/A_log/dt_bias/norm→init. + """DIL: Delta-net Initialization from LLM. + + Converts attention → gated_delta_net by mapping Q/K/V/O projections + to the fused in_proj_qkvz and out_proj, respecting GQA head grouping. Produces FLAT layout for in_proj_qkvz: [Q_all | K_all | V_all | Z_all] This matches Apriel2/Fast-LLM's expected layout. 
@@ -157,7 +481,6 @@ def plan_attention_to_gated_delta_net( v_ref = Ref(key=source_prefix / "v_proj" / "weight") # Build FLAT layout: [Q_all | K_all | V_all | Z_all] - # Collect slices for each projection type across all heads q_slices: list[Expr] = [] k_slices: list[Expr] = [] v_slices: list[Expr] = [] @@ -209,28 +532,306 @@ def plan_attention_to_gated_delta_net( dim=0, ) - # BA uses flat layout: [b_all | a_all] - in_proj_ba_expr = Init(shape=(2 * num_v_heads, hidden_size), init_type="zeros") # b=a=0 → β=0.5 - out_proj_expr = Ref(key=source_prefix / "o_proj" / "weight") - conv_weight_expr = Init(shape=(conv_dim, 1, conv_kernel_size), init_type="scaled_identity_conv") - A_log_expr = Init(shape=(num_v_heads,), init_type="slow_decay") - dt_bias_expr = Init(shape=(num_v_heads,), init_type="zeros") - norm_weight_expr = Init(shape=(head_v_dim,), init_type="ones") + return ExprPlan(mappings={ + target_prefix / "in_proj_qkvz" / "weight": in_proj_qkvz_expr, + target_prefix / "in_proj_ba" / "weight": Init( + shape=(2 * num_v_heads, hidden_size), init_type="zeros" + ), # b=a=0 → β=0.5 + target_prefix / "out_proj" / "weight": Ref(key=source_prefix / "o_proj" / "weight"), + target_prefix / "convolution" / "weight": Init( + shape=(conv_dim, 1, conv_kernel_size), init_type="scaled_identity_conv" + ), + target_prefix / "A_log": Init(shape=(num_v_heads,), init_type="slow_decay"), + target_prefix / "dt_bias": Init(shape=(num_v_heads,), init_type="zeros"), + target_prefix / "norm" / "weight": Init(shape=(head_v_dim,), init_type="ones"), + }) + + +def plan_kil_attention_to_kda( + *, + hidden_size: int, + num_heads: int, + head_dim: int, + conv_kernel_size: int, + source_num_q_heads: int, + source_num_kv_heads: int, + source_head_dim: int, + source_prefix: W, + target_prefix: W, +) -> ExprPlan: + """KIL: Kimi Initialization from LLM. + + Converts attention → KDA by transferring Q/K/V/O projections directly. + Gates, convolutions, and learnable parameters are randomly initialized. 
+ + Transfer (with GQA tiling if needed): + - q_proj: Transfer from attention.q_proj + - k_proj: Transfer from attention.k_proj (tiled if GQA) + - v_proj: Transfer from attention.v_proj (tiled if GQA) + - o_proj: Transfer from attention.o_proj + + Random init (no attention analogue): + - f_a_proj, f_b_proj: Gate kernel (low-rank factorization) + - g_a_proj, g_b_proj: Output gate (low-rank factorization) + - beta_proj: Per-head beta gating + - q_conv, k_conv, v_conv: Causal convolutions (scaled identity) + - A_log: State matrix log (slow decay) + - dt_bias: Time step bias (zeros) + - norm: Gated RMS normalization (ones) + """ + projection_size = num_heads * head_dim + source_q_size = source_num_q_heads * source_head_dim + source_kv_size = source_num_kv_heads * source_head_dim + + q_ref = Ref(key=source_prefix / "q_proj" / "weight") + k_ref = Ref(key=source_prefix / "k_proj" / "weight") + v_ref = Ref(key=source_prefix / "v_proj" / "weight") + + # Q: tile source Q heads to fill target projection_size + if source_q_size == projection_size: + q_expr: Expr = q_ref + else: + q_slices: list[Expr] = [] + for h in range(num_heads): + src_h = h % source_num_q_heads + row_start = src_h * source_head_dim + q_slices.append( + Slice(expr=q_ref, slices=((row_start, row_start + head_dim, None), (None, None, None))) + ) + q_expr = Concat(exprs=tuple(q_slices), dim=0) + + # K: tile source KV heads to fill target projection_size + if source_kv_size == projection_size: + k_expr: Expr = k_ref + else: + k_slices: list[Expr] = [] + for h in range(num_heads): + src_h = h % source_num_kv_heads + row_start = src_h * source_head_dim + k_slices.append( + Slice(expr=k_ref, slices=((row_start, row_start + head_dim, None), (None, None, None))) + ) + k_expr = Concat(exprs=tuple(k_slices), dim=0) + + # V: tile source KV heads to fill target projection_size + if source_kv_size == projection_size: + v_expr: Expr = v_ref + else: + v_slices: list[Expr] = [] + for h in range(num_heads): + src_h = h % source_num_kv_heads + row_start = src_h * source_head_dim + v_slices.append( + Slice(expr=v_ref, slices=((row_start, row_start + head_dim, None), (None, None, None))) + ) + v_expr = Concat(exprs=tuple(v_slices), dim=0) + + return ExprPlan(mappings={ + # Transfer main projections + target_prefix / "q_proj" / "weight": q_expr, + target_prefix / "k_proj" / "weight": k_expr, + target_prefix / "v_proj" / "weight": v_expr, + target_prefix / "o_proj" / "weight": Ref(key=source_prefix / "o_proj" / "weight"), + # Random init: convolutions (scaled identity for near-passthrough initially) + target_prefix / "q_conv" / "weight": Init( + shape=(projection_size, 1, conv_kernel_size), init_type="scaled_identity_conv" + ), + target_prefix / "k_conv" / "weight": Init( + shape=(projection_size, 1, conv_kernel_size), init_type="scaled_identity_conv" + ), + target_prefix / "v_conv" / "weight": Init( + shape=(projection_size, 1, conv_kernel_size), init_type="scaled_identity_conv" + ), + # Random init: gate kernels (low-rank factorization) + target_prefix / "f_a_proj" / "weight": Init(shape=(head_dim, hidden_size), init_type="kaiming"), + target_prefix / "f_b_proj" / "weight": Init(shape=(projection_size, head_dim), init_type="kaiming"), + # Random init: output gate (low-rank factorization) + target_prefix / "g_a_proj" / "weight": Init(shape=(head_dim, hidden_size), init_type="kaiming"), + target_prefix / "g_b_proj" / "weight": Init(shape=(projection_size, head_dim), init_type="kaiming"), + # Random init: beta projection + target_prefix / "beta_proj" / 
"weight": Init(shape=(num_heads, hidden_size), init_type="kaiming"), + # Random init: learnable parameters + target_prefix / "A_log": Init(shape=(num_heads,), init_type="slow_decay"), + target_prefix / "dt_bias": Init(shape=(projection_size,), init_type="zeros"), + # Random init: normalization + target_prefix / "norm" / "weight": Init(shape=(head_dim,), init_type="ones"), + }) + + +# ============================================================================= +# SECTION 3: Dispatch Logic +# ============================================================================= + + +def _plan_mixer_transfer( + source_type: str, + target_type: str, + source_config: dict, + target_config: dict, + source_prefix: W, + target_prefix: W, + hidden_size: int, +) -> ExprPlan: + """Transfer weights between mixer types. + + For same-type transfers, uses passthrough via per-mixer plan functions. + For cross-type transfers, dispatches to MIL/DIL/KIL converters. + Raises ValueError if no converter exists for the type pair. + """ + # Same-type: passthrough via unified per-mixer function + if source_type == target_type: + planner = _MIXER_PLANNERS.get(target_type) + if planner is not None: + return planner( + prefix=target_prefix, + config=target_config, + hidden_size=hidden_size, + source_prefix=source_prefix, + ) + + # Attention variants are interchangeable + if source_type in _ATTENTION_TYPES and target_type in _ATTENTION_TYPES: + return _plan_attention_mixer( + prefix=target_prefix, + config=target_config, + hidden_size=hidden_size, + source_prefix=source_prefix, + ) + + # Attention → Mamba (MIL) + if source_type in _ATTENTION_TYPES and target_type == "mamba": + return plan_mil_attention_to_mamba( + hidden_size=hidden_size, + d_inner=target_config.get("d_inner", 2 * hidden_size), + d_xb=target_config.get("d_xb", hidden_size // 4), + dt_rank=target_config.get("dt_rank", hidden_size // 16), + d_state=target_config["d_state"], + d_conv=target_config["d_conv"], + repeat_kv_before_conv=target_config["repeat_kv_before_conv"], + conv_bias=target_config["conv_bias"], + dt_bias=target_config["dt_proj_bias"], + dt_min=target_config["dt_min"], + dt_max=target_config["dt_max"], + dt_init_floor=target_config["dt_init_floor"], + source_prefix=source_prefix, + target_prefix=target_prefix, + ) + + # Attention → GatedDeltaNet (DIL) + if source_type in _ATTENTION_TYPES and target_type == "gdn": + source_heads = source_config["heads"] + source_kv_heads = source_config["head_groups"] + source_head_size = source_config["head_size"] + + return plan_dil_attention_to_gdn( + hidden_size=hidden_size, + num_v_heads=target_config.get("value_heads", source_heads), + num_k_heads=target_config.get("key_heads", source_kv_heads), + head_k_dim=target_config.get("key_head_dim", source_head_size), + head_v_dim=target_config.get("value_head_dim", source_head_size), + conv_kernel_size=target_config["convolution_layer"]["kernel_size"], + source_num_q_heads=source_heads, + source_num_kv_heads=source_kv_heads, + source_head_dim=source_head_size, + source_prefix=source_prefix, + target_prefix=target_prefix, + ) + + # Attention → KDA (KIL) + if source_type in _ATTENTION_TYPES and target_type == "kda": + source_heads = source_config["heads"] + source_kv_heads = source_config["head_groups"] + source_head_size = source_config["head_size"] + + return plan_kil_attention_to_kda( + hidden_size=hidden_size, + num_heads=target_config.get("heads", source_heads), + head_dim=target_config.get("head_dim", source_head_size), + 
conv_kernel_size=target_config.get("convolution_layer", {}).get("kernel_size", 4), + source_num_q_heads=source_heads, + source_num_kv_heads=source_kv_heads, + source_head_dim=source_head_size, + source_prefix=source_prefix, + target_prefix=target_prefix, + ) + + raise ValueError( + f"No converter available for {source_type} -> {target_type}. " + f"Use 'init: random' to initialize randomly, or implement a converter." + ) + + +def _plan_random_mixer( + prefix: W, + mixer_type: str, + config: dict, + hidden_size: int, +) -> ExprPlan: + """Random initialization for any mixer type. + + Dispatches to the per-mixer plan function with source_prefix=None. + """ + planner = _MIXER_PLANNERS.get(mixer_type) + if planner is None: + raise ValueError(f"Unknown mixer type: {mixer_type}") + return planner(prefix=prefix, config=config, hidden_size=hidden_size, source_prefix=None) + + +# ============================================================================= +# SECTION 4: Main Entry Point +# ============================================================================= + + +def plan_surgery( + source_config: dict, + target_config: dict, +) -> ExprPlan: + """Build plan for Apriel2→Apriel2 surgery (MIL, DIL, KIL, stochastic mixers, etc.).""" + hidden_size = target_config.get("hidden_size", source_config.get("hidden_size")) + assert hidden_size is not None, "hidden_size must be specified in source or target config" + + source_decoder = source_config.get("decoder", {}) + target_decoder = target_config.get("decoder", {}) + + num_source_layers = source_decoder.get("num_blocks", 0) + num_target_layers = target_decoder.get("num_blocks", num_source_layers) + + plan = _plan_non_decoder_weights(source_config) + + for target_layer_idx in range(num_target_layers): + source_layer_idx = target_layer_idx % num_source_layers if num_source_layers > 0 else 0 + source_block = _get_block_config(source_decoder, source_layer_idx) + target_block = _get_block_config(target_decoder, target_layer_idx) + + plan += _plan_mixer( + target_layer_idx, source_layer_idx, + source_block.get("mixer", {}), target_block.get("mixer", {}), + hidden_size, + ) + plan += _plan_mlp( + target_layer_idx, source_layer_idx, + source_block.get("mlp", {}), target_block.get("mlp", {}), + hidden_size, + ) + plan += _plan_norms( + target_layer_idx, source_layer_idx, + source_block, target_block, + hidden_size, + ) - # Apriel2GatedDeltaNet is now inlined (no .gdn wrapper), uses 'convolution' to match Fast-LLM return ExprPlan( - mappings={ - target_prefix / "in_proj_qkvz" / "weight": in_proj_qkvz_expr, - target_prefix / "in_proj_ba" / "weight": in_proj_ba_expr, - target_prefix / "out_proj" / "weight": out_proj_expr, - target_prefix / "convolution" / "weight": conv_weight_expr, - target_prefix / "A_log": A_log_expr, - target_prefix / "dt_bias": dt_bias_expr, - target_prefix / "norm" / "weight": norm_weight_expr, - } + mappings=plan.mappings, + source_format="apriel2", + target_format="apriel2", + metadata=plan.metadata, ) +# ============================================================================= +# SECTION 5: Non-Mixer Helpers +# ============================================================================= + + def _plan_non_decoder_weights(config: dict) -> ExprPlan: """Passthrough for embeddings, lm_head, final norm, vision encoder.""" mappings: dict[W, Expr] = {} @@ -298,51 +899,6 @@ def _get_block_config(decoder_config: dict, layer_idx: int) -> dict: return {} -def plan_surgery( - source_config: dict, - target_config: dict, -) -> ExprPlan: - """Build 
plan for Apriel2→Apriel2 surgery (MIL, DIL, stochastic mixers, etc.).""" - hidden_size = target_config.get("hidden_size", source_config.get("hidden_size")) - assert hidden_size is not None, "hidden_size must be specified in source or target config" - - source_decoder = source_config.get("decoder", {}) - target_decoder = target_config.get("decoder", {}) - - num_source_layers = source_decoder.get("num_blocks", 0) - num_target_layers = target_decoder.get("num_blocks", num_source_layers) - - plan = _plan_non_decoder_weights(source_config) - - for target_layer_idx in range(num_target_layers): - source_layer_idx = target_layer_idx % num_source_layers if num_source_layers > 0 else 0 - source_block = _get_block_config(source_decoder, source_layer_idx) - target_block = _get_block_config(target_decoder, target_layer_idx) - - plan += _plan_mixer( - target_layer_idx, source_layer_idx, - source_block.get("mixer", {}), target_block.get("mixer", {}), - hidden_size, - ) - plan += _plan_mlp( - target_layer_idx, source_layer_idx, - source_block.get("mlp", {}), target_block.get("mlp", {}), - hidden_size, - ) - plan += _plan_norms( - target_layer_idx, source_layer_idx, - source_block, target_block, - hidden_size, - ) - - return ExprPlan( - mappings=plan.mappings, - source_format="apriel2", - target_format="apriel2", - metadata=plan.metadata, - ) - - def _plan_mixer( target_layer_idx: int, source_layer_idx: int, @@ -350,6 +906,7 @@ def _plan_mixer( target_mixer: dict, hidden_size: int, ) -> ExprPlan: + """Plan mixer weights, handling stochastic wrapper routing.""" source_type = source_mixer.get("type", "attention") target_type = target_mixer.get("type", source_type) @@ -429,200 +986,6 @@ def _plan_mixer( ) -def _plan_mixer_transfer( - source_type: str, - target_type: str, - source_config: dict, - target_config: dict, - source_prefix: W, - target_prefix: W, - hidden_size: int, -) -> ExprPlan: - """Transfer weights. 
Raises ValueError if no converter for this type pair.""" - # Attention → Attention - if source_type in ("attention", "sliding_window") and target_type in ("attention", "sliding_window"): - return ExprPlan( - mappings={ - target_prefix / proj / "weight": Ref(key=source_prefix / proj / "weight") - for proj in ["q_proj", "k_proj", "v_proj", "o_proj"] - } - ) - - # Attention → Mamba (MIL) - if source_type in ("attention", "sliding_window") and target_type == "mamba": - d_inner = target_config.get("d_inner", 2 * hidden_size) - dt_rank = target_config.get("dt_rank", hidden_size // 16) - d_xb = target_config.get("d_xb", hidden_size // 4) - d_state = target_config["d_state"] - d_conv = target_config["d_conv"] - repeat_kv_before_conv = target_config["repeat_kv_before_conv"] - conv_bias = target_config["conv_bias"] - dt_bias = target_config["dt_proj_bias"] - dt_min = target_config["dt_min"] - dt_max = target_config["dt_max"] - dt_init_floor = target_config["dt_init_floor"] - - return plan_mil_attention_to_mamba( - layer_idx=0, - hidden_size=hidden_size, - d_inner=d_inner, - d_xb=d_xb, - dt_rank=dt_rank, - d_state=d_state, - d_conv=d_conv, - repeat_kv_before_conv=repeat_kv_before_conv, - conv_bias=conv_bias, - dt_bias=dt_bias, - dt_min=dt_min, - dt_max=dt_max, - dt_init_floor=dt_init_floor, - source_prefix=source_prefix, - target_prefix=target_prefix, - ) - - # Mamba → Mamba - if source_type == "mamba" and target_type == "mamba": - return ExprPlan( - mappings={ - target_prefix / name: Ref(key=source_prefix / name) - for name in [ - "in_proj.weight", - "out_proj.weight", - "dt_in_proj.weight", - "dt_proj.weight", - "dt_proj.bias", - "conv1d.weight", - "conv1d.bias", - "A_log", - "D", - ] - } - ) - - # Attention → GatedDeltaNet (DIL) - if source_type in ("attention", "sliding_window") and target_type == "gdn": - source_heads = source_config["heads"] - source_kv_heads = source_config["head_groups"] - source_head_size = source_config["head_size"] - num_v_heads = target_config.get("value_heads", source_heads) - num_k_heads = target_config.get("key_heads", source_kv_heads) - head_k_dim = target_config.get("key_head_dim", source_head_size) - head_v_dim = target_config.get("value_head_dim", source_head_size) - conv_kernel_size = target_config["convolution_layer"]["kernel_size"] - - return plan_attention_to_gated_delta_net( - hidden_size=hidden_size, - num_v_heads=num_v_heads, - num_k_heads=num_k_heads, - head_k_dim=head_k_dim, - head_v_dim=head_v_dim, - conv_kernel_size=conv_kernel_size, - source_num_q_heads=source_heads, - source_num_kv_heads=source_kv_heads, - source_head_dim=source_head_size, - source_prefix=source_prefix, - target_prefix=target_prefix, - ) - - # GatedDeltaNet → GatedDeltaNet (no .gdn wrapper, uses 'convolution' to match Fast-LLM) - if source_type == "gdn" and target_type == "gdn": - return ExprPlan( - mappings={ - target_prefix / name: Ref(key=source_prefix / name) - for name in [ - "in_proj_qkvz.weight", - "in_proj_ba.weight", - "out_proj.weight", - "convolution.weight", - "A_log", - "dt_bias", - "norm.weight", - ] - } - ) - - raise ValueError( - f"No converter available for {source_type} -> {target_type}. " - f"Use 'init: random' to initialize randomly, or implement a converter." 
- ) - - -def _plan_random_mixer( - prefix: W, - mixer_type: str, - config: dict, - hidden_size: int, -) -> ExprPlan: - mappings: dict[W, Expr] = {} - - if mixer_type in ("attention", "sliding_window"): - heads = config["heads"] - head_groups = config["head_groups"] - head_size = config["head_size"] - q_size = heads * head_size - kv_size = head_groups * head_size - - mappings[prefix / "q_proj" / "weight"] = Init(shape=(q_size, hidden_size), init_type="kaiming") - mappings[prefix / "k_proj" / "weight"] = Init(shape=(kv_size, hidden_size), init_type="kaiming") - mappings[prefix / "v_proj" / "weight"] = Init(shape=(kv_size, hidden_size), init_type="kaiming") - mappings[prefix / "o_proj" / "weight"] = Init(shape=(hidden_size, q_size), init_type="kaiming") - - elif mixer_type == "mamba": - d_inner = config["d_inner"] - d_state = config["d_state"] - dt_rank = config["dt_rank"] - d_xb = config["d_xb"] - d_conv = config["d_conv"] - repeat_kv_before_conv = config["repeat_kv_before_conv"] - conv_bias = config["conv_bias"] - dt_bias = config["dt_proj_bias"] - dt_min = config["dt_min"] - dt_max = config["dt_max"] - dt_init_floor = config["dt_init_floor"] - - conv_channels = d_inner if repeat_kv_before_conv else d_xb - mappings[prefix / "in_proj" / "weight"] = Init( - shape=(2 * d_inner + 2 * d_xb, hidden_size), init_type="kaiming" - ) - mappings[prefix / "out_proj" / "weight"] = Init(shape=(hidden_size, d_inner), init_type="kaiming") - mappings[prefix / "dt_in_proj" / "weight"] = Init(shape=(dt_rank, hidden_size), init_type="kaiming") - mappings[prefix / "dt_proj" / "weight"] = Init(shape=(d_inner, dt_rank), init_type="kaiming") - mappings[prefix / "conv1d" / "weight"] = Init(shape=(conv_channels, 1, d_conv), init_type="kaiming") - if conv_bias: - mappings[prefix / "conv1d" / "bias"] = Init(shape=(conv_channels,), init_type="zeros") - if dt_bias: - mappings[prefix / "dt_proj" / "bias"] = Init( - shape=(d_inner,), - init_type="dt_bias", - init_params={"dt_min": dt_min, "dt_max": dt_max, "dt_init_floor": dt_init_floor}, - ) - mappings[prefix / "A_log"] = Init(shape=(d_inner, d_state), init_type="s4d") - mappings[prefix / "D"] = Init(shape=(d_inner,), init_type="ones") - - elif mixer_type == "gdn": - num_v_heads = config["value_heads"] - num_k_heads = config["key_heads"] - head_k_dim = config["key_head_dim"] - head_v_dim = config["value_head_dim"] - conv_kernel_size = config["convolution_layer"]["kernel_size"] - key_dim = head_k_dim * num_k_heads - value_dim = head_v_dim * num_v_heads - conv_dim = key_dim * 2 + value_dim - # No .gdn wrapper, uses 'convolution' to match Fast-LLM naming - qkvz_size = key_dim * 2 + value_dim * 2 # Q, K both key_dim; V, Z both value_dim - mappings[prefix / "in_proj_qkvz" / "weight"] = Init(shape=(qkvz_size, hidden_size), init_type="kaiming") - mappings[prefix / "in_proj_ba" / "weight"] = Init(shape=(num_v_heads * 2, hidden_size), init_type="zeros") - mappings[prefix / "out_proj" / "weight"] = Init(shape=(hidden_size, value_dim), init_type="kaiming") - mappings[prefix / "convolution" / "weight"] = Init( - shape=(conv_dim, 1, conv_kernel_size), init_type="scaled_identity_conv" - ) - mappings[prefix / "A_log"] = Init(shape=(num_v_heads,), init_type="slow_decay") - mappings[prefix / "dt_bias"] = Init(shape=(num_v_heads,), init_type="zeros") - mappings[prefix / "norm" / "weight"] = Init(shape=(head_v_dim,), init_type="ones") - - return ExprPlan(mappings=mappings) - - def _plan_mlp( target_layer_idx: int, source_layer_idx: int, @@ -630,6 +993,7 @@ def _plan_mlp( target_mlp: dict, 
hidden_size: int, ) -> ExprPlan: + """Plan MLP weights.""" if target_mlp.get("init") == "random": return _plan_random_mlp(target_layer_idx, target_mlp, hidden_size) return _plan_mlp_transfer(target_layer_idx, source_layer_idx, source_mlp, target_mlp, hidden_size) @@ -642,6 +1006,7 @@ def _plan_mlp_transfer( target_mlp: dict, hidden_size: int, ) -> ExprPlan: + """Passthrough for MLP weights.""" source_mlp_path = W("model", "decoder", "blocks", source_layer_idx, "mlp") target_mlp_path = W("model", "decoder", "blocks", target_layer_idx, "mlp") @@ -654,12 +1019,10 @@ def _plan_mlp_transfer( f"Use 'init: random' to initialize randomly." ) - mappings: dict[W, Expr] = { + return ExprPlan(mappings={ target_mlp_path / proj / "weight": Ref(key=source_mlp_path / proj / "weight") for proj in ["gate_proj", "up_proj", "down_proj"] - } - - return ExprPlan(mappings=mappings) + }) def _plan_random_mlp( @@ -667,12 +1030,19 @@ def _plan_random_mlp( target_mlp: dict, hidden_size: int, ) -> ExprPlan: + """Random initialization for MLP.""" target_mlp_path = W("model", "decoder", "blocks", target_layer_idx, "mlp") intermediate_size = target_mlp["intermediate_size"] return ExprPlan(mappings={ - target_mlp_path / "gate_proj" / "weight": Init(shape=(intermediate_size, hidden_size), init_type="kaiming"), - target_mlp_path / "up_proj" / "weight": Init(shape=(intermediate_size, hidden_size), init_type="kaiming"), - target_mlp_path / "down_proj" / "weight": Init(shape=(hidden_size, intermediate_size), init_type="kaiming"), + target_mlp_path / "gate_proj" / "weight": Init( + shape=(intermediate_size, hidden_size), init_type="kaiming" + ), + target_mlp_path / "up_proj" / "weight": Init( + shape=(intermediate_size, hidden_size), init_type="kaiming" + ), + target_mlp_path / "down_proj" / "weight": Init( + shape=(hidden_size, intermediate_size), init_type="kaiming" + ), }) @@ -683,6 +1053,7 @@ def _plan_norms( target_block: dict, hidden_size: int, ) -> ExprPlan: + """Plan normalization layer weights.""" target_norm = target_block.get("normalization", {}) if target_norm.get("init") == "random": return _plan_random_norms(target_layer_idx, hidden_size) @@ -696,6 +1067,7 @@ def _plan_norms_transfer( target_block: dict, hidden_size: int, ) -> ExprPlan: + """Passthrough for normalization layer weights.""" source_layer = W("model", "decoder", "blocks", source_layer_idx) target_layer = W("model", "decoder", "blocks", target_layer_idx) @@ -711,18 +1083,17 @@ def _plan_norms_transfer( f"Use 'init: random' to initialize randomly." 
) - mappings: dict[W, Expr] = { + return ExprPlan(mappings={ target_layer / norm_name / "weight": Ref(key=source_layer / norm_name / "weight") for norm_name in ["input_layernorm", "post_attention_layernorm"] - } - - return ExprPlan(mappings=mappings) + }) def _plan_random_norms( target_layer_idx: int, hidden_size: int, ) -> ExprPlan: + """Random initialization for normalization layers.""" target_layer = W("model", "decoder", "blocks", target_layer_idx) return ExprPlan(mappings={ target_layer / norm_name / "weight": Init(shape=(hidden_size,), init_type="ones") diff --git a/fast_llm_external_models/apriel2/examples/comprehensive.yaml b/fast_llm_external_models/apriel2/examples/comprehensive.yaml index ceed2fe6f..b609fccb2 100644 --- a/fast_llm_external_models/apriel2/examples/comprehensive.yaml +++ b/fast_llm_external_models/apriel2/examples/comprehensive.yaml @@ -7,8 +7,10 @@ # - Pure sliding window attention (transfer with window override) # - Pure mamba (MIL conversion from attention) # - Pure gdn (DIL conversion from attention) +# - Pure kda (KIL conversion from attention) # - Stochastic mixer: attention + mamba # - Stochastic mixer: swa + gdn +# - Stochastic mixer: attention + kda # # Usage: # python convert.py ServiceNow-AI/Apriel-1.5-15b-Thinker output/ \ @@ -24,44 +26,44 @@ decoder: - stoch_am # 3 - swa # 4 - stoch_sg # 5 - - gdn # 6 + - kda # 6 - attn # 7 - - stoch_sg # 8 + - stoch_ak # 8 - mamba # 9 - swa # 10 - stoch_am # 11 - gdn # 12 - - stoch_sg # 13 + - stoch_ak # 13 - attn # 14 - mamba # 15 - stoch_am # 16 - swa # 17 - - gdn # 18 + - kda # 18 - attn # 19 - stoch_sg # 20 - mamba # 21 - - stoch_am # 22 + - stoch_ak # 22 - swa # 23 - attn # 24 - gdn # 25 - - stoch_sg # 26 + - stoch_ak # 26 - mamba # 27 - swa # 28 - stoch_am # 29 - - gdn # 30 + - kda # 30 - attn # 31 - mamba # 32 - stoch_sg # 33 - swa # 34 - - stoch_am # 35 + - stoch_ak # 35 - attn # 36 - gdn # 37 - mamba # 38 - - stoch_sg # 39 + - stoch_ak # 39 - stoch_am # 40 - swa # 41 - attn # 42 - - gdn # 43 + - kda # 43 - mamba # 44 - stoch_sg # 45 - swa # 46 @@ -174,3 +176,38 @@ decoder: init: transfer normalization: init: transfer + + # Pure kimi delta attention - KIL conversion from attention + kda: + mixer: + type: kda + init: transfer # Uses KIL conversion + # Required param (cannot be derived) + convolution_layer: + kernel_size: 4 + # Optional - defaults derived from source attention if not specified + # heads: 32 # defaults to source heads + # head_dim: 160 # defaults to source head_size + mlp: + init: transfer + normalization: + init: transfer + + # Stochastic: attention + kimi delta attention + stoch_ak: + mixer: + type: stochastic + main_mixer_name: attention + mixers: + attention: + type: attention + init: transfer + kda: + type: kda + init: transfer # KIL + convolution_layer: + kernel_size: 4 + mlp: + init: transfer + normalization: + init: transfer diff --git a/fast_llm_external_models/apriel2/examples/hybrid_kil.yaml b/fast_llm_external_models/apriel2/examples/hybrid_kil.yaml new file mode 100644 index 000000000..162624d8c --- /dev/null +++ b/fast_llm_external_models/apriel2/examples/hybrid_kil.yaml @@ -0,0 +1,96 @@ +# Example: Hybrid architecture with KIL conversion +# +# Converts attention-only model to a hybrid with: +# - First 8 layers: pure attention (keep for long-range) +# - Middle 32 layers: stochastic mixer with attention + kda (KIL converted) +# - Last 8 layers: pure attention (keep for output quality) +# +# The kda branches are initialized from attention weights via KIL. 
+ +decoder: + type: pattern + # Pattern: 8x attention, then 32x stochastic, then 8x attention + # Total 48 layers for Apriel 1.5 + pattern: + - attn # 0 + - attn # 1 + - attn # 2 + - attn # 3 + - attn # 4 + - attn # 5 + - attn # 6 + - attn # 7 + - hybrid # 8 + - hybrid # 9 + - hybrid # 10 + - hybrid # 11 + - hybrid # 12 + - hybrid # 13 + - hybrid # 14 + - hybrid # 15 + - hybrid # 16 + - hybrid # 17 + - hybrid # 18 + - hybrid # 19 + - hybrid # 20 + - hybrid # 21 + - hybrid # 22 + - hybrid # 23 + - hybrid # 24 + - hybrid # 25 + - hybrid # 26 + - hybrid # 27 + - hybrid # 28 + - hybrid # 29 + - hybrid # 30 + - hybrid # 31 + - hybrid # 32 + - hybrid # 33 + - hybrid # 34 + - hybrid # 35 + - hybrid # 36 + - hybrid # 37 + - hybrid # 38 + - hybrid # 39 + - attn # 40 + - attn # 41 + - attn # 42 + - attn # 43 + - attn # 44 + - attn # 45 + - attn # 46 + - attn # 47 + + blocks: + attn: + # Pure attention - transfer weights directly + mixer: + type: attention + init: transfer + mlp: + init: transfer + normalization: + init: transfer + + hybrid: + # Stochastic mixer with attention (transferred) and kda (KIL) + mixer: + type: stochastic + main_mixer_name: attention + mixers: + attention: + type: attention + init: transfer + # Full attention for global context + kda: + type: kda + init: transfer # Uses KIL conversion from attention + convolution_layer: + kernel_size: 4 # required, no default + # KDA dimensions can be configured or derived from source + # heads: 32 # defaults to source heads + # head_dim: 128 # defaults to source head_size + mlp: + init: transfer + normalization: + init: transfer diff --git a/fast_llm_external_models/apriel2/examples/stochastic_supernet.yaml b/fast_llm_external_models/apriel2/examples/stochastic_supernet.yaml index 8894fd0fd..2f0ed6a5d 100644 --- a/fast_llm_external_models/apriel2/examples/stochastic_supernet.yaml +++ b/fast_llm_external_models/apriel2/examples/stochastic_supernet.yaml @@ -1,4 +1,4 @@ -# Example: Stochastic supernet with attention + sliding window + gated delta net +# Example: Stochastic supernet with attention + sliding window + gated delta net + kda # # Converts a homogeneous attention model to a stochastic supernet # where each layer can sample from multiple mixer types during training. @@ -7,6 +7,7 @@ # - Full attention (direct weight transfer) # - Sliding window attention (transfer with window size override) # - Gated delta net (DIL initialization from attention weights) +# - Kimi delta attention (KIL initialization from attention weights) # # Usage: # python convert.py ServiceNow-AI/Apriel-1.5-15b-Thinker output/ \ @@ -43,6 +44,16 @@ decoder: convolution_layer: kernel_size: 4 + # Kimi delta attention - KIL initialization maps Q/K/V/O -> KDA projections + # KDA dimensions are derived from source attention: + # heads <- heads (40 for Apriel 1.5) + # head_dim <- head_size (128 for Apriel 1.5) + kda: + type: kda + init: transfer + convolution_layer: + kernel_size: 4 + # MLP and normalization transfer from source mlp: init: transfer diff --git a/fast_llm_external_models/apriel2/examples/train_supernet_small.yaml b/fast_llm_external_models/apriel2/examples/train_supernet_small.yaml index c7016b814..78c22e57f 100644 --- a/fast_llm_external_models/apriel2/examples/train_supernet_small.yaml +++ b/fast_llm_external_models/apriel2/examples/train_supernet_small.yaml @@ -1,19 +1,29 @@ # Training config for small Apriel2 stochastic supernet (single GPU) # # This config loads a converted Apriel2 model and trains it on multimodal data. 
+# The stochastic supernet includes attention, sliding window, gated delta net, and KDA mixers. +# Training uses activation-level distillation from a teacher model (attention-only) to guide +# the alternative mixers (GDN, KDA) to produce similar activations. # # Prerequisites: # -# 1. Convert a source model to Apriel2 format with reduced layers: +# 1. Convert the student model (stochastic supernet) with reduced layers: # (Note: multiple --surgery flags are composed left-to-right) # -# python -m fast_llm_external_models.apriel2.conversion.convert \ -# mistral-community/pixtral-12b \ +# python fast_llm_external_models/apriel2/convert.py \ +# ServiceNow-AI/Apriel-1.5-15b-Thinker \ # /tmp/apriel2-supernet-small \ # --surgery fast_llm_external_models/apriel2/examples/stochastic_supernet.yaml \ # --surgery fast_llm_external_models/apriel2/examples/small.yaml # -# 2. Create a multimodal dataset with matching patch size (16x16): +# 2. Convert the teacher model (attention-only, same layer reduction): +# +# python fast_llm_external_models/apriel2/convert.py \ +# ServiceNow-AI/Apriel-1.5-15b-Thinker \ +# /tmp/apriel2-teacher-small \ +# --surgery fast_llm_external_models/apriel2/examples/small.yaml +# +# 3. Create a multimodal dataset with matching patch size (16x16): # # python -c " # from tests.utils.dataset import _get_test_dataset, DATASET_CACHE @@ -31,13 +41,51 @@ # ) # " # -# 3. Run training: +# 4. Run training: # # fast-llm train train_multimodal \ # -c fast_llm_external_models/apriel2/examples/train_supernet_small.yaml # # The trained model will be exported to: # /tmp/apriel2-supernet-small-trained/export/apriel2/{iteration}/ +# +# 5. Load and test the trained model, then switch mixers at runtime: +# +# python -c " +# import torch +# from transformers import AutoProcessor, AutoModelForImageTextToText +# +# # Load the trained Apriel2 VLM (includes stochastic supernet with KDA) +# model = AutoModelForImageTextToText.from_pretrained( +# '/tmp/apriel2-supernet-small-trained/export/apriel2/10', +# torch_dtype=torch.bfloat16, +# device_map='auto', +# trust_remote_code=True, +# ) +# processor = AutoProcessor.from_pretrained('ServiceNow-AI/Apriel-1.5-15b-Thinker') +# +# # Show available mixers in the stochastic supernet +# block = model.model.decoder.blocks[0] +# print(f'Available mixers: {list(block.mixer.mixers.keys())}') +# print(f'Current main mixer: {block.mixer.main_mixer_name}') +# +# # Switch all blocks to use KDA as the main mixer (used during inference) +# for block in model.model.decoder.blocks: +# block.mixer.main_mixer_name = 'kda' +# print(f'Switched to: {model.model.decoder.blocks[0].mixer.main_mixer_name}') +# +# # Generate with KDA +# chat = [{'role': 'user', 'content': [{'type': 'text', 'text': 'Hello'}]}] +# inputs = processor.apply_chat_template( +# chat, add_generation_prompt=True, tokenize=True, return_dict=True, return_tensors='pt' +# ) +# inputs = {k: v.to(model.device) if isinstance(v, torch.Tensor) else v for k, v in inputs.items()} +# inputs.pop('token_type_ids', None) +# +# with torch.no_grad(): +# output_ids = model.generate(**inputs, max_new_tokens=50, do_sample=True, temperature=0.7) +# print(processor.decode(output_ids[0], skip_special_tokens=True)) +# " # Load pretrained model pretrained: @@ -49,14 +97,42 @@ pretrained: # Model config (mostly loaded from pretrained, but we need to specify some fast-llm specific settings) model: base_model: + # Freeze all components except the mixer by setting lr_scale: 0 + # The mixer will train with the default learning rate 
(lr_scale: 1.0 implicitly) + decoder: + block: + mlp: + lr_scale: 0.0 # Freeze MLP + normalization: + lr_scale: 0.0 # Freeze layer norms (norm_1 and norm_2 in each block) + # Activation-level distillation: teach mixers to mimic teacher's attention outputs + distillation_model: teacher + activation_distillation_factor: 0.1 + embeddings: + lr_scale: 0.0 # Freeze word embeddings head: + lr_scale: 0.0 # Freeze output head (includes final norm) cross_entropy_implementation: torch + vision_encoder: + lr_scale: 0.0 # Freeze vision encoder multi_stage: zero_stage: 2 # ZeRO stage 2 for memory efficiency distributed: compute_dtype: bf16 seed: 42 +# Teacher model for activation-level distillation +# Uses the same architecture but with standard attention (no stochastic mixer) +reference_models: + teacher: + model: + type: multimodal + pretrained: + path: /tmp/apriel2-teacher-small + format: apriel2 + model_weights: true + load_config: model + # Batch configuration (small for single GPU) batch: sequence_length: 512 # Short sequences for testing @@ -82,18 +158,20 @@ optimizer: # Training configuration training: - train_iters: 10 # Just a few iterations for testing + train_iters: 100 # Extended training run num_workers: 2 logs: interval: 1 checkpoint: interval: null # Disable checkpointing for quick test export: - interval: 10 # Export at the end + interval: 100 # Export at the end format: apriel2 # Export back to Apriel2 HF format test_iters: 0 evaluators: {} -# Experiment directory +# Experiment directory and logging run: experiment_dir: /tmp/apriel2-supernet-small-trained + # Enable model debug logging to see stochastic mixer selection per iteration + model_debug_level: 1 diff --git a/fast_llm_external_models/apriel2/modeling_apriel2.py b/fast_llm_external_models/apriel2/modeling_apriel2.py index 79028040d..4c263b4e2 100644 --- a/fast_llm_external_models/apriel2/modeling_apriel2.py +++ b/fast_llm_external_models/apriel2/modeling_apriel2.py @@ -38,6 +38,14 @@ except ImportError: rms_norm_gated = None +# KDA implementation - matches Fast-LLM's kda.py +try: + from fla.ops.kda import chunk_kda, fused_recurrent_kda + from fla.ops.kda.gate import fused_kda_gate +except ImportError: + chunk_kda = None + fused_recurrent_kda = None + fused_kda_gate = None is_fast_path_available = is_mamba_ssm_available() and is_causal_conv1d_available() @@ -87,15 +95,35 @@ def torch_causal_conv1d_fn(x, weight, bias=None, activation="silu"): @torch.compile def torch_causal_conv1d_update(x, conv_state, weight, bias=None, activation="silu"): + """ + Single-step causal convolution update. 
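+
+    Sketch of the semantics implemented below: the cached window and the new
+    token are concatenated into a full kernel_size window, the depthwise kernel
+    is summed over that window (plus optional bias and silu), and the cache is
+    then shifted left by one position.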
+ + Args: + x: New input [batch, dim] + conv_state: Previous state [batch, dim, kernel_size-1], updated in-place + weight: Convolution kernel [dim, kernel_size] + bias: Optional bias [dim] + activation: Activation function name + + Returns: + Output [batch, dim] + """ assert activation == "silu", f"Only silu activation is supported, got {activation}" dtype = x.dtype - conv_state.copy_(torch.roll(conv_state, shifts=-1, dims=-1)) - conv_state[:, :, -1] = x - x = torch.sum(conv_state * weight.unsqueeze(0), dim=-1) + # Concatenate state with new input to get full kernel_size window + # conv_state: [batch, dim, kernel_size-1], x: [batch, dim] -> full: [batch, dim, kernel_size] + full_state = torch.cat([conv_state, x.unsqueeze(-1)], dim=-1) + + # Convolve: sum over last dimension + out = torch.sum(full_state * weight.unsqueeze(0), dim=-1) if bias is not None: - x = x + bias - return F.silu(x).to(dtype=dtype) + out = out + bias + + # Update state in-place: shift left and add new value + conv_state.copy_(full_state[:, :, 1:]) + + return F.silu(out).to(dtype=dtype) def torch_selective_scan_fn( @@ -109,16 +137,183 @@ def torch_selective_state_update(state, x, dt, A, B, C, D=None, z=None, dt_bias= if is_fast_path_available: - from causal_conv1d import causal_conv1d_fn, causal_conv1d_update + from causal_conv1d import causal_conv1d_fn as _causal_conv1d_fn + from causal_conv1d import causal_conv1d_update as _causal_conv1d_update from mamba_ssm.ops.selective_scan_interface import selective_scan_fn from mamba_ssm.ops.triton.selective_state_update import selective_state_update else: - causal_conv1d_fn = torch_causal_conv1d_fn - causal_conv1d_update = torch_causal_conv1d_update + _causal_conv1d_fn = None + _causal_conv1d_update = None selective_scan_fn = torch_selective_scan_fn selective_state_update = torch_selective_state_update +class CausalConv1d(nn.Conv1d): + """ + Causal 1D convolution that pads only on the left side. + + Subclasses nn.Conv1d for weight storage/checkpoint compatibility, but overrides + forward to use proper causal (left-only) padding instead of nn.Conv1d's symmetric padding. + + Supports: + - Prefill mode: process full sequence, optionally return final state for caching + - Decode mode: single-token update using cached conv state + - CUDA fast path (causal_conv1d library) with automatic CPU/fallback support + """ + + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: int, + activation: str = "silu", + **kwargs, + ): + # Remove padding from kwargs since we handle it ourselves + kwargs.pop("padding", None) + super().__init__( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + padding=0, # No built-in padding; we handle it in forward + **kwargs, + ) + self._activation = activation + + @property + def _weight(self) -> torch.Tensor: + """Weight in [dim, kernel_size] format for causal_conv1d functions.""" + return self.weight.squeeze(1) + + def _use_fast_path(self, x: torch.Tensor) -> bool: + """Check if we can use CUDA fast path.""" + return _causal_conv1d_fn is not None and x.device.type == "cuda" + + def forward( + self, + x: torch.Tensor, + conv_state: torch.Tensor | None = None, + return_final_state: bool = False, + ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: + """ + Apply causal convolution. + + Args: + x: Input tensor [batch, dim, seq_len] + conv_state: Previous conv state [batch, dim, kernel_size-1] for continuing + from cached state. If None, starts fresh. 
+ return_final_state: If True, return (output, final_state) tuple where + final_state can be used for subsequent decode steps. + + Returns: + If return_final_state is False: output tensor [batch, dim, seq_len] + If return_final_state is True: (output, final_state) tuple + """ + batch_size, dim, seq_len = x.shape + + # CUDA kernel limitation: return_final_states requires channel-last layout, + # which is impossible to achieve when seq_len==1. Fall back to PyTorch. + use_fast_path = self._use_fast_path(x) and not (return_final_state and seq_len == 1) + + if use_fast_path: + # CUDA fast path + if return_final_state: + # causal_conv1d requires channel-last layout for returning final states. + # Channel-last means: stride(1)==1 AND stride(2)==dim (channels are contiguous). + # For shape [batch, dim, seq], standard contiguous is (dim*seq, seq, 1). + # Channel-last is (dim*seq, 1, dim) - achieved via transpose+contiguous+transpose. + if x.stride(1) != 1 or x.stride(2) < dim: + x = x.transpose(1, 2).contiguous().transpose(1, 2) + # Allocate final state buffer with correct memory layout + # causal_conv1d requires final_states.stride(1) == 1 + final_state = x.new_zeros(batch_size, self.kernel_size[0] - 1, dim).transpose(1, 2) + else: + final_state = None + + out = _causal_conv1d_fn( + x, + self._weight, + bias=self.bias, + initial_states=conv_state, + return_final_states=return_final_state, + final_states_out=final_state, + activation=self._activation, + ) + + if return_final_state: + if isinstance(out, tuple): + out, final_state = out + # Return a contiguous copy (still in channel-last layout) so callers can modify it in-place + # final_state has shape [batch, dim, state_len] with channel-last strides + # We need to preserve the channel-last layout for subsequent CUDA kernel calls + if final_state.stride(1) != 1: + # Already contiguous in channel-last + pass + else: + # Make a copy that's safe to modify in-place + final_state = final_state.clone() + return out, final_state + return out + else: + # PyTorch fallback + state_len = self.kernel_size[0] - 1 + + if conv_state is not None: + # Prepend state to input for proper convolution with history + x_with_state = torch.cat([conv_state, x], dim=-1) + out_with_state = torch_causal_conv1d_fn( + x_with_state, self._weight, bias=self.bias, activation=self._activation + ) + # Only keep outputs for the new input positions (not the state positions) + out = out_with_state[:, :, state_len:] + else: + out = torch_causal_conv1d_fn(x, self._weight, bias=self.bias, activation=self._activation) + + if return_final_state: + # Final state: last kernel_size-1 positions of input (with state if provided) + if conv_state is not None: + combined = torch.cat([conv_state, x], dim=-1) + final_state = combined[:, :, -state_len:].clone() + elif seq_len < state_len: + final_state = F.pad(x, (state_len - seq_len, 0)) + else: + final_state = x[:, :, -state_len:].clone() + return out, final_state + return out + + def update( + self, + x: torch.Tensor, + conv_state: torch.Tensor, + ) -> torch.Tensor: + """ + Single-token decode step using cached conv state. 
+ + Args: + x: Input tensor [batch, dim] (single token) + conv_state: Conv state [batch, dim, kernel_size-1], will be updated in-place + + Returns: + Output tensor [batch, dim] + """ + if self._use_fast_path(x): + return _causal_conv1d_update( + x, + conv_state, + self._weight, + bias=self.bias, + activation=self._activation, + ) + else: + return torch_causal_conv1d_update( + x, + conv_state, + self._weight, + bias=self.bias, + activation=self._activation, + ) + + def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: """ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, @@ -407,8 +602,8 @@ def get_mixer_class(mixer_type: str) -> type: return Apriel2Mamba elif mixer_type == "gdn": return Apriel2GatedDeltaNet - elif mixer_type == "kimi_linear_attention": - return KimiLinearAttention + elif mixer_type == "kda": + return KimiDeltaAttention elif mixer_type == "stochastic": return Apriel2StochasticMixer else: @@ -429,7 +624,7 @@ def create_mixer(mixer_config: dict, hidden_size: int, layer_idx: int, config, a raise ValueError("Stochastic mixers cannot contain nested stochastic mixers") return mixer_class(mixer_config, config, layer_idx) else: - # mamba, gdn, kimi_linear_attention all have same signature + # mamba, gdn, kda all have same signature return mixer_class(hidden_size, mixer_config, layer_idx=layer_idx) @@ -476,29 +671,29 @@ def __init__( self.layer_idx = layer_idx self.repeat_kv_before_conv = repeat_kv_before_conv + self.activation = "silu" # Hardcoded for Mamba + if self.repeat_kv_before_conv: - self.conv1d = nn.Conv1d( + self.conv1d = CausalConv1d( in_channels=self.d_inner, out_channels=self.d_inner, bias=conv_bias, kernel_size=d_conv, groups=self.d_inner, - padding=d_conv - 1, + activation=self.activation, **factory_kwargs, ) else: - self.conv1d = nn.Conv1d( + self.conv1d = CausalConv1d( in_channels=self.d_xb, out_channels=self.d_xb, bias=conv_bias, kernel_size=d_conv, groups=self.d_xb, - padding=d_conv - 1, + activation=self.activation, **factory_kwargs, ) - self.activation = "silu" # Hardcoded for Mamba - self.num_xb_head = self.d_xb // self.d_state self.num_C_head = self.d_inner // self.d_state self.repeat_group = self.num_C_head // self.num_xb_head @@ -607,16 +802,11 @@ def forward( x = repeat_kv(x, self.repeat_group) x = rearrange(x, "b n_group l dstate -> b (n_group dstate) l") + # Compute short convolution if conv_state is not None: + # Store padded input for future decode steps (convention: state size = d_conv) conv_state.copy_(F.pad(x, (self.d_conv - x.shape[-1], 0))) - - # Compute short convolution - x = causal_conv1d_fn( - x=x, - weight=rearrange(self.conv1d.weight, "d 1 w -> d w"), - bias=self.conv1d.bias, - activation=self.activation, - ) + x = self.conv1d(x) if not self.repeat_kv_before_conv: x = rearrange(x, "b (n_group dstate) l -> b n_group l dstate", dstate=self.d_state) @@ -691,13 +881,7 @@ def step(self, hidden_states, conv_state, ssm_state): x = rearrange(x, "b n_group dstate -> b (n_group dstate)") # Conv step - x = causal_conv1d_update( - x, - conv_state, - rearrange(self.conv1d.weight, "d 1 w -> d w"), - self.conv1d.bias, - self.activation, - ) + x = self.conv1d.update(x, conv_state) if not self.repeat_kv_before_conv: x = rearrange(x, "b (n_group dstate) -> b n_group dstate", dstate=self.d_state) @@ -843,12 +1027,18 @@ class GatedRMSNormalization(nn.Module): """ Gated RMS normalization layer matching Fast-LLM's implementation. 
Uses fla.modules.fused_norm_gate.rms_norm_gated when available. + + Args: + hidden_size: Size of the hidden dimension + eps: Epsilon for numerical stability + activation: Gating activation function ("silu" or "sigmoid") """ - def __init__(self, hidden_size: int, eps: float = 1e-5): + def __init__(self, hidden_size: int, eps: float = 1e-5, activation: str = "silu"): super().__init__() self.weight = nn.Parameter(torch.ones(hidden_size)) self.eps = eps + self.activation = activation def forward(self, input_: torch.Tensor, gate: torch.Tensor) -> torch.Tensor: # Use PyTorch fallback on CPU since fla requires CUDA @@ -863,7 +1053,7 @@ def _forward_fla(self, input_: torch.Tensor, gate: torch.Tensor) -> torch.Tensor gate, self.weight, None, - activation="silu", + activation=self.activation, eps=self.eps, residual=None, prenorm=False, @@ -877,7 +1067,11 @@ def _forward_local(self, input_: torch.Tensor, gate: torch.Tensor) -> torch.Tens variance = hidden_states.pow(2).mean(-1, keepdim=True) hidden_states = hidden_states * torch.rsqrt(variance + self.eps) hidden_states = self.weight * hidden_states.to(input_dtype) - return hidden_states * F.silu(gate) + # Apply gating with configured activation + if self.activation == "sigmoid": + return hidden_states * torch.sigmoid(gate) + else: # silu + return hidden_states * F.silu(gate) class Apriel2GatedDeltaNet(nn.Module): @@ -926,13 +1120,13 @@ def __init__( self.out_proj = nn.Linear(self.value_dim, d_model, bias=False, device=device, dtype=dtype) # Convolution - named 'convolution' to match Fast-LLM - self.convolution = nn.Conv1d( + self.convolution = CausalConv1d( in_channels=self.conv_dim, out_channels=self.conv_dim, bias=False, kernel_size=self.conv_kernel_size, groups=self.conv_dim, - padding=self.conv_kernel_size - 1, + activation=self.activation, device=device, dtype=dtype, ) @@ -1027,32 +1221,19 @@ def forward(self, hidden_states: torch.Tensor, past_key_values=None, attention_m # Apply causal convolution if use_precomputed_states: - # Single token update - use cached conv state - # torch_causal_conv1d_update expects [batch, conv_dim] not [batch, conv_dim, 1] - mixed_qkv = torch_causal_conv1d_update( + # Single token decode - use cached conv state + mixed_qkv = self.convolution.update( mixed_qkv.squeeze(2), # [batch, conv_dim, 1] -> [batch, conv_dim] conv_state, - self.convolution.weight.squeeze(1), - None, # bias - "silu", - ).unsqueeze( - 2 - ) # [batch, conv_dim] -> [batch, conv_dim, 1] + ).unsqueeze(2) # [batch, conv_dim] -> [batch, conv_dim, 1] else: - # Prefill - store padded state for future decoding - if past_key_values is not None: - # Pad to kernel size and store for future decoding - padded = F.pad(mixed_qkv, (self.conv_kernel_size - mixed_qkv.shape[-1], 0)) - past_key_values.conv_states[self.layer_idx] = padded[:, :, -self.conv_kernel_size :] - # Apply convolution - # note, using F.silu(self.convolution(mixed_qkv)[:, :, :seq_len]) is numerically different than applying causal_conv1d_fn - # which failed the test test_fast_llm_gdn_matches_apriel2_forward - mixed_qkv = causal_conv1d_fn( - x=mixed_qkv, - weight=self.convolution.weight.squeeze(1), - bias=self.convolution.bias, - activation=self.activation, - ) + # Prefill mode + use_cache = past_key_values is not None + if use_cache: + mixed_qkv, final_state = self.convolution(mixed_qkv, return_final_state=True) + past_key_values.conv_states[self.layer_idx] = final_state + else: + mixed_qkv = self.convolution(mixed_qkv) mixed_qkv = mixed_qkv.transpose(1, 2) # [batch, seq, conv_dim] @@ -1156,8 
+1337,22 @@ def preprocess( return {} -class KimiLinearAttention(nn.Module): - """KimiLinearAttention mixer - stub for future implementation.""" +class KimiDeltaAttention(nn.Module): + """ + Kimi Delta Attention (KDA) implementation matching Fast-LLM's kda.py. + + Weight names match Fast-LLM: + - q_proj, k_proj, v_proj, o_proj - main projections + - f_a_proj, f_b_proj - gate kernel (low-rank) + - g_a_proj, g_b_proj - output gate (low-rank) + - beta_proj - beta gating + - q_conv, k_conv, v_conv - CausalConv1d modules + - A_log, dt_bias - learnable parameters + - norm - gated RMS normalization + + Uses fla.ops.kda.chunk_kda and fused_recurrent_kda kernels. + Uses CausalConv1d for convolutions (CUDA fast path with PyTorch fallback). + """ def __init__( self, @@ -1168,7 +1363,205 @@ def __init__( dtype=None, ): super().__init__() - raise NotImplementedError("KimiLinearAttention not yet implemented in apriel2") + + if chunk_kda is None or fused_kda_gate is None: + raise ImportError( + "KimiDeltaAttention requires the `fla` package. " "Please install it with `pip install -U fla-core`." + ) + + self.layer_idx = layer_idx + self.hidden_size = d_model + self.mode = "chunk" + + # Config params - match Fast-LLM naming + self.num_heads = config_dict.get("heads", 32) + self.head_dim = config_dict.get("head_dim", 64) + conv_config = config_dict.get("convolution_layer", {}) + self.conv_kernel_size = conv_config.get("kernel_size", 4) + norm_config = config_dict.get("normalization", {}) + self.norm_eps = norm_config.get("epsilon", 1e-5) + self.norm_activation = norm_config.get("activation", "sigmoid") + + # Derived dimensions + self.projection_size = self.head_dim * self.num_heads + + # Projection layers - names match Fast-LLM exactly + self.q_proj = nn.Linear(d_model, self.projection_size, bias=False, device=device, dtype=dtype) + self.k_proj = nn.Linear(d_model, self.projection_size, bias=False, device=device, dtype=dtype) + self.v_proj = nn.Linear(d_model, self.projection_size, bias=False, device=device, dtype=dtype) + + # Convolutions - use CausalConv1d for proper left-only padding + # Named to match Fast-LLM (q_conv, k_conv, v_conv) + self.q_conv = CausalConv1d( + in_channels=self.projection_size, + out_channels=self.projection_size, + kernel_size=self.conv_kernel_size, + groups=self.projection_size, # depthwise + bias=False, + activation="silu", + device=device, + dtype=dtype, + ) + self.k_conv = CausalConv1d( + in_channels=self.projection_size, + out_channels=self.projection_size, + kernel_size=self.conv_kernel_size, + groups=self.projection_size, + bias=False, + activation="silu", + device=device, + dtype=dtype, + ) + self.v_conv = CausalConv1d( + in_channels=self.projection_size, + out_channels=self.projection_size, + kernel_size=self.conv_kernel_size, + groups=self.projection_size, + bias=False, + activation="silu", + device=device, + dtype=dtype, + ) + + # Gate kernel projections (low-rank: hidden -> head_dim -> projection) + self.f_a_proj = nn.Linear(d_model, self.head_dim, bias=False, device=device, dtype=dtype) + self.f_b_proj = nn.Linear(self.head_dim, self.projection_size, bias=False, device=device, dtype=dtype) + + # Output gate projections (low-rank) + self.g_a_proj = nn.Linear(d_model, self.head_dim, bias=False, device=device, dtype=dtype) + self.g_b_proj = nn.Linear(self.head_dim, self.projection_size, bias=False, device=device, dtype=dtype) + + # Beta projection - named beta_proj to match Fast-LLM (not b_proj) + self.beta_proj = nn.Linear(d_model, self.num_heads, bias=False, 
device=device, dtype=dtype) + + # Output projection + self.o_proj = nn.Linear(self.projection_size, d_model, bias=False, device=device, dtype=dtype) + + # Learnable parameters - match Fast-LLM shapes + # A_log: 1D shape (num_heads,) to match Fast-LLM + self.A_log = nn.Parameter( + torch.zeros(self.num_heads, device=device, dtype=torch.float32).uniform_(1, 16).log() + ) + self.dt_bias = nn.Parameter(torch.ones(self.projection_size, device=device, dtype=torch.float32)) + + # Normalization - use GatedRMSNormalization (same wrapper as GDN, with sigmoid activation) + self.norm = GatedRMSNormalization(self.head_dim, eps=self.norm_eps, activation=self.norm_activation) + + def _apply_conv( + self, x: torch.Tensor, conv: CausalConv1d, conv_state: torch.Tensor | None, use_cache: bool + ): + """ + Apply causal convolution with cache support. + + Args: + x: Input tensor [batch, seq, dim] + conv: CausalConv1d module + conv_state: Previous conv state [batch, dim, kernel_size-1] or None + use_cache: Whether to output final state for caching + + Returns: + (output, new_conv_state) tuple + """ + seq_len = x.shape[1] + x = x.transpose(1, 2) # [batch, dim, seq] + + # Single token decode with existing cache + if conv_state is not None and seq_len == 1: + out = conv.update(x.squeeze(2), conv_state) + return out.unsqueeze(1), conv_state # [batch, 1, dim] + + # Prefill mode + if use_cache: + out, final_state = conv(x, conv_state=conv_state, return_final_state=True) + else: + out = conv(x, conv_state=conv_state) + final_state = None + + return out.transpose(1, 2), final_state # [batch, seq, dim] + + def forward( + self, + hidden_states: torch.Tensor, + past_key_values=None, + attention_mask: Optional[torch.Tensor] = None, + **kwargs, + ): + batch_size, seq_len, _ = hidden_states.shape + mode = "fused_recurrent" if seq_len <= 64 else self.mode + if self.training: + mode = "chunk" + + # Get cache states if available + conv_state_q, conv_state_k, conv_state_v = None, None, None + recurrent_state = None + use_cache = past_key_values is not None + + if past_key_values is not None: + conv_states = past_key_values.conv_states[self.layer_idx] + if conv_states is not None: + conv_state_q, conv_state_k, conv_state_v = conv_states + recurrent_state = past_key_values.recurrent_states[self.layer_idx] + + # Project Q, K, V and apply convolutions + q, conv_state_q = self._apply_conv(self.q_proj(hidden_states), self.q_conv, conv_state_q, use_cache) + k, conv_state_k = self._apply_conv(self.k_proj(hidden_states), self.k_conv, conv_state_k, use_cache) + v, conv_state_v = self._apply_conv(self.v_proj(hidden_states), self.v_conv, conv_state_v, use_cache) + + # Gate kernel computation + g = self.f_b_proj(self.f_a_proj(hidden_states)) + g = rearrange(g, "... (h d) -> ... h d", d=self.head_dim) + g = fused_kda_gate(g, self.A_log.float(), dt_bias=self.dt_bias) + + # Beta gating + beta = self.beta_proj(hidden_states).float().sigmoid() + + # Reshape Q, K, V to head format + q, k = map(lambda x: rearrange(x, "... (h d) -> ... h d", d=self.head_dim), (q, k)) + v = rearrange(v, "... (h d) -> ... 
h d", d=self.head_dim) + + # Run KDA kernel + if mode == "chunk": + o, recurrent_state = chunk_kda( + q=q, + k=k, + v=v, + g=g, + beta=beta, + initial_state=None, + output_final_state=False, + use_qk_l2norm_in_kernel=True, + ) + else: + o, recurrent_state = fused_recurrent_kda( + q=q, + k=k, + v=v, + g=g, + beta=beta, + initial_state=recurrent_state, + output_final_state=True, + use_qk_l2norm_in_kernel=True, + ) + + # Update cache + if past_key_values is not None: + past_key_values.recurrent_states[self.layer_idx] = recurrent_state + past_key_values.conv_states[self.layer_idx] = (conv_state_q, conv_state_k, conv_state_v) + + # Output gating and normalization + g_out = self.g_b_proj(self.g_a_proj(hidden_states)) + g_out = rearrange(g_out, "... (h d) -> ... h d", d=self.head_dim) + + # Flatten for normalization, then reshape back + o_shape = o.shape + o = self.norm(o.reshape(-1, o.shape[-1]), g_out.reshape(-1, g_out.shape[-1])) + o = o.reshape(o_shape) + + # Reshape and project output + o = rearrange(o, "b t h d -> b t (h d)") + o = self.o_proj(o) + + return (o,) @classmethod def setup( @@ -1177,11 +1570,8 @@ def setup( hidden_size: int, max_position_embeddings: int, ) -> nn.ModuleDict: - """KimiLinearAttention setup not implemented.""" - raise NotImplementedError("KimiLinearAttention not yet implemented in apriel2") - - def forward(self, hidden_states: torch.Tensor, **kwargs): - raise NotImplementedError("KimiLinearAttention not yet implemented in apriel2") + """KimiDeltaAttention has no setup resources - returns empty ModuleDict.""" + return nn.ModuleDict() def preprocess( self, @@ -1189,8 +1579,8 @@ def preprocess( resources: Optional[nn.ModuleDict], **kwargs: Unpack[BlockSequenceKwargs], ) -> PreprocessingOutput: - """KimiLinearAttention preprocessing not implemented.""" - raise NotImplementedError("KimiLinearAttention not yet implemented in apriel2") + """KimiDeltaAttention has no preprocessing - returns empty dict.""" + return {} class Apriel2BlockSequence(nn.Module): diff --git a/fast_llm_external_models/tests/test_apriel2/conftest.py b/fast_llm_external_models/tests/test_apriel2/conftest.py index 9473bd180..8585aec65 100644 --- a/fast_llm_external_models/tests/test_apriel2/conftest.py +++ b/fast_llm_external_models/tests/test_apriel2/conftest.py @@ -496,6 +496,126 @@ def apriel2_config_all_mixers(): ) +@pytest.fixture +def apriel2_config_kda(): + """Apriel2 config with pure KDA (Kimi Delta Attention) layers. + + Tests KDA-specific cache behavior: + - Tuple conv states (q, k, v) instead of single tensor + - Recurrent state handling + """ + from fast_llm_external_models.apriel2.configuration_apriel2 import Apriel2Config + + return Apriel2Config( + vocab_size=100, + hidden_size=64, + decoder={ + "type": "fixed", + "num_blocks": 2, + "block": { + "mixer": { + "type": "kda", + "heads": 4, + "head_dim": 16, + "convolution_layer": {"kernel_size": 4}, + "normalization": {"epsilon": 1e-5}, + }, + "mlp": {"type": "mlp", "intermediate_size": 256, "gated": True}, + "normalization": {"type": "rms_norm", "epsilon": 1e-5}, + }, + }, + ) + + +@pytest.fixture +def apriel2_config_all_mixers_with_kda(): + """Apriel2 config with all 5 mixer types including KDA. 
+ + This config exercises: + - All mixer types (attention, swa, mamba, gdn, kda) + - KDA's tuple conv state handling in stochastic context + - Cache isolation between all mixer types + """ + from fast_llm_external_models.apriel2.configuration_apriel2 import Apriel2Config + + return Apriel2Config( + vocab_size=100, + hidden_size=64, + decoder={ + "type": "pattern", + "num_blocks": 2, + "pattern": ["attn", "all_mixers"], + "blocks": { + "attn": { + "mixer": { + "type": "attention", + "heads": 4, + "head_groups": 2, + "head_size": 16, + "rotary": {"type": "mistral_1d", "theta": 10000.0}, + }, + "mlp": {"type": "mlp", "intermediate_size": 256, "gated": True}, + "normalization": {"type": "rms_norm", "epsilon": 1e-5}, + }, + "all_mixers": { + "mixer": { + "type": "stochastic", + "main_mixer_name": "attention", + "mixers": { + "attention": { + "type": "attention", + "heads": 4, + "head_groups": 2, + "head_size": 16, + "rotary": {"type": "mistral_1d", "theta": 10000.0}, + }, + "swa": { + "type": "attention", + "heads": 4, + "head_groups": 2, + "head_size": 16, + "window_size": 2048, + "rotary": {"type": "mistral_1d", "theta": 1000000.0}, + }, + "mamba": { + "type": "mamba", + "d_inner": 128, + "d_state": 16, + "dt_rank": 4, + "d_xb": 32, + "d_conv": 4, + "repeat_kv_before_conv": True, + "conv_bias": True, + "dt_proj_bias": True, + "dt_min": 0.001, + "dt_max": 0.1, + "dt_init_floor": 1e-4, + }, + "gdn": { + "type": "gdn", + "value_heads": 4, + "key_heads": 2, + "key_head_dim": 16, + "value_head_dim": 16, + "convolution_layer": {"kernel_size": 4}, + }, + "kda": { + "type": "kda", + "heads": 4, + "head_dim": 16, + "convolution_layer": {"kernel_size": 4}, + "normalization": {"epsilon": 1e-5}, + }, + }, + }, + "mlp": {"type": "mlp", "intermediate_size": 256, "gated": True}, + "normalization": {"type": "rms_norm", "epsilon": 1e-5}, + }, + }, + }, + ) + + @pytest.fixture def apriel2_config_comprehensive(): """Comprehensive Apriel2 config combining all features for thorough testing. @@ -750,7 +870,7 @@ def comprehensive_torture_chain(): This is the REAL stress test. 
It exercises: - Fixed → Pattern decoder transitions - Per-layer heterogeneity - - All type conversions: FA ↔ SWA ↔ Mamba ↔ GDN + - All type conversions: FA ↔ SWA ↔ Mamba ↔ GDN ↔ KDA - Stochastic wrapping/unwrapping - Both init: transfer and init: random - Destructive operations (remove sub-mixers, collapse stochastic) @@ -809,17 +929,17 @@ def comprehensive_torture_chain(): }, }, # ===================================================================== - # STEP 2: Add stochastic wrappers with MIL/DIL conversions + # STEP 2: Add stochastic wrappers with MIL/DIL/KIL conversions # Layer 0: stochastic{attn, mamba:MIL} # Layer 1: swa (unchanged) # Layer 2: stochastic{attn, gdn:DIL} # Layer 3: swa (unchanged) - # Layer 4: attn (unchanged) + # Layer 4: stochastic{attn, kda:KIL} # ===================================================================== { "decoder": { "type": "pattern", - "pattern": ["stoch_am", "swa", "stoch_ag", "swa", "attn"], + "pattern": ["stoch_am", "swa", "stoch_ag", "swa", "stoch_ak"], "blocks": { "stoch_am": { "mixer": { @@ -862,8 +982,20 @@ def comprehensive_torture_chain(): "mlp": {"init": "transfer"}, "normalization": {"init": "transfer"}, }, - "attn": { - "mixer": {"type": "attention", "init": "transfer"}, + "stoch_ak": { + "mixer": { + "type": "stochastic", + "main_mixer_name": "attention", + "mixers": { + "attention": {"type": "attention", "init": "transfer"}, + "kda": { + "type": "kda", + "init": "transfer", # KIL conversion + "convolution_layer": {"kernel_size": 4}, + "normalization": {"epsilon": 1e-5}, + }, + }, + }, "mlp": {"init": "transfer"}, "normalization": {"init": "transfer"}, }, @@ -876,12 +1008,12 @@ def comprehensive_torture_chain(): # Layer 1: mamba (MIL from swa!) # Layer 2: stoch{attn, gdn} (unchanged) # Layer 3: gdn (DIL from swa!) - # Layer 4: attn (unchanged) + # Layer 4: stoch{attn, kda} (unchanged) # ===================================================================== { "decoder": { "type": "pattern", - "pattern": ["stoch_am", "mamba", "stoch_ag", "gdn", "attn"], + "pattern": ["stoch_am", "mamba", "stoch_ag", "gdn", "stoch_ak"], "blocks": { "stoch_am": { "mixer": { @@ -929,8 +1061,20 @@ def comprehensive_torture_chain(): "mlp": {"init": "transfer"}, "normalization": {"init": "transfer"}, }, - "attn": { - "mixer": {"type": "attention", "init": "transfer"}, + "stoch_ak": { + "mixer": { + "type": "stochastic", + "main_mixer_name": "attention", + "mixers": { + "attention": {"type": "attention", "init": "transfer"}, + "kda": { + "type": "kda", + "init": "transfer", + "convolution_layer": {"kernel_size": 4}, + "normalization": {"epsilon": 1e-5}, + }, + }, + }, "mlp": {"init": "transfer"}, "normalization": {"init": "transfer"}, }, @@ -943,12 +1087,12 @@ def comprehensive_torture_chain(): # Layer 1: mamba (unchanged) # Layer 2: stoch{attn, gdn, mamba:RANDOM} # Layer 3: gdn (unchanged) - # Layer 4: stoch{attn, swa:RANDOM} (wrap in stochastic!) 
+ # Layer 4: stoch{attn, kda, swa:RANDOM} (add swa to existing stoch_ak) # ===================================================================== { "decoder": { "type": "pattern", - "pattern": ["stoch_ams", "mamba", "stoch_agm", "gdn", "stoch_as"], + "pattern": ["stoch_ams", "mamba", "stoch_agm", "gdn", "stoch_aks"], "blocks": { "stoch_ams": { "mixer": { @@ -1006,12 +1150,18 @@ def comprehensive_torture_chain(): "mlp": {"init": "transfer"}, "normalization": {"init": "transfer"}, }, - "stoch_as": { + "stoch_aks": { "mixer": { "type": "stochastic", "main_mixer_name": "attention", "mixers": { "attention": {"type": "attention", "init": "transfer"}, + "kda": { + "type": "kda", + "init": "transfer", + "convolution_layer": {"kernel_size": 4}, + "normalization": {"epsilon": 1e-5}, + }, "swa": { "type": "attention", "init": "random", # Random init! @@ -1035,12 +1185,12 @@ def comprehensive_torture_chain(): # Layer 1: attn (random init - type change from mamba!) # Layer 2: gdn (collapse stochastic, keep gdn) # Layer 3: swa (random init - type change from gdn!) - # Layer 4: stoch{attn, swa} (unchanged) + # Layer 4: kda (collapse stochastic, keep kda - tests KDA passthrough) # ===================================================================== { "decoder": { "type": "pattern", - "pattern": ["stoch_ms", "attn", "gdn", "swa", "stoch_as"], + "pattern": ["stoch_ms", "attn", "gdn", "swa", "kda"], "blocks": { "stoch_ms": { "mixer": { @@ -1093,18 +1243,12 @@ def comprehensive_torture_chain(): "mlp": {"init": "transfer"}, "normalization": {"init": "transfer"}, }, - "stoch_as": { + "kda": { "mixer": { - "type": "stochastic", - "main_mixer_name": "attention", - "mixers": { - "attention": {"type": "attention", "init": "transfer"}, - "swa": { - "type": "attention", - "init": "transfer", - "window_size": 128, - }, - }, + "type": "kda", + "init": "transfer", # Transfer from stoch's kda + "convolution_layer": {"kernel_size": 4}, + "normalization": {"epsilon": 1e-5}, }, "mlp": {"init": "transfer"}, "normalization": {"init": "transfer"}, @@ -1119,14 +1263,14 @@ def comprehensive_torture_chain(): # Layer 1: attention # Layer 2: gdn # Layer 3: swa - # Layer 4: stoch{attention (main), swa} - # Layers 1,3,4 have attention-based sources → can MIL/DIL to full supernet - # Layers 0,2 have mamba/gdn sources → keep structure, just transfer + # Layer 4: kda + # Layers 1,3 have attention-based sources → can MIL/DIL/KIL to full supernet + # Layers 0,2,4 have mamba/gdn/kda sources → keep structure, just transfer # ===================================================================== { "decoder": { "type": "pattern", - "pattern": ["stoch_ms", "supernet", "gdn", "supernet", "supernet"], + "pattern": ["stoch_ms", "supernet", "gdn", "supernet", "kda"], "blocks": { "stoch_ms": { # Layer 0: preserve stoch{mamba, swa} @@ -1156,7 +1300,7 @@ def comprehensive_torture_chain(): "normalization": {"init": "transfer"}, }, "supernet": { - # Layers 1,3,4: full supernet via MIL/DIL from attention + # Layers 1,3: full supernet via MIL/DIL/KIL from attention # NOTE: Explicit geometry required because this is a NEW block # and the default base (stoch_ms) is mamba-based, so geometry # can't be derived via cross-type composition. 
@@ -1191,11 +1335,30 @@ def comprehensive_torture_chain(): "value_head_dim": 32, "convolution_layer": {"kernel_size": 4}, }, + "kda": { + "type": "kda", + "init": "transfer", # KIL conversion + "heads": 8, + "head_dim": 32, + "convolution_layer": {"kernel_size": 4}, + "normalization": {"epsilon": 1e-5}, + }, }, }, "mlp": {"init": "transfer"}, "normalization": {"init": "transfer"}, }, + "kda": { + # Layer 4: preserve pure kda + "mixer": { + "type": "kda", + "init": "transfer", + "convolution_layer": {"kernel_size": 4}, + "normalization": {"epsilon": 1e-5}, + }, + "mlp": {"init": "transfer"}, + "normalization": {"init": "transfer"}, + }, }, }, }, diff --git a/fast_llm_external_models/tests/test_apriel2/test_cache.py b/fast_llm_external_models/tests/test_apriel2/test_cache.py index 5392119a7..ca8158b4f 100644 --- a/fast_llm_external_models/tests/test_apriel2/test_cache.py +++ b/fast_llm_external_models/tests/test_apriel2/test_cache.py @@ -1,147 +1,1258 @@ -"""Unit tests for Apriel2Cache.""" +"""Comprehensive tests for Apriel2Cache. + +Architecture Overview +===================== +Apriel2Cache manages state for autoregressive generation across different mixer types: + +1. **Attention Cache** (_AttentionCache): Stores key/value states + - Supports sliding window (window_size) for SWA + - Efficient roll optimization for single-token decode + +2. **SSM Cache** (_SSMCache): Stores conv and recurrent states + - Used by Mamba, GDN, KDA + - KDA uses tuple conv states (q, k, v), others use single tensor + +3. **Stochastic Mixer Routing**: For layers with multiple mixer options + - Each mixer has independent cache (no sharing) + - active_mixer pointer routes operations to correct sub-cache + - Switching mixers preserves each mixer's independent state + +Cache Invalidation Semantics +============================ +When switching between mixers in a stochastic layer: +- Each mixer maintains its OWN independent history +- Switching does NOT invalidate the previous mixer's cache +- Switching does NOT copy state between mixers +- To invalidate: call reset() explicitly + +This is intentional for training with stochastic sampling where each mixer +should learn from its own history. For inference, main_mixer_name is fixed. + +Test Organization +================= +1. CREATION & PROPERTIES - Cache initialization, config parsing +2. ATTENTION CACHE - Updates, sliding window, concatenation +3. SSM CACHE - Conv states, recurrent states, KDA tuples +4. STOCHASTIC ROUTING - Active mixer, isolation, switching +5. CACHE INVALIDATION - Reset, per-mixer reset, coherence +6. BEAM SEARCH - batch_repeat, reorder, select +7. HF INTEGRATION - get_mask_sizes, indexing, properties +8. GENERATION PATTERNS - Prefill→decode, crop→continue +9. 
ERROR HANDLING - Guards, bounds, invalid operations +""" import pytest import torch -from fast_llm_external_models.apriel2.cache import Apriel2Cache +from fast_llm_external_models.apriel2.cache import ( + Apriel2Cache, + _AttentionCache, + _SSMCache, +) -class TestCacheBasics: - """Test basic cache creation and properties.""" - def test_cache_creation(self, apriel2_config_tiny): - """Test cache creation from config.""" - cache = Apriel2Cache(apriel2_config_tiny) - num_blocks = apriel2_config_tiny.decoder["num_blocks"] - assert len(cache) == num_blocks - assert cache.is_compileable == False +# ============================================================================= +# FIXTURES - Configs and Sample Data +# ============================================================================= + + +@pytest.fixture +def tiny_attention_config(): + """Minimal config with pure attention layers.""" + from fast_llm_external_models.apriel2.configuration_apriel2 import Apriel2Config + + return Apriel2Config( + vocab_size=100, + hidden_size=64, + decoder={ + "type": "fixed", + "num_blocks": 2, + "block": { + "mixer": {"type": "attention", "heads": 4, "head_groups": 2, "head_size": 16}, + "mlp": {"type": "mlp", "intermediate_size": 256}, + "normalization": {"type": "rms_norm", "epsilon": 1e-5}, + }, + }, + ) + + +@pytest.fixture +def swa_config(): + """Config with sliding window attention.""" + from fast_llm_external_models.apriel2.configuration_apriel2 import Apriel2Config + + return Apriel2Config( + vocab_size=100, + hidden_size=64, + decoder={ + "type": "fixed", + "num_blocks": 2, + "block": { + "mixer": { + "type": "attention", + "heads": 4, + "head_groups": 2, + "head_size": 16, + "window_size": 8, # Small for testing + }, + "mlp": {"type": "mlp", "intermediate_size": 256}, + "normalization": {"type": "rms_norm", "epsilon": 1e-5}, + }, + }, + ) + + +@pytest.fixture +def ssm_config(): + """Config with pure SSM layers (mamba).""" + from fast_llm_external_models.apriel2.configuration_apriel2 import Apriel2Config + + return Apriel2Config( + vocab_size=100, + hidden_size=64, + decoder={ + "type": "fixed", + "num_blocks": 2, + "block": { + "mixer": { + "type": "mamba", + "d_inner": 128, + "d_state": 16, + "dt_rank": 4, + "d_conv": 4, + }, + "mlp": {"type": "mlp", "intermediate_size": 256}, + "normalization": {"type": "rms_norm", "epsilon": 1e-5}, + }, + }, + ) + + +@pytest.fixture +def kda_config(): + """Config with pure KDA layers.""" + from fast_llm_external_models.apriel2.configuration_apriel2 import Apriel2Config + + return Apriel2Config( + vocab_size=100, + hidden_size=64, + decoder={ + "type": "fixed", + "num_blocks": 2, + "block": { + "mixer": { + "type": "kda", + "heads": 4, + "head_dim": 16, + "convolution_layer": {"kernel_size": 4}, + "normalization": {"epsilon": 1e-5}, + }, + "mlp": {"type": "mlp", "intermediate_size": 256}, + "normalization": {"type": "rms_norm", "epsilon": 1e-5}, + }, + }, + ) + + +@pytest.fixture +def stochastic_config(): + """Config with stochastic mixer (attention + mamba).""" + from fast_llm_external_models.apriel2.configuration_apriel2 import Apriel2Config + + return Apriel2Config( + vocab_size=100, + hidden_size=64, + decoder={ + "type": "pattern", + "num_blocks": 2, + "pattern": ["attn", "stochastic"], + "blocks": { + "attn": { + "mixer": {"type": "attention", "heads": 4, "head_groups": 2, "head_size": 16}, + "mlp": {"type": "mlp", "intermediate_size": 256}, + "normalization": {"type": "rms_norm", "epsilon": 1e-5}, + }, + "stochastic": { + "mixer": { + "type": 
"stochastic", + "main_mixer_name": "attention", + "mixers": { + "attention": {"type": "attention", "heads": 4, "head_groups": 2, "head_size": 16}, + "mamba": {"type": "mamba", "d_inner": 128, "d_state": 16, "dt_rank": 4, "d_conv": 4}, + }, + }, + "mlp": {"type": "mlp", "intermediate_size": 256}, + "normalization": {"type": "rms_norm", "epsilon": 1e-5}, + }, + }, + }, + ) + + +@pytest.fixture +def all_mixers_config(): + """Config with stochastic mixer containing all 5 mixer types.""" + from fast_llm_external_models.apriel2.configuration_apriel2 import Apriel2Config + + return Apriel2Config( + vocab_size=100, + hidden_size=64, + decoder={ + "type": "pattern", + "num_blocks": 2, + "pattern": ["attn", "all_mixers"], + "blocks": { + "attn": { + "mixer": {"type": "attention", "heads": 4, "head_groups": 2, "head_size": 16}, + "mlp": {"type": "mlp", "intermediate_size": 256}, + "normalization": {"type": "rms_norm", "epsilon": 1e-5}, + }, + "all_mixers": { + "mixer": { + "type": "stochastic", + "main_mixer_name": "attention", + "mixers": { + "attention": {"type": "attention", "heads": 4, "head_groups": 2, "head_size": 16}, + "swa": { + "type": "attention", + "heads": 4, + "head_groups": 2, + "head_size": 16, + "window_size": 1024, + }, + "mamba": {"type": "mamba", "d_inner": 128, "d_state": 16, "dt_rank": 4, "d_conv": 4}, + "gdn": { + "type": "gdn", + "value_heads": 4, + "key_heads": 2, + "key_head_dim": 16, + "value_head_dim": 16, + "convolution_layer": {"kernel_size": 4}, + }, + "kda": { + "type": "kda", + "heads": 4, + "head_dim": 16, + "convolution_layer": {"kernel_size": 4}, + "normalization": {"epsilon": 1e-5}, + }, + }, + }, + "mlp": {"type": "mlp", "intermediate_size": 256}, + "normalization": {"type": "rms_norm", "epsilon": 1e-5}, + }, + }, + }, + ) + + +@pytest.fixture +def multi_window_config(): + """Config with multiple different window sizes.""" + from fast_llm_external_models.apriel2.configuration_apriel2 import Apriel2Config + + return Apriel2Config( + vocab_size=100, + hidden_size=64, + decoder={ + "type": "pattern", + "num_blocks": 3, + "pattern": ["full", "small_window", "large_window"], + "blocks": { + "full": { + "mixer": {"type": "attention", "heads": 4, "head_groups": 2, "head_size": 16}, + "mlp": {"type": "mlp", "intermediate_size": 256}, + "normalization": {"type": "rms_norm", "epsilon": 1e-5}, + }, + "small_window": { + "mixer": { + "type": "attention", + "heads": 4, + "head_groups": 2, + "head_size": 16, + "window_size": 512, + }, + "mlp": {"type": "mlp", "intermediate_size": 256}, + "normalization": {"type": "rms_norm", "epsilon": 1e-5}, + }, + "large_window": { + "mixer": { + "type": "attention", + "heads": 4, + "head_groups": 2, + "head_size": 16, + "window_size": 2048, + }, + "mlp": {"type": "mlp", "intermediate_size": 256}, + "normalization": {"type": "rms_norm", "epsilon": 1e-5}, + }, + }, + }, + ) + + +@pytest.fixture +def sample_kv(): + """Sample key/value tensors: [batch=2, heads=4, seq=10, head_dim=16].""" + return torch.randn(2, 4, 10, 16), torch.randn(2, 4, 10, 16) + + +@pytest.fixture +def sample_conv_single(): + """Sample single-tensor conv state: [batch=2, d_inner=128, kernel=4].""" + return torch.randn(2, 128, 4) + + +@pytest.fixture +def sample_conv_tuple(): + """Sample tuple conv state for KDA: (q, k, v) each [batch=2, d=64, kernel=3].""" + return (torch.randn(2, 64, 3), torch.randn(2, 64, 3), torch.randn(2, 64, 3)) + + +@pytest.fixture +def sample_recurrent(): + """Sample recurrent state: [batch=2, heads=4, head_dim=16, d_state=16].""" + return 
torch.randn(2, 4, 16, 16) + + +# ============================================================================= +# SECTION 1: CACHE CREATION & PROPERTIES +# ============================================================================= + + +class TestCacheCreation: + """Test cache initialization from config.""" + + def test_attention_cache_creation(self, tiny_attention_config): + """Create cache for pure attention config.""" + cache = Apriel2Cache(tiny_attention_config) + + assert len(cache) == 2 + assert cache.mixer_types == ["attention", "attention"] + assert all(isinstance(l, _AttentionCache) for l in cache.layers) + + def test_ssm_cache_creation(self, ssm_config): + """Create cache for pure SSM config.""" + cache = Apriel2Cache(ssm_config) + + assert len(cache) == 2 + assert cache.mixer_types == ["mamba", "mamba"] + assert all(isinstance(l, _SSMCache) for l in cache.layers) + + def test_kda_cache_creation(self, kda_config): + """Create cache for pure KDA config.""" + cache = Apriel2Cache(kda_config) + + assert len(cache) == 2 + assert cache.mixer_types == ["kda", "kda"] + assert all(isinstance(l, _SSMCache) for l in cache.layers) + + def test_stochastic_cache_creation(self, stochastic_config): + """Create cache for stochastic mixer config.""" + cache = Apriel2Cache(stochastic_config) + + assert len(cache) == 2 + # Layer 0: pure attention, Layer 1: stochastic (dict) + assert isinstance(cache.layers[0], _AttentionCache) + assert isinstance(cache.layers[1], dict) + assert set(cache.layers[1].keys()) == {"attention", "mamba"} + + def test_swa_window_captured(self, swa_config): + """Verify sliding window size is captured.""" + cache = Apriel2Cache(swa_config) + + assert cache.layers[0].window == 8 + assert cache.is_sliding == [True, True] + + def test_active_mixers_initialized_none(self, stochastic_config): + """Verify active_mixers starts as None for all layers.""" + cache = Apriel2Cache(stochastic_config) + + assert cache.active_mixers == [None, None] + + +class TestCacheProperties: + """Test cache property accessors.""" + + def test_empty_cache_properties(self, tiny_attention_config): + """Test properties of uninitialized cache.""" + cache = Apriel2Cache(tiny_attention_config) + assert cache.is_initialized == False - assert isinstance(cache.is_sliding, list) - assert len(cache.is_sliding) == num_blocks + assert cache.has_previous_state == False + assert cache.max_batch_size is None + assert cache.max_cache_len is None + assert cache.is_compileable == False - def test_cache_properties_empty(self, apriel2_cache): - """Test cache properties when empty.""" - assert apriel2_cache.is_initialized == False - assert apriel2_cache.has_previous_state == False - assert apriel2_cache.max_batch_size is None - assert apriel2_cache.max_cache_len is None + def test_is_initialized_attention(self, tiny_attention_config, sample_kv): + """is_initialized detects attention cache.""" + cache = Apriel2Cache(tiny_attention_config) + cache.update(*sample_kv, layer_idx=0) + assert cache.is_initialized == True -class TestAttentionCache: - """Test attention cache operations.""" + def test_is_initialized_ssm(self, ssm_config, sample_conv_single): + """is_initialized detects SSM cache.""" + cache = Apriel2Cache(ssm_config) + cache.conv_states[0] = sample_conv_single - def test_attention_update(self, apriel2_cache, sample_attention_states): - """Test updating attention cache.""" - key, value = sample_attention_states - k_out, v_out = apriel2_cache.update(key, value, layer_idx=0) + assert cache.is_initialized == True - 
assert k_out.shape == key.shape - assert v_out.shape == value.shape - assert apriel2_cache.is_initialized == True - assert apriel2_cache.get_seq_length(0) == key.shape[2] + def test_has_previous_state_ssm_only(self, ssm_config, sample_conv_single): + """has_previous_state only looks at SSM conv states.""" + cache = Apriel2Cache(ssm_config) - def test_attention_concatenation(self, apriel2_cache, sample_attention_states): - """Test that cache concatenates new states.""" - key1, value1 = sample_attention_states - apriel2_cache.update(key1, value1, layer_idx=0) + assert cache.has_previous_state == False + cache.conv_states[0] = sample_conv_single + assert cache.has_previous_state == True - # Add more tokens - key2 = torch.randn(2, 8, 5, 64) - value2 = torch.randn(2, 8, 5, 64) - k_out, v_out = apriel2_cache.update(key2, value2, layer_idx=0) + def test_has_previous_state_ignores_attention(self, tiny_attention_config, sample_kv): + """has_previous_state ignores attention cache.""" + cache = Apriel2Cache(tiny_attention_config) + cache.update(*sample_kv, layer_idx=0) - assert k_out.shape[2] == 15 # 10 + 5 - assert apriel2_cache.get_seq_length(0) == 15 + # Attention cache is set, but has_previous_state only checks SSM + assert cache.has_previous_state == False + def test_max_batch_size_from_attention(self, tiny_attention_config, sample_kv): + """max_batch_size from attention cache.""" + cache = Apriel2Cache(tiny_attention_config) + cache.update(*sample_kv, layer_idx=0) -class TestSSMCache: - """Test SSM cache operations.""" + assert cache.max_batch_size == 2 - def test_ssm_direct_access(self, apriel2_config_stochastic): - """Test direct SSM state access.""" - cache = Apriel2Cache(apriel2_config_stochastic) + def test_max_batch_size_from_ssm(self, ssm_config, sample_conv_single): + """max_batch_size from SSM cache.""" + cache = Apriel2Cache(ssm_config) + cache.conv_states[0] = sample_conv_single - # Set active mixer to mamba - cache.set_active_mixer(1, "mamba") + assert cache.max_batch_size == 2 + + def test_max_batch_size_from_kda_tuple(self, kda_config, sample_conv_tuple): + """max_batch_size from KDA tuple conv state.""" + cache = Apriel2Cache(kda_config) + cache.conv_states[0] = sample_conv_tuple + + assert cache.max_batch_size == 2 + + def test_max_cache_len_single_window(self, swa_config): + """max_cache_len with single window size.""" + cache = Apriel2Cache(swa_config) + assert cache.max_cache_len == 8 + + def test_max_cache_len_multiple_windows(self, multi_window_config): + """max_cache_len returns minimum window.""" + cache = Apriel2Cache(multi_window_config) + assert cache.max_cache_len == 512 # min(512, 2048) + + def test_max_cache_len_no_windows(self, tiny_attention_config): + """max_cache_len is None when no windows.""" + cache = Apriel2Cache(tiny_attention_config) + assert cache.max_cache_len is None + + def test_is_sliding_mixed(self, multi_window_config): + """is_sliding reflects per-layer window presence.""" + cache = Apriel2Cache(multi_window_config) + assert cache.is_sliding == [False, True, True] + + +# ============================================================================= +# SECTION 2: ATTENTION CACHE OPERATIONS +# ============================================================================= + + +class TestAttentionCacheBasics: + """Test basic attention cache operations.""" + + def test_update_stores_kv(self, tiny_attention_config, sample_kv): + """update() stores key/value states.""" + cache = Apriel2Cache(tiny_attention_config) + key, value = sample_kv + + k_out, v_out = 
cache.update(key, value, layer_idx=0) + + torch.testing.assert_close(k_out, key) + torch.testing.assert_close(v_out, value) + assert cache.get_seq_length(0) == 10 + + def test_update_concatenates(self, tiny_attention_config, sample_kv): + """Subsequent updates concatenate.""" + cache = Apriel2Cache(tiny_attention_config) + key, value = sample_kv + + cache.update(key, value, layer_idx=0) + k_out, v_out = cache.update(key, value, layer_idx=0) + + assert k_out.shape[-2] == 20 + assert cache.get_seq_length(0) == 20 + + def test_key_value_cache_accessors(self, tiny_attention_config, sample_kv): + """Test key_cache and value_cache accessors.""" + cache = Apriel2Cache(tiny_attention_config) + cache.update(*sample_kv, layer_idx=0) - # Set conv states - conv = torch.randn(2, 128, 4) - cache.conv_states[1] = conv + assert cache.key_cache[0] is not None + assert cache.value_cache[0] is not None + torch.testing.assert_close(cache.key_cache[0], sample_kv[0]) - # Retrieve and verify - retrieved = cache.conv_states[1] - assert retrieved is not None - assert torch.allclose(retrieved, conv) +class TestSlidingWindowAttention: + """Test sliding window attention behavior.""" -class TestStochasticMixer: + def test_initial_within_window(self, swa_config): + """Initial sequence within window is kept.""" + cache = Apriel2Cache(swa_config) + key = torch.randn(2, 4, 5, 16) # seq=5 < window=8 + value = torch.randn(2, 4, 5, 16) + + cache.update(key, value, layer_idx=0) + + assert cache.get_seq_length(0) == 5 + + def test_initial_exceeds_window(self, swa_config): + """Initial sequence > window is truncated to last window tokens.""" + cache = Apriel2Cache(swa_config) + key = torch.arange(12).float().view(1, 1, 12, 1).expand(2, 4, 12, 16) + value = key.clone() + + k_out, v_out = cache.update(key, value, layer_idx=0) + + assert cache.get_seq_length(0) == 8 + # Should keep tokens 4-11 (last 8) + assert k_out[0, 0, 0, 0].item() == 4.0 + + def test_single_token_roll_path(self, swa_config): + """Single token decode with full window uses efficient roll.""" + cache = Apriel2Cache(swa_config) + + # Fill window exactly + key1 = torch.arange(8).float().view(1, 1, 8, 1).expand(2, 4, 8, 16) + cache.update(key1, key1.clone(), layer_idx=0) + + # Decode single token + key2 = torch.full((2, 4, 1, 16), 8.0) + k_out, _ = cache.update(key2, key2.clone(), layer_idx=0) + + assert cache.get_seq_length(0) == 8 + assert k_out[0, 0, 0, 0].item() == 1.0 # Token 0 rolled out + assert k_out[0, 0, 7, 0].item() == 8.0 # New token at end + + def test_multi_token_cat_slice_path(self, swa_config): + """Multiple tokens use cat+slice path.""" + cache = Apriel2Cache(swa_config) + + # Fill window + key1 = torch.randn(2, 4, 8, 16) + cache.update(key1, key1.clone(), layer_idx=0) + + # Add 3 tokens + key2 = torch.randn(2, 4, 3, 16) + k_out, _ = cache.update(key2, key2.clone(), layer_idx=0) + + assert cache.get_seq_length(0) == 8 + torch.testing.assert_close(k_out[..., -3:, :], key2) + + def test_partial_then_fill_then_overflow(self, swa_config): + """Progressive filling: partial → full → overflow.""" + cache = Apriel2Cache(swa_config) + + cache.update(torch.randn(2, 4, 5, 16), torch.randn(2, 4, 5, 16), layer_idx=0) + assert cache.get_seq_length(0) == 5 + + cache.update(torch.randn(2, 4, 3, 16), torch.randn(2, 4, 3, 16), layer_idx=0) + assert cache.get_seq_length(0) == 8 + + cache.update(torch.randn(2, 4, 2, 16), torch.randn(2, 4, 2, 16), layer_idx=0) + assert cache.get_seq_length(0) == 8 + + def test_contiguous_output(self, swa_config): + """Outputs are 
contiguous after windowing.""" + cache = Apriel2Cache(swa_config) + + cache.update(torch.randn(2, 4, 10, 16), torch.randn(2, 4, 10, 16), layer_idx=0) + cache.update(torch.randn(2, 4, 5, 16), torch.randn(2, 4, 5, 16), layer_idx=0) + + assert cache.layers[0].key.is_contiguous() + assert cache.layers[0].value.is_contiguous() + + +# ============================================================================= +# SECTION 3: SSM CACHE OPERATIONS +# ============================================================================= + + +class TestSSMCacheBasics: + """Test basic SSM cache operations.""" + + def test_conv_states_accessor(self, ssm_config, sample_conv_single): + """Test conv_states accessor.""" + cache = Apriel2Cache(ssm_config) + + cache.conv_states[0] = sample_conv_single + torch.testing.assert_close(cache.conv_states[0], sample_conv_single) + + def test_recurrent_states_accessor(self, ssm_config, sample_recurrent): + """Test recurrent_states accessor.""" + cache = Apriel2Cache(ssm_config) + + cache.recurrent_states[0] = sample_recurrent + torch.testing.assert_close(cache.recurrent_states[0], sample_recurrent) + + def test_ssm_seq_length_always_zero(self, ssm_config, sample_conv_single): + """get_seq_length returns 0 for SSM (no KV cache).""" + cache = Apriel2Cache(ssm_config) + cache.conv_states[0] = sample_conv_single + + assert cache.get_seq_length(0) == 0 + + +class TestKDACache: + """Test KDA-specific cache operations with tuple conv states.""" + + def test_tuple_conv_storage(self, kda_config, sample_conv_tuple): + """KDA stores tuple conv states.""" + cache = Apriel2Cache(kda_config) + + cache.conv_states[0] = sample_conv_tuple + + assert isinstance(cache.conv_states[0], tuple) + assert len(cache.conv_states[0]) == 3 + for i in range(3): + torch.testing.assert_close(cache.conv_states[0][i], sample_conv_tuple[i]) + + def test_tuple_with_recurrent(self, kda_config, sample_conv_tuple, sample_recurrent): + """KDA can have both tuple conv and recurrent states.""" + cache = Apriel2Cache(kda_config) + + cache.conv_states[0] = sample_conv_tuple + cache.recurrent_states[0] = sample_recurrent + + assert isinstance(cache.conv_states[0], tuple) + assert cache.recurrent_states[0] is not None + + def test_has_previous_state_detects_tuple(self, kda_config, sample_conv_tuple): + """has_previous_state works with tuple conv states.""" + cache = Apriel2Cache(kda_config) + + assert cache.has_previous_state == False + cache.conv_states[0] = sample_conv_tuple + assert cache.has_previous_state == True + + +# ============================================================================= +# SECTION 4: STOCHASTIC ROUTING +# ============================================================================= + + +class TestStochasticRouting: """Test stochastic mixer cache routing.""" - def test_set_active_mixer(self, apriel2_config_stochastic): - """Test setting active mixer.""" - cache = Apriel2Cache(apriel2_config_stochastic) + def test_set_active_mixer(self, stochastic_config): + """set_active_mixer sets the pointer.""" + cache = Apriel2Cache(stochastic_config) + cache.set_active_mixer(1, "attention") assert cache.active_mixers[1] == "attention" - def test_routing_to_different_mixers(self, apriel2_config_stochastic, sample_attention_states): - """Test that different mixers use separate caches.""" - cache = Apriel2Cache(apriel2_config_stochastic) - key, value = sample_attention_states + cache.set_active_mixer(1, "mamba") + assert cache.active_mixers[1] == "mamba" + + def test_operations_route_to_active(self, 
stochastic_config, sample_kv): + """Operations route to currently active mixer.""" + cache = Apriel2Cache(stochastic_config) - # Use attention mixer cache.set_active_mixer(1, "attention") - cache.update(key, value, layer_idx=1) + cache.update(*sample_kv, layer_idx=1) attn_len = cache.get_seq_length(1) - # Switch to mamba mixer - should have empty cache cache.set_active_mixer(1, "mamba") mamba_len = cache.get_seq_length(1) assert attn_len == 10 - assert mamba_len == 0 # Different cache + assert mamba_len == 0 # Mamba cache is separate and empty + + def test_each_mixer_independent_cache(self, stochastic_config, sample_kv, sample_conv_single): + """Each mixer maintains independent cache.""" + cache = Apriel2Cache(stochastic_config) + + # Fill attention cache + cache.set_active_mixer(1, "attention") + cache.update(*sample_kv, layer_idx=1) + + # Fill mamba cache + cache.set_active_mixer(1, "mamba") + cache.conv_states[1] = sample_conv_single + + # Both preserved + cache.set_active_mixer(1, "attention") + assert cache.get_seq_length(1) == 10 + + cache.set_active_mixer(1, "mamba") + torch.testing.assert_close(cache.conv_states[1], sample_conv_single) + + +class TestMixerSwitching: + """Test behavior when switching between mixers mid-generation.""" + + def test_switch_preserves_previous_state(self, stochastic_config, sample_kv): + """Switching mixers preserves previous mixer's state.""" + cache = Apriel2Cache(stochastic_config) + + cache.set_active_mixer(1, "attention") + cache.update(*sample_kv, layer_idx=1) + original_key = cache.layers[1]["attention"].key.clone() + + # Switch to mamba, do something + cache.set_active_mixer(1, "mamba") + cache.conv_states[1] = torch.randn(2, 128, 4) + + # Switch back - attention unchanged + cache.set_active_mixer(1, "attention") + torch.testing.assert_close(cache.layers[1]["attention"].key, original_key) + + def test_switch_does_not_copy_state(self, stochastic_config, sample_kv): + """Switching does NOT copy state between mixers.""" + cache = Apriel2Cache(stochastic_config) + + # Fill attention with 10 tokens + cache.set_active_mixer(1, "attention") + cache.update(*sample_kv, layer_idx=1) + + # Switch to mamba - it has NO history from attention + cache.set_active_mixer(1, "mamba") + assert cache.conv_states[1] is None + assert cache.recurrent_states[1] is None + + def test_has_previous_state_checks_all_sub_caches(self, stochastic_config): + """has_previous_state checks ALL sub-caches, not just active.""" + cache = Apriel2Cache(stochastic_config) + + cache.set_active_mixer(1, "mamba") + cache.conv_states[1] = torch.randn(2, 128, 4) + + # Even if we switch away, has_previous_state still detects it + cache.set_active_mixer(1, "attention") + assert cache.has_previous_state == True -class TestBeamSearch: - """Test beam search operations.""" +class TestAllMixerTypes: + """Test cache isolation across all 5 mixer types.""" - def test_batch_repeat_interleave(self, apriel2_cache, sample_attention_states): - """Test repeating cache for beam search.""" - key, value = sample_attention_states - apriel2_cache.update(key, value, layer_idx=0) + def test_all_five_mixer_types_isolated(self, all_mixers_config): + """All 5 mixer types maintain isolated caches.""" + cache = Apriel2Cache(all_mixers_config) + layer_idx = 1 # Stochastic layer - apriel2_cache.batch_repeat_interleave(2) - assert apriel2_cache.max_batch_size == 4 # 2 * 2 + # Fill each mixer's cache + cache.set_active_mixer(layer_idx, "attention") + attn_kv = (torch.randn(2, 4, 10, 16), torch.randn(2, 4, 10, 16)) + 
cache.update(*attn_kv, layer_idx=layer_idx) - def test_reorder_cache(self, apriel2_cache, sample_attention_states): - """Test reordering cache for beam search.""" - key, value = sample_attention_states - apriel2_cache.update(key, value, layer_idx=0) + cache.set_active_mixer(layer_idx, "swa") + swa_kv = (torch.randn(2, 4, 5, 16), torch.randn(2, 4, 5, 16)) + cache.update(*swa_kv, layer_idx=layer_idx) + + cache.set_active_mixer(layer_idx, "mamba") + mamba_conv = torch.randn(2, 128, 4) + cache.conv_states[layer_idx] = mamba_conv + + cache.set_active_mixer(layer_idx, "gdn") + gdn_conv = torch.randn(2, 64, 3) + cache.conv_states[layer_idx] = gdn_conv + + cache.set_active_mixer(layer_idx, "kda") + kda_conv = (torch.randn(2, 64, 3), torch.randn(2, 64, 3), torch.randn(2, 64, 3)) + cache.conv_states[layer_idx] = kda_conv + + # Verify all preserved + cache.set_active_mixer(layer_idx, "attention") + assert cache.get_seq_length(layer_idx) == 10 + + cache.set_active_mixer(layer_idx, "swa") + assert cache.get_seq_length(layer_idx) == 5 + + cache.set_active_mixer(layer_idx, "mamba") + torch.testing.assert_close(cache.conv_states[layer_idx], mamba_conv) + + cache.set_active_mixer(layer_idx, "gdn") + torch.testing.assert_close(cache.conv_states[layer_idx], gdn_conv) + + cache.set_active_mixer(layer_idx, "kda") + assert isinstance(cache.conv_states[layer_idx], tuple) + + +# ============================================================================= +# SECTION 5: CACHE INVALIDATION +# ============================================================================= + + +class TestCacheInvalidation: + """Test cache invalidation and reset semantics. + + Key principle: Each mixer maintains independent state. To invalidate: + - reset() clears ALL caches across ALL layers and mixers + - There is no per-mixer reset (by design - each mixer is independent) + """ + + def test_reset_clears_attention(self, tiny_attention_config, sample_kv): + """reset() clears attention cache.""" + cache = Apriel2Cache(tiny_attention_config) + cache.update(*sample_kv, layer_idx=0) + + cache.reset() + + assert cache.is_initialized == False + assert cache.get_seq_length(0) == 0 + + def test_reset_clears_ssm(self, ssm_config, sample_conv_single, sample_recurrent): + """reset() clears SSM cache.""" + cache = Apriel2Cache(ssm_config) + cache.conv_states[0] = sample_conv_single + cache.recurrent_states[0] = sample_recurrent + + cache.reset() + + assert cache.has_previous_state == False + assert cache.conv_states[0] is None + assert cache.recurrent_states[0] is None + + def test_reset_clears_kda_tuple(self, kda_config, sample_conv_tuple): + """reset() clears KDA tuple conv states.""" + cache = Apriel2Cache(kda_config) + cache.conv_states[0] = sample_conv_tuple + + cache.reset() + + assert cache.conv_states[0] is None + + def test_reset_clears_all_stochastic_mixers(self, all_mixers_config): + """reset() clears ALL mixer caches in stochastic layer.""" + cache = Apriel2Cache(all_mixers_config) + layer_idx = 1 + + # Fill all mixers + cache.set_active_mixer(layer_idx, "attention") + cache.update(torch.randn(2, 4, 10, 16), torch.randn(2, 4, 10, 16), layer_idx=layer_idx) + + cache.set_active_mixer(layer_idx, "mamba") + cache.conv_states[layer_idx] = torch.randn(2, 128, 4) + + cache.set_active_mixer(layer_idx, "kda") + cache.conv_states[layer_idx] = (torch.randn(2, 64, 3),) * 3 + + cache.reset() + + # All cleared + assert cache.layers[layer_idx]["attention"].key is None + assert cache.layers[layer_idx]["mamba"].conv is None + assert 
cache.layers[layer_idx]["kda"].conv is None + + def test_crop_truncates_attention(self, tiny_attention_config, sample_kv): + """crop() truncates attention cache to max_length.""" + cache = Apriel2Cache(tiny_attention_config) + cache.update(*sample_kv, layer_idx=0) + + cache.crop(5) + + assert cache.get_seq_length(0) == 5 + + def test_crop_affects_all_layers(self, tiny_attention_config, sample_kv): + """crop() affects all layers.""" + cache = Apriel2Cache(tiny_attention_config) + cache.update(*sample_kv, layer_idx=0) + cache.update(*sample_kv, layer_idx=1) + + cache.crop(3) + + assert cache.get_seq_length(0) == 3 + assert cache.get_seq_length(1) == 3 + + def test_crop_ignores_ssm(self, ssm_config, sample_conv_single): + """crop() only affects attention, not SSM.""" + cache = Apriel2Cache(ssm_config) + cache.conv_states[0] = sample_conv_single + + cache.crop(5) # Should not crash + + # Conv state unchanged + torch.testing.assert_close(cache.conv_states[0], sample_conv_single) + + +# ============================================================================= +# SECTION 6: BEAM SEARCH OPERATIONS +# ============================================================================= + + +class TestBatchRepeatInterleave: + """Test batch_repeat_interleave for beam search expansion.""" + + def test_repeat_attention(self, tiny_attention_config, sample_kv): + """Repeat attention cache for beam search.""" + cache = Apriel2Cache(tiny_attention_config) + cache.update(*sample_kv, layer_idx=0) + + cache.batch_repeat_interleave(3) + + assert cache.max_batch_size == 6 # 2 * 3 + + def test_repeat_ssm(self, ssm_config, sample_conv_single, sample_recurrent): + """Repeat SSM cache for beam search.""" + cache = Apriel2Cache(ssm_config) + cache.conv_states[0] = sample_conv_single + cache.recurrent_states[0] = sample_recurrent + + cache.batch_repeat_interleave(4) + + assert cache.conv_states[0].shape[0] == 8 # 2 * 4 + assert cache.recurrent_states[0].shape[0] == 8 + + def test_repeat_kda_tuple(self, kda_config, sample_conv_tuple): + """Repeat KDA tuple conv states.""" + cache = Apriel2Cache(kda_config) + cache.conv_states[0] = sample_conv_tuple + + cache.batch_repeat_interleave(3) + + for c in cache.conv_states[0]: + assert c.shape[0] == 6 + + def test_repeat_stochastic_all_mixers(self, all_mixers_config): + """Repeat all mixer caches in stochastic layer.""" + cache = Apriel2Cache(all_mixers_config) + layer_idx = 1 + + cache.set_active_mixer(layer_idx, "attention") + cache.update(torch.randn(2, 4, 10, 16), torch.randn(2, 4, 10, 16), layer_idx=layer_idx) + + cache.set_active_mixer(layer_idx, "mamba") + cache.conv_states[layer_idx] = torch.randn(2, 128, 4) + + cache.batch_repeat_interleave(2) + + cache.set_active_mixer(layer_idx, "attention") + assert cache.layers[layer_idx]["attention"].key.shape[0] == 4 + + cache.set_active_mixer(layer_idx, "mamba") + assert cache.conv_states[layer_idx].shape[0] == 4 + + def test_repeat_skips_none(self, tiny_attention_config): + """Repeat gracefully skips None caches.""" + cache = Apriel2Cache(tiny_attention_config) + # Don't fill anything + + cache.batch_repeat_interleave(3) # Should not crash + + assert cache.max_batch_size is None + + +class TestReorderCache: + """Test reorder_cache for beam search hypothesis selection.""" + + def test_reorder_attention(self, tiny_attention_config, sample_kv): + """Reorder attention cache.""" + cache = Apriel2Cache(tiny_attention_config) + key, value = sample_kv + # Make batches distinguishable + key = torch.arange(2).float().view(2, 1, 1, 
1).expand(2, 4, 10, 16) + cache.update(key, key.clone(), layer_idx=0) + + beam_idx = torch.tensor([1, 0]) + cache.reorder_cache(beam_idx) + + assert cache.layers[0].key[0, 0, 0, 0].item() == 1.0 + assert cache.layers[0].key[1, 0, 0, 0].item() == 0.0 + + def test_reorder_ssm(self, ssm_config): + """Reorder SSM cache.""" + cache = Apriel2Cache(ssm_config) + conv = torch.arange(2).float().view(2, 1, 1).expand(2, 128, 4) + cache.conv_states[0] = conv.clone() + + beam_idx = torch.tensor([1, 0]) + cache.reorder_cache(beam_idx) + + assert cache.conv_states[0][0, 0, 0].item() == 1.0 + + def test_reorder_kda_tuple(self, kda_config): + """Reorder KDA tuple conv states.""" + cache = Apriel2Cache(kda_config) + conv_q = torch.arange(2).float().view(2, 1, 1).expand(2, 64, 3) + cache.conv_states[0] = (conv_q.clone(), conv_q.clone(), conv_q.clone()) beam_idx = torch.tensor([1, 0]) - apriel2_cache.reorder_cache(beam_idx) + cache.reorder_cache(beam_idx) + + for c in cache.conv_states[0]: + assert c[0, 0, 0].item() == 1.0 + + +class TestBatchSelectIndices: + """Test batch_select_indices for beam selection.""" + + def test_select_attention(self, tiny_attention_config, sample_kv): + """Select subset of attention cache.""" + cache = Apriel2Cache(tiny_attention_config) + key = torch.arange(4).float().view(4, 1, 1, 1).expand(4, 4, 10, 16) + cache.update(key, key.clone(), layer_idx=0) + + indices = torch.tensor([0, 3]) + cache.batch_select_indices(indices) + + assert cache.max_batch_size == 2 + assert cache.layers[0].key[0, 0, 0, 0].item() == 0.0 + assert cache.layers[0].key[1, 0, 0, 0].item() == 3.0 + + def test_select_kda_tuple(self, kda_config): + """Select subset of KDA tuple conv states.""" + cache = Apriel2Cache(kda_config) + conv = tuple(torch.arange(4).float().view(4, 1, 1).expand(4, 64, 3).clone() for _ in range(3)) + cache.conv_states[0] = conv + + indices = torch.tensor([1, 2]) + cache.batch_select_indices(indices) + + for c in cache.conv_states[0]: + assert c.shape[0] == 2 + assert c[0, 0, 0].item() == 1.0 + + +# ============================================================================= +# SECTION 7: HUGGINGFACE INTEGRATION +# ============================================================================= + + +class TestGetMaskSizes: + """Test get_mask_sizes() for attention mask computation.""" + + def test_empty_cache(self, tiny_attention_config): + """Mask sizes with empty cache.""" + cache = Apriel2Cache(tiny_attention_config) + cache_position = torch.arange(10) + + kv_length, kv_offset = cache.get_mask_sizes(cache_position, layer_idx=0) + + assert kv_length == 10 + assert kv_offset == 0 + + def test_with_cached_tokens(self, tiny_attention_config, sample_kv): + """Mask sizes with cached tokens.""" + cache = Apriel2Cache(tiny_attention_config) + cache.update(*sample_kv, layer_idx=0) # 10 tokens + + cache_position = torch.arange(5) + kv_length, kv_offset = cache.get_mask_sizes(cache_position, layer_idx=0) + + assert kv_length == 15 # 10 + 5 + assert kv_offset == 10 + + def test_single_token_decode(self, tiny_attention_config, sample_kv): + """Mask sizes for single token decode.""" + cache = Apriel2Cache(tiny_attention_config) + cache.update(*sample_kv, layer_idx=0) + + cache_position = torch.arange(1) + kv_length, kv_offset = cache.get_mask_sizes(cache_position, layer_idx=0) + + assert kv_length == 11 + assert kv_offset == 10 + + def test_ssm_returns_query_only(self, ssm_config, sample_conv_single): + """SSM layers return query_length (no KV cache).""" + cache = Apriel2Cache(ssm_config) + 
cache.conv_states[0] = sample_conv_single + + cache_position = torch.arange(5) + kv_length, kv_offset = cache.get_mask_sizes(cache_position, layer_idx=0) + + assert kv_length == 5 + assert kv_offset == 0 + + +class TestCacheIndexing: + """Test cache[idx] indexing.""" + + def test_attention_returns_kv(self, tiny_attention_config, sample_kv): + """Indexing attention layer returns (key, value).""" + cache = Apriel2Cache(tiny_attention_config) + cache.update(*sample_kv, layer_idx=0) + + result = cache[0] + + assert isinstance(result, tuple) + torch.testing.assert_close(result[0], sample_kv[0]) + + def test_empty_returns_empty_tensors(self, tiny_attention_config): + """Indexing empty layer returns empty tensors.""" + cache = Apriel2Cache(tiny_attention_config) + + result = cache[0] + + assert result[0].numel() == 0 + assert result[1].numel() == 0 + + def test_ssm_returns_empty(self, ssm_config, sample_conv_single): + """Indexing SSM layer returns empty (no KV).""" + cache = Apriel2Cache(ssm_config) + cache.conv_states[0] = sample_conv_single + + result = cache[0] + + assert result[0].numel() == 0 + + def test_stochastic_attention_returns_kv(self, stochastic_config, sample_kv): + """Indexing stochastic with attention active returns KV.""" + cache = Apriel2Cache(stochastic_config) + cache.set_active_mixer(1, "attention") + cache.update(*sample_kv, layer_idx=1) + + result = cache[1] + + torch.testing.assert_close(result[0], sample_kv[0]) + + +# ============================================================================= +# SECTION 8: GENERATION PATTERNS +# ============================================================================= + + +class TestGenerationPatterns: + """Test real-world generation patterns.""" + + def test_prefill_then_decode(self, tiny_attention_config, sample_kv): + """Prefill with long prompt, then decode token-by-token.""" + cache = Apriel2Cache(tiny_attention_config) + cache.update(*sample_kv, layer_idx=0) # Prefill 10 tokens + + for _ in range(5): + new_kv = (torch.randn(2, 4, 1, 16), torch.randn(2, 4, 1, 16)) + cache.update(*new_kv, layer_idx=0) + + assert cache.get_seq_length(0) == 15 + + def test_crop_then_continue(self, tiny_attention_config, sample_kv): + """Crop old context, continue generation.""" + cache = Apriel2Cache(tiny_attention_config) + cache.update(*sample_kv, layer_idx=0) + cache.update(*sample_kv, layer_idx=0) # 20 tokens + + cache.crop(5) # Keep last 5 + cache.update(torch.randn(2, 4, 3, 16), torch.randn(2, 4, 3, 16), layer_idx=0) + + assert cache.get_seq_length(0) == 8 + + def test_reset_between_generations(self, tiny_attention_config, sample_kv): + """Reset between independent generations.""" + cache = Apriel2Cache(tiny_attention_config) + + # First generation + cache.update(*sample_kv, layer_idx=0) + assert cache.is_initialized == True + + # Reset + cache.reset() + assert cache.is_initialized == False + + # Second generation + cache.update(*sample_kv, layer_idx=0) + assert cache.get_seq_length(0) == 10 + + def test_multi_layer_consistency(self, tiny_attention_config, sample_kv): + """All layers updated consistently.""" + cache = Apriel2Cache(tiny_attention_config) + + for layer_idx in range(2): + cache.update(*sample_kv, layer_idx=layer_idx) + cache.update(torch.randn(2, 4, 1, 16), torch.randn(2, 4, 1, 16), layer_idx=layer_idx) + + for layer_idx in range(2): + assert cache.get_seq_length(layer_idx) == 11 + + +# ============================================================================= +# SECTION 9: ERROR HANDLING +# 
============================================================================= + + +class TestErrorHandling: + """Test error conditions and guards.""" + + def test_stochastic_update_without_active_mixer(self, stochastic_config): + """update() on stochastic without active_mixer raises.""" + cache = Apriel2Cache(stochastic_config) + + with pytest.raises(RuntimeError, match="needs active_mixer set"): + cache.update(torch.randn(2, 4, 10, 16), torch.randn(2, 4, 10, 16), layer_idx=1) + + def test_stochastic_accessor_without_active_mixer(self, stochastic_config): + """Accessing stochastic cache without active_mixer raises.""" + cache = Apriel2Cache(stochastic_config) + + with pytest.raises(RuntimeError, match="requires set_active_mixer"): + _ = cache.conv_states[1] + + def test_accessor_error_lists_available_mixers(self, stochastic_config): + """Error message lists available mixers.""" + cache = Apriel2Cache(stochastic_config) + + with pytest.raises(RuntimeError, match="Available mixers:"): + _ = cache.key_cache[1] + + def test_invalid_mixer_name(self, stochastic_config): + """Invalid mixer name raises KeyError on access.""" + cache = Apriel2Cache(stochastic_config) + cache.set_active_mixer(1, "nonexistent") + + with pytest.raises(KeyError): + cache.update(torch.randn(2, 4, 10, 16), torch.randn(2, 4, 10, 16), layer_idx=1) + + def test_layer_idx_out_of_bounds(self, tiny_attention_config): + """Out-of-bounds layer_idx raises IndexError.""" + cache = Apriel2Cache(tiny_attention_config) + + with pytest.raises(IndexError): + cache.update(torch.randn(2, 4, 10, 16), torch.randn(2, 4, 10, 16), layer_idx=999) + + +# ============================================================================= +# SECTION 10: INTERNAL CLASSES +# ============================================================================= + + +class TestAttentionCacheInternal: + """Test internal _AttentionCache class directly.""" + + def test_unbounded_growth(self): + """No window allows unbounded growth.""" + cache = _AttentionCache(window=None) + + for _ in range(10): + cache.update(torch.randn(2, 4, 100, 16), torch.randn(2, 4, 100, 16)) + + assert cache.key.shape[-2] == 1000 - # Cache should still be valid - assert apriel2_cache.is_initialized == True + def test_window_enforced(self): + """Window caps cache size.""" + cache = _AttentionCache(window=50) + for _ in range(10): + cache.update(torch.randn(2, 4, 100, 16), torch.randn(2, 4, 100, 16)) -class TestCacheReset: - """Test cache reset operations.""" + assert cache.key.shape[-2] == 50 - def test_reset(self, apriel2_cache, sample_attention_states): - """Test resetting cache.""" - key, value = sample_attention_states - apriel2_cache.update(key, value, layer_idx=0) - assert apriel2_cache.is_initialized == True +class TestSSMCacheInternal: + """Test internal _SSMCache class directly.""" - apriel2_cache.reset() + def test_initial_none(self): + """Initial states are None.""" + cache = _SSMCache() - assert apriel2_cache.is_initialized == False - assert apriel2_cache.get_seq_length(0) == 0 + assert cache.conv is None + assert cache.recurrent is None - def test_crop(self, apriel2_cache, sample_attention_states): - """Test cropping cache to max length.""" - key, value = sample_attention_states - apriel2_cache.update(key, value, layer_idx=0) + def test_stores_tuple(self): + """Can store tuple (for KDA).""" + cache = _SSMCache() + cache.conv = (torch.randn(2, 64, 3),) * 3 - apriel2_cache.crop(5) - assert apriel2_cache.get_seq_length(0) == 5 + assert isinstance(cache.conv, tuple) diff --git 
a/fast_llm_external_models/tests/test_apriel2/test_cache_routing.py b/fast_llm_external_models/tests/test_apriel2/test_cache_routing.py deleted file mode 100644 index a37cf945c..000000000 --- a/fast_llm_external_models/tests/test_apriel2/test_cache_routing.py +++ /dev/null @@ -1,291 +0,0 @@ -"""Tests for stochastic mixer cache routing and bug fixes.""" - -import pytest -import torch -from fast_llm_external_models.apriel2.cache import Apriel2Cache - - -class TestHasPreviousState: - """Test has_previous_state property with stochastic mixers.""" - - def test_checks_all_sub_caches(self, apriel2_config_stochastic): - """Test that has_previous_state checks ALL sub-caches, not just main mixer.""" - cache = Apriel2Cache(apriel2_config_stochastic) - - # Initially no SSM state - assert cache.has_previous_state == False - - # Set active mixer to mamba (NOT the main mixer which is attention) - cache.set_active_mixer(1, "mamba") - cache.conv_states[1] = torch.randn(2, 128, 4) - - # Should detect SSM state even though main mixer is "attention" - assert cache.has_previous_state == True - - def test_detects_any_ssm_cache(self, apriel2_config_multi_mixer): - """Test that has_previous_state detects SSM state in any sub-cache.""" - cache = Apriel2Cache(apriel2_config_multi_mixer) - - # Fill mamba_v1 - cache.set_active_mixer(0, "mamba_v1") - cache.conv_states[0] = torch.randn(2, 128, 4) - - # Fill mamba_v2 - cache.set_active_mixer(0, "mamba_v2") - cache.conv_states[0] = torch.randn(2, 128, 4) - - # Should detect SSM state from either variant - assert cache.has_previous_state == True - - -class TestPropertyAccessorGuards: - """Test that property accessors guard against None active_mixer.""" - - def test_get_raises_error_without_active_mixer(self, apriel2_config_stochastic): - """Test that accessing cache without set_active_mixer raises clear error.""" - cache = Apriel2Cache(apriel2_config_stochastic) - - with pytest.raises(RuntimeError) as exc_info: - _ = cache.conv_states[1] - - assert "requires set_active_mixer()" in str(exc_info.value) - assert "Available mixers:" in str(exc_info.value) - - def test_set_raises_error_without_active_mixer(self, apriel2_config_stochastic): - """Test that setting cache without set_active_mixer raises clear error.""" - cache = Apriel2Cache(apriel2_config_stochastic) - - with pytest.raises(RuntimeError) as exc_info: - cache.conv_states[1] = torch.randn(2, 128, 4) - - assert "requires set_active_mixer()" in str(exc_info.value) - - def test_access_works_after_set_active_mixer(self, apriel2_config_stochastic): - """Test that access works correctly after set_active_mixer.""" - cache = Apriel2Cache(apriel2_config_stochastic) - - # Set active mixer - cache.set_active_mixer(1, "mamba") - - # Now access should work - cache.conv_states[1] = torch.randn(2, 128, 4) - retrieved = cache.conv_states[1] - - assert retrieved is not None - - -class TestMixerSwitching: - """Test cache behavior when switching between different mixers.""" - - @pytest.mark.skipif(not torch.cuda.is_available(), reason="SSM mixers require CUDA") - def test_cache_preserves_state_across_mixer_switches(self, apriel2_config_all_mixers, device): - """Verify cache maintains independent state for each mixer when switching. - - This is the critical test for stochastic mixers: when we switch which mixer - is active, the cache must preserve previous mixer states while updating the - current mixer's state. 
- """ - if device.type != "cuda": - pytest.skip("SSM mixers require CUDA device") - - from fast_llm_external_models.apriel2.modeling_apriel2 import Apriel2ForCausalLM - - model = Apriel2ForCausalLM(apriel2_config_all_mixers).to(device) - model.eval() - - stochastic_layer_idx = 1 # Layer 1 is the stochastic layer - stochastic_layer = model.model.decoder.blocks[stochastic_layer_idx] - input_ids = torch.randint(0, apriel2_config_all_mixers.vocab_size, (2, 10), device=device) - - # Forward 1: Use attention (default main mixer) - stochastic_layer.mixer.main_mixer_name = "attention" - outputs1 = model(input_ids, use_cache=True) - cache = outputs1.past_key_values - - # Verify: only attention has data - layer_cache = cache.layers[stochastic_layer_idx] - assert layer_cache['attention'].key is not None, "Attention cache should have KV states" - assert layer_cache['swa'].key is None, "SWA cache should be empty" - assert layer_cache['mamba'].conv is None, "Mamba cache should be empty" - assert layer_cache['gdn'].conv is None, "GatedDeltaNet cache should be empty" - attn_seq_len_1 = layer_cache['attention'].key.shape[-2] - - # Forward 2: Switch to mamba (new token) - stochastic_layer.mixer.main_mixer_name = "mamba" - new_token = torch.randint(0, apriel2_config_all_mixers.vocab_size, (2, 1), device=device) - outputs2 = model(new_token, past_key_values=cache, use_cache=True) - cache = outputs2.past_key_values - - # Verify: attention preserved, mamba added - assert layer_cache['attention'].key is not None, "Attention cache should be preserved" - assert layer_cache['attention'].key.shape[-2] == attn_seq_len_1, "Attention seq_len should not change" - assert layer_cache['mamba'].conv is not None, "Mamba cache should now have SSM states" - assert layer_cache['swa'].key is None, "SWA cache should still be empty" - assert layer_cache['gdn'].conv is None, "GatedDeltaNet cache should still be empty" - - # Forward 3: Switch to swa - stochastic_layer.mixer.main_mixer_name = "swa" - outputs3 = model(new_token, past_key_values=cache, use_cache=True) - cache = outputs3.past_key_values - - # Verify: attention + mamba preserved, swa added - assert layer_cache['attention'].key is not None, "Attention cache should be preserved" - assert layer_cache['mamba'].conv is not None, "Mamba cache should be preserved" - assert layer_cache['swa'].key is not None, "SWA cache should now have KV states" - assert layer_cache['gdn'].conv is None, "GatedDeltaNet cache should still be empty" - - # Forward 4: Switch to gated_delta_net - stochastic_layer.mixer.main_mixer_name = "gdn" - outputs4 = model(new_token, past_key_values=cache, use_cache=True) - cache = outputs4.past_key_values - - # Verify: ALL mixers now have independent state - assert layer_cache['attention'].key is not None, "Attention cache should be preserved" - assert layer_cache['mamba'].conv is not None, "Mamba cache should be preserved" - assert layer_cache['swa'].key is not None, "SWA cache should be preserved" - assert layer_cache['gdn'].conv is not None, "GatedDeltaNet cache should now have SSM states" - - @pytest.mark.skipif(not torch.cuda.is_available(), reason="SSM mixers require CUDA") - def test_cache_isolation_between_attention_and_ssm(self, apriel2_config_all_mixers, device): - """Verify attention caches (KV) and SSM caches (conv/recurrent) don't interfere.""" - if device.type != "cuda": - pytest.skip("SSM mixers require CUDA device") - - from fast_llm_external_models.apriel2.modeling_apriel2 import Apriel2ForCausalLM - - model = 
Apriel2ForCausalLM(apriel2_config_all_mixers).to(device) - model.eval() - - stochastic_layer_idx = 1 - stochastic_layer = model.model.decoder.blocks[stochastic_layer_idx] - input_ids = torch.randint(0, apriel2_config_all_mixers.vocab_size, (2, 10), device=device) - - # Forward with attention - stochastic_layer.mixer.main_mixer_name = "attention" - outputs1 = model(input_ids, use_cache=True) - cache = outputs1.past_key_values - - # Get attention cache state - attn_cache = cache.layers[stochastic_layer_idx]['attention'] - attn_key = attn_cache.key.clone() - attn_value = attn_cache.value.clone() - - # Forward with mamba (using same cache) - stochastic_layer.mixer.main_mixer_name = "mamba" - new_token = torch.randint(0, apriel2_config_all_mixers.vocab_size, (2, 1), device=device) - outputs2 = model(new_token, past_key_values=cache, use_cache=True) - cache = outputs2.past_key_values - - # Verify attention cache unchanged - assert torch.allclose(cache.layers[stochastic_layer_idx]['attention'].key, attn_key), \ - "Attention KV cache should not be modified when mamba is active" - assert torch.allclose(cache.layers[stochastic_layer_idx]['attention'].value, attn_value), \ - "Attention KV cache should not be modified when mamba is active" - - # Verify mamba cache is populated - assert cache.layers[stochastic_layer_idx]['mamba'].conv is not None, \ - "Mamba SSM cache should be populated" - - def test_seq_len_tracking_per_mixer(self, apriel2_config_all_mixers): - """Verify seq_len is tracked independently for each mixer.""" - from fast_llm_external_models.apriel2.modeling_apriel2 import Apriel2ForCausalLM - - model = Apriel2ForCausalLM(apriel2_config_all_mixers) - model.eval() - - stochastic_layer_idx = 1 - stochastic_layer = model.model.decoder.blocks[stochastic_layer_idx] - - # Forward with attention (10 tokens) - input_ids1 = torch.randint(0, apriel2_config_all_mixers.vocab_size, (2, 10)) - stochastic_layer.mixer.main_mixer_name = "attention" - outputs1 = model(input_ids1, use_cache=True) - cache = outputs1.past_key_values - - cache.set_active_mixer(stochastic_layer_idx, "attention") - assert cache.get_seq_length(stochastic_layer_idx) == 10 - - # Forward with swa (5 tokens) - independent from attention - input_ids2 = torch.randint(0, apriel2_config_all_mixers.vocab_size, (2, 5)) - stochastic_layer.mixer.main_mixer_name = "swa" - outputs2 = model(input_ids2, use_cache=True) - cache2 = Apriel2Cache(apriel2_config_all_mixers) # Fresh cache for swa - outputs2 = model(input_ids2, past_key_values=cache2, use_cache=True) - cache2 = outputs2.past_key_values - - cache2.set_active_mixer(stochastic_layer_idx, "swa") - assert cache2.get_seq_length(stochastic_layer_idx) == 5 - - # Original cache should still have attention with seq_len=10 - cache.set_active_mixer(stochastic_layer_idx, "attention") - assert cache.get_seq_length(stochastic_layer_idx) == 10 - - -class TestMultipleMixersSameType: - """Test multiple mixers of the same type with independent caches.""" - - def test_attention_variants_independent(self, apriel2_config_multi_mixer): - """Test that different attention mixers have independent caches.""" - cache = Apriel2Cache(apriel2_config_multi_mixer) - - # Fill attn_small cache - cache.set_active_mixer(0, "attn_small") - key_small = torch.randn(2, 8, 10, 64) - value_small = torch.randn(2, 8, 10, 64) - cache.update(key_small, value_small, 0) - - assert cache.get_seq_length(0) == 10 - - # Switch to attn_large - should have empty cache - cache.set_active_mixer(0, "attn_large") - assert 
cache.get_seq_length(0) == 0 - - # Fill attn_large - key_large = torch.randn(2, 8, 5, 64) - value_large = torch.randn(2, 8, 5, 64) - cache.update(key_large, value_large, 0) - - assert cache.get_seq_length(0) == 5 - - # Switch back to attn_small - should still have original data - cache.set_active_mixer(0, "attn_small") - assert cache.get_seq_length(0) == 10 - - def test_ssm_variants_independent(self, apriel2_config_multi_mixer): - """Test that different SSM mixers have independent caches.""" - cache = Apriel2Cache(apriel2_config_multi_mixer) - - # Fill mamba_v1 - cache.set_active_mixer(0, "mamba_v1") - conv1 = torch.randn(2, 128, 4) - cache.conv_states[0] = conv1 - - # Fill mamba_v2 - cache.set_active_mixer(0, "mamba_v2") - conv2 = torch.randn(2, 128, 4) - cache.conv_states[0] = conv2 - - # Verify they're different - cache.set_active_mixer(0, "mamba_v1") - retrieved1 = cache.conv_states[0] - - cache.set_active_mixer(0, "mamba_v2") - retrieved2 = cache.conv_states[0] - - assert not torch.allclose(retrieved1, retrieved2) - assert torch.allclose(retrieved1, conv1) - assert torch.allclose(retrieved2, conv2) - - def test_different_window_sizes(self, apriel2_config_multi_mixer): - """Test that attention mixers with different window sizes are independent.""" - cache = Apriel2Cache(apriel2_config_multi_mixer) - - # Check that attn_small and attn_large have different window sizes - cache.set_active_mixer(0, "attn_small") - window_small = cache.get_max_cache_shape(0) - - cache.set_active_mixer(0, "attn_large") - window_large = cache.get_max_cache_shape(0) - - assert window_small == 2048 - assert window_large == 8192 diff --git a/fast_llm_external_models/tests/test_apriel2/test_causal_conv1d.py b/fast_llm_external_models/tests/test_apriel2/test_causal_conv1d.py new file mode 100644 index 000000000..ec6abc1d2 --- /dev/null +++ b/fast_llm_external_models/tests/test_apriel2/test_causal_conv1d.py @@ -0,0 +1,543 @@ +"""Tests for CausalConv1d consistency across all code paths. + +The Key Consistency Property +============================ +For ANY input sequence, ALL of the following must produce the SAME output: + +1. Prefill entire sequence at once (CPU/PyTorch fallback) +2. Prefill entire sequence at once (CUDA fast path) +3. Prefill in chunks with state passing (CPU) +4. Prefill in chunks with state passing (CUDA) +5. Prefill prefix + decode remaining tokens one-by-one (CPU) +6. Prefill prefix + decode remaining tokens one-by-one (CUDA) +7. Mixed: CUDA prefill → CPU decode +8. 
Mixed: CPU prefill → CUDA decode + +This is critical because during inference: +- Prefill processes the prompt (potentially chunked for long prompts) +- Decode generates tokens one at a time +- If these paths diverge, generation quality degrades silently +""" + +import pytest +import torch + +from fast_llm_external_models.apriel2.modeling_apriel2 import CausalConv1d, _causal_conv1d_fn + + +# ============================================================================= +# Fixtures +# ============================================================================= + + +@pytest.fixture +def conv(): + """CausalConv1d layer with fixed random weights (on CPU).""" + torch.manual_seed(42) + return CausalConv1d( + in_channels=64, + out_channels=64, + kernel_size=4, + groups=64, + bias=True, + activation="silu", + device="cpu", + ) + + +@pytest.fixture +def dim(): + return 64 + + +@pytest.fixture +def kernel_size(): + return 4 + + +# ============================================================================= +# Helpers +# ============================================================================= + + +def to_device(conv: CausalConv1d, device: str) -> CausalConv1d: + """Create a copy of conv on the specified device.""" + import copy + return copy.deepcopy(conv).to(device) + + +def prefill(conv: CausalConv1d, x: torch.Tensor, state: torch.Tensor = None) -> tuple[torch.Tensor, torch.Tensor]: + """Prefill and return (output, final_state).""" + return conv(x, conv_state=state, return_final_state=True) + + +def decode_sequence(conv: CausalConv1d, tokens: torch.Tensor, state: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + """Decode multiple tokens one-by-one, return (stacked_outputs, final_state). + + Args: + conv: CausalConv1d layer + tokens: [batch, dim, num_tokens] - tokens to decode + state: [batch, dim, kernel_size-1] - initial state (modified in-place) + + Returns: + outputs: [batch, dim, num_tokens] - output for each token + state: final state after all tokens + """ + outputs = [] + for i in range(tokens.shape[-1]): + token = tokens[:, :, i] + out = conv.update(token, state) + outputs.append(out) + return torch.stack(outputs, dim=-1), state + + +# ============================================================================= +# Unit Tests +# ============================================================================= + + +class TestCausalConv1dBasics: + """Basic functionality tests.""" + + def test_output_shape(self, conv, dim): + """Output shape matches input shape.""" + x = torch.randn(2, dim, 16, device="cpu") + out = conv(x) + assert out.shape == x.shape + + def test_state_shape(self, conv, dim, kernel_size): + """Returned state has correct shape.""" + x = torch.randn(2, dim, 16, device="cpu") + out, state = conv(x, return_final_state=True) + assert state.shape == (2, dim, kernel_size - 1) + + def test_deterministic(self, conv, dim): + """Same input produces same output.""" + x = torch.randn(2, dim, 16, device="cpu") + out1 = conv(x) + out2 = conv(x) + torch.testing.assert_close(out1, out2) + + def test_update_output_shape(self, conv, dim, kernel_size): + """Update produces single token output.""" + token = torch.randn(2, dim, device="cpu") + state = torch.randn(2, dim, kernel_size - 1, device="cpu") + out = conv.update(token, state) + assert out.shape == (2, dim) + + def test_fast_path_detection(self, conv, dim): + """Fast path correctly detected based on device.""" + x_cpu = torch.randn(2, dim, 16, device="cpu") + assert not conv._use_fast_path(x_cpu) + + if torch.cuda.is_available(): 
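+            # The fast path is only taken for CUDA inputs when the optional causal-conv1d
+            # kernels (_causal_conv1d_fn) are importable; otherwise CausalConv1d uses the
+            # PyTorch fallback exercised by the CPU assertion above.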
+ x_cuda = torch.randn(2, dim, 16, device="cuda") + conv_cuda = conv.cuda() + # Fast path available only if CUDA kernels installed + expected = _causal_conv1d_fn is not None + assert conv_cuda._use_fast_path(x_cuda) == expected + + +# ============================================================================= +# Backend Equivalence (CUDA vs CPU) +# ============================================================================= + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required") +@pytest.mark.skipif(_causal_conv1d_fn is None, reason="CUDA conv kernels required") +class TestBackendEquivalence: + """CUDA and CPU backends produce identical results.""" + + @pytest.mark.parametrize("seq_len", [1, 4, 8, 17, 32, 65]) + @pytest.mark.parametrize("batch_size", [1, 2, 4]) + def test_prefill_cuda_vs_cpu(self, conv, dim, seq_len, batch_size): + """CUDA prefill matches CPU prefill.""" + torch.manual_seed(123) + x = torch.randn(batch_size, dim, seq_len, device="cpu") + + # CPU + out_cpu = conv(x) + + # CUDA + conv_cuda = to_device(conv, "cuda") + out_cuda = conv_cuda(x.cuda()).cpu() + + torch.testing.assert_close(out_cuda, out_cpu, atol=1e-4, rtol=1e-4) + + @pytest.mark.parametrize("seq_len", [1, 4, 8, 17, 32]) + def test_prefill_with_state_cuda_vs_cpu(self, conv, dim, kernel_size, seq_len): + """CUDA prefill with state output matches CPU.""" + torch.manual_seed(123) + x = torch.randn(2, dim, seq_len, device="cpu") + + # CPU + out_cpu, state_cpu = prefill(conv, x) + + # CUDA + conv_cuda = to_device(conv, "cuda") + out_cuda, state_cuda = prefill(conv_cuda, x.cuda()) + out_cuda, state_cuda = out_cuda.cpu(), state_cuda.cpu() + + torch.testing.assert_close(out_cuda, out_cpu, atol=1e-4, rtol=1e-4) + torch.testing.assert_close(state_cuda, state_cpu, atol=1e-5, rtol=1e-5) + + def test_decode_cuda_vs_cpu(self, conv, dim, kernel_size): + """CUDA single-token decode matches CPU.""" + torch.manual_seed(123) + token = torch.randn(2, dim, device="cpu") + state = torch.randn(2, dim, kernel_size - 1, device="cpu") + + # CPU + state_cpu = state.clone() + out_cpu = conv.update(token, state_cpu) + + # CUDA + conv_cuda = to_device(conv, "cuda") + state_cuda = state.cuda() + out_cuda = conv_cuda.update(token.cuda(), state_cuda).cpu() + state_cuda = state_cuda.cpu() + + torch.testing.assert_close(out_cuda, out_cpu, atol=1e-4, rtol=1e-4) + torch.testing.assert_close(state_cuda, state_cpu, atol=1e-5, rtol=1e-5) + + +# ============================================================================= +# Chunking Consistency +# ============================================================================= + + +class TestChunkingConsistency: + """Chunked prefill matches full prefill.""" + + @pytest.mark.parametrize("total_len", [16, 33, 64]) + @pytest.mark.parametrize("chunk_size", [4, 7, 16]) + def test_chunked_prefill_cpu(self, conv, dim, total_len, chunk_size): + """CPU: Chunked prefill matches full prefill.""" + torch.manual_seed(123) + x = torch.randn(2, dim, total_len, device="cpu") + + # Reference: full prefill + ref_out, _ = prefill(conv, x) + + # Chunked prefill + outputs = [] + state = None + for start in range(0, total_len, chunk_size): + chunk = x[:, :, start:start + chunk_size] + out, state = prefill(conv, chunk, state) + outputs.append(out) + + chunked_out = torch.cat(outputs, dim=-1) + torch.testing.assert_close(chunked_out, ref_out, atol=1e-5, rtol=1e-5) + + @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required") + @pytest.mark.skipif(_causal_conv1d_fn is None, reason="CUDA 
conv kernels required") + @pytest.mark.parametrize("total_len", [16, 33, 64]) + @pytest.mark.parametrize("chunk_size", [4, 7, 16]) + def test_chunked_prefill_cuda(self, conv, dim, total_len, chunk_size): + """CUDA: Chunked prefill matches full prefill.""" + torch.manual_seed(123) + x = torch.randn(2, dim, total_len, device="cpu") + + conv_cuda = to_device(conv, "cuda") + + # Reference: full prefill + ref_out, _ = prefill(conv_cuda, x.cuda()) + + # Chunked prefill + outputs = [] + state = None + for start in range(0, total_len, chunk_size): + chunk = x[:, :, start:start + chunk_size].cuda() + out, state = prefill(conv_cuda, chunk, state) + outputs.append(out) + + chunked_out = torch.cat(outputs, dim=-1) + torch.testing.assert_close(chunked_out, ref_out, atol=1e-4, rtol=1e-4) + + +# ============================================================================= +# Decode Consistency +# ============================================================================= + + +class TestDecodeConsistency: + """Token-by-token decode matches batch prefill.""" + + @pytest.mark.parametrize("prefill_len", [4, 8, 16]) + @pytest.mark.parametrize("decode_len", [1, 5, 10]) + def test_prefill_then_decode_cpu(self, conv, dim, prefill_len, decode_len): + """CPU: Prefill + decode matches full prefill.""" + torch.manual_seed(123) + total_len = prefill_len + decode_len + x = torch.randn(2, dim, total_len, device="cpu") + + # Reference: full prefill + ref_out, _ = prefill(conv, x) + + # Prefill prefix, then decode rest + out_prefix, state = prefill(conv, x[:, :, :prefill_len]) + out_decode, _ = decode_sequence(conv, x[:, :, prefill_len:], state) + + combined = torch.cat([out_prefix, out_decode], dim=-1) + torch.testing.assert_close(combined, ref_out, atol=1e-5, rtol=1e-5) + + @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required") + @pytest.mark.skipif(_causal_conv1d_fn is None, reason="CUDA conv kernels required") + @pytest.mark.parametrize("prefill_len", [4, 8, 16]) + @pytest.mark.parametrize("decode_len", [1, 5, 10]) + def test_prefill_then_decode_cuda(self, conv, dim, prefill_len, decode_len): + """CUDA: Prefill + decode matches full prefill.""" + torch.manual_seed(123) + total_len = prefill_len + decode_len + x = torch.randn(2, dim, total_len, device="cuda") + + conv_cuda = to_device(conv, "cuda") + + # Reference: full prefill + ref_out, _ = prefill(conv_cuda, x) + + # Prefill prefix, then decode rest + out_prefix, state = prefill(conv_cuda, x[:, :, :prefill_len]) + out_decode, _ = decode_sequence(conv_cuda, x[:, :, prefill_len:], state) + + combined = torch.cat([out_prefix, out_decode], dim=-1) + torch.testing.assert_close(combined, ref_out, atol=1e-4, rtol=1e-4) + + +# ============================================================================= +# Global Consistency: The Ultimate Test +# ============================================================================= + + +class TestGlobalConsistency: + """ALL code paths must produce identical results for the same input.""" + + def test_all_cpu_paths_match(self, conv, dim): + """All CPU paths produce identical output.""" + torch.manual_seed(42) + + total_len = 24 + prefill_len = 16 + chunk_size = 8 + x = torch.randn(2, dim, total_len, device="cpu") + + # Reference: full prefill + reference, _ = prefill(conv, x) + + # Path 1: Chunked prefill + outputs = [] + state = None + for start in range(0, total_len, chunk_size): + chunk = x[:, :, start:start + chunk_size] + out, state = prefill(conv, chunk, state) + outputs.append(out) + path1 = 
torch.cat(outputs, dim=-1) + + # Path 2: Prefill + decode + out_prefix, state = prefill(conv, x[:, :, :prefill_len]) + out_decode, _ = decode_sequence(conv, x[:, :, prefill_len:], state) + path2 = torch.cat([out_prefix, out_decode], dim=-1) + + # Path 3: All decode (extreme case) + # Prefill first kernel_size-1 tokens, decode rest + init_len = conv.kernel_size[0] - 1 + out_init, state = prefill(conv, x[:, :, :init_len]) + out_decode, _ = decode_sequence(conv, x[:, :, init_len:], state) + path3 = torch.cat([out_init, out_decode], dim=-1) + + torch.testing.assert_close(path1, reference, atol=1e-5, rtol=1e-5) + torch.testing.assert_close(path2, reference, atol=1e-5, rtol=1e-5) + torch.testing.assert_close(path3, reference, atol=1e-5, rtol=1e-5) + + @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required") + @pytest.mark.skipif(_causal_conv1d_fn is None, reason="CUDA conv kernels required") + def test_all_paths_match_cross_device(self, conv, dim): + """All paths (CPU and CUDA) produce identical output.""" + torch.manual_seed(42) + + total_len = 24 + prefill_len = 16 + chunk_size = 8 + x = torch.randn(2, dim, total_len, device="cpu") + + conv_cuda = to_device(conv, "cuda") + + # REFERENCE: CPU full prefill (simplest, most trustworthy) + reference, _ = prefill(conv, x) + + results = {} + + # CPU paths + # --------- + + # CPU chunked + outputs, state = [], None + for start in range(0, total_len, chunk_size): + out, state = prefill(conv, x[:, :, start:start + chunk_size], state) + outputs.append(out) + results["cpu_chunked"] = torch.cat(outputs, dim=-1) + + # CPU prefill + decode + out_prefix, state = prefill(conv, x[:, :, :prefill_len]) + out_decode, _ = decode_sequence(conv, x[:, :, prefill_len:], state) + results["cpu_prefill_decode"] = torch.cat([out_prefix, out_decode], dim=-1) + + # CUDA paths + # ---------- + + # CUDA full prefill + results["cuda_full"], _ = prefill(conv_cuda, x.cuda()) + results["cuda_full"] = results["cuda_full"].cpu() + + # CUDA chunked + outputs, state = [], None + for start in range(0, total_len, chunk_size): + out, state = prefill(conv_cuda, x[:, :, start:start + chunk_size].cuda(), state) + outputs.append(out.cpu()) + results["cuda_chunked"] = torch.cat(outputs, dim=-1) + + # CUDA prefill + decode + out_prefix, state = prefill(conv_cuda, x[:, :, :prefill_len].cuda()) + out_decode, _ = decode_sequence(conv_cuda, x[:, :, prefill_len:].cuda(), state) + results["cuda_prefill_decode"] = torch.cat([out_prefix.cpu(), out_decode.cpu()], dim=-1) + + # Mixed paths + # ----------- + + # CPU prefill, CUDA decode + out_prefix, state = prefill(conv, x[:, :, :prefill_len]) + state = state.cuda() + out_decode, _ = decode_sequence(conv_cuda, x[:, :, prefill_len:].cuda(), state) + results["cpu_prefill_cuda_decode"] = torch.cat([out_prefix, out_decode.cpu()], dim=-1) + + # CUDA prefill, CPU decode + out_prefix, state = prefill(conv_cuda, x[:, :, :prefill_len].cuda()) + out_prefix, state = out_prefix.cpu(), state.cpu() + out_decode, _ = decode_sequence(conv, x[:, :, prefill_len:], state) + results["cuda_prefill_cpu_decode"] = torch.cat([out_prefix, out_decode], dim=-1) + + # Verify all match reference + tolerances = { + "cpu_chunked": 1e-5, + "cpu_prefill_decode": 1e-5, + "cuda_full": 1e-4, + "cuda_chunked": 1e-4, + "cuda_prefill_decode": 1e-4, + "cpu_prefill_cuda_decode": 1e-4, + "cuda_prefill_cpu_decode": 1e-4, + } + + for name, result in results.items(): + tol = tolerances[name] + torch.testing.assert_close( + result, reference, atol=tol, rtol=tol, + msg=f"Path 
'{name}' diverged from reference" + ) + + @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required") + @pytest.mark.skipif(_causal_conv1d_fn is None, reason="CUDA conv kernels required") + def test_long_decode_no_drift(self, conv, dim): + """Long decode sequence doesn't accumulate errors.""" + torch.manual_seed(42) + + prefill_len = 8 + decode_len = 100 # Long decode to catch drift + total_len = prefill_len + decode_len + x = torch.randn(2, dim, total_len, device="cpu") + + conv_cuda = to_device(conv, "cuda") + + # Reference: CPU full prefill + reference, _ = prefill(conv, x) + + # CUDA prefill + long decode + out_prefix, state = prefill(conv_cuda, x[:, :, :prefill_len].cuda()) + out_decode, _ = decode_sequence(conv_cuda, x[:, :, prefill_len:].cuda(), state) + result = torch.cat([out_prefix.cpu(), out_decode.cpu()], dim=-1) + + # Check max error at each position doesn't grow + errors = (result - reference).abs().max(dim=1).values.max(dim=0).values # [seq_len] + + # First positions should have small error + assert errors[:prefill_len].max() < 1e-4, "Prefill error too large" + + # Decode errors shouldn't grow unboundedly + # Allow slightly more tolerance for later positions but not exponential growth + assert errors[prefill_len:].max() < 1e-3, "Decode error too large" + + # Check no systematic drift (errors shouldn't consistently increase) + decode_errors = errors[prefill_len:] + first_half = decode_errors[:len(decode_errors)//2].mean() + second_half = decode_errors[len(decode_errors)//2:].mean() + assert second_half < first_half * 2, "Errors growing over decode steps (drift detected)" + + +# ============================================================================= +# Edge Cases +# ============================================================================= + + +class TestEdgeCases: + """Edge cases and boundary conditions.""" + + def test_single_token_prefill(self, conv, dim, kernel_size): + """Prefill with just 1 token works.""" + x = torch.randn(2, dim, 1, device="cpu") + out, state = prefill(conv, x) + + assert out.shape == (2, dim, 1) + assert state.shape == (2, dim, kernel_size - 1) + + def test_seq_shorter_than_kernel(self, conv, dim, kernel_size): + """Sequence shorter than kernel_size works.""" + seq_len = kernel_size - 2 # Shorter than kernel + x = torch.randn(2, dim, seq_len, device="cpu") + out, state = prefill(conv, x) + + assert out.shape == (2, dim, seq_len) + assert state.shape == (2, dim, kernel_size - 1) + + def test_seq_exactly_kernel_size(self, conv, dim, kernel_size): + """Sequence exactly kernel_size works.""" + x = torch.randn(2, dim, kernel_size, device="cpu") + out, state = prefill(conv, x) + + assert out.shape == (2, dim, kernel_size) + + def test_batch_size_one(self, conv, dim): + """Batch size 1 works.""" + x = torch.randn(1, dim, 16, device="cpu") + out, state = prefill(conv, x) + + assert out.shape == (1, dim, 16) + + def test_empty_decode_after_prefill(self, conv, dim, kernel_size): + """Zero decode steps after prefill is valid.""" + x = torch.randn(2, dim, 16, device="cpu") + out_prefill, state = prefill(conv, x) + + # No decode, just verify state is usable + token = torch.randn(2, dim, device="cpu") + out_token = conv.update(token, state) + assert out_token.shape == (2, dim) + + @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required") + @pytest.mark.skipif(_causal_conv1d_fn is None, reason="CUDA conv kernels required") + def test_state_device_transfer(self, conv, dim, kernel_size): + """State can be transferred between 
devices.""" + x = torch.randn(2, dim, 16, device="cpu") + + # Prefill on CPU + _, state_cpu = prefill(conv, x) + + # Transfer state to CUDA + state_cuda = state_cpu.cuda() + conv_cuda = to_device(conv, "cuda") + + # Decode on CUDA with transferred state + token = torch.randn(2, dim, device="cuda") + out = conv_cuda.update(token, state_cuda) + + assert out.shape == (2, dim) + assert out.device.type == "cuda" diff --git a/fast_llm_external_models/tests/test_apriel2/test_compose_configs.py b/fast_llm_external_models/tests/test_apriel2/test_compose_configs.py index a1d048d7a..0bd6ac88d 100644 --- a/fast_llm_external_models/tests/test_apriel2/test_compose_configs.py +++ b/fast_llm_external_models/tests/test_apriel2/test_compose_configs.py @@ -166,6 +166,31 @@ def test_cross_type_attention_to_mamba(self, source_config): assert mixer["d_state"] == 64 assert mixer["d_conv"] == 4 + def test_cross_type_attention_to_kda(self, source_config): + """attention→kda derives KDA dims from attention geometry.""" + surgery = { + "decoder": { + "block": { + "mixer": { + "type": "kda", + "init": "transfer", + "convolution_layer": {"kernel_size": 4}, + "normalization": {"epsilon": 1e-5}, + }, + }, + }, + } + result = compose_configs(source_config, surgery) + + mixer = result["decoder"]["block"]["mixer"] + assert mixer["type"] == "kda" + # Derived from source attention geometry + assert mixer["heads"] == 8 # from heads + assert mixer["head_dim"] == 32 # from head_size + # From surgery + assert mixer["convolution_layer"]["kernel_size"] == 4 + assert mixer["normalization"]["epsilon"] == 1e-5 + def test_stochastic_submixer_inheritance(self, source_config): """Law 6: Sub-mixers inherit from base mixer when wrapping in stochastic.""" surgery = { diff --git a/fast_llm_external_models/tests/test_apriel2/test_expr_plan.py b/fast_llm_external_models/tests/test_apriel2/test_expr_plan.py index 5e1a0c9db..c487ab3a3 100644 --- a/fast_llm_external_models/tests/test_apriel2/test_expr_plan.py +++ b/fast_llm_external_models/tests/test_apriel2/test_expr_plan.py @@ -23,7 +23,8 @@ fuse, full_slice, make_slice, - plan_attention_to_gated_delta_net, + plan_dil_attention_to_gdn, + plan_kil_attention_to_kda, plan_llava_to_apriel2, plan_mil_attention_to_mamba, plan_surgery, @@ -629,7 +630,6 @@ def test_plan_llava_is_all_refs(self, llava_pixtral_config): def test_plan_mil_attention_to_mamba(self): """MIL plan produces correct expressions.""" exprs = plan_mil_attention_to_mamba( - layer_idx=0, hidden_size=64, d_inner=128, d_xb=32, @@ -667,7 +667,6 @@ def test_plan_mil_attention_to_mamba(self): def test_plan_mil_execution(self): """MIL plan executes correctly with actual weights.""" plan = plan_mil_attention_to_mamba( - layer_idx=0, hidden_size=64, d_inner=128, d_xb=32, @@ -709,10 +708,10 @@ def test_plan_mil_execution(self): # out_proj should be 4.0 assert torch.allclose(result[W("mamba.out_proj.weight")], torch.full((64, 128), 4.0)) - def test_plan_attention_to_gated_delta_net(self): + def test_plan_dil_attention_to_gdn(self): """DIL plan produces correct per-head-group interleaved structure.""" # MHA case: num_v_heads == num_k_heads (no GQA), 1 v_head per group - plan = plan_attention_to_gated_delta_net( + plan = plan_dil_attention_to_gdn( hidden_size=64, num_v_heads=4, num_k_heads=4, @@ -797,11 +796,11 @@ def test_plan_attention_to_gated_delta_net(self): assert norm_weight.shape == (16,) # head_v_dim assert norm_weight.init_type == "ones" - def test_plan_attention_to_gated_delta_net_gqa(self): + def 
test_plan_dil_attention_to_gdn_gqa(self): """DIL plan handles GQA with tiling (not padding).""" # GQA case: 4 v_heads, 2 k_heads → 2 v_heads per group # Source has 4 Q heads, 2 KV heads - plan = plan_attention_to_gated_delta_net( + plan = plan_dil_attention_to_gdn( hidden_size=64, num_v_heads=4, num_k_heads=2, @@ -843,7 +842,7 @@ def test_plan_attention_to_gated_delta_net_gqa(self): def test_plan_dil_execution(self): """DIL plan executes correctly with FLAT layout [Q_all | K_all | V_all | Z_all].""" # MHA case: 4 k_heads, 4 v_heads (1 v_head per group) - plan = plan_attention_to_gated_delta_net( + plan = plan_dil_attention_to_gdn( hidden_size=64, num_v_heads=4, num_k_heads=4, @@ -954,7 +953,7 @@ def test_plan_dil_execution_gqa(self): """DIL plan executes correctly with GQA and FLAT layout.""" # GQA: 4 v_heads, 2 k_heads → 2 v_heads per group # Source: 4 Q heads, 2 KV heads - plan = plan_attention_to_gated_delta_net( + plan = plan_dil_attention_to_gdn( hidden_size=64, num_v_heads=4, num_k_heads=2, @@ -1025,6 +1024,159 @@ def test_plan_dil_execution_gqa(self): # Z_all (rows 128-191): zeros assert torch.allclose(in_proj_qkvz[2*key_dim+value_dim:], torch.zeros(value_dim, 64)) + def test_plan_kil_attention_to_kda(self): + """AIK plan produces correct structure for attention → KDA conversion.""" + plan = plan_kil_attention_to_kda( + hidden_size=64, + num_heads=4, + head_dim=16, + conv_kernel_size=4, + source_num_q_heads=4, + source_num_kv_heads=4, + source_head_dim=16, + source_prefix=W("attn"), + target_prefix=W(""), + ) + + # KDA has 15 weight tensors + assert len(plan.mappings) == 15 + + # Main projections transferred from attention + assert W("q_proj.weight") in plan.mappings + assert W("k_proj.weight") in plan.mappings + assert W("v_proj.weight") in plan.mappings + assert W("o_proj.weight") in plan.mappings + + # Convolutions (random init) + assert W("q_conv.weight") in plan.mappings + assert W("k_conv.weight") in plan.mappings + assert W("v_conv.weight") in plan.mappings + + # Gate kernels (random init) + assert W("f_a_proj.weight") in plan.mappings + assert W("f_b_proj.weight") in plan.mappings + assert W("g_a_proj.weight") in plan.mappings + assert W("g_b_proj.weight") in plan.mappings + + # Beta projection (random init) + assert W("beta_proj.weight") in plan.mappings + + # Learnable parameters + assert W("A_log") in plan.mappings + assert W("dt_bias") in plan.mappings + + # Normalization + assert W("norm.weight") in plan.mappings + + # Verify source refs for transferred weights + assert plan.mappings[W("q_proj.weight")].find_refs() == {W("attn.q_proj.weight")} + assert plan.mappings[W("o_proj.weight")].find_refs() == {W("attn.o_proj.weight")} + + # Verify random init weights have no refs + assert plan.mappings[W("q_conv.weight")].find_refs() == set() + assert plan.mappings[W("A_log")].find_refs() == set() + + def test_plan_kil_execution(self): + """AIK plan executes correctly for matching dimensions.""" + plan = plan_kil_attention_to_kda( + hidden_size=64, + num_heads=4, + head_dim=16, + conv_kernel_size=4, + source_num_q_heads=4, + source_num_kv_heads=4, + source_head_dim=16, + source_prefix=W("attn"), + target_prefix=W(""), + ) + + projection_size = 64 + + # Create attention weights + q_weight = torch.randn(projection_size, 64) + k_weight = torch.randn(projection_size, 64) + v_weight = torch.randn(projection_size, 64) + o_weight = torch.randn(64, projection_size) + + sources = { + W("attn.q_proj.weight"): q_weight, + W("attn.k_proj.weight"): k_weight, + W("attn.v_proj.weight"): 
v_weight, + W("attn.o_proj.weight"): o_weight, + } + + result = execute(plan, sources, seed=42) + + # Transferred weights should match exactly + assert torch.allclose(result[W("q_proj.weight")], q_weight) + assert torch.allclose(result[W("k_proj.weight")], k_weight) + assert torch.allclose(result[W("v_proj.weight")], v_weight) + assert torch.allclose(result[W("o_proj.weight")], o_weight) + + # Random init weights should have correct shapes + assert result[W("q_conv.weight")].shape == (projection_size, 1, 4) + assert result[W("k_conv.weight")].shape == (projection_size, 1, 4) + assert result[W("v_conv.weight")].shape == (projection_size, 1, 4) + assert result[W("f_a_proj.weight")].shape == (16, 64) # (head_dim, hidden_size) + assert result[W("f_b_proj.weight")].shape == (64, 16) # (projection_size, head_dim) + assert result[W("g_a_proj.weight")].shape == (16, 64) + assert result[W("g_b_proj.weight")].shape == (64, 16) + assert result[W("beta_proj.weight")].shape == (4, 64) # (num_heads, hidden_size) + assert result[W("A_log")].shape == (4,) # (num_heads,) + assert result[W("dt_bias")].shape == (projection_size,) # (projection_size,) + assert result[W("norm.weight")].shape == (16,) # (head_dim,) + + def test_plan_kil_execution_gqa(self): + """AIK plan executes correctly with GQA (tiling K/V from fewer source heads).""" + # Target: 4 heads (no GQA in KDA) + # Source: 4 Q heads, 2 KV heads (GQA) + plan = plan_kil_attention_to_kda( + hidden_size=64, + num_heads=4, + head_dim=16, + conv_kernel_size=4, + source_num_q_heads=4, + source_num_kv_heads=2, + source_head_dim=16, + source_prefix=W("attn"), + target_prefix=W(""), + ) + + # Create attention weights with distinct values per head + # Q: 4 heads, each head has value (head_idx + 1) + q_weight = torch.cat([torch.full((16, 64), float(i + 1)) for i in range(4)], dim=0) + # K: 2 heads, each head has value (head_idx + 1) * 10 + k_weight = torch.cat([torch.full((16, 64), float(i + 1) * 10) for i in range(2)], dim=0) + # V: 2 heads, each head has value (head_idx + 1) * 100 + v_weight = torch.cat([torch.full((16, 64), float(i + 1) * 100) for i in range(2)], dim=0) + + sources = { + W("attn.q_proj.weight"): q_weight, + W("attn.k_proj.weight"): k_weight, + W("attn.v_proj.weight"): v_weight, + W("attn.o_proj.weight"): torch.randn(64, 64), + } + + result = execute(plan, sources, seed=42) + + # Q: direct copy (4 heads → 4 heads) + assert torch.allclose(result[W("q_proj.weight")], q_weight) + + # K: tiled from 2 heads to 4 heads using modulo + # head 0 → src 0 (10), head 1 → src 1 (20), head 2 → src 0 (10), head 3 → src 1 (20) + k_result = result[W("k_proj.weight")] + assert torch.allclose(k_result[0:16], torch.full((16, 64), 10.0)) + assert torch.allclose(k_result[16:32], torch.full((16, 64), 20.0)) + assert torch.allclose(k_result[32:48], torch.full((16, 64), 10.0)) + assert torch.allclose(k_result[48:64], torch.full((16, 64), 20.0)) + + # V: same tiling pattern + v_result = result[W("v_proj.weight")] + assert torch.allclose(v_result[0:16], torch.full((16, 64), 100.0)) + assert torch.allclose(v_result[16:32], torch.full((16, 64), 200.0)) + assert torch.allclose(v_result[32:48], torch.full((16, 64), 100.0)) + assert torch.allclose(v_result[48:64], torch.full((16, 64), 200.0)) + class TestFullPipeline: """Test full conversion + surgery pipeline.""" diff --git a/fast_llm_external_models/tests/test_apriel2/test_mixer_equivalence.py b/fast_llm_external_models/tests/test_apriel2/test_mixer_equivalence.py index 74bde087b..1aa8a56d9 100644 --- 
a/fast_llm_external_models/tests/test_apriel2/test_mixer_equivalence.py +++ b/fast_llm_external_models/tests/test_apriel2/test_mixer_equivalence.py @@ -1,11 +1,27 @@ """Tests for numerical equivalence between Apriel2 mixers and reference implementations. -Tests forward-pass equivalence between: -1. Apriel2Attention vs MistralAttention (using conversion machinery) -2. Apriel2Attention vs PixtralAttention (non-causal) -3. Apriel2GatedDeltaNet vs Qwen3NextGatedDeltaNet (using conversion machinery) - -Uses the apriel2/conversion module for weight transformations rather than hand-rolled copying. +This module verifies that Apriel2's mixer implementations produce outputs numerically +equivalent to their reference implementations (HuggingFace transformers, FLA, etc.). + +Test Categories: +================ +1. DETERMINISM - Verify same input → same output (no random variation) +2. EQUIVALENCE - Verify Apriel2 output matches reference implementation output +3. FAST/SLOW PATH - Verify CUDA kernels match PyTorch fallback + +Test Philosophy: +================ +- Equivalence tests use the apriel2/conversion module for weight transformations, + ensuring we test the same code paths used in production checkpoint conversion. +- Determinism tests use fixed seeds and verify bitwise equality. +- All tests use fp32 by default for numerical precision; bf16 is skipped for + correctness tests (would be used for performance benchmarks). + +Mixer Coverage: +=============== +- Attention: vs MistralAttention (causal), vs PixtralAttention (non-causal) +- GatedDeltaNet: vs Qwen3NextGatedDeltaNet +- KimiDeltaAttention: vs FLA KimiDeltaAttention """ import pytest @@ -13,54 +29,68 @@ import torch.nn as nn from fast_llm_external_models.apriel2.conversion import ( + Concat, ExprPlan, Ref, + Slice, W, execute, ) # ============================================================================= -# Fixtures for configs +# Shared Fixtures # ============================================================================= @pytest.fixture(params=[1, 2, 4]) def batch_size(request): - """Batch sizes to test.""" + """Batch sizes to test. Covers single-sample, small batch, and typical batch.""" return request.param @pytest.fixture(params=[1, 16, 64, 128]) def seq_len(request): - """Sequence lengths to test.""" + """Sequence lengths to test. + + - 1: Single token decode + - 16: Very short sequence + - 64: Typical sequence + - 128: Longer sequence (approaches chunk boundaries) + """ return request.param @pytest.fixture(params=[256, 512]) def hidden_size(request): - """Hidden sizes to test.""" + """Hidden sizes to test. 256 is minimal, 512 exercises larger matrices.""" return request.param @pytest.fixture( params=[ - (8, 8, 32), # MHA: 8 heads, 8 kv heads, 32 head_dim - (8, 4, 32), # GQA: 8 heads, 4 kv heads, 32 head_dim - (8, 2, 64), # GQA: 8 heads, 2 kv heads, 64 head_dim - (4, 1, 64), # MQA: 4 heads, 1 kv head, 64 head_dim + pytest.param((8, 8, 32), id="mha-8h-32d"), # MHA: 8 heads, 8 kv heads, 32 head_dim + pytest.param((8, 4, 32), id="gqa-8h4kv-32d"), # GQA: 8 heads, 4 kv heads, 32 head_dim + pytest.param((8, 2, 64), id="gqa-8h2kv-64d"), # GQA: 8 heads, 2 kv heads, 64 head_dim + pytest.param((4, 1, 64), id="mqa-4h1kv-64d"), # MQA: 4 heads, 1 kv head, 64 head_dim ] ) def attention_config(request): - """Attention head configurations: (num_heads, num_kv_heads, head_dim).""" + """Attention head configurations: (num_heads, num_kv_heads, head_dim). 
+ + Covers: + - MHA (multi-head attention): heads == kv_heads + - GQA (grouped query attention): heads > kv_heads + - MQA (multi-query attention): kv_heads == 1 + """ return request.param @pytest.fixture( params=[ - (8, 4, 32, 32), # 8 value heads, 4 key heads, 32 key_dim, 32 value_dim - (8, 2, 64, 64), # 8 value heads, 2 key heads, 64 key_dim, 64 value_dim - (4, 2, 32, 64), # 4 value heads, 2 key heads, 32 key_dim, 64 value_dim + pytest.param((8, 4, 32, 32), id="8v-4k-32d"), # 8 value heads, 4 key heads, symmetric dims + pytest.param((8, 2, 64, 64), id="8v-2k-64d"), # 8 value heads, 2 key heads, larger dims + pytest.param((4, 2, 32, 64), id="4v-2k-asym"), # Asymmetric key/value dims ] ) def gdn_config(request): @@ -68,17 +98,31 @@ def gdn_config(request): return request.param +@pytest.fixture( + params=[ + pytest.param((4, 8), id="4h-8d"), # 4 heads, 8 head_dim (small) + pytest.param((8, 16), id="8h-16d"), # 8 heads, 16 head_dim (medium) + pytest.param((4, 32), id="4h-32d"), # 4 heads, 32 head_dim (large head_dim) + ] +) +def kda_config(request): + """KDA configurations: (num_heads, head_dim).""" + return request.param + + # ============================================================================= -# Test Mode Fixtures (bundle device/dtype/attn_impl/tolerance coherently) +# Test Mode Configuration # ============================================================================= @pytest.fixture( params=[ "precise", - # "fast" mode (bf16/sdpa) is skipped: small tensor sizes in these tests - # make GPU overhead dominate, and precise mode is sufficient for correctness. - pytest.param("fast", marks=pytest.mark.skip(reason="Small tensors; precise mode sufficient")), + # "fast" mode (bf16/sdpa) is intentionally skipped: + # - These are correctness tests, not performance benchmarks + # - bf16 has ~3 decimal digits precision, masking real bugs + # - Small tensor sizes make GPU overhead dominate anyway + pytest.param("fast", marks=pytest.mark.skip(reason="Correctness tests use fp32")), ] ) def test_mode(request): @@ -88,17 +132,13 @@ def test_mode(request): @pytest.fixture def test_dtype(test_mode): - """Dtype derived from test_mode: fp32 for precise, bf16 for fast.""" + """Dtype derived from test_mode.""" return torch.float32 if test_mode == "precise" else torch.bfloat16 @pytest.fixture def attn_impl(test_mode): - """Attention implementation derived from test_mode. - - Uses PyTorch's SDPA (scaled_dot_product_attention) for fast mode, which - provides fused kernels without the special initialization flash_attention_2 needs. - """ + """Attention implementation derived from test_mode.""" return "eager" if test_mode == "precise" else "sdpa" @@ -106,23 +146,17 @@ def attn_impl(test_mode): def tolerance(test_mode): """Tolerance (rtol, atol) derived from test_mode. - bf16 has ~3 decimal digits precision, so needs looser tolerance. - fp32 "precise" mode uses 2e-4 to accommodate minor differences in - kernel implementations (e.g., fla vs pure PyTorch) while still - catching real bugs. + fp32 uses 2e-4 to accommodate minor kernel differences while catching real bugs. + bf16 would use 1e-2 due to ~3 decimal digit precision. """ - if test_mode == "precise": - return (2e-4, 2e-4) - else: - return (1e-2, 1e-2) + return (2e-4, 2e-4) if test_mode == "precise" else (1e-2, 1e-2) @pytest.fixture(autouse=True) def override_dtype_for_test_mode(test_mode): """Override default dtype based on test_mode. - This runs after conftest's set_default_dtype and temporarily changes - the dtype for tests that use test_mode. 
+ Runs after conftest's set_default_dtype fixture. """ dtype = torch.float32 if test_mode == "precise" else torch.bfloat16 old_dtype = torch.get_default_dtype() @@ -132,25 +166,88 @@ def override_dtype_for_test_mode(test_mode): # ============================================================================= -# Helper functions +# Helper Functions # ============================================================================= -def assert_close(a: torch.Tensor, b: torch.Tensor, rtol: float = 1e-4, atol: float = 1e-4, msg: str = ""): - """Assert two tensors are close with detailed error message.""" - if not torch.allclose(a, b, rtol=rtol, atol=atol): - diff = (a - b).abs() +def assert_close( + actual: torch.Tensor, + expected: torch.Tensor, + rtol: float = 1e-4, + atol: float = 1e-4, + msg: str = "", +): + """Assert two tensors are close with detailed error diagnostics. + + Args: + actual: Tensor from implementation under test + expected: Tensor from reference implementation + rtol: Relative tolerance + atol: Absolute tolerance + msg: Context message for failure + """ + if not torch.allclose(actual, expected, rtol=rtol, atol=atol): + diff = (actual - expected).abs() max_diff = diff.max().item() mean_diff = diff.mean().item() + max_idx = diff.argmax().item() + raise AssertionError( + f"{msg}\n" + f" Max diff: {max_diff:.6e} at flat index {max_idx}\n" + f" Mean diff: {mean_diff:.6e}\n" + f" Tolerance: rtol={rtol}, atol={atol}\n" + f" Shapes: actual={actual.shape}, expected={expected.shape}" + ) + + +def assert_deterministic(out1: torch.Tensor, out2: torch.Tensor, mixer_name: str): + """Assert two outputs from same input are bitwise identical. + + Args: + out1: First forward pass output + out2: Second forward pass output + mixer_name: Name of mixer for error message + """ + if not torch.equal(out1, out2): + diff = (out1 - out2).abs() + max_diff = diff.max().item() + num_diff = (diff > 0).sum().item() raise AssertionError( - f"{msg}\nMax diff: {max_diff:.6e}, Mean diff: {mean_diff:.6e}, " f"rtol={rtol}, atol={atol}" + f"{mixer_name} output is not deterministic!\n" + f" {num_diff} elements differ (of {diff.numel()} total)\n" + f" Max difference: {max_diff:.6e}" ) +def extract_module_weights(module: nn.Module) -> dict[W, torch.Tensor]: + """Extract weights from a module as a dict with W keys for conversion plan.""" + weights = {} + for name, param in module.named_parameters(): + parts = name.split(".") + key = W(*parts) + weights[key] = param.data + return weights + + +def load_weights_into_module(module: nn.Module, weights: dict[W, torch.Tensor]): + """Load weights from conversion plan output into a module.""" + with torch.no_grad(): + for name, param in module.named_parameters(): + parts = name.split(".") + key = W(*parts) + if key in weights: + param.copy_(weights[key]) + + +# ============================================================================= +# Conversion Plans (Weight Transformations for Equivalence Tests) +# ============================================================================= + + def plan_mistral_attention_to_apriel2() -> ExprPlan: - """Build plan for MistralAttention -> Apriel2Attention weight renaming. + """MistralAttention -> Apriel2Attention weight mapping. - Both use q_proj/k_proj/v_proj/o_proj naming, so this is identity mapping. + Both use identical q_proj/k_proj/v_proj/o_proj naming, so this is identity. 
""" return ExprPlan( mappings={ @@ -168,54 +265,28 @@ def plan_qwen3next_gdn_to_apriel2( head_k_dim: int, head_v_dim: int, ) -> ExprPlan: - """Build plan for Qwen3NextGatedDeltaNet -> Apriel2GatedDeltaNet weight conversion. + """Qwen3NextGatedDeltaNet -> Apriel2GatedDeltaNet weight conversion. Qwen3Next uses GROUPED layout: for each key_head group, [Q_g | K_g | V_group | Z_group] Apriel2/Fast-LLM uses FLAT layout: [Q_all | K_all | V_all | Z_all] This plan rearranges in_proj_qkvz weights from grouped to flat layout. - Other weights are direct copies (with conv1d -> convolution rename). """ - from fast_llm_external_models.apriel2.conversion import Concat, Slice - - # Dimensions - key_dim = num_k_heads * head_k_dim - value_dim = num_v_heads * head_v_dim + # Dimensions per group v_per_group = (num_v_heads // num_k_heads) * head_v_dim group_size = head_k_dim * 2 + v_per_group * 2 # Q + K + V_group + Z_group qkvz_ref = Ref(key=W("in_proj_qkvz", "weight")) - # Extract Q, K, V, Z from each group and concatenate by type - q_slices = [] - k_slices = [] - v_slices = [] - z_slices = [] - + # Extract Q, K, V, Z from each group + q_slices, k_slices, v_slices, z_slices = [], [], [], [] for g in range(num_k_heads): base = g * group_size - # Q_g: [base, base + head_k_dim) q_slices.append(Slice(expr=qkvz_ref, slices=((base, base + head_k_dim, None), (None, None, None)))) - # K_g: [base + head_k_dim, base + 2*head_k_dim) - k_slices.append( - Slice(expr=qkvz_ref, slices=((base + head_k_dim, base + 2 * head_k_dim, None), (None, None, None))) - ) - # V_group_g: [base + 2*head_k_dim, base + 2*head_k_dim + v_per_group) - v_slices.append( - Slice( - expr=qkvz_ref, - slices=((base + 2 * head_k_dim, base + 2 * head_k_dim + v_per_group, None), (None, None, None)), - ) - ) - # Z_group_g: [base + 2*head_k_dim + v_per_group, base + group_size) - z_slices.append( - Slice( - expr=qkvz_ref, - slices=((base + 2 * head_k_dim + v_per_group, base + group_size, None), (None, None, None)), - ) - ) + k_slices.append(Slice(expr=qkvz_ref, slices=((base + head_k_dim, base + 2 * head_k_dim, None), (None, None, None)))) + v_slices.append(Slice(expr=qkvz_ref, slices=((base + 2 * head_k_dim, base + 2 * head_k_dim + v_per_group, None), (None, None, None)))) + z_slices.append(Slice(expr=qkvz_ref, slices=((base + 2 * head_k_dim + v_per_group, base + group_size, None), (None, None, None)))) - # Concatenate: [Q_all | K_all | V_all | Z_all] in_proj_qkvz_expr = Concat( exprs=( Concat(exprs=tuple(q_slices), dim=0), @@ -226,26 +297,18 @@ def plan_qwen3next_gdn_to_apriel2( dim=0, ) - # Similarly rearrange in_proj_ba: grouped [b_group | a_group] -> flat [b_all | a_all] + # Similarly rearrange in_proj_ba ba_ref = Ref(key=W("in_proj_ba", "weight")) - ba_per_group = (num_v_heads // num_k_heads) * 2 # b + a for the group + ba_per_group = (num_v_heads // num_k_heads) * 2 - b_slices = [] - a_slices = [] + b_slices, a_slices = [], [] for g in range(num_k_heads): base = g * ba_per_group - b_slices.append( - Slice(expr=ba_ref, slices=((base, base + num_v_heads // num_k_heads, None), (None, None, None))) - ) - a_slices.append( - Slice(expr=ba_ref, slices=((base + num_v_heads // num_k_heads, base + ba_per_group, None), (None, None, None))) - ) + b_slices.append(Slice(expr=ba_ref, slices=((base, base + num_v_heads // num_k_heads, None), (None, None, None)))) + a_slices.append(Slice(expr=ba_ref, slices=((base + num_v_heads // num_k_heads, base + ba_per_group, None), (None, None, None)))) in_proj_ba_expr = Concat( - exprs=( - Concat(exprs=tuple(b_slices), 
dim=0), - Concat(exprs=tuple(a_slices), dim=0), - ), + exprs=(Concat(exprs=tuple(b_slices), dim=0), Concat(exprs=tuple(a_slices), dim=0)), dim=0, ) @@ -254,7 +317,7 @@ def plan_qwen3next_gdn_to_apriel2( W("in_proj_qkvz", "weight"): in_proj_qkvz_expr, W("in_proj_ba", "weight"): in_proj_ba_expr, W("out_proj", "weight"): Ref(key=W("out_proj", "weight")), - W("convolution", "weight"): Ref(key=W("conv1d", "weight")), # rename + W("convolution", "weight"): Ref(key=W("conv1d", "weight")), W("dt_bias"): Ref(key=W("dt_bias")), W("A_log"): Ref(key=W("A_log")), W("norm", "weight"): Ref(key=W("norm", "weight")), @@ -262,42 +325,192 @@ def plan_qwen3next_gdn_to_apriel2( ) -def extract_module_weights(module: nn.Module) -> dict[W, torch.Tensor]: - """Extract weights from a module as a dict with W keys.""" - weights = {} - for name, param in module.named_parameters(): - # Convert "a.b.c" to W("a", "b", "c") - parts = name.split(".") - key = W(*parts) - weights[key] = param.data - return weights +def plan_fla_kda_to_apriel2() -> ExprPlan: + """FLA KimiDeltaAttention -> Apriel2 KimiDeltaAttention weight mapping. + Key renames: + - q_conv1d -> q_conv (same for k, v) + - f_proj.0/1 -> f_a_proj/f_b_proj + - g_proj.0/1 -> g_a_proj/g_b_proj + - b_proj -> beta_proj + - o_norm -> norm -def load_weights_into_module(module: nn.Module, weights: dict[W, torch.Tensor]): - """Load weights from a dict with W keys into a module.""" - with torch.no_grad(): - for name, param in module.named_parameters(): - parts = name.split(".") - key = W(*parts) - if key in weights: - param.copy_(weights[key]) + Note: FLA has bias on g_proj.1, Apriel2 doesn't. Test zeroes this bias. + """ + return ExprPlan( + mappings={ + # Projections (same names) + W("q_proj", "weight"): Ref(key=W("q_proj", "weight")), + W("k_proj", "weight"): Ref(key=W("k_proj", "weight")), + W("v_proj", "weight"): Ref(key=W("v_proj", "weight")), + W("o_proj", "weight"): Ref(key=W("o_proj", "weight")), + # Convolutions (conv1d -> conv) + W("q_conv", "weight"): Ref(key=W("q_conv1d", "weight")), + W("k_conv", "weight"): Ref(key=W("k_conv1d", "weight")), + W("v_conv", "weight"): Ref(key=W("v_conv1d", "weight")), + # Gate projections (Sequential -> separate) + W("f_a_proj", "weight"): Ref(key=W("f_proj", "0", "weight")), + W("f_b_proj", "weight"): Ref(key=W("f_proj", "1", "weight")), + W("g_a_proj", "weight"): Ref(key=W("g_proj", "0", "weight")), + W("g_b_proj", "weight"): Ref(key=W("g_proj", "1", "weight")), + # Beta (b_proj -> beta_proj) + W("beta_proj", "weight"): Ref(key=W("b_proj", "weight")), + # Learnable params + W("A_log"): Ref(key=W("A_log")), + W("dt_bias"): Ref(key=W("dt_bias")), + # Normalization (o_norm -> norm) + W("norm", "weight"): Ref(key=W("o_norm", "weight")), + } + ) + + +# ============================================================================= +# SECTION 1: DETERMINISM TESTS +# ============================================================================= + + +class TestDeterminism: + """Verify mixers produce deterministic outputs. + + These tests run the same input through a mixer twice and verify + bitwise-identical outputs. 
Non-determinism would indicate: + - Uncontrolled randomness in kernels + - Race conditions in parallel operations + - Floating-point non-associativity issues + """ + + @pytest.mark.skipif(not torch.cuda.is_available(), reason="Requires CUDA") + def test_attention_determinism(self, attention_config): + """Verify Apriel2Attention produces identical output on repeated calls.""" + from fast_llm_external_models.apriel2.configuration_apriel2 import Apriel2TextConfig + from fast_llm_external_models.apriel2.modeling_apriel2 import Apriel2Attention + + num_heads, num_kv_heads, head_dim = attention_config + hidden_size = 256 + batch_size, seq_len = 2, 32 + + mixer_config = { + "type": "attention", + "heads": num_heads, + "head_groups": num_kv_heads, + "head_size": head_dim, + "add_linear_biases": False, + "causal": True, + "rotary": {"type": "mistral_1d", "theta": 10000.0}, + } + + config = Apriel2TextConfig( + hidden_size=hidden_size, + decoder={ + "type": "fixed", + "num_blocks": 1, + "block": { + "mixer": mixer_config, + "mlp": {"type": "mlp", "intermediate_size": hidden_size * 4}, + "normalization": {"type": "rms_norm", "epsilon": 1e-5}, + }, + }, + embeddings={"max_position_embeddings": 4096}, + ) + config._attn_implementation = "eager" + + torch.manual_seed(42) + model = Apriel2Attention(hidden_size, mixer_config, layer_idx=0, config=config) + model.eval() + + torch.manual_seed(123) + hidden_states = torch.randn(batch_size, seq_len, hidden_size) + position_ids = torch.arange(seq_len).unsqueeze(0).expand(batch_size, -1) + + rotary_resources = Apriel2Attention.setup(mixer_config, hidden_size, 4096) + position_embeddings = rotary_resources["rotary_emb"](hidden_states, position_ids) + + with torch.no_grad(): + out1 = model(hidden_states, position_embeddings=position_embeddings)[0] + out2 = model(hidden_states, position_embeddings=position_embeddings)[0] + + assert_deterministic(out1, out2, "Apriel2Attention") + + @pytest.mark.skipif(not torch.cuda.is_available(), reason="GDN requires CUDA") + def test_gdn_determinism(self, gdn_config): + """Verify Apriel2GatedDeltaNet produces identical output on repeated calls.""" + from fast_llm_external_models.apriel2.modeling_apriel2 import Apriel2GatedDeltaNet + + value_heads, key_heads, key_head_dim, value_head_dim = gdn_config + hidden_size = 256 + batch_size, seq_len = 2, 32 + + config_dict = { + "type": "gdn", + "value_heads": value_heads, + "key_heads": key_heads, + "key_head_dim": key_head_dim, + "value_head_dim": value_head_dim, + "convolution_layer": {"kernel_size": 4}, + "norm_eps": 1e-5, + } + + torch.manual_seed(42) + model = Apriel2GatedDeltaNet(hidden_size, config_dict, layer_idx=0) + model.eval() + + torch.manual_seed(123) + hidden_states = torch.randn(batch_size, seq_len, hidden_size) + + with torch.no_grad(): + out1 = model(hidden_states)[0] + out2 = model(hidden_states)[0] + + assert_deterministic(out1, out2, "Apriel2GatedDeltaNet") + + @pytest.mark.skipif(not torch.cuda.is_available(), reason="KDA requires CUDA") + def test_kda_determinism(self, kda_config): + """Verify Apriel2 KimiDeltaAttention produces identical output on repeated calls.""" + from fast_llm_external_models.apriel2.modeling_apriel2 import KimiDeltaAttention + + num_heads, head_dim = kda_config + hidden_size = num_heads * head_dim + batch_size, seq_len = 2, 32 + + config_dict = { + "type": "kda", + "heads": num_heads, + "head_dim": head_dim, + "convolution_layer": {"kernel_size": 4}, + "normalization": {"epsilon": 1e-5}, + } + + torch.manual_seed(42) + model = 
KimiDeltaAttention(hidden_size, config_dict, layer_idx=0) + model.eval() + + torch.manual_seed(123) + hidden_states = torch.randn(batch_size, seq_len, hidden_size) + + with torch.no_grad(): + out1 = model(hidden_states)[0] + out2 = model(hidden_states)[0] + + assert_deterministic(out1, out2, "KimiDeltaAttention") # ============================================================================= -# Apriel2Attention vs MistralAttention Tests +# SECTION 2: EQUIVALENCE TESTS - Attention # ============================================================================= -class TestApriel2AttentionVsMistral: - """Test equivalence between Apriel2Attention and MistralAttention.""" +class TestAttentionEquivalence: + """Verify Apriel2Attention matches reference attention implementations. + + Tests both causal (vs Mistral) and non-causal (vs Pixtral) modes. + """ @pytest.fixture def mistral_config(self, hidden_size, attention_config, attn_impl): - """Create MistralConfig for testing.""" + """Create MistralConfig for causal attention testing.""" from transformers import MistralConfig num_heads, num_kv_heads, head_dim = attention_config - config = MistralConfig( hidden_size=hidden_size, num_attention_heads=num_heads, @@ -311,32 +524,26 @@ def mistral_config(self, hidden_size, attention_config, attn_impl): return config @pytest.fixture - def apriel2_mixer_config(self, attention_config): - """Create Apriel2 mixer config dict.""" - num_heads, num_kv_heads, head_dim = attention_config - - return { - "type": "attention", - "heads": num_heads, - "head_groups": num_kv_heads, - "head_size": head_dim, - "add_linear_biases": False, - "causal": True, - "rotary": {"type": "mistral_1d", "theta": 10000.0}, - } - - @pytest.fixture - def apriel2_config(self, hidden_size, apriel2_mixer_config, attn_impl): - """Create Apriel2Config for testing.""" + def apriel2_config(self, hidden_size, attention_config, attn_impl): + """Create Apriel2Config for causal attention testing.""" from fast_llm_external_models.apriel2.configuration_apriel2 import Apriel2TextConfig + num_heads, num_kv_heads, head_dim = attention_config config = Apriel2TextConfig( hidden_size=hidden_size, decoder={ "type": "fixed", "num_blocks": 1, "block": { - "mixer": apriel2_mixer_config, + "mixer": { + "type": "attention", + "heads": num_heads, + "head_groups": num_kv_heads, + "head_size": head_dim, + "add_linear_biases": False, + "causal": True, + "rotary": {"type": "mistral_1d", "theta": 10000.0}, + }, "mlp": {"type": "mlp", "intermediate_size": hidden_size * 4}, "normalization": {"type": "rms_norm", "epsilon": 1e-5}, }, @@ -347,42 +554,38 @@ def apriel2_config(self, hidden_size, apriel2_mixer_config, attn_impl): return config @pytest.mark.skipif(not torch.cuda.is_available(), reason="Requires CUDA") - def test_forward_equivalence( + def test_causal_vs_mistral( self, mistral_config, apriel2_config, - apriel2_mixer_config, batch_size, seq_len, hidden_size, tolerance, ): - """Test that Apriel2Attention produces same output as MistralAttention.""" + """Verify Apriel2Attention (causal) matches MistralAttention output.""" from transformers.models.mistral.modeling_mistral import MistralAttention, MistralRotaryEmbedding from fast_llm_external_models.apriel2.modeling_apriel2 import Apriel2Attention - # Create models (uses default device/dtype from fixtures) + mixer_config = apriel2_config.decoder["block"]["mixer"] + + # Create models mistral_attn = MistralAttention(mistral_config, layer_idx=0) - apriel2_attn = Apriel2Attention(hidden_size, apriel2_mixer_config, 
layer_idx=0, config=apriel2_config) + apriel2_attn = Apriel2Attention(hidden_size, mixer_config, layer_idx=0, config=apriel2_config) - # Use conversion machinery to transfer weights + # Transfer weights using conversion plan plan = plan_mistral_attention_to_apriel2() source_weights = extract_module_weights(mistral_attn) target_weights = execute(plan, source_weights, seed=42) load_weights_into_module(apriel2_attn, target_weights) - # Create input + # Create inputs torch.manual_seed(42) hidden_states = torch.randn(batch_size, seq_len, hidden_size) - - # Create position_ids position_ids = torch.arange(seq_len).unsqueeze(0).expand(batch_size, -1) - - # Create causal mask causal_mask = torch.triu(torch.full((seq_len, seq_len), float("-inf")), diagonal=1).unsqueeze(0).unsqueeze(0) - # Compute position embeddings using Mistral's rotary embedding - # Use the same position embeddings for both to ensure equivalence test is fair + # Compute rotary embeddings mistral_rotary = MistralRotaryEmbedding(config=mistral_config) position_embeddings = mistral_rotary(hidden_states, position_ids) @@ -390,68 +593,50 @@ def test_forward_equivalence( apriel2_attn.eval() with torch.no_grad(): - # Mistral forward - position_embeddings is now a required positional arg - mistral_out = mistral_attn( - hidden_states, - position_embeddings=position_embeddings, - attention_mask=causal_mask, - )[0] - - # Apriel2 forward - use the same position embeddings - apriel2_out = apriel2_attn( - hidden_states, - attention_mask=causal_mask, - position_embeddings=position_embeddings, - )[0] + mistral_out = mistral_attn(hidden_states, position_embeddings=position_embeddings, attention_mask=causal_mask)[0] + apriel2_out = apriel2_attn(hidden_states, attention_mask=causal_mask, position_embeddings=position_embeddings)[0] rtol, atol = tolerance assert_close( - apriel2_out, - mistral_out, - rtol=rtol, - atol=atol, - msg=f"Apriel2Attention vs MistralAttention mismatch " - f"(batch={batch_size}, seq={seq_len}, hidden={hidden_size})", + apriel2_out, mistral_out, rtol=rtol, atol=atol, + msg=f"Apriel2Attention vs MistralAttention (batch={batch_size}, seq={seq_len}, hidden={hidden_size})" ) - -# ============================================================================= -# Apriel2Attention vs PixtralAttention Tests (non-causal) -# ============================================================================= - - -class TestApriel2AttentionVsPixtral: - """Test equivalence between Apriel2Attention and PixtralAttention (non-causal). - - Note: Full 2D rotary equivalence tests are in test_rotary_2d_equivalence.py. - This test focuses on verifying the attention mechanism itself is equivalent - when given the same inputs. 
- """ - - @pytest.fixture - def pixtral_config(self, attention_config, attn_impl): - """Create PixtralVisionConfig for testing.""" + @pytest.mark.skipif(not torch.cuda.is_available(), reason="Requires CUDA") + @pytest.mark.parametrize("seq_len", [16, 64]) # Must be perfect squares for 2D position + def test_noncausal_vs_pixtral( + self, + attention_config, + batch_size, + seq_len, + attn_impl, + tolerance, + ): + """Verify Apriel2Attention (non-causal) matches PixtralAttention output.""" + from transformers.models.pixtral.modeling_pixtral import PixtralAttention, PixtralRotaryEmbedding from transformers.models.pixtral.configuration_pixtral import PixtralVisionConfig + from fast_llm_external_models.apriel2.configuration_apriel2 import Apriel2TextConfig + from fast_llm_external_models.apriel2.modeling_apriel2 import Apriel2Attention num_heads, _, head_dim = attention_config hidden_size = num_heads * head_dim - config = PixtralVisionConfig( + # Verify seq_len is perfect square + grid_size = int(seq_len**0.5) + if grid_size * grid_size != seq_len: + pytest.skip(f"seq_len {seq_len} is not a perfect square for 2D position test") + + # Create configs + pixtral_config = PixtralVisionConfig( hidden_size=hidden_size, num_attention_heads=num_heads, intermediate_size=hidden_size * 4, num_hidden_layers=1, rope_theta=10000.0, ) - config._attn_implementation = attn_impl - return config - - @pytest.fixture - def apriel2_mixer_config_noncausal(self, attention_config): - """Create Apriel2 mixer config dict for non-causal attention.""" - num_heads, _, head_dim = attention_config + pixtral_config._attn_implementation = attn_impl - return { + mixer_config = { "type": "attention", "heads": num_heads, "head_groups": num_heads, # Pixtral uses MHA @@ -461,38 +646,13 @@ def apriel2_mixer_config_noncausal(self, attention_config): "rotary": {"type": "pixtral_2d", "theta": 10000.0, "patch_size": 16, "max_image_size": 1024}, } - @pytest.mark.skipif(not torch.cuda.is_available(), reason="Requires CUDA") - @pytest.mark.parametrize("seq_len", [16, 64]) # Override to use specific lengths for vision - def test_forward_equivalence_noncausal( - self, - pixtral_config, - apriel2_mixer_config_noncausal, - attention_config, - batch_size, - seq_len, - attn_impl, - tolerance, - ): - """Test that Apriel2Attention (non-causal) produces same output as PixtralAttention. - - This test creates 1D position embeddings in the format both implementations expect, - allowing us to verify the core attention mechanism is equivalent. 
- """ - from transformers.models.pixtral.modeling_pixtral import PixtralAttention, PixtralRotaryEmbedding - from fast_llm_external_models.apriel2.modeling_apriel2 import Apriel2Attention - from fast_llm_external_models.apriel2.configuration_apriel2 import Apriel2TextConfig - - num_heads, _, head_dim = attention_config - hidden_size = num_heads * head_dim - - # Create Apriel2 config apriel2_config = Apriel2TextConfig( hidden_size=hidden_size, decoder={ "type": "fixed", "num_blocks": 1, "block": { - "mixer": apriel2_mixer_config_noncausal, + "mixer": mixer_config, "mlp": {"type": "mlp", "intermediate_size": hidden_size * 4}, "normalization": {"type": "rms_norm", "epsilon": 1e-5}, }, @@ -501,81 +661,55 @@ def test_forward_equivalence_noncausal( ) apriel2_config._attn_implementation = attn_impl - # Create models (uses default device/dtype from conftest fixtures) + # Create models pixtral_attn = PixtralAttention(pixtral_config) - apriel2_attn = Apriel2Attention( - hidden_size, apriel2_mixer_config_noncausal, layer_idx=0, config=apriel2_config - ) + apriel2_attn = Apriel2Attention(hidden_size, mixer_config, layer_idx=0, config=apriel2_config) - # Use conversion machinery to transfer weights (Pixtral uses same naming as Mistral) + # Transfer weights plan = plan_mistral_attention_to_apriel2() source_weights = extract_module_weights(pixtral_attn) target_weights = execute(plan, source_weights, seed=42) load_weights_into_module(apriel2_attn, target_weights) - # Create input + # Create inputs torch.manual_seed(42) hidden_states = torch.randn(batch_size, seq_len, hidden_size) - # For 2D rotary, we need position_ids that represent 2D positions - # Simulate a small image grid - grid_size = int(seq_len**0.5) - if grid_size * grid_size != seq_len: - pytest.skip(f"seq_len {seq_len} is not a perfect square for 2D position test") - rotary_emb = PixtralRotaryEmbedding(config=pixtral_config) position_ids = torch.arange(seq_len) cos, sin = rotary_emb(hidden_states, position_ids) - # Add batch dimension for compatibility with both Pixtral and Apriel2 (Mistral) conventions position_embeddings = (cos.unsqueeze(0), sin.unsqueeze(0)) pixtral_attn.eval() apriel2_attn.eval() with torch.no_grad(): - # Pixtral forward with explicit position embeddings - pixtral_out = pixtral_attn( - hidden_states, - attention_mask=None, - position_embeddings=position_embeddings, - )[0] - - # Apriel2 forward with same position embeddings - apriel2_out = apriel2_attn( - hidden_states, - attention_mask=None, - position_embeddings=position_embeddings, - )[0] + pixtral_out = pixtral_attn(hidden_states, attention_mask=None, position_embeddings=position_embeddings)[0] + apriel2_out = apriel2_attn(hidden_states, attention_mask=None, position_embeddings=position_embeddings)[0] rtol, atol = tolerance assert_close( - apriel2_out, - pixtral_out, - rtol=rtol, - atol=atol, - msg=f"Apriel2Attention (non-causal) vs PixtralAttention mismatch " - f"(batch={batch_size}, seq={seq_len}, hidden={hidden_size})", + apriel2_out, pixtral_out, rtol=rtol, atol=atol, + msg=f"Apriel2Attention (non-causal) vs PixtralAttention (batch={batch_size}, seq={seq_len})" ) # ============================================================================= -# Apriel2GatedDeltaNet vs Qwen3NextGatedDeltaNet Tests +# SECTION 2: EQUIVALENCE TESTS - GatedDeltaNet # ============================================================================= -class TestApriel2GDNVsQwen3Next: - """Test equivalence between Apriel2GatedDeltaNet and Qwen3NextGatedDeltaNet.""" +class TestGDNEquivalence: + 
"""Verify Apriel2GatedDeltaNet matches Qwen3NextGatedDeltaNet.""" @pytest.fixture def qwen3_config(self, hidden_size, gdn_config): - """Create Qwen3NextConfig for testing.""" + """Create Qwen3NextConfig for GDN testing.""" from transformers.models.qwen3_next.configuration_qwen3_next import Qwen3NextConfig value_heads, key_heads, key_head_dim, value_head_dim = gdn_config - return Qwen3NextConfig( hidden_size=hidden_size, - # Qwen3NextConfig uses different param names for GDN: linear_num_value_heads=value_heads, linear_num_key_heads=key_heads, linear_key_head_dim=key_head_dim, @@ -583,65 +717,55 @@ def qwen3_config(self, hidden_size, gdn_config): linear_conv_kernel_dim=4, rms_norm_eps=1e-5, max_position_embeddings=4096, - # Attention params (not used for GDN but required) num_attention_heads=8, num_key_value_heads=2, head_dim=64, - # Explicitly set dtype to avoid torch.get_current_dtype() fallback torch_dtype=torch.get_default_dtype(), ) - @pytest.fixture - def apriel2_gdn_config(self, gdn_config): - """Create Apriel2 GDN config dict.""" - value_heads, key_heads, key_head_dim, value_head_dim = gdn_config - - return { - "type": "gdn", - "value_heads": value_heads, - "key_heads": key_heads, - "key_head_dim": key_head_dim, - "value_head_dim": value_head_dim, - "convolution_layer": {"kernel_size": 4}, - "norm_eps": 1e-5, - } - @pytest.mark.skipif(not torch.cuda.is_available(), reason="GDN requires CUDA") - @pytest.mark.parametrize("seed", [42, 123, 456, 789, 1337]) - def test_forward_equivalence( + @pytest.mark.parametrize("seed", [42, 123, 456]) + def test_vs_qwen3next( self, qwen3_config, - apriel2_gdn_config, - hidden_size, gdn_config, + hidden_size, batch_size, seq_len, seed, tolerance, ): - """Test that Apriel2GatedDeltaNet produces same output as Qwen3NextGatedDeltaNet.""" + """Verify Apriel2GatedDeltaNet matches Qwen3NextGatedDeltaNet output.""" from transformers.models.qwen3_next.modeling_qwen3_next import Qwen3NextGatedDeltaNet from fast_llm_external_models.apriel2.modeling_apriel2 import Apriel2GatedDeltaNet value_heads, key_heads, key_head_dim, value_head_dim = gdn_config - # Create models with different random seeds for weight initialization + config_dict = { + "type": "gdn", + "value_heads": value_heads, + "key_heads": key_heads, + "key_head_dim": key_head_dim, + "value_head_dim": value_head_dim, + "convolution_layer": {"kernel_size": 4}, + "norm_eps": 1e-5, + } + + # Create models torch.manual_seed(seed) qwen_gdn = Qwen3NextGatedDeltaNet(qwen3_config, layer_idx=0) - apriel2_gdn = Apriel2GatedDeltaNet(hidden_size, apriel2_gdn_config, layer_idx=0) + apriel2_gdn = Apriel2GatedDeltaNet(hidden_size, config_dict, layer_idx=0) - # Use conversion machinery to transfer weights (handles layout differences) + # Transfer weights plan = plan_qwen3next_gdn_to_apriel2( - num_k_heads=key_heads, - num_v_heads=value_heads, - head_k_dim=key_head_dim, - head_v_dim=value_head_dim, + num_k_heads=key_heads, num_v_heads=value_heads, + head_k_dim=key_head_dim, head_v_dim=value_head_dim, ) source_weights = extract_module_weights(qwen_gdn) target_weights = execute(plan, source_weights, seed=seed) load_weights_into_module(apriel2_gdn, target_weights) - # Create input with same seed for reproducibility + # Create input torch.manual_seed(seed) hidden_states = torch.randn(batch_size, seq_len, hidden_size) @@ -649,155 +773,120 @@ def test_forward_equivalence( apriel2_gdn.eval() with torch.no_grad(): - # Qwen3NextGatedDeltaNet returns tensor directly, Apriel2 returns tuple qwen_out = qwen_gdn(hidden_states) 
apriel2_out = apriel2_gdn(hidden_states)[0] rtol, atol = tolerance assert_close( - apriel2_out, - qwen_out, - rtol=rtol, - atol=atol, - msg=f"Apriel2GatedDeltaNet vs Qwen3NextGatedDeltaNet mismatch " - f"(batch={batch_size}, seq={seq_len}, hidden={hidden_size})", + apriel2_out, qwen_out, rtol=rtol, atol=atol, + msg=f"Apriel2GatedDeltaNet vs Qwen3NextGatedDeltaNet (batch={batch_size}, seq={seq_len})" ) # ============================================================================= -# Fast Path vs Slow Path Tests +# SECTION 2: EQUIVALENCE TESTS - KimiDeltaAttention # ============================================================================= -class TestFastVsSlowPath: - """Test that fast path (CUDA kernels) and slow path (PyTorch) produce same results.""" - - @pytest.mark.skipif(not torch.cuda.is_available(), reason="Requires CUDA") - def test_gdn_fast_vs_slow_path(self, gdn_config, batch_size): - """Test GDN produces same output with fast path vs slow path.""" - from fast_llm_external_models.apriel2.modeling_apriel2 import ( - Apriel2GatedDeltaNet, - chunk_gated_delta_rule, - torch_chunk_gated_delta_rule, - ) +class TestKDAEquivalence: + """Verify Apriel2 KimiDeltaAttention matches FLA KimiDeltaAttention.""" - if chunk_gated_delta_rule is None: - pytest.skip("Fast path (fla) not available") + @pytest.mark.skipif(not torch.cuda.is_available(), reason="KDA requires CUDA") + @pytest.mark.parametrize("seed", [42, 123, 456]) + def test_vs_fla( + self, + kda_config, + batch_size, + seq_len, + seed, + tolerance, + ): + """Verify Apriel2 KimiDeltaAttention matches FLA KimiDeltaAttention output.""" + from fla.layers.kda import KimiDeltaAttention as FLA_KDA + from fast_llm_external_models.apriel2.modeling_apriel2 import KimiDeltaAttention as Apriel2_KDA - value_heads, key_heads, key_head_dim, value_head_dim = gdn_config - hidden_size = 256 - seq_len = 32 + num_heads, head_dim = kda_config + hidden_size = num_heads * head_dim - gdn_config_dict = { - "type": "gdn", - "value_heads": value_heads, - "key_heads": key_heads, - "key_head_dim": key_head_dim, - "value_head_dim": value_head_dim, + config_dict = { + "type": "kda", + "heads": num_heads, + "head_dim": head_dim, "convolution_layer": {"kernel_size": 4}, - "norm_eps": 1e-5, + "normalization": {"epsilon": 1e-5}, } - # Create model (uses default device/dtype from conftest fixtures) - torch.manual_seed(42) - model = Apriel2GatedDeltaNet(hidden_size, gdn_config_dict, layer_idx=0) + # Create FLA KDA + torch.manual_seed(seed) + fla_kda = FLA_KDA( + hidden_size=hidden_size, + num_heads=num_heads, + head_dim=head_dim, + conv_size=4, + conv_bias=False, + norm_eps=1e-5, + layer_idx=0, + ) + # FLA has g_proj.1 bias=True but Apriel2/upstream Kimi doesn't - zero it out + fla_kda.g_proj[1].bias.data.zero_() + + # Create Apriel2 KDA + apriel2_kda = Apriel2_KDA(hidden_size, config_dict, layer_idx=0) + + # Transfer weights + plan = plan_fla_kda_to_apriel2() + source_weights = extract_module_weights(fla_kda) + target_weights = execute(plan, source_weights, seed=seed) + load_weights_into_module(apriel2_kda, target_weights) # Create input - torch.manual_seed(123) + torch.manual_seed(seed) hidden_states = torch.randn(batch_size, seq_len, hidden_size) - model.eval() + fla_kda.eval() + apriel2_kda.eval() - # Run with fast path with torch.no_grad(): - model._chunk_gated_delta_rule = chunk_gated_delta_rule - fast_out = model(hidden_states)[0].clone() + # use_cache=True ensures FLA initializes conv cache for short sequences + fla_out = fla_kda(hidden_states, 
use_cache=True)[0] + apriel2_out = apriel2_kda(hidden_states)[0] - # Run with slow path - with torch.no_grad(): - model._chunk_gated_delta_rule = torch_chunk_gated_delta_rule - slow_out = model(hidden_states)[0].clone() - - assert_close(fast_out, slow_out, rtol=1e-3, atol=1e-3, msg="Fast path vs slow path mismatch for GDN") + rtol, atol = tolerance + assert_close( + apriel2_out, fla_out, rtol=rtol, atol=atol, + msg=f"Apriel2 KDA vs FLA KDA (batch={batch_size}, seq={seq_len}, hidden={hidden_size})" + ) # ============================================================================= -# Determinism Tests +# SECTION 3: FAST PATH vs SLOW PATH TESTS # ============================================================================= -class TestDeterminism: - """Test that models produce deterministic outputs.""" - - @pytest.mark.skipif(not torch.cuda.is_available(), reason="Requires CUDA") - def test_attention_determinism(self, attention_config): - """Test Apriel2Attention produces deterministic output.""" - from fast_llm_external_models.apriel2.modeling_apriel2 import Apriel2Attention - from fast_llm_external_models.apriel2.configuration_apriel2 import Apriel2TextConfig - - num_heads, num_kv_heads, head_dim = attention_config - hidden_size = 256 - batch_size = 2 - seq_len = 32 +class TestFastVsSlowPath: + """Verify CUDA kernel outputs match PyTorch fallback outputs. - mixer_config = { - "type": "attention", - "heads": num_heads, - "head_groups": num_kv_heads, - "head_size": head_dim, - "add_linear_biases": False, - "causal": True, - "rotary": {"type": "mistral_1d", "theta": 10000.0}, - } + These tests ensure the optimized CUDA kernels (from fla-core) produce + the same results as the pure PyTorch implementations used on CPU or + when CUDA kernels are unavailable. 
+ """ - config = Apriel2TextConfig( - hidden_size=hidden_size, - decoder={ - "type": "fixed", - "num_blocks": 1, - "block": { - "mixer": mixer_config, - "mlp": {"type": "mlp", "intermediate_size": hidden_size * 4}, - "normalization": {"type": "rms_norm", "epsilon": 1e-5}, - }, - }, - embeddings={"max_position_embeddings": 4096}, + @pytest.mark.skipif(not torch.cuda.is_available(), reason="Requires CUDA") + def test_gdn_fast_vs_slow(self, gdn_config, batch_size): + """Verify GDN CUDA kernel matches PyTorch fallback.""" + from fast_llm_external_models.apriel2.modeling_apriel2 import ( + Apriel2GatedDeltaNet, + chunk_gated_delta_rule, + torch_chunk_gated_delta_rule, ) - config._attn_implementation = "eager" - # Create model with fixed seed (uses default device/dtype from conftest fixtures) - torch.manual_seed(42) - model = Apriel2Attention(hidden_size, mixer_config, layer_idx=0, config=config) - model.eval() - - # Create input with fixed seed - torch.manual_seed(123) - hidden_states = torch.randn(batch_size, seq_len, hidden_size) - position_ids = torch.arange(seq_len).unsqueeze(0).expand(batch_size, -1) - - # Get rotary embeddings - rotary_resources = Apriel2Attention.setup(mixer_config, hidden_size, 4096) - rotary_emb = rotary_resources["rotary_emb"] - position_embeddings = rotary_emb(hidden_states, position_ids) - - # Run twice - with torch.no_grad(): - out1 = model(hidden_states, position_embeddings=position_embeddings)[0] - out2 = model(hidden_states, position_embeddings=position_embeddings)[0] - - assert torch.equal(out1, out2), "Attention output is not deterministic" - - @pytest.mark.skipif(not torch.cuda.is_available(), reason="GDN requires CUDA") - def test_gdn_determinism(self, gdn_config): - """Test Apriel2GatedDeltaNet produces deterministic output.""" - from fast_llm_external_models.apriel2.modeling_apriel2 import Apriel2GatedDeltaNet + if chunk_gated_delta_rule is None: + pytest.skip("Fast path (fla) not available") value_heads, key_heads, key_head_dim, value_head_dim = gdn_config - hidden_size = 256 - batch_size = 2 - seq_len = 32 + hidden_size, seq_len = 256, 32 - gdn_config_dict = { + config_dict = { "type": "gdn", "value_heads": value_heads, "key_heads": key_heads, @@ -807,18 +896,24 @@ def test_gdn_determinism(self, gdn_config): "norm_eps": 1e-5, } - # Create model with fixed seed (uses default device/dtype from conftest fixtures) torch.manual_seed(42) - model = Apriel2GatedDeltaNet(hidden_size, gdn_config_dict, layer_idx=0) + model = Apriel2GatedDeltaNet(hidden_size, config_dict, layer_idx=0) model.eval() - # Create input with fixed seed torch.manual_seed(123) hidden_states = torch.randn(batch_size, seq_len, hidden_size) - # Run twice with torch.no_grad(): - out1 = model(hidden_states)[0] - out2 = model(hidden_states)[0] + # Fast path (CUDA kernel) + model._chunk_gated_delta_rule = chunk_gated_delta_rule + fast_out = model(hidden_states)[0].clone() + + # Slow path (PyTorch fallback) + model._chunk_gated_delta_rule = torch_chunk_gated_delta_rule + slow_out = model(hidden_states)[0].clone() - assert torch.equal(out1, out2), "GDN output is not deterministic" + # Looser tolerance for kernel vs reference comparison + assert_close( + fast_out, slow_out, rtol=1e-3, atol=1e-3, + msg="GDN fast path (CUDA) vs slow path (PyTorch)" + ) diff --git a/fast_llm_external_models/tests/test_apriel2/test_plan_composition_torture.py b/fast_llm_external_models/tests/test_apriel2/test_plan_composition_torture.py index 0ba6a4628..3b4adc7f5 100644 --- 
a/fast_llm_external_models/tests/test_apriel2/test_plan_composition_torture.py +++ b/fast_llm_external_models/tests/test_apriel2/test_plan_composition_torture.py @@ -966,6 +966,7 @@ def test_plan_config_consistency_comprehensive( class TestPlanCompositionWithRealYAML: """Test plan composition using real YAML surgery files.""" + @requires_cuda def test_stochastic_supernet_yaml_end_to_end(self, llava_pixtral_checkpoint): """Test full pipeline with stochastic_supernet.yaml.""" import yaml diff --git a/tests/layers/test_gdn_equivalence.py b/tests/layers/test_gdn_equivalence.py index 803d2eaac..dae4f52b2 100644 --- a/tests/layers/test_gdn_equivalence.py +++ b/tests/layers/test_gdn_equivalence.py @@ -1,29 +1,49 @@ +"""Test numerical equivalence between Fast-LLM GDN and Apriel2 GatedDeltaNet.""" + import pytest import torch from fast_llm.config import UpdateType from fast_llm.layers.block.config import BlockKwargs from fast_llm.models.gpt.config import GPTBaseModelConfig, GPTModelConfig -from fast_llm_external_models.apriel2.modeling_apriel2 import Apriel2GatedDeltaNet +from fast_llm.utils import Assert from tests.utils.utils import get_base_model, get_stage, requires_cuda +try: + from fast_llm_external_models.apriel2.modeling_apriel2 import Apriel2GatedDeltaNet +except ImportError: + Apriel2GatedDeltaNet = None + +try: + from fla.ops.gated_delta_rule import chunk_gated_delta_rule + + _gdn_kernel_available = True +except ImportError: + _gdn_kernel_available = False + +# Test constants VOCAB_SIZE = 500 -HIDDEN_SIZE = 16 +HIDDEN_SIZE = 64 SEQ_LEN = 65 +BATCH_SIZE = 2 NUM_V_HEADS = 4 NUM_K_HEADS = 2 -HEAD_DIM = 4 +HEAD_DIM = 16 KERNEL_SIZE = 4 @pytest.mark.slow @requires_cuda +@pytest.mark.skipif(Apriel2GatedDeltaNet is None, reason="Apriel2 GDN not available") +@pytest.mark.skipif(not _gdn_kernel_available, reason="GDN CUDA kernels not available") def test_fast_llm_gdn_matches_apriel2_forward(): - torch.manual_seed(0) + """Verify Fast-LLM GDN output matches Apriel2 GatedDeltaNet.""" + torch.manual_seed(42) device = torch.device("cuda") dtype = torch.bfloat16 - config_gdn = { + # Create Apriel2 GDN layer + gdn_config = { "value_heads": NUM_V_HEADS, "key_heads": NUM_K_HEADS, "key_head_dim": HEAD_DIM, @@ -31,11 +51,10 @@ def test_fast_llm_gdn_matches_apriel2_forward(): "convolution_layer": {"kernel_size": KERNEL_SIZE, "activation": "silu"}, "norm_eps": 1e-5, } + hf_layer = Apriel2GatedDeltaNet(HIDDEN_SIZE, gdn_config, layer_idx=0, dtype=dtype).to(device=device, dtype=dtype) + hf_layer.eval() - hf_layer = ( - Apriel2GatedDeltaNet(HIDDEN_SIZE, config_gdn, layer_idx=0, dtype=dtype).to(device=device, dtype=dtype).eval() - ) - + # Create Fast-LLM GDN layer config = GPTBaseModelConfig.from_dict( { "decoder": { @@ -68,35 +87,35 @@ def test_fast_llm_gdn_matches_apriel2_forward(): ) fast_layer = model.decoder[0].mixer get_stage([fast_layer], distributed, [], {}) - fast_layer.to(device=device, dtype=dtype).eval() - - with torch.no_grad(): - fast_layer.in_proj_qkvz.weight.copy_(hf_layer.in_proj_qkvz.weight) - fast_layer.in_proj_ba.weight.copy_(hf_layer.in_proj_ba.weight) - fast_layer.convolution.weight.copy_(hf_layer.convolution.weight) - if fast_layer.convolution.bias is not None and hf_layer.convolution.bias is not None: - fast_layer.convolution.bias.copy_(hf_layer.convolution.bias) - fast_layer.out_proj.weight.copy_(hf_layer.out_proj.weight) - fast_layer.A_log.copy_(hf_layer.A_log) - fast_layer.dt_bias.copy_(hf_layer.dt_bias) - fast_layer.norm.weight.copy_(hf_layer.norm.weight) - - hidden_states = 
torch.randn(1, SEQ_LEN, HIDDEN_SIZE, device=device, dtype=dtype, requires_grad=False) - hf_state_dict = hf_layer.state_dict() - for k, p in fast_layer.state_dict().items(): - torch.testing.assert_close(p, hf_state_dict[k], atol=1e-5, rtol=1e-5) + fast_layer.to(device=device, dtype=dtype) + fast_layer.eval() + + # Copy weights: parameter names match exactly, so use load_state_dict + hf_layer.load_state_dict(fast_layer.state_dict()) + + # Verify all parameters match + hf_state = hf_layer.state_dict() + for name, fast_param in fast_layer.state_dict().items(): + assert name in hf_state, f"Parameter {name} missing in HF layer" + hf_param = hf_state[name] + if fast_param.shape != hf_param.shape: + hf_param = hf_param.reshape_as(fast_param) + Assert.all_equal(fast_param, hf_param) + + # Forward passes + hidden_states = torch.randn(BATCH_SIZE, SEQ_LEN, HIDDEN_SIZE, device=device, dtype=dtype, requires_grad=False) hf_out = hf_layer(hidden_states)[0] - sequence_lengths = [[SEQ_LEN] for _ in range(hidden_states.size(0))] fast_kwargs = { BlockKwargs.device: device, BlockKwargs.sequence_first: False, BlockKwargs.hidden_dims: (HIDDEN_SIZE,), BlockKwargs.sequence_length: SEQ_LEN, - BlockKwargs.sequence_lengths: sequence_lengths, + BlockKwargs.sequence_lengths: [[SEQ_LEN] for _ in range(BATCH_SIZE)], } fast_layer.preprocess(fast_kwargs) fast_out, _ = fast_layer(hidden_states, fast_kwargs) - torch.testing.assert_close(fast_out, hf_out, atol=1e-5, rtol=1e-5) + # Compare outputs + Assert.rms_close(fast_out, hf_out, 1e-5) diff --git a/tests/layers/test_kda_equivalence.py b/tests/layers/test_kda_equivalence.py index 8745236d4..fb0042c45 100644 --- a/tests/layers/test_kda_equivalence.py +++ b/tests/layers/test_kda_equivalence.py @@ -1,3 +1,5 @@ +"""Test numerical equivalence between Fast-LLM KDA and Apriel2 KimiDeltaAttention.""" + import pytest import torch @@ -5,55 +7,52 @@ from fast_llm.config import UpdateType from fast_llm.layers.block.config import BlockKwargs from fast_llm.models.gpt.config import GPTBaseModelConfig, GPTModelConfig +from fast_llm.utils import Assert from tests.utils.utils import get_base_model, get_stage, requires_cuda try: - from fast_llm_external_models.apriel_hybrid_ssm.configuration_apriel_hybrid_ssm import AprielHybridSSMConfig - from fast_llm_external_models.apriel_hybrid_ssm.modeling_apriel_hybrid_ssm import KimiDeltaAttention + from fast_llm_external_models.apriel2.modeling_apriel2 import KimiDeltaAttention except ImportError: - AprielHybridSSMConfig, KimiDeltaAttention = None, None + KimiDeltaAttention = None +# Test constants VOCAB_SIZE = 500 -HIDDEN_SIZE = 16 +HIDDEN_SIZE = 64 SEQ_LEN = 65 +BATCH_SIZE = 2 NUM_HEADS = 4 -HEAD_DIM = 4 +HEAD_DIM = 16 KERNEL_SIZE = 4 @pytest.mark.slow @requires_cuda -@pytest.mark.skipif(KimiDeltaAttention is None or AprielHybridSSMConfig is None, reason="Apriel KDA deps missing") +@pytest.mark.skipif(KimiDeltaAttention is None, reason="Apriel2 KDA not available") @pytest.mark.skipif(kda_module.chunk_kda is None, reason="KDA fused kernels not available") -def test_fast_llm_kda_matches_apriel_forward(): - torch.manual_seed(0) +def test_fast_llm_kda_matches_apriel2_forward(): + """Verify Fast-LLM KDA output matches Apriel2 KimiDeltaAttention.""" + torch.manual_seed(42) device = torch.device("cuda") dtype = torch.bfloat16 - hf_config = AprielHybridSSMConfig( - hidden_size=HIDDEN_SIZE, - num_attention_heads=NUM_HEADS, - num_hidden_layers=1, - rms_norm_eps=1e-6, - ) - hf_config.short_conv_kernel_size = KERNEL_SIZE - hf_config.head_dim = HEAD_DIM - 
hf_config.num_heads = NUM_HEADS - hf_layer = KimiDeltaAttention(hf_config, layer_idx=0).to(device=device, dtype=dtype).eval() + # Shared config - parameter names match exactly between implementations + kda_config = { + "heads": NUM_HEADS, + "head_dim": HEAD_DIM, + "convolution_layer": {"kernel_size": KERNEL_SIZE, "activation": "silu"}, + "normalization": {"epsilon": 1e-5, "activation": "sigmoid"}, + } + # Create Apriel2 KDA layer + hf_layer = KimiDeltaAttention(HIDDEN_SIZE, kda_config, layer_idx=0).to(device=device, dtype=dtype) + hf_layer.eval() + + # Create Fast-LLM KDA layer config = GPTBaseModelConfig.from_dict( { "decoder": { "num_blocks": 1, - "block": { - "mixer": { - "type": "kda", - "heads": NUM_HEADS, - "head_dim": HEAD_DIM, - "convolution_layer": {"kernel_size": KERNEL_SIZE, "activation": "silu"}, - "normalization": {"epsilon": hf_config.rms_norm_eps, "activation": "sigmoid"}, - } - }, + "block": {"mixer": {"type": "kda", **kda_config}}, }, "embeddings": {"vocab_size": VOCAB_SIZE}, "hidden_size": HIDDEN_SIZE, @@ -71,68 +70,30 @@ def test_fast_llm_kda_matches_apriel_forward(): ) fast_layer = model.decoder[0].mixer get_stage([fast_layer], distributed, [], {}) - fast_layer.to(device=device, dtype=dtype).eval() - - with torch.no_grad(): - fast_layer.q_proj.weight.copy_(hf_layer.q_proj.weight) - fast_layer.k_proj.weight.copy_(hf_layer.k_proj.weight) - fast_layer.v_proj.weight.copy_(hf_layer.v_proj.weight) - fast_layer.q_conv.weight.copy_(hf_layer.q_conv1d.weight) - fast_layer.k_conv.weight.copy_(hf_layer.k_conv1d.weight) - fast_layer.v_conv.weight.copy_(hf_layer.v_conv1d.weight) - if fast_layer.q_conv.bias is not None and hf_layer.q_conv1d.bias is not None: - fast_layer.q_conv.bias.copy_(hf_layer.q_conv1d.bias) - if fast_layer.k_conv.bias is not None and hf_layer.k_conv1d.bias is not None: - fast_layer.k_conv.bias.copy_(hf_layer.k_conv1d.bias) - if fast_layer.v_conv.bias is not None and hf_layer.v_conv1d.bias is not None: - fast_layer.v_conv.bias.copy_(hf_layer.v_conv1d.bias) - fast_layer.f_a_proj.weight.copy_(hf_layer.f_a_proj.weight) - fast_layer.f_b_proj.weight.copy_(hf_layer.f_b_proj.weight) - fast_layer.g_a_proj.weight.copy_(hf_layer.g_a_proj.weight) - fast_layer.g_b_proj.weight.copy_(hf_layer.g_b_proj.weight) - fast_layer.beta_proj.weight.copy_(hf_layer.b_proj.weight) - fast_layer.o_proj.weight.copy_(hf_layer.o_proj.weight) - fast_layer.A_log.copy_(hf_layer.A_log.reshape_as(fast_layer.A_log)) - fast_layer.dt_bias.copy_(hf_layer.dt_bias.reshape_as(fast_layer.dt_bias)) - fast_layer.norm.weight.copy_(hf_layer.o_norm.weight) - - param_map = { - "q_proj.weight": "q_proj.weight", - "k_proj.weight": "k_proj.weight", - "v_proj.weight": "v_proj.weight", - "q_conv.weight": "q_conv1d.weight", - "k_conv.weight": "k_conv1d.weight", - "v_conv.weight": "v_conv1d.weight", - "f_a_proj.weight": "f_a_proj.weight", - "f_b_proj.weight": "f_b_proj.weight", - "g_a_proj.weight": "g_a_proj.weight", - "g_b_proj.weight": "g_b_proj.weight", - "beta_proj.weight": "b_proj.weight", - "o_proj.weight": "o_proj.weight", - "A_log": "A_log", - "dt_bias": "dt_bias", - "norm.weight": "o_norm.weight", - } - for fast_name, hf_name in param_map.items(): - fast_param = fast_layer.state_dict()[fast_name] - hf_param = hf_layer.state_dict()[hf_name] - if fast_param.shape != hf_param.shape: - hf_param = hf_param.reshape_as(fast_param) - print(f"Comparing parameter {fast_name} with shape {fast_param.shape}") - torch.testing.assert_close(fast_param, hf_param, atol=1e-5, rtol=1e-5) - - hidden_states = torch.randn(2, SEQ_LEN, 
HIDDEN_SIZE, device=device, dtype=dtype, requires_grad=False) - hf_layer.training = True - hf_out = hf_layer(hidden_states) - - sequence_lengths = [[SEQ_LEN] for _ in range(hidden_states.size(0))] + fast_layer.to(device=device, dtype=dtype) + fast_layer.eval() + + # Copy weights: parameter names match exactly, so use load_state_dict + hf_layer.load_state_dict(fast_layer.state_dict()) + + # Verify all parameters match + hf_state = hf_layer.state_dict() + for name, fast_param in fast_layer.state_dict().items(): + Assert.all_equal(fast_param, hf_state[name]) + + # Forward passes + hidden_states = torch.randn(BATCH_SIZE, SEQ_LEN, HIDDEN_SIZE, device=device, dtype=dtype, requires_grad=False) + + hf_out = hf_layer(hidden_states)[0] + fast_kwargs = { BlockKwargs.device: device, BlockKwargs.sequence_first: False, - BlockKwargs.sequence_lengths: sequence_lengths, + BlockKwargs.sequence_lengths: [[SEQ_LEN] for _ in range(BATCH_SIZE)], BlockKwargs.hidden_dims: (HIDDEN_SIZE,), } fast_layer.preprocess(fast_kwargs) fast_out, _ = fast_layer(hidden_states, fast_kwargs) - torch.testing.assert_close(fast_out, hf_out, atol=1e-5, rtol=1e-5) + # Compare outputs + Assert.rms_close(fast_out, hf_out, 1e-5) diff --git a/tests/layers/test_mamba_equivalence.py b/tests/layers/test_mamba_equivalence.py new file mode 100644 index 000000000..ccf2dba41 --- /dev/null +++ b/tests/layers/test_mamba_equivalence.py @@ -0,0 +1,175 @@ +"""Test numerical equivalence between Fast-LLM Mamba2 and Apriel2 Mamba. + +Note: Fast-LLM's "mamba_2" type is actually a Mamba 1 variant (not the true Mamba 2 +architecture). It corresponds to the HuggingFace/Apriel Mamba implementation. +""" + +import pytest +import torch + +from fast_llm.config import UpdateType +from fast_llm.layers.block.config import BlockKwargs +from fast_llm.layers.ssm.config import Mamba2Config # Ensures mamba_2 type is registered +from fast_llm.models.gpt.config import GPTBaseModelConfig, GPTModelConfig +from fast_llm.utils import Assert +from tests.utils.utils import get_base_model, get_stage, requires_cuda + +# Ensure Mamba2Config is registered for dynamic type lookup +_ = Mamba2Config + +try: + from fast_llm_external_models.apriel2.modeling_apriel2 import Apriel2Mamba +except ImportError: + Apriel2Mamba = None + +try: + from mamba_ssm.ops.selective_scan_interface import selective_scan_fn + + _mamba_kernel_available = True +except (ImportError, RuntimeError): + _mamba_kernel_available = False + +# Test constants +VOCAB_SIZE = 500 +HIDDEN_SIZE = 64 +SEQ_LEN = 65 +BATCH_SIZE = 2 +D_INNER = 128 +D_XB = 64 +D_STATE = 16 +D_CONV = 4 +DT_RANK = 4 + + +def _copy_weights(fast_layer, hf_layer): + """Copy weights from Apriel2 Mamba to Fast-LLM Mamba2.""" + with torch.no_grad(): + # Main projections + fast_layer.in_proj.weight.copy_(hf_layer.in_proj.weight) + if fast_layer.in_proj.bias is not None and hf_layer.in_proj.bias is not None: + fast_layer.in_proj.bias.copy_(hf_layer.in_proj.bias) + + # DT projections + fast_layer.dt_in_proj.weight.copy_(hf_layer.dt_in_proj.weight) + if fast_layer.dt_in_proj.bias is not None and hf_layer.dt_in_proj.bias is not None: + fast_layer.dt_in_proj.bias.copy_(hf_layer.dt_in_proj.bias) + + fast_layer.dt_proj.weight.copy_(hf_layer.dt_proj.weight) + if fast_layer.dt_proj.bias is not None and hf_layer.dt_proj.bias is not None: + fast_layer.dt_proj.bias.copy_(hf_layer.dt_proj.bias) + + # Convolution (Fast-LLM uses "convolution", Apriel2 uses "conv1d") + fast_layer.convolution.weight.copy_(hf_layer.conv1d.weight) + if fast_layer.convolution.bias is 
not None and hf_layer.conv1d.bias is not None: + fast_layer.convolution.bias.copy_(hf_layer.conv1d.bias) + + # SSM parameters + fast_layer.A_log.copy_(hf_layer.A_log) + fast_layer.D.copy_(hf_layer.D) + + # Output projection + fast_layer.out_proj.weight.copy_(hf_layer.out_proj.weight) + if fast_layer.out_proj.bias is not None and hf_layer.out_proj.bias is not None: + fast_layer.out_proj.bias.copy_(hf_layer.out_proj.bias) + + +@pytest.mark.slow +@requires_cuda +@pytest.mark.skipif(Apriel2Mamba is None, reason="Apriel2 Mamba not available") +@pytest.mark.skipif(not _mamba_kernel_available, reason="Mamba CUDA kernels not available") +@pytest.mark.parametrize("add_linear_biases", [True, False]) +@pytest.mark.parametrize("repeat_kv_before_conv", [True, False]) +def test_fast_llm_mamba2_matches_apriel2(add_linear_biases, repeat_kv_before_conv): + """Verify Fast-LLM Mamba2 output matches Apriel2 Mamba. + + Args: + add_linear_biases: Whether to add biases to linear layers. + repeat_kv_before_conv: Whether to repeat KV before or after convolution. + """ + torch.manual_seed(42) + device = torch.device("cuda") + dtype = torch.bfloat16 + + # Create Apriel2 Mamba layer + # Note: Apriel2 has separate conv_bias and dt_proj_bias controls. + # We align them with Fast-LLM's single add_linear_biases flag. + mamba_config = { + "d_inner": D_INNER, + "d_xb": D_XB, + "state_size": D_STATE, + "d_conv": D_CONV, + "dt_rank": DT_RANK, + "conv_bias": add_linear_biases, + "dt_proj_bias": add_linear_biases, + "add_linear_biases": add_linear_biases, + "repeat_kv_before_conv": repeat_kv_before_conv, + "dt_min": 0.001, + "dt_max": 0.1, + "dt_init_floor": 1e-4, + } + hf_layer = Apriel2Mamba(HIDDEN_SIZE, mamba_config, layer_idx=0, dtype=dtype).to(device=device, dtype=dtype) + hf_layer.eval() + + # Create Fast-LLM Mamba2 layer + config = GPTBaseModelConfig.from_dict( + { + "decoder": { + "num_blocks": 1, + "block": { + "mixer": { + "type": "mamba_2", + "d_inner": D_INNER, + "d_xb": D_XB, + "state_size": D_STATE, + "convolution_layer": {"kernel_size": D_CONV}, + "dt_rank": DT_RANK, + "add_linear_biases": add_linear_biases, + "repeat_kv_before_conv": repeat_kv_before_conv, + } + }, + }, + "embeddings": {"vocab_size": VOCAB_SIZE}, + "hidden_size": HIDDEN_SIZE, + }, + update_type=UpdateType.update, + ) + + model, distributed = get_base_model( + GPTModelConfig.from_dict( + { + "base_model": config, + "distributed": {}, + }, + ) + ) + fast_layer = model.decoder[0].mixer + get_stage([fast_layer], distributed, [], {}) + fast_layer.to(device=device, dtype=dtype) + fast_layer.eval() + + # Copy weights + _copy_weights(fast_layer, hf_layer) + + # Verify key parameters match (not all names match between implementations) + Assert.all_equal(fast_layer.in_proj.weight, hf_layer.in_proj.weight) + Assert.all_equal(fast_layer.convolution.weight, hf_layer.conv1d.weight) + Assert.all_equal(fast_layer.A_log, hf_layer.A_log) + Assert.all_equal(fast_layer.D, hf_layer.D) + Assert.all_equal(fast_layer.out_proj.weight, hf_layer.out_proj.weight) + + # Forward passes + hidden_states = torch.randn(BATCH_SIZE, SEQ_LEN, HIDDEN_SIZE, device=device, dtype=dtype, requires_grad=False) + + hf_out = hf_layer(hidden_states)[0] + + fast_kwargs = { + BlockKwargs.device: device, + BlockKwargs.sequence_first: False, + BlockKwargs.sequence_lengths: [[SEQ_LEN] for _ in range(BATCH_SIZE)], + BlockKwargs.hidden_dims: (HIDDEN_SIZE,), + } + fast_layer.preprocess(fast_kwargs) + fast_out, _ = fast_layer(hidden_states, fast_kwargs) + + # Compare outputs (slightly looser 
tolerance for Mamba due to numerical differences)
+    Assert.rms_close(fast_out, hf_out, 1e-4)
diff --git a/tests/utils/distributed_configs.py b/tests/utils/distributed_configs.py
index fac595905..53373e0ca 100644
--- a/tests/utils/distributed_configs.py
+++ b/tests/utils/distributed_configs.py
@@ -56,7 +56,7 @@ def get_config(relative: float = 0, absolute: float = 0, **kwargs) -> CompareConfig:
 _bf16_compare = get_config(
     sub_configs={
         ("init", None): get_config(),
-        (None, "fw"): get_config(1e-2, 1e-3),
+        (None, "fw"): get_config(1.5e-2, 1.5e-3),
         (None, "bw"): get_config(1.5e-2, 1e-5),
         (None, "bias"): get_config(2e-2, 1e-3),
         (None, "gradient"): get_config(2e-2, 5e-5),
diff --git a/tests/utils/model_configs.py b/tests/utils/model_configs.py
index 9231168aa..e943dc96a 100644
--- a/tests/utils/model_configs.py
+++ b/tests/utils/model_configs.py
@@ -202,6 +202,8 @@ def _update_and_add_testing_config(
             "save": True,
             "show": False,
         },
+        # Uncomment to enable model debug logging:
+        # "model_debug_level": _LOG_LEVEL,
     },
     "training": {
         "logs": {"interval": 1},
@@ -929,6 +931,11 @@ def _update_and_add_testing_config(
                 "d_xb": 256,
                 "add_linear_biases": False,
             },
+            "kda": {
+                "type": "kda",
+                "heads": 4,
+                "head_dim": 16,
+            },
         },
         "sampling_strategy": "uniform",
         "main_mixer_name": "attn",
@@ -956,9 +963,17 @@ def _update_and_add_testing_config(
                 "value_head_dim": 16,
             },
         },
+        "kda": {
+            **copy.deepcopy(_llama_block),
+            "mixer": {
+                "type": "kda",
+                "heads": 4,
+                "head_dim": 16,
+            },
+        },
     },
-    "pattern": ["attn_full", "mamba", "stochastic", "attn_swa", "gdn", "stochastic"],
-    "num_blocks": 6,
+    "pattern": ["attn_full", "mamba", "stochastic", "attn_swa", "gdn", "kda", "stochastic"],
+    "num_blocks": 7,
     },
 },
 megatron_args=None,
@@ -1018,7 +1033,9 @@ def _update_and_add_testing_config(
     compare_factor=6.0,
     # Micro-sequence split and sequence-first not supported for Mamba.
     # TP excluded because no gradient reductions implemented for TP norm in GDN (use STP instead).
-    skip_tests=("sdp", "ms", "bf4", "df4", TP_NO_STP),
+    # bf2_df2 depends on df4, so must also be skipped.
+    skip_tests=("sdp", "ms", "bf4", "df4", "bf2_df2", TP_NO_STP),
+    auto_model_class=transformers.AutoModelForImageTextToText,
 )
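
A minimal, self-contained sketch of the fast-vs-slow comparison pattern used in test_gdn_fast_vs_slow: run the same module on the same input twice, swapping the kernel implementation between runs, and compare with a tolerance instead of torch.equal. ToyMixer, fast_rule, and slow_rule are illustrative stand-ins, not Fast-LLM, Apriel2, or fla APIs.

import torch
from torch.testing import assert_close


def fast_rule(x: torch.Tensor) -> torch.Tensor:
    # Stand-in for a fused kernel path (e.g. fla's chunk_gated_delta_rule).
    return torch.nn.functional.silu(x)


def slow_rule(x: torch.Tensor) -> torch.Tensor:
    # Stand-in for the PyTorch reference path (e.g. torch_chunk_gated_delta_rule),
    # written differently so the two agree only up to floating-point error.
    return x * torch.sigmoid(x)


class ToyMixer(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.proj = torch.nn.Linear(16, 16)
        self._rule = fast_rule  # swapped per path, mirroring the test

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self._rule(self.proj(x))


torch.manual_seed(0)
mixer = ToyMixer().eval()
x = torch.randn(2, 8, 16)

with torch.no_grad():
    mixer._rule = fast_rule
    fast_out = mixer(x).clone()
    mixer._rule = slow_rule
    slow_out = mixer(x).clone()

# Kernel vs reference: looser rtol/atol rather than bit-exact equality.
assert_close(fast_out, slow_out, rtol=1e-3, atol=1e-3)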
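
The equivalence tests compare outputs with Assert.rms_close rather than element-wise assert_close. A standalone approximation of that check (the exact semantics of Fast-LLM's helper are assumed here, not quoted) bounds the root-mean-square of the error, which is less sensitive to a handful of bf16 outliers than a per-element tolerance:

import torch


def rms_close(actual: torch.Tensor, expected: torch.Tensor, threshold: float) -> None:
    # RMS of the error, accumulated in fp32 to avoid bf16 reduction noise.
    rms = (actual.float() - expected.float()).pow(2).mean().sqrt().item()
    assert rms <= threshold, f"RMS error {rms:.3e} exceeds threshold {threshold:.0e}"


# Toy usage: two tensors that differ only by bf16 rounding pass a 1e-2 bound.
x = torch.randn(2, 65, 64)
rms_close(x.bfloat16().float(), x, 1e-2)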
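
A sketch of the two weight-sync strategies these tests use, with toy modules rather than the Fast-LLM or Apriel2 classes: when parameter names line up exactly (the GDN and KDA tests), load_state_dict copies everything in one call; when a tensor is exposed under a different attribute name (the Mamba test's "convolution" vs "conv1d"), the copy has to be spelled out under torch.no_grad.

import torch


class RefLayer(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.in_proj = torch.nn.Linear(8, 8, bias=False)
        self.conv1d = torch.nn.Conv1d(8, 8, 4, groups=8, padding=3)


class FastLayer(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.in_proj = torch.nn.Linear(8, 8, bias=False)
        # Same tensor as RefLayer.conv1d, but under a different name.
        self.convolution = torch.nn.Conv1d(8, 8, 4, groups=8, padding=3)


# Case 1: names match exactly -> load_state_dict handles everything.
ref_a, ref_b = RefLayer(), RefLayer()
ref_a.load_state_dict(ref_b.state_dict())

# Case 2: one parameter is renamed -> map it explicitly.
ref, fast = RefLayer(), FastLayer()
with torch.no_grad():
    ref.in_proj.weight.copy_(fast.in_proj.weight)
    ref.conv1d.weight.copy_(fast.convolution.weight)
    if ref.conv1d.bias is not None and fast.convolution.bias is not None:
        ref.conv1d.bias.copy_(fast.convolution.bias)

torch.testing.assert_close(ref.conv1d.weight, fast.convolution.weight)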