Merged
9 changes: 6 additions & 3 deletions fast_llm/layers/block/block.py
@@ -83,9 +83,12 @@ def _get_meta(
return None
if dims is None:
dims = tuple(f"dim_{i}" for i in range(tensor.ndim))
-hidden_dims = {
-    dim.name: dim for dim in kwargs[BlockKwargs.hidden_dims] + (kwargs[BlockKwargs.sequence_q_dim],)
-}
+hidden_dims = {}
Review comment (Collaborator): These are required kwargs, why would they be missing?

+if BlockKwargs.hidden_dims in kwargs:
+    for dim in kwargs[BlockKwargs.hidden_dims]:
+        hidden_dims[dim.name] = dim
+if BlockKwargs.sequence_q_dim in kwargs:
+    hidden_dims[kwargs[BlockKwargs.sequence_q_dim].name] = kwargs[BlockKwargs.sequence_q_dim]
return TensorMeta.from_dims(
tuple(
(
5 changes: 4 additions & 1 deletion fast_llm/layers/decoder/mlp/mlp.py
@@ -137,5 +137,8 @@ def _forward(
transposed_layer_2_weight=self.layer_2.transposed_weight,
)
bias = self.layer_2.bias if self._parallel_dim.group else None
-self._debug(out, None, kwargs.get(BlockKwargs.hidden_dims), kwargs, bias=bias)
+# Use None for dims when output_dim differs from hidden_dim (e.g., adapter projections)
+# to let _debug infer dims from actual tensor shape
+dims = None if self._output_dim != self._hidden_dim else kwargs.get(BlockKwargs.hidden_dims)
Review comment (Collaborator): This won't work, it will produce incorrect results in distributed settings.

+self._debug(out, None, dims, kwargs, bias=bias)
return out, bias
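
A toy sketch of the reviewer's concern (all names hypothetical, not fast_llm APIs): when the activation is sharded over a tensor-parallel group, dims inferred from the local tensor shape describe the shard rather than the global tensor, so a dims=None fallback can mislabel debug output in distributed settings.

from dataclasses import dataclass

@dataclass
class Dim:
    name: str
    size: int  # global size

def infer_dims(local_shape: tuple[int, ...]) -> tuple[Dim, ...]:
    # All a dims=None fallback can do: name dims after the local shape.
    return tuple(Dim(f"dim_{i}", s) for i, s in enumerate(local_shape))

hidden = Dim("hidden", 4096)  # global hidden size
tp_world_size = 4             # tensor-parallel group size
local_shape = (8, 512, hidden.size // tp_world_size)  # per-rank shard

# The inferred last dim reports the shard size (1024), not the global 4096.
assert infer_dims(local_shape)[-1].size == 1024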
8 changes: 7 additions & 1 deletion fast_llm/layers/decoder/stochastic_mixer.py
@@ -118,7 +118,13 @@ def _forward(
mixer_name = self._sample_mixer_name(kwargs)

if get_model_debug_level() > 0:
logger.debug(f"StochasticMixer selecting mixer {mixer_name}: {type(self.mixers[mixer_name]).__name__}")
from fast_llm.layers.block.config import BlockKwargs

iteration = kwargs.get(BlockKwargs.iteration, "?")
logger.info(
f"StochasticMixer iter={iteration} selecting mixer '{mixer_name}' "
f"({type(self.mixers[mixer_name]).__name__})"
)

return self.mixers[mixer_name]._forward(input_, kwargs, losses, metrics)

4 changes: 3 additions & 1 deletion fast_llm/layers/vision/vision_encoder.py
@@ -27,7 +27,9 @@ def __init__(
peft: PeftConfig | None,
):
super().__init__(config, distributed_config, hidden_dim=hidden_dim, lr_scale=lr_scale, peft=peft)
-vision_hidden_dim = TensorDim("hidden", self._config.hidden_size)
+# Internal hidden dimension for embeddings and encoder (may differ from output hidden_dim for adapter)
+self._vision_hidden_dim = TensorDim("hidden", self._config.hidden_size)
Review thread: jlamypoirier marked this conversation as resolved.
+vision_hidden_dim = self._vision_hidden_dim
self.embeddings = self._config.embeddings.get_layer(
distributed_config,
vision_hidden_dim,
154 changes: 153 additions & 1 deletion fast_llm/models/gpt/conversion/apriel2.py
@@ -9,7 +9,7 @@
from fast_llm.engine.checkpoint.huggingface import HuggingfaceStateDictCheckpointHandler
from fast_llm.layers.attention.config import AttentionConfig
from fast_llm.layers.decoder.config import DecoderBlockConfig, StochasticMixerConfig
-from fast_llm.layers.ssm.config import GatedDeltaNetConfig, Mamba2Config
+from fast_llm.layers.ssm.config import GatedDeltaNetConfig, KimiDeltaAttentionConfig, Mamba2Config
from fast_llm.models.gpt.config import GPTBaseModelConfig, GPTModelConfig
from fast_llm.models.gpt.conversion.config import Apriel2TextCheckpointFormat
from fast_llm.models.gpt.conversion.llama import (
@@ -271,6 +271,144 @@ def get_converters(
]


class Apriel2KimiDeltaAttentionConverter:
    @classmethod
    def import_config(cls, config: dict) -> dict:
        result = {
            "type": "kda",
            "heads": config["heads"],
            "head_dim": config["head_dim"],
        }
        if "convolution_layer" in config:
            result["convolution_layer"] = config["convolution_layer"]
        if "normalization" in config:
            result["normalization"] = config["normalization"]
        return result

    @classmethod
    def export_config(cls, config: KimiDeltaAttentionConfig) -> dict:
        return {
            "type": "kda",
            "heads": config.heads,
            "head_dim": config.head_dim,
            "convolution_layer": {
                "kernel_size": config.convolution_layer.kernel_size,
            },
            "normalization": {
                "epsilon": config.normalization.epsilon,
            },
        }

    @classmethod
    def get_converters(
        cls,
        config: KimiDeltaAttentionConfig,
        fast_llm_prefix: str,
        hf_prefix: str,
        drop_on_export: bool = False,
    ) -> list[WeightConverter]:
        # Fast-LLM KDA uses abbreviated names matching the external module:
        # q_proj, k_proj, v_proj, q_conv, k_conv, v_conv, f_a_proj, f_b_proj,
        # g_a_proj, g_b_proj, beta_proj, o_proj, A_log, dt_bias, norm
        return [
            # Q/K/V projections
            *get_weight_and_bias_converters(
                f"{fast_llm_prefix}.q_proj",
                f"{hf_prefix}.q_proj",
                False,
                drop_on_export=drop_on_export,
            ),
            *get_weight_and_bias_converters(
                f"{fast_llm_prefix}.k_proj",
                f"{hf_prefix}.k_proj",
                False,
                drop_on_export=drop_on_export,
            ),
            *get_weight_and_bias_converters(
                f"{fast_llm_prefix}.v_proj",
                f"{hf_prefix}.v_proj",
                False,
                drop_on_export=drop_on_export,
            ),
            # Convolutions (Q, K, V)
            *get_weight_and_bias_converters(
                f"{fast_llm_prefix}.q_conv",
                f"{hf_prefix}.q_conv",
                False,
                drop_on_export=drop_on_export,
            ),
            *get_weight_and_bias_converters(
                f"{fast_llm_prefix}.k_conv",
                f"{hf_prefix}.k_conv",
                False,
                drop_on_export=drop_on_export,
            ),
            *get_weight_and_bias_converters(
                f"{fast_llm_prefix}.v_conv",
                f"{hf_prefix}.v_conv",
                False,
                drop_on_export=drop_on_export,
            ),
            # Gate projections (f_a, f_b, g_a, g_b)
            *get_weight_and_bias_converters(
                f"{fast_llm_prefix}.f_a_proj",
                f"{hf_prefix}.f_a_proj",
                False,
                drop_on_export=drop_on_export,
            ),
            *get_weight_and_bias_converters(
                f"{fast_llm_prefix}.f_b_proj",
                f"{hf_prefix}.f_b_proj",
                False,
                drop_on_export=drop_on_export,
            ),
            *get_weight_and_bias_converters(
                f"{fast_llm_prefix}.g_a_proj",
                f"{hf_prefix}.g_a_proj",
                False,
                drop_on_export=drop_on_export,
            ),
            *get_weight_and_bias_converters(
                f"{fast_llm_prefix}.g_b_proj",
                f"{hf_prefix}.g_b_proj",
                False,
                drop_on_export=drop_on_export,
            ),
            # Beta projection
            *get_weight_and_bias_converters(
                f"{fast_llm_prefix}.beta_proj",
                f"{hf_prefix}.beta_proj",
                False,
                drop_on_export=drop_on_export,
            ),
            # Output projection
            *get_weight_and_bias_converters(
                f"{fast_llm_prefix}.o_proj",
                f"{hf_prefix}.o_proj",
                False,
                drop_on_export=drop_on_export,
            ),
            # Learnable parameters
            get_parameter_converter(
                f"{fast_llm_prefix}.A_log",
                f"{hf_prefix}.A_log",
                drop_on_export=drop_on_export,
            ),
            get_parameter_converter(
                f"{fast_llm_prefix}.dt_bias",
                f"{hf_prefix}.dt_bias",
                drop_on_export=drop_on_export,
            ),
            # Normalization
            *LlamaNormalizationConverter.get_converters(
                config.normalization,
                f"{fast_llm_prefix}.norm",
                f"{hf_prefix}.norm",
                drop_on_export=drop_on_export,
            ),
        ]


class Apriel2StochasticMixerConverter:
@classmethod
def import_config(cls, config: dict) -> dict:
@@ -283,6 +421,8 @@ def import_config(cls, config: dict) -> dict:
mixers[name] = Apriel2MambaConverter.import_config(sub_mixer_config)
elif mixer_type == "gdn":
mixers[name] = Apriel2GatedDeltaNetConverter.import_config(sub_mixer_config)
elif mixer_type == "kda":
mixers[name] = Apriel2KimiDeltaAttentionConverter.import_config(sub_mixer_config)
else:
raise ValueError(f"Unknown sub-mixer type: {mixer_type}")

@@ -306,6 +446,8 @@ def export_config(cls, config: StochasticMixerConfig) -> dict:
mixers[name] = Apriel2MambaConverter.export_config(sub_mixer)
elif mixer_type is GatedDeltaNetConfig:
mixers[name] = Apriel2GatedDeltaNetConverter.export_config(sub_mixer)
elif mixer_type is KimiDeltaAttentionConfig:
mixers[name] = Apriel2KimiDeltaAttentionConverter.export_config(sub_mixer)
else:
raise ValueError(f"Unknown sub-mixer type: {mixer_type}")

@@ -336,6 +478,9 @@ def get_converters(
elif mixer_type is GatedDeltaNetConfig:
converter_class = Apriel2GatedDeltaNetConverter
hf_sub_mixer_prefix = f"{hf_prefix}.mixers.{name}"
elif mixer_type is KimiDeltaAttentionConfig:
converter_class = Apriel2KimiDeltaAttentionConverter
hf_sub_mixer_prefix = f"{hf_prefix}.mixers.{name}"
else:
raise ValueError(f"Unknown sub-mixer type: {mixer_type}")
converters.extend(
@@ -364,6 +509,8 @@ def import_config(cls, config: dict, block_config: dict) -> dict:
mixer = Apriel2StochasticMixerConverter.import_config(mixer_config)
elif mixer_type == "gdn":
mixer = Apriel2GatedDeltaNetConverter.import_config(mixer_config)
elif mixer_type == "kda":
mixer = Apriel2KimiDeltaAttentionConverter.import_config(mixer_config)
else:
raise ValueError(f"Unknown mixer type: {mixer_type}")

@@ -404,6 +551,8 @@ def export_config(cls, config: DecoderBlockConfig) -> dict:
mixer = Apriel2StochasticMixerConverter.export_config(config.mixer)
elif mixer_type is GatedDeltaNetConfig:
mixer = Apriel2GatedDeltaNetConverter.export_config(config.mixer)
elif mixer_type is KimiDeltaAttentionConfig:
mixer = Apriel2KimiDeltaAttentionConverter.export_config(config.mixer)
else:
raise ValueError(f"Unknown mixer type: {mixer_type}")

@@ -460,6 +609,9 @@ def get_converters(
elif mixer_type is GatedDeltaNetConfig:
converter_class = Apriel2GatedDeltaNetConverter
hf_mixer_prefix = f"{hf_prefix}.mixer"
elif mixer_type is KimiDeltaAttentionConfig:
converter_class = Apriel2KimiDeltaAttentionConverter
hf_mixer_prefix = f"{hf_prefix}.mixer"
else:
raise ValueError(f"Unknown mixer type: {mixer_type}")

2 changes: 1 addition & 1 deletion fast_llm/models/multimodal/conversion/apriel2.py
@@ -406,7 +406,7 @@ def _export_config(cls, config: MultiModalModelConfig) -> dict[str, typing.Any]:
"auto_map": {
"AutoConfig": "configuration_apriel2.Apriel2Config",
"AutoModel": "modeling_apriel2.Apriel2Model",
"AutoModelForCausalLM": "modeling_apriel2.Apriel2ForConditionalGeneration",
"AutoModelForImageTextToText": "modeling_apriel2.Apriel2ForConditionalGeneration",
},
},
)
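
With the updated auto_map key, exported checkpoints are meant to load through the transformers image-text-to-text auto class rather than AutoModelForCausalLM. A minimal usage sketch (the checkpoint path is a placeholder):

from transformers import AutoModelForImageTextToText

# trust_remote_code is required because modeling_apriel2.py ships with the checkpoint.
model = AutoModelForImageTextToText.from_pretrained(
    "path/to/exported-apriel2-checkpoint",  # placeholder
    trust_remote_code=True,
)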
5 changes: 3 additions & 2 deletions fast_llm/models/multimodal/model.py
@@ -134,10 +134,11 @@ def preprocess_meta(
TensorDim("patch_width", self._config.vision_encoder.embeddings.patch_width),
)
)
+# Use vision encoder's internal hidden dim (for embeddings/encoder), not the output dim (for adapter)
hidden_dims = (
-    (hidden_batch_and_sequence_q_dim, scalar_dim, self.vision_encoder._hidden_dim)
+    (hidden_batch_and_sequence_q_dim, scalar_dim, self.vision_encoder._vision_hidden_dim)
    if (sequence_first := kwargs[LanguageModelKwargs.sequence_first])
-    else (scalar_dim, hidden_batch_and_sequence_q_dim, self.vision_encoder._hidden_dim)
+    else (scalar_dim, hidden_batch_and_sequence_q_dim, self.vision_encoder._vision_hidden_dim)
)
kwargs[self._vision_encoder_namespace] = {
VisionKwargs.sequence_first: sequence_first,
24 changes: 21 additions & 3 deletions fast_llm_external_models/apriel2/cache.py
@@ -162,7 +162,11 @@ def _reorder_cache_obj(self, cache, beam_idx):
cache.value = cache.value.index_select(0, beam_idx.to(cache.value.device))
elif isinstance(cache, _SSMCache):
if cache.conv is not None:
-cache.conv = cache.conv.index_select(0, beam_idx.to(cache.conv.device))
+# Handle both single tensor (GDN/Mamba) and tuple (KDA) conv states
+if isinstance(cache.conv, tuple):
+    cache.conv = tuple(c.index_select(0, beam_idx.to(c.device)) for c in cache.conv)
+else:
+    cache.conv = cache.conv.index_select(0, beam_idx.to(cache.conv.device))
if cache.recurrent is not None:
cache.recurrent = cache.recurrent.index_select(0, beam_idx.to(cache.recurrent.device))

@@ -208,7 +212,11 @@ def _batch_repeat_cache_obj(self, cache, repeats):
cache.value = cache.value.repeat_interleave(repeats, dim=0)
elif isinstance(cache, _SSMCache):
if cache.conv is not None:
-cache.conv = cache.conv.repeat_interleave(repeats, dim=0)
+# Handle both single tensor (GDN/Mamba) and tuple (KDA) conv states
+if isinstance(cache.conv, tuple):
+    cache.conv = tuple(c.repeat_interleave(repeats, dim=0) for c in cache.conv)
+else:
+    cache.conv = cache.conv.repeat_interleave(repeats, dim=0)
if cache.recurrent is not None:
cache.recurrent = cache.recurrent.repeat_interleave(repeats, dim=0)

@@ -227,7 +235,11 @@ def _batch_select_cache_obj(self, cache, indices):
cache.value = cache.value.index_select(0, indices.to(cache.value.device))
elif isinstance(cache, _SSMCache):
if cache.conv is not None:
-cache.conv = cache.conv.index_select(0, indices.to(cache.conv.device))
+# Handle both single tensor (GDN/Mamba) and tuple (KDA) conv states
+if isinstance(cache.conv, tuple):
+    cache.conv = tuple(c.index_select(0, indices.to(c.device)) for c in cache.conv)
+else:
+    cache.conv = cache.conv.index_select(0, indices.to(cache.conv.device))
if cache.recurrent is not None:
cache.recurrent = cache.recurrent.index_select(0, indices.to(cache.recurrent.device))

@@ -274,11 +286,17 @@ def max_batch_size(self):
if isinstance(cache, _AttentionCache) and cache.key is not None:
return cache.key.shape[0]
if isinstance(cache, _SSMCache) and cache.conv is not None:
# Handle both single tensor and tuple conv states
if isinstance(cache.conv, tuple):
return cache.conv[0].shape[0]
return cache.conv.shape[0]
else:
if isinstance(layer, _AttentionCache) and layer.key is not None:
return layer.key.shape[0]
if isinstance(layer, _SSMCache) and layer.conv is not None:
# Handle both single tensor and tuple conv states
if isinstance(layer.conv, tuple):
return layer.conv[0].shape[0]
return layer.conv.shape[0]
return None

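
The tuple-vs-tensor branch now repeats in three methods. A standalone sketch of the pattern (shapes illustrative, not the cache classes above) shows why it covers both layouts; folding it into a helper like this would also deduplicate the branches:

import torch

def reorder_conv(conv, beam_idx: torch.Tensor):
    # KDA keeps one conv state per q/k/v projection (a tuple);
    # GDN/Mamba keep a single tensor.
    if isinstance(conv, tuple):
        return tuple(c.index_select(0, beam_idx.to(c.device)) for c in conv)
    return conv.index_select(0, beam_idx.to(conv.device))

kda_conv = tuple(torch.randn(4, 8, 3) for _ in range(3))  # (batch, channels, kernel)
gdn_conv = torch.randn(4, 8, 3)
beam_idx = torch.tensor([2, 2, 0, 1])

assert all(c.shape == (4, 8, 3) for c in reorder_conv(kda_conv, beam_idx))
assert reorder_conv(gdn_conv, beam_idx).shape == (4, 8, 3)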
6 changes: 4 additions & 2 deletions fast_llm_external_models/apriel2/conversion/__init__.py
@@ -120,8 +120,9 @@

# Plan builders (generic)
from fast_llm_external_models.apriel2.conversion.converters import (
-plan_attention_to_gated_delta_net,
plan_mil_attention_to_mamba,
+plan_dil_attention_to_gdn,
+plan_kil_attention_to_kda,
plan_surgery,
)

@@ -170,7 +171,8 @@
# Plan builders (generic)
"plan_surgery",
"plan_mil_attention_to_mamba",
"plan_attention_to_gated_delta_net",
"plan_dil_attention_to_gdn",
"plan_kil_attention_to_kda",
# Config composition
"compose_configs",
# Source-specific converters
17 changes: 17 additions & 0 deletions fast_llm_external_models/apriel2/conversion/config.py
@@ -31,6 +31,7 @@
- attention → sliding_window: preserve heads, head_groups, head_size
- attention → gdn: heads → value_heads, head_groups → key_heads
- attention → mamba: derive d_inner, d_xb, dt_rank from hidden_size
- attention → kda: preserve heads, head_size → head_dim

**Stochastic Mixer Composition**
Two semantics based on whether surgery declares `type: stochastic`:
@@ -439,6 +440,22 @@ def _compose_single_mixer(source: dict, surgery: dict, hidden_size: int) -> dict:
result["init"] = surgery["init"]
return result

elif target_type == "kda":
# Attention → KDA: derive heads/head_dim from attention geometry
result = {
"type": "kda",
"heads": surgery.get("heads", heads),
"head_dim": surgery.get("head_dim", head_size),
}
# Copy KDA-specific fields from surgery
for key in ["convolution_layer", "normalization"]:
if key in surgery:
result[key] = surgery[key]
# Preserve init
if "init" in surgery:
result["init"] = surgery["init"]
return result

# Fallback: start fresh with surgery, no inheritance
result = copy.deepcopy(surgery)
result["type"] = target_type