diff --git a/tests/unit/model_bridge/supported_architectures/test_baichuan_adapter.py b/tests/unit/model_bridge/supported_architectures/test_baichuan_adapter.py new file mode 100644 index 000000000..d1e2348df --- /dev/null +++ b/tests/unit/model_bridge/supported_architectures/test_baichuan_adapter.py @@ -0,0 +1,771 @@ +"""Unit tests for BaichuanArchitectureAdapter. + +Tests cover: +- Config attributes +- Component mapping structure and HF module names +- Weight conversion keys/types +- split_qkv_matrix (W_pack) numerical correctness +- preprocess_weights (QKV split, fold_ln, NormHead normalization) +- Factory registration (both v1 and v2 class names) +""" + +from types import SimpleNamespace +from typing import Any + +import pytest +import torch +import torch.nn as nn + +from transformer_lens.config import TransformerBridgeConfig +from transformer_lens.conversion_utils.conversion_steps import RearrangeTensorConversion +from transformer_lens.conversion_utils.param_processing_conversion import ( + ParamProcessingConversion, +) +from transformer_lens.model_bridge.generalized_components import ( + BlockBridge, + EmbeddingBridge, + GatedMLPBridge, + JointQKVPositionEmbeddingsAttentionBridge, + RMSNormalizationBridge, + UnembeddingBridge, +) +from transformer_lens.model_bridge.supported_architectures.baichuan import ( + BaichuanArchitectureAdapter, + _BaichuanAttentionBridge, +) + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + + +def _make_cfg( + n_heads: int = 32, + d_model: int = 64, + n_layers: int = 2, + d_vocab: int = 100, + n_ctx: int = 128, +) -> TransformerBridgeConfig: + """Minimal TransformerBridgeConfig for Baichuan adapter tests.""" + return TransformerBridgeConfig( + d_model=d_model, + d_head=d_model // n_heads, + n_layers=n_layers, + n_ctx=n_ctx, + n_heads=n_heads, + d_vocab=d_vocab, + default_prepend_bos=True, + architecture="BaichuanForCausalLM", + ) + + +@pytest.fixture +def cfg() -> TransformerBridgeConfig: + return _make_cfg(n_heads=8, d_model=64) + + +@pytest.fixture +def adapter(cfg: TransformerBridgeConfig) -> BaichuanArchitectureAdapter: + return BaichuanArchitectureAdapter(cfg) + + +def _make_w_pack_component(d_model: int) -> Any: + """Synthetic attention namespace with W_pack linear.""" + ns = SimpleNamespace() + ns.W_pack = nn.Linear(d_model, 3 * d_model, bias=False) + return ns + + +# --------------------------------------------------------------------------- +# Config attribute tests +# --------------------------------------------------------------------------- + + +class TestBaichuanAdapterConfig: + def test_normalization_type(self, adapter: BaichuanArchitectureAdapter) -> None: + assert adapter.cfg.normalization_type == "RMS" + + def test_positional_embedding_type(self, adapter: BaichuanArchitectureAdapter) -> None: + assert adapter.cfg.positional_embedding_type == "rotary" + + def test_final_rms(self, adapter: BaichuanArchitectureAdapter) -> None: + assert adapter.cfg.final_rms is True + + def test_gated_mlp(self, adapter: BaichuanArchitectureAdapter) -> None: + assert adapter.cfg.gated_mlp is True + + def test_attn_only(self, adapter: BaichuanArchitectureAdapter) -> None: + assert adapter.cfg.attn_only is False + + def test_uses_rms_norm(self, adapter: BaichuanArchitectureAdapter) -> None: + assert adapter.cfg.uses_rms_norm is True + + def test_eps_attr(self, adapter: BaichuanArchitectureAdapter) -> None: + assert adapter.cfg.eps_attr == 
"variance_epsilon" + + def test_supports_fold_ln_false(self, adapter: BaichuanArchitectureAdapter) -> None: + assert adapter.supports_fold_ln is False + + +# --------------------------------------------------------------------------- +# Component mapping tests +# --------------------------------------------------------------------------- + + +class TestBaichuanAdapterComponentMapping: + @staticmethod + def _mapping(adapter: BaichuanArchitectureAdapter) -> dict[str, Any]: + mapping = adapter.component_mapping + assert mapping is not None + return mapping + + def test_embed_type_and_name(self, adapter: BaichuanArchitectureAdapter) -> None: + mapping = self._mapping(adapter) + assert isinstance(mapping["embed"], EmbeddingBridge) + assert mapping["embed"].name == "model.embed_tokens" + + def test_no_top_level_rotary_emb(self, adapter: BaichuanArchitectureAdapter) -> None: + mapping = self._mapping(adapter) + assert "rotary_emb" not in mapping + + def test_blocks_type_and_name(self, adapter: BaichuanArchitectureAdapter) -> None: + mapping = self._mapping(adapter) + assert isinstance(mapping["blocks"], BlockBridge) + assert mapping["blocks"].name == "model.layers" + + def test_ln_final_type_and_name(self, adapter: BaichuanArchitectureAdapter) -> None: + mapping = self._mapping(adapter) + assert isinstance(mapping["ln_final"], RMSNormalizationBridge) + assert mapping["ln_final"].name == "model.norm" + + def test_unembed_type_and_name(self, adapter: BaichuanArchitectureAdapter) -> None: + mapping = self._mapping(adapter) + assert isinstance(mapping["unembed"], UnembeddingBridge) + assert mapping["unembed"].name == "lm_head" + + def test_ln1_type_and_name(self, adapter: BaichuanArchitectureAdapter) -> None: + blocks = self._mapping(adapter)["blocks"] + assert isinstance(blocks.submodules["ln1"], RMSNormalizationBridge) + assert blocks.submodules["ln1"].name == "input_layernorm" + + def test_ln2_type_and_name(self, adapter: BaichuanArchitectureAdapter) -> None: + blocks = self._mapping(adapter)["blocks"] + assert isinstance(blocks.submodules["ln2"], RMSNormalizationBridge) + assert blocks.submodules["ln2"].name == "post_attention_layernorm" + + def test_attn_type_and_name(self, adapter: BaichuanArchitectureAdapter) -> None: + blocks = self._mapping(adapter)["blocks"] + assert isinstance(blocks.submodules["attn"], JointQKVPositionEmbeddingsAttentionBridge) + assert blocks.submodules["attn"].name == "self_attn" + + def test_attn_qkv_name(self, adapter: BaichuanArchitectureAdapter) -> None: + blocks = self._mapping(adapter)["blocks"] + assert blocks.submodules["attn"].submodules["qkv"].name == "W_pack" + + def test_attn_o_name(self, adapter: BaichuanArchitectureAdapter) -> None: + blocks = self._mapping(adapter)["blocks"] + assert blocks.submodules["attn"].submodules["o"].name == "o_proj" + + def test_mlp_type_and_name(self, adapter: BaichuanArchitectureAdapter) -> None: + blocks = self._mapping(adapter)["blocks"] + assert isinstance(blocks.submodules["mlp"], GatedMLPBridge) + assert blocks.submodules["mlp"].name == "mlp" + + def test_mlp_gate_name(self, adapter: BaichuanArchitectureAdapter) -> None: + blocks = self._mapping(adapter)["blocks"] + assert blocks.submodules["mlp"].submodules["gate"].name == "gate_proj" + + def test_mlp_in_name(self, adapter: BaichuanArchitectureAdapter) -> None: + blocks = self._mapping(adapter)["blocks"] + assert blocks.submodules["mlp"].submodules["in"].name == "up_proj" + + def test_mlp_out_name(self, adapter: BaichuanArchitectureAdapter) -> None: + blocks = 
self._mapping(adapter)["blocks"] + assert blocks.submodules["mlp"].submodules["out"].name == "down_proj" + + +# --------------------------------------------------------------------------- +# Weight conversion tests +# --------------------------------------------------------------------------- + + +class TestBaichuanAdapterWeightConversions: + def test_four_conversion_keys(self, adapter: BaichuanArchitectureAdapter) -> None: + convs = adapter.weight_processing_conversions + assert convs is not None + assert len(convs) == 4 + + def test_qkvo_keys_present(self, adapter: BaichuanArchitectureAdapter) -> None: + convs = adapter.weight_processing_conversions + assert convs is not None + for key in [ + "blocks.{i}.attn.q.weight", + "blocks.{i}.attn.k.weight", + "blocks.{i}.attn.v.weight", + "blocks.{i}.attn.o.weight", + ]: + assert key in convs + + def test_q_conversion_type(self, adapter: BaichuanArchitectureAdapter) -> None: + convs = adapter.weight_processing_conversions + assert convs is not None + conv = convs["blocks.{i}.attn.q.weight"] + assert isinstance(conv, ParamProcessingConversion) + assert isinstance(conv.tensor_conversion, RearrangeTensorConversion) + + def test_q_rearrange_pattern(self, adapter: BaichuanArchitectureAdapter) -> None: + convs = adapter.weight_processing_conversions + assert convs is not None + conv = convs["blocks.{i}.attn.q.weight"] + assert isinstance(conv, ParamProcessingConversion) + assert isinstance(conv.tensor_conversion, RearrangeTensorConversion) + assert conv.tensor_conversion.pattern == "(n h) m -> n m h" + + def test_q_rearrange_n_equals_n_heads(self, adapter: BaichuanArchitectureAdapter) -> None: + convs = adapter.weight_processing_conversions + assert convs is not None + conv = convs["blocks.{i}.attn.q.weight"] + assert isinstance(conv, ParamProcessingConversion) + assert isinstance(conv.tensor_conversion, RearrangeTensorConversion) + assert conv.tensor_conversion.axes_lengths["n"] == adapter.cfg.n_heads + + def test_k_rearrange_n_equals_n_heads(self, adapter: BaichuanArchitectureAdapter) -> None: + # Baichuan is MHA (no GQA), so K also uses n_heads + convs = adapter.weight_processing_conversions + assert convs is not None + conv = convs["blocks.{i}.attn.k.weight"] + assert isinstance(conv, ParamProcessingConversion) + assert isinstance(conv.tensor_conversion, RearrangeTensorConversion) + assert conv.tensor_conversion.axes_lengths["n"] == adapter.cfg.n_heads + + def test_o_rearrange_pattern(self, adapter: BaichuanArchitectureAdapter) -> None: + convs = adapter.weight_processing_conversions + assert convs is not None + conv = convs["blocks.{i}.attn.o.weight"] + assert isinstance(conv, ParamProcessingConversion) + assert isinstance(conv.tensor_conversion, RearrangeTensorConversion) + assert conv.tensor_conversion.pattern == "m (n h) -> n h m" + + def test_no_source_key_on_q(self, adapter: BaichuanArchitectureAdapter) -> None: + convs = adapter.weight_processing_conversions + assert convs is not None + conv = convs["blocks.{i}.attn.q.weight"] + assert isinstance(conv, ParamProcessingConversion) + assert conv.source_key is None + + +# --------------------------------------------------------------------------- +# split_qkv_matrix (W_pack) tests +# --------------------------------------------------------------------------- + + +class TestBaichuanSplitWPack: + def _adapter(self, n_heads: int = 8, d_model: int = 64) -> BaichuanArchitectureAdapter: + return BaichuanArchitectureAdapter(_make_cfg(n_heads=n_heads, d_model=d_model)) + + def 
test_returns_three_linears(self) -> None: + adapter = self._adapter() + attn = _make_w_pack_component(64) + q, k, v = adapter._split_baichuan_w_pack(attn) + assert isinstance(q, nn.Linear) + assert isinstance(k, nn.Linear) + assert isinstance(v, nn.Linear) + + def test_output_shapes(self) -> None: + d_model = 64 + adapter = self._adapter(d_model=d_model) + attn = _make_w_pack_component(d_model) + q, k, v = adapter._split_baichuan_w_pack(attn) + assert q.weight.shape == (d_model, d_model) + assert k.weight.shape == (d_model, d_model) + assert v.weight.shape == (d_model, d_model) + + def test_no_bias(self) -> None: + adapter = self._adapter() + attn = _make_w_pack_component(64) + q, k, v = adapter._split_baichuan_w_pack(attn) + assert q.bias is None + assert k.bias is None + assert v.bias is None + + def test_concatenated_split_correctness(self) -> None: + """W_pack = [Q|K|V] concatenated — verify split recovers each part.""" + d_model = 32 + adapter = self._adapter(n_heads=4, d_model=d_model) + attn = _make_w_pack_component(d_model) + # Fill W_pack: Q=1.0, K=2.0, V=3.0 + w = torch.zeros(3 * d_model, d_model) + w[:d_model, :] = 1.0 + w[d_model : 2 * d_model, :] = 2.0 + w[2 * d_model :, :] = 3.0 + attn.W_pack.weight = nn.Parameter(w) + + q, k, v = adapter._split_baichuan_w_pack(attn) + assert torch.all(q.weight == 1.0), "Q should be 1.0" + assert torch.all(k.weight == 2.0), "K should be 2.0" + assert torch.all(v.weight == 3.0), "V should be 3.0" + + def test_round_trip_recombine(self) -> None: + """Split → recombine must equal original W_pack weights.""" + d_model = 64 + adapter = self._adapter(d_model=d_model) + attn = _make_w_pack_component(d_model) + original_w = attn.W_pack.weight.data.clone() + + q, k, v = adapter._split_baichuan_w_pack(attn) + recombined = torch.cat([q.weight.data, k.weight.data, v.weight.data], dim=0) + assert torch.equal(recombined, original_w) + + def test_forward_output_shapes(self) -> None: + d_model = 64 + adapter = self._adapter(d_model=d_model) + attn = _make_w_pack_component(d_model) + q, k, v = adapter._split_baichuan_w_pack(attn) + x = torch.randn(2, 5, d_model) + assert q(x).shape == (2, 5, d_model) + assert k(x).shape == (2, 5, d_model) + assert v(x).shape == (2, 5, d_model) + + +# --------------------------------------------------------------------------- +# preprocess_weights tests +# --------------------------------------------------------------------------- + + +class TestBaichuanPreprocessWeights: + def _make_state_dict( + self, + adapter: BaichuanArchitectureAdapter, + d_model: int = 64, + n_layers: int = 2, + d_mlp: int = 16, + d_vocab: int = 100, + ln1_scale: float = 1.0, + qkv_val: float = 1.0, + ) -> dict[str, torch.Tensor]: + """Bridge-format state dict with fused W_pack for each layer.""" + state: dict[str, torch.Tensor] = {} + for i in range(n_layers): + state[f"blocks.{i}.attn.qkv.weight"] = torch.full((3 * d_model, d_model), qkv_val) + state[f"blocks.{i}.ln1.weight"] = torch.full((d_model,), ln1_scale) + state[f"blocks.{i}.ln2.weight"] = torch.ones(d_model) + state[f"blocks.{i}.mlp.gate.weight"] = torch.ones(d_mlp, d_model) + state[f"blocks.{i}.mlp.in.weight"] = torch.ones(d_mlp, d_model) + state[f"blocks.{i}.attn.o.weight"] = torch.ones(d_model, d_model) + state["ln_final.weight"] = torch.ones(d_model) + state["unembed.weight"] = torch.ones(d_vocab, d_model) + return state + + def test_fused_key_removed_and_split_keys_written(self) -> None: + adapter = BaichuanArchitectureAdapter(_make_cfg(n_heads=8, d_model=64)) + adapter._fold_ln_requested 
= True + sd = self._make_state_dict(adapter) + result = adapter.preprocess_weights(sd) + assert "blocks.0.attn.qkv.weight" not in result + assert "blocks.0.attn.q.weight" in result + assert "blocks.0.attn.k.weight" in result + assert "blocks.0.attn.v.weight" in result + + def test_split_shapes(self) -> None: + d_model = 64 + adapter = BaichuanArchitectureAdapter(_make_cfg(n_heads=8, d_model=d_model)) + adapter._fold_ln_requested = True + sd = self._make_state_dict(adapter, d_model=d_model) + result = adapter.preprocess_weights(sd) + # Baichuan is MHA: Q, K, V each have shape [d_model, d_model] + assert result["blocks.0.attn.q.weight"].shape == (d_model, d_model) + assert result["blocks.0.attn.k.weight"].shape == (d_model, d_model) + assert result["blocks.0.attn.v.weight"].shape == (d_model, d_model) + + def test_ln1_fold_applied(self) -> None: + d_model = 64 + adapter = BaichuanArchitectureAdapter(_make_cfg(n_heads=8, d_model=d_model)) + adapter._fold_ln_requested = True + sd = self._make_state_dict(adapter, d_model=d_model, ln1_scale=2.0, qkv_val=1.0) + result = adapter.preprocess_weights(sd) + assert torch.all(result["blocks.0.attn.q.weight"] == 2.0) + assert torch.all(result["blocks.0.attn.k.weight"] == 2.0) + assert torch.all(result["blocks.0.attn.v.weight"] == 2.0) + + def test_ln1_reset_to_ones(self) -> None: + adapter = BaichuanArchitectureAdapter(_make_cfg(n_heads=8, d_model=64)) + adapter._fold_ln_requested = True + sd = self._make_state_dict(adapter, ln1_scale=3.0) + result = adapter.preprocess_weights(sd) + assert torch.all(result["blocks.0.ln1.weight"] == 1.0) + + def test_ln2_fold_applied(self) -> None: + d_model = 64 + adapter = BaichuanArchitectureAdapter(_make_cfg(n_heads=8, d_model=d_model)) + adapter._fold_ln_requested = True + sd = self._make_state_dict(adapter, d_model=d_model) + sd["blocks.0.ln2.weight"] = torch.full((d_model,), 3.0) + result = adapter.preprocess_weights(sd) + assert torch.all(result["blocks.0.mlp.gate.weight"] == 3.0) + assert torch.all(result["blocks.0.mlp.in.weight"] == 3.0) + + def test_no_fold_still_splits_qkv(self) -> None: + """Without fold_ln, W_pack must still be split for weight conversions.""" + adapter = BaichuanArchitectureAdapter(_make_cfg(n_heads=8, d_model=64)) + adapter._fold_ln_requested = False + sd = self._make_state_dict(adapter) + result = adapter.preprocess_weights(sd) + assert "blocks.0.attn.qkv.weight" not in result + assert "blocks.0.attn.q.weight" in result + assert "blocks.0.attn.k.weight" in result + assert "blocks.0.attn.v.weight" in result + + def test_ln_final_fold_values(self) -> None: + """ln_final fold multiplies unembed weights by ln_final scale.""" + d_model = 64 + adapter = BaichuanArchitectureAdapter(_make_cfg(n_heads=8, d_model=d_model)) + adapter._fold_ln_requested = True + sd = self._make_state_dict(adapter, d_model=d_model) + sd["ln_final.weight"] = torch.full((d_model,), 2.0) + sd["unembed.weight"] = torch.ones(100, d_model) + result = adapter.preprocess_weights(sd) + assert torch.all(result["unembed.weight"] == 2.0) + assert torch.all(result["ln_final.weight"] == 1.0) + + def test_dtype_preserved(self) -> None: + adapter = BaichuanArchitectureAdapter(_make_cfg(n_heads=8, d_model=64)) + adapter._fold_ln_requested = True + sd = self._make_state_dict(adapter) + sd = {k: v.to(torch.bfloat16) for k, v in sd.items()} + result = adapter.preprocess_weights(sd) + assert result["blocks.0.attn.q.weight"].dtype == torch.bfloat16 + + def test_all_layers_processed(self) -> None: + adapter = 
BaichuanArchitectureAdapter(_make_cfg(n_heads=8, d_model=64, n_layers=3)) + adapter._fold_ln_requested = True + sd = self._make_state_dict(adapter, n_layers=3) + result = adapter.preprocess_weights(sd) + for i in range(3): + assert f"blocks.{i}.attn.qkv.weight" not in result + assert f"blocks.{i}.attn.q.weight" in result + + +# --------------------------------------------------------------------------- +# prepare_model tests (NormHead normalization) +# --------------------------------------------------------------------------- + + +class TestBaichuanPrepareModel: + def _adapter(self) -> BaichuanArchitectureAdapter: + return BaichuanArchitectureAdapter(_make_cfg(n_heads=8, d_model=64)) + + def test_normhead_weights_normalized(self) -> None: + """NormHead (has first_flag) should have row-normalized weights after prepare_model.""" + adapter = self._adapter() + lm_head = SimpleNamespace( + weight=nn.Parameter(torch.full((100, 64), 2.0)), + first_flag=True, + ) + hf_model = SimpleNamespace(lm_head=lm_head) + adapter.prepare_model(hf_model) + row_norms = lm_head.weight.data.float().norm(dim=-1) + assert torch.allclose(row_norms, torch.ones_like(row_norms), atol=1e-5) + + def test_regular_linear_unchanged(self) -> None: + """nn.Linear lm_head (no first_flag) should not be modified.""" + adapter = self._adapter() + lm_head = nn.Linear(64, 100, bias=False) + original_w = lm_head.weight.data.clone() + hf_model = SimpleNamespace(lm_head=lm_head) + adapter.prepare_model(hf_model) + assert torch.equal(lm_head.weight.data, original_w) + + def test_no_lm_head_is_noop(self) -> None: + """Model without lm_head should not raise.""" + adapter = self._adapter() + hf_model = SimpleNamespace() + adapter.prepare_model(hf_model) # should not raise + + def test_recomputes_rotary_from_scratch_when_inv_freq_is_meta(self) -> None: + """Baichuan2's inv_freq/cos_cached are plain attrs that land on meta under + HF v5 meta-init; prepare_model must recompute real values regardless.""" + adapter = self._adapter() + head_dim = adapter.cfg.d_model // adapter.cfg.n_heads + # Meta-device rotary matching v2's plain-attribute shape + rotary = SimpleNamespace( + inv_freq=torch.empty(head_dim // 2, device="meta"), + cos_cached=torch.empty(1, 1, 16, head_dim, device="meta"), + sin_cached=torch.empty(1, 1, 16, head_dim, device="meta"), + max_seq_len_cached=16, + ) + layer = SimpleNamespace(self_attn=SimpleNamespace(rotary_emb=rotary)) + hf_model = SimpleNamespace(model=SimpleNamespace(layers=[layer])) + + adapter.prepare_model(hf_model) + + assert rotary.inv_freq.device.type == "cpu" + assert rotary.cos_cached.device.type == "cpu" + assert rotary.sin_cached.device.type == "cpu" + assert rotary.cos_cached.shape == (1, 1, 16, head_dim) + # Sanity: cos(0) == 1 and position 0 of each head_dim element equals 1. 
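+        # (At t = 0 the rotary angle is 0 * inv_freq = 0 for every element,
+        # so the entire head_dim slice of cos_cached must be exactly 1.)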
+ assert torch.allclose( + rotary.cos_cached[0, 0, 0, :], + torch.ones(head_dim), + atol=1e-6, + ) + + +# --------------------------------------------------------------------------- +# Factory registration tests +# --------------------------------------------------------------------------- + + +class TestBaichuanFactoryRegistration: + def test_factory_v2_key(self) -> None: + from transformer_lens.factories.architecture_adapter_factory import ( + SUPPORTED_ARCHITECTURES, + ) + + assert "BaichuanForCausalLM" in SUPPORTED_ARCHITECTURES + + def test_factory_v1_key(self) -> None: + from transformer_lens.factories.architecture_adapter_factory import ( + SUPPORTED_ARCHITECTURES, + ) + + assert "BaiChuanForCausalLM" in SUPPORTED_ARCHITECTURES + + def test_factory_v2_returns_baichuan_adapter(self) -> None: + from transformer_lens.factories.architecture_adapter_factory import ( + ArchitectureAdapterFactory, + ) + + cfg = _make_cfg(n_heads=8, d_model=64) + cfg.architecture = "BaichuanForCausalLM" + adapter = ArchitectureAdapterFactory.select_architecture_adapter(cfg) + assert isinstance(adapter, BaichuanArchitectureAdapter) + + def test_factory_v1_returns_baichuan_adapter(self) -> None: + from transformer_lens.factories.architecture_adapter_factory import ( + ArchitectureAdapterFactory, + ) + + cfg = _make_cfg(n_heads=8, d_model=64) + cfg.architecture = "BaiChuanForCausalLM" + adapter = ArchitectureAdapterFactory.select_architecture_adapter(cfg) + assert isinstance(adapter, BaichuanArchitectureAdapter) + + def test_import_from_init(self) -> None: + from transformer_lens.model_bridge.supported_architectures import ( + BaichuanArchitectureAdapter as FromInit, + ) + + assert FromInit is BaichuanArchitectureAdapter + + +# --------------------------------------------------------------------------- +# Attention bridge: position_ids → position_embeddings conversion +# --------------------------------------------------------------------------- + + +class _FakeRotary(nn.Module): + """Minimal stand-in for Baichuan's RotaryEmbedding (returns 4D cached cos/sin).""" + + def __init__(self, head_dim: int, max_seq_len: int) -> None: + super().__init__() + self.max_seq_len_cached = max_seq_len + # Fill with position-dependent values so tests can verify indexing. 
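+        # cos[p, :] == p and sin[p, :] == -p, so a wrong position index
+        # produces a visibly wrong value rather than a silent pass.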
+ cos = ( + torch.arange(max_seq_len, dtype=torch.float32)[:, None] + .expand(max_seq_len, head_dim) + .clone() + ) + sin = -cos + self.register_buffer("cos_cached", cos[None, None, :, :]) + self.register_buffer("sin_cached", sin[None, None, :, :]) + self.calls: list[int] = [] + + def forward(self, x: torch.Tensor, seq_len: int) -> tuple[torch.Tensor, torch.Tensor]: + self.calls.append(seq_len) + cos_cached = self.cos_cached + sin_cached = self.sin_cached + assert isinstance(cos_cached, torch.Tensor) + assert isinstance(sin_cached, torch.Tensor) + return ( + cos_cached[:, :, :seq_len, :].to(dtype=x.dtype), + sin_cached[:, :, :seq_len, :].to(dtype=x.dtype), + ) + + +class _FakeAttention(nn.Module): + """nn.Module container that exposes a `rotary_emb` + `o_proj` to the bridge.""" + + def __init__(self, rotary: _FakeRotary, d_model: int) -> None: + super().__init__() + self.rotary_emb = rotary + self.o_proj = nn.Linear(d_model, d_model, bias=False) + nn.init.zeros_(self.o_proj.weight) + + +def _make_attention_bridge(cfg: TransformerBridgeConfig) -> _BaichuanAttentionBridge: + from transformer_lens.model_bridge.generalized_components import LinearBridge + + return _BaichuanAttentionBridge( + name="self_attn", + config=cfg, + split_qkv_matrix=lambda _c: ( + nn.Linear(cfg.d_model, cfg.d_model, bias=False), + nn.Linear(cfg.d_model, cfg.d_model, bias=False), + nn.Linear(cfg.d_model, cfg.d_model, bias=False), + ), + submodules={ + "qkv": LinearBridge(name="W_pack"), + "o": LinearBridge(name="o_proj"), + }, + ) + + +def _wire_bridge( + cfg: TransformerBridgeConfig, +) -> tuple[_BaichuanAttentionBridge, _FakeRotary, int]: + """Build a bridge with a fake HF attention (rotary + o_proj) attached.""" + head_dim = cfg.d_model // cfg.n_heads + bridge = _make_attention_bridge(cfg) + rotary = _FakeRotary(head_dim=head_dim, max_seq_len=32) + fake_attn = _FakeAttention(rotary, cfg.d_model) + bridge.set_original_component(fake_attn) + # `o` LinearBridge is normally wired by setup_components via component_mapping; + # wire it directly for unit tests that construct the bridge standalone. 
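+    # Otherwise the output-projection step in _reconstruct_attention would
+    # have no underlying o_proj module to apply.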
+ bridge.o.set_original_component(fake_attn.o_proj) + return bridge, rotary, head_dim + + +class TestBaichuanAttentionBridgeRotary: + """Regression tests for the attention bridge's rotary + KV-cache contract.""" + + def test_uses_position_ids_when_position_embeddings_absent( + self, cfg: TransformerBridgeConfig + ) -> None: + bridge, rotary, head_dim = _wire_bridge(cfg) + + batch, seq = 1, 4 + q = torch.zeros(batch, seq, cfg.d_model) + k = torch.zeros_like(q) + v = torch.zeros_like(q) + position_ids = torch.tensor([[0, 1, 2, 3]]) + + attn_output, _, present = bridge._reconstruct_attention( + q, k, v, position_ids=position_ids, use_cache=True + ) + + # rotary_emb called once, with kv_seq_len=seq (no past) + assert rotary.calls == [seq] + assert attn_output.shape == (batch, seq, cfg.d_model) + assert present is not None + present_k, present_v = present + assert present_k.shape == (batch, cfg.n_heads, seq, head_dim) + assert present_v.shape == (batch, cfg.n_heads, seq, head_dim) + + def test_preserves_explicit_position_embeddings(self, cfg: TransformerBridgeConfig) -> None: + bridge, rotary, head_dim = _wire_bridge(cfg) + + batch, seq = 1, 4 + q = torch.zeros(batch, seq, cfg.d_model) + k = torch.zeros_like(q) + v = torch.zeros_like(q) + explicit = ( + torch.ones(batch, seq, head_dim) * 7, + torch.ones(batch, seq, head_dim) * 9, + ) + + bridge._reconstruct_attention( + q, + k, + v, + position_embeddings=explicit, + position_ids=torch.tensor([[0, 1, 2, 3]]), + use_cache=True, + ) + # Caller-supplied embeddings must win; rotary_emb must not be called. + assert rotary.calls == [] + + def test_use_cache_false_returns_none_present(self, cfg: TransformerBridgeConfig) -> None: + bridge, _, _ = _wire_bridge(cfg) + q = torch.zeros(1, 4, cfg.d_model) + _, _, present = bridge._reconstruct_attention( + q, q.clone(), q.clone(), position_ids=torch.tensor([[0, 1, 2, 3]]) + ) + assert present is None + + def test_concats_past_key_value_along_seq_dim(self, cfg: TransformerBridgeConfig) -> None: + """With past cache of length P and current seq S, the present cache's + k/v have seq dim P+S and rotary is requested with kv_seq_len=P+S.""" + bridge, rotary, head_dim = _wire_bridge(cfg) + + batch, past_len, seq = 1, 3, 2 + past_k = torch.randn(batch, cfg.n_heads, past_len, head_dim) + past_v = torch.randn(batch, cfg.n_heads, past_len, head_dim) + + q = torch.zeros(batch, seq, cfg.d_model) + k = torch.zeros_like(q) + v = torch.zeros_like(q) + # HF's Model.forward generates position_ids offset by past_len. + position_ids = torch.tensor([[past_len, past_len + 1]]) + + _, _, present = bridge._reconstruct_attention( + q, + k, + v, + past_key_value=(past_k, past_v), + position_ids=position_ids, + use_cache=True, + ) + assert rotary.calls == [past_len + seq] + assert present is not None + present_k, present_v = present + assert present_k.shape == (batch, cfg.n_heads, past_len + seq, head_dim) + assert present_v.shape == (batch, cfg.n_heads, past_len + seq, head_dim) + # First past_len slots must be the provided past, unchanged. 
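+        # torch.equal (bitwise), not allclose: the concat must not re-rotate
+        # or otherwise modify cached values.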
+ assert torch.equal(present_k[:, :, :past_len, :], past_k) + assert torch.equal(present_v[:, :, :past_len, :], past_v) + + +# --------------------------------------------------------------------------- +# prepare_loading: bitsandbytes preflight +# --------------------------------------------------------------------------- + + +class TestBaichuanPrepareLoadingBitsandbytes: + """The adapter must point users at `uv sync --group quantization` when bnb is missing.""" + + def test_preflight_raises_clean_import_error( + self, adapter: BaichuanArchitectureAdapter, monkeypatch: pytest.MonkeyPatch + ) -> None: + import transformer_lens.model_bridge.supported_architectures.baichuan as baichuan_mod + + # Force the preflight path: make find_spec report bitsandbytes missing, + # and make get_class_from_dynamic_module surface the transformers-style + # "requires the following packages... bitsandbytes" error. + monkeypatch.setattr(baichuan_mod.importlib.util, "find_spec", lambda name: None) + + def _raise_bnb(*_a: Any, **_k: Any) -> None: + raise ImportError( + "This modeling file requires the following packages that were " + "not found in your environment: bitsandbytes" + ) + + import transformers.dynamic_module_utils as dmu + + monkeypatch.setattr(dmu, "get_class_from_dynamic_module", _raise_bnb) + + with pytest.raises(ImportError, match="uv sync --group quantization"): + adapter.prepare_loading("baichuan-inc/Baichuan2-7B-Chat", {}) + + def test_preflight_no_false_positive_when_bnb_installed( + self, adapter: BaichuanArchitectureAdapter, monkeypatch: pytest.MonkeyPatch + ) -> None: + """If bnb IS installed, the transformers error won't mention bnb, so no raise.""" + import transformers.dynamic_module_utils as dmu + + def _raise_generic(*_a: Any, **_k: Any) -> None: + raise ValueError("some unrelated loader failure") + + monkeypatch.setattr(dmu, "get_class_from_dynamic_module", _raise_generic) + # Must not raise — the generic failure path is swallowed (remote load + # may legitimately fail for offline tests, e.g. no network access). 
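+        # prepare_loading re-raises only when the message names bitsandbytes
+        # and the package is genuinely missing.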
+ adapter.prepare_loading("baichuan-inc/Baichuan2-7B-Chat", {}) diff --git a/transformer_lens/factories/architecture_adapter_factory.py b/transformer_lens/factories/architecture_adapter_factory.py index a55f51a5a..b5432aff1 100644 --- a/transformer_lens/factories/architecture_adapter_factory.py +++ b/transformer_lens/factories/architecture_adapter_factory.py @@ -7,6 +7,7 @@ from transformer_lens.model_bridge.architecture_adapter import ArchitectureAdapter from transformer_lens.model_bridge.supported_architectures import ( ApertusArchitectureAdapter, + BaichuanArchitectureAdapter, BertArchitectureAdapter, BloomArchitectureAdapter, CodeGenArchitectureAdapter, @@ -63,6 +64,8 @@ # Export supported architectures SUPPORTED_ARCHITECTURES = { "ApertusForCausalLM": ApertusArchitectureAdapter, + "BaiChuanForCausalLM": BaichuanArchitectureAdapter, + "BaichuanForCausalLM": BaichuanArchitectureAdapter, "BertForMaskedLM": BertArchitectureAdapter, "BloomForCausalLM": BloomArchitectureAdapter, "CodeGenForCausalLM": CodeGenArchitectureAdapter, diff --git a/transformer_lens/model_bridge/supported_architectures/__init__.py b/transformer_lens/model_bridge/supported_architectures/__init__.py index 7f990e393..772d76942 100644 --- a/transformer_lens/model_bridge/supported_architectures/__init__.py +++ b/transformer_lens/model_bridge/supported_architectures/__init__.py @@ -6,6 +6,9 @@ from transformer_lens.model_bridge.supported_architectures.apertus import ( ApertusArchitectureAdapter, ) +from transformer_lens.model_bridge.supported_architectures.baichuan import ( + BaichuanArchitectureAdapter, +) from transformer_lens.model_bridge.supported_architectures.bert import ( BertArchitectureAdapter, ) @@ -165,6 +168,7 @@ __all__ = [ "ApertusArchitectureAdapter", + "BaichuanArchitectureAdapter", "BertArchitectureAdapter", "BloomArchitectureAdapter", "CodeGenArchitectureAdapter", diff --git a/transformer_lens/model_bridge/supported_architectures/baichuan.py b/transformer_lens/model_bridge/supported_architectures/baichuan.py new file mode 100644 index 000000000..a50fabc37 --- /dev/null +++ b/transformer_lens/model_bridge/supported_architectures/baichuan.py @@ -0,0 +1,447 @@ +"""Baichuan architecture adapter. + +Supports both BaiChuanForCausalLM (v1) and BaichuanForCausalLM (v2). +Both use combined QKV via W_pack with RoPE, RMSNorm, and gated MLP. +""" + +import importlib.util +import sys +from typing import Any + +import torch +import torch.nn as nn + +from transformer_lens.conversion_utils.conversion_steps import RearrangeTensorConversion +from transformer_lens.conversion_utils.param_processing_conversion import ( + ParamProcessingConversion, +) +from transformer_lens.model_bridge.architecture_adapter import ArchitectureAdapter +from transformer_lens.model_bridge.compat import patch_dynamic_cache_v5 +from transformer_lens.model_bridge.generalized_components import ( + BlockBridge, + EmbeddingBridge, + GatedMLPBridge, + JointQKVPositionEmbeddingsAttentionBridge, + LinearBridge, + RMSNormalizationBridge, + UnembeddingBridge, +) + + +class _BaichuanAttentionBridge(JointQKVPositionEmbeddingsAttentionBridge): + """Attention bridge for Baichuan's v4-era decoder-layer contract. + + Baichuan predates HF's Cache API and differs from the base bridge in two + ways we have to own: + + 1. **Rotary from position_ids**: HF passes `position_ids` (not a + pre-computed `position_embeddings` tuple), so we call the per-layer + `rotary_emb(v, seq_len=kv_seq_len)` ourselves and slice cos/sin by + `position_ids`. + 2. 
**Legacy (k, v) cache tuple**: HF's DecoderLayer passes + `past_key_value=(k, v)` (singular, per-layer legacy tuple) and expects + `self_attn(...)` to return a matching `(k_full, v_full)` as + `present_key_value` so Model.forward's `next_decoder_cache` accumulates + real tensors. The base bridge's `_update_kv_cache` only handles the + Cache-object plural path, so we reimplement the attention body here + (mirroring HF's own Attention.forward). + """ + + def _reconstruct_attention( + self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, **kwargs + ) -> tuple: + assert self.original_component is not None + assert self.config is not None + num_heads = self.config.n_heads + num_kv_heads = getattr(self.config, "n_key_value_heads", None) or num_heads + + q, k, v, batch_size, seq_len, head_dim = self._reshape_qkv_to_heads( + q, k, v, num_heads, num_kv_heads + ) + + past_kv_raw = kwargs.get("past_key_value") + past_key_value: tuple[torch.Tensor, torch.Tensor] | None = None + if ( + isinstance(past_kv_raw, tuple) + and len(past_kv_raw) >= 2 + and isinstance(past_kv_raw[0], torch.Tensor) + and isinstance(past_kv_raw[1], torch.Tensor) + ): + past_key_value = (past_kv_raw[0], past_kv_raw[1]) + past_len = past_key_value[0].shape[-2] if past_key_value is not None else 0 + + # Rotary: derive cos/sin over the full kv_seq_len, index by position_ids. + if "position_embeddings" not in kwargs: + rotary_emb = getattr(self.original_component, "rotary_emb", None) + position_ids = kwargs.get("position_ids") + if rotary_emb is not None and position_ids is not None: + kv_seq_len = seq_len + past_len + cos, sin = rotary_emb(v, seq_len=kv_seq_len) + cos = cos.squeeze(1).squeeze(0)[position_ids] + sin = sin.squeeze(1).squeeze(0)[position_ids] + kwargs["position_embeddings"] = (cos, sin) + + position_embeddings = kwargs.get("position_embeddings") + if position_embeddings is not None and isinstance(position_embeddings, tuple): + cos, sin = self._apply_position_embedding_hooks(position_embeddings) + q, k = self._apply_rotary_pos_emb(q, k, cos, sin) + + # Concat prior (k, v) — already rotary-applied from its own step. + if past_key_value is not None: + k = torch.cat([past_key_value[0], k], dim=-2) + v = torch.cat([past_key_value[1], v], dim=-2) + + # Build present cache from pre-GQA-expansion (k, v) so downstream + # steps don't pay for duplicated heads. 
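+        # (Baichuan itself is pure MHA, so this is a no-op today; it keeps
+        # the cache contract correct for any GQA variant routed through here.)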
+ use_cache = bool(kwargs.get("use_cache", False)) + present_key_value = (k, v) if use_cache else None + + if num_kv_heads != num_heads: + n_rep = num_heads // num_kv_heads + k = k.repeat_interleave(n_rep, dim=1) + v = v.repeat_interleave(n_rep, dim=1) + + kv_seq_len = k.shape[-2] + attn_scores = torch.matmul(q, k.transpose(-2, -1)) * (head_dim ** (-0.5)) + attention_mask = kwargs.get("attention_mask", None) + attn_scores = self._apply_reconstruct_attention_mask( + attn_scores=attn_scores, + attention_mask=attention_mask, + seq_len=kv_seq_len, + q_seq_len=seq_len, + ) + attn_scores = self.hook_attn_scores(attn_scores) + attn_weights = self._softmax_dropout_pattern(attn_scores) + attn_output = torch.matmul(attn_weights, v) + attn_output = self._reshape_attn_output( + attn_output, batch_size, seq_len, num_heads, head_dim + ) + if ( + bool(getattr(self.config, "use_attn_result", False)) + and hasattr(self, "o") + and self.o.original_component is not None + ): + attn_output = self.o.hook_in(attn_output) + z_4d = attn_output.view(batch_size, seq_len, num_heads, head_dim) + attn_output = self._compute_per_head_result(z_4d, num_heads, head_dim) + else: + attn_output = self._apply_output_projection(attn_output) + + return (attn_output, attn_weights, present_key_value) + + +def _patch_init_weights_for_baichuan() -> None: + """Prevent _init_weights from re-randomizing loaded checkpoint weights. + + Transformers v5 calls _init_weights on all modules after weight + materialization. For modules with real (non-meta) tensors, we must + skip re-initialization to preserve the loaded checkpoint values. + """ + for key in list(sys.modules.keys()): + if "baichuan" not in key.lower() or "modeling" not in key.lower(): + continue + module = sys.modules[key] + # Both v1 (BaiChuan) and v2 (Baichuan) define a PreTrainedModel subclass + for cls_name in ("BaiChuanPreTrainedModel", "BaichuanPreTrainedModel", "PreTrainedModel"): + pretrained_cls = getattr(module, cls_name, None) + if pretrained_cls is None or getattr(pretrained_cls, "_tl_patched", False): + continue + # Only patch classes that define their own _init_weights + if "_init_weights" not in pretrained_cls.__dict__: + continue + + original_init_weights = pretrained_cls._init_weights + + def safe_init_weights(self, mod, _original=original_init_weights): # type: ignore[no-untyped-def] + first_param = next(mod.parameters(), None) + if first_param is not None and first_param.device.type != "meta": + return + _original(self, mod) + + pretrained_cls._init_weights = safe_init_weights + pretrained_cls._tl_patched = True + + +class BaichuanArchitectureAdapter(ArchitectureAdapter): + """Architecture adapter for Baichuan models (v1 and v2). + + Baichuan uses combined QKV via W_pack (nn.Linear(h, 3*h)) with RoPE, + RMSNorm, and gated MLP (SwiGLU). Per-layer rotary embeddings. 
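+
+    Baichuan2 Chat checkpoints also ship a NormHead unembedding; prepare_model
+    row-normalizes lm_head once up front so the bridge never needs NormHead's
+    forward path.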
+ + Optional Parameters (may not exist in state_dict): + ------------------------------------------------- + Baichuan models do NOT have biases on any projection: + + - blocks.{i}.attn.b_Q / b_K / b_V / b_O — no bias + - blocks.{i}.mlp.b_gate / b_in / b_out — no bias + - blocks.{i}.ln1.b / ln2.b / ln_final.b — RMSNorm has no bias + """ + + def __init__(self, cfg: Any) -> None: + super().__init__(cfg) + + self.cfg.normalization_type = "RMS" + self.cfg.positional_embedding_type = "rotary" + self.cfg.final_rms = True + self.cfg.gated_mlp = True + self.cfg.attn_only = False + self.cfg.uses_rms_norm = True + self.cfg.eps_attr = "variance_epsilon" + + # Fused W_pack prevents standard fold_ln from reaching Q/K/V separately. + # preprocess_weights() handles it instead. + self.supports_fold_ln = False + + self.weight_processing_conversions = { + "blocks.{i}.attn.q.weight": ParamProcessingConversion( + tensor_conversion=RearrangeTensorConversion("(n h) m -> n m h", n=cfg.n_heads), + ), + "blocks.{i}.attn.k.weight": ParamProcessingConversion( + tensor_conversion=RearrangeTensorConversion("(n h) m -> n m h", n=cfg.n_heads), + ), + "blocks.{i}.attn.v.weight": ParamProcessingConversion( + tensor_conversion=RearrangeTensorConversion("(n h) m -> n m h", n=cfg.n_heads), + ), + "blocks.{i}.attn.o.weight": ParamProcessingConversion( + tensor_conversion=RearrangeTensorConversion("m (n h) -> n h m", n=cfg.n_heads), + ), + } + + self.component_mapping = { + "embed": EmbeddingBridge(name="model.embed_tokens"), + "blocks": BlockBridge( + name="model.layers", + submodules={ + "ln1": RMSNormalizationBridge(name="input_layernorm", config=self.cfg), + "ln2": RMSNormalizationBridge(name="post_attention_layernorm", config=self.cfg), + "attn": _BaichuanAttentionBridge( + name="self_attn", + config=self.cfg, + split_qkv_matrix=self._split_baichuan_w_pack, + submodules={ + "qkv": LinearBridge(name="W_pack"), + "o": LinearBridge(name="o_proj"), + }, + ), + "mlp": GatedMLPBridge( + name="mlp", + config=self.cfg, + submodules={ + "gate": LinearBridge(name="gate_proj"), + "in": LinearBridge(name="up_proj"), + "out": LinearBridge(name="down_proj"), + }, + ), + }, + ), + "ln_final": RMSNormalizationBridge(name="model.norm", config=self.cfg), + "unembed": UnembeddingBridge(name="lm_head", config=self.cfg), + } + + def _split_baichuan_w_pack( + self, attention_component: Any + ) -> tuple[nn.Linear, nn.Linear, nn.Linear]: + """Split Baichuan's W_pack into separate Q, K, V linear modules. + + W_pack is a simple concatenation: [Q | K | V], each of size hidden_size. + No interleaving, no GQA — all three chunks are equal size. 
+ """ + w_pack = attention_component.W_pack + weight = w_pack.weight.data + d_model = weight.shape[1] + hidden_size = d_model # Q, K, V each have hidden_size output features + + q_w = weight[:hidden_size, :] + k_w = weight[hidden_size : 2 * hidden_size, :] + v_w = weight[2 * hidden_size :, :] + + def _make_linear(w: torch.Tensor) -> nn.Linear: + lin = nn.Linear(d_model, hidden_size, bias=False) + lin.weight = nn.Parameter(w) + return lin + + return _make_linear(q_w), _make_linear(k_w), _make_linear(v_w) + + def setup_component_testing(self, hf_model: Any, bridge_model: Any = None) -> None: + """Inject per-layer rotary embedding for component testing.""" + try: + rotary_emb = hf_model.model.layers[0].self_attn.rotary_emb + except (AttributeError, IndexError): + return + + if bridge_model is not None and hasattr(bridge_model, "blocks"): + for block in bridge_model.blocks: + if hasattr(block, "attn"): + block.attn.set_rotary_emb(rotary_emb) + + attn_bridge = self.get_generalized_component("blocks.0.attn") + attn_bridge.set_rotary_emb(rotary_emb) + + def prepare_loading(self, model_name: str, model_kwargs: dict) -> None: + """Patch transformers v5 incompatibilities before from_pretrained runs.""" + patch_dynamic_cache_v5() + + # Force-import the remote modeling module so we can patch _init_weights. + # Baichuan2 variants ship quantizer.py which imports bitsandbytes; + # transformers' check_imports scans every .py file in the repo and + # raises ImportError if bitsandbytes is missing, even though quantizer + # is not used in normal inference. Catch that case and tell the user + # how to install the optional dependency group. + try: + from transformers.dynamic_module_utils import get_class_from_dynamic_module + + last_exc: Exception | None = None + # Try both class names (v1 and v2) + for cls_name in ( + "modeling_baichuan.BaichuanForCausalLM", + "modeling_baichuan.BaiChuanForCausalLM", + ): + try: + get_class_from_dynamic_module(cls_name, model_name) + last_exc = None + break + except Exception as exc: + last_exc = exc + continue + if last_exc is not None and "bitsandbytes" in str(last_exc): + if importlib.util.find_spec("bitsandbytes") is None: + raise ImportError( + "Baichuan2 variants require `bitsandbytes` for " + "trust_remote_code loading (their shipped quantizer.py " + "imports it). Install the quantization extras: " + "`uv sync --group quantization`." + ) from last_exc + except ImportError: + raise + except Exception: + pass + + _patch_init_weights_for_baichuan() + + def prepare_model(self, hf_model: Any) -> None: + """Fix rotary caches and normalize NormHead weights before bridge creation. + + RotaryEmbedding differs between v1 and v2: + - v1 (Baichuan-7B): `inv_freq` is a persistent buffer, loaded from the + checkpoint as bfloat16, but `cos_cached`/`sin_cached` are non-persistent + and materialize as garbage under meta-init. + - v2 (Baichuan2-*): `inv_freq`, `cos_cached`, `sin_cached` are all plain + attributes (no `register_buffer`). v5's meta-init materializes them on + meta, and nothing in the checkpoint overwrites them. + + Both cases are resolved by computing inv_freq + caches from scratch at + float32 using config-derived head_dim and base=10000. Recomputing v1 at + float32 is also an upgrade over its bfloat16 checkpoint values. + + Baichuan2 Chat also uses NormHead which row-normalizes lm_head during + forward. We apply that once here so the bridge sees the normalized + weights directly without needing NormHead's forward path. 
+ """ + # Pick a real device/dtype by scanning real (non-meta) parameters. + target_device = torch.device("cpu") + params_fn = getattr(hf_model, "parameters", None) + if callable(params_fn): + for param in params_fn(): + if param.device.type != "meta": + target_device = param.device + break + + head_dim = self.cfg.d_model // self.cfg.n_heads + base = 10000.0 + + model_core = getattr(hf_model, "model", None) + if model_core is not None: + for layer in getattr(model_core, "layers", []): + rotary = getattr(getattr(layer, "self_attn", None), "rotary_emb", None) + if rotary is None: + continue + max_seq = getattr(rotary, "max_seq_len_cached", self.cfg.n_ctx or 4096) + inv_freq = 1.0 / ( + base + ** ( + torch.arange(0, head_dim, 2, device=target_device, dtype=torch.float32) + / head_dim + ) + ) + t = torch.arange(max_seq, device=target_device, dtype=torch.float32) + freqs = torch.einsum("i,j->ij", t, inv_freq) + emb = torch.cat((freqs, freqs), dim=-1) + rotary.inv_freq = inv_freq + rotary.cos_cached = emb.cos()[None, None, :, :] + rotary.sin_cached = emb.sin()[None, None, :, :] + + # Normalize NormHead weights (Baichuan2 Chat) + lm_head = getattr(hf_model, "lm_head", None) + if lm_head is not None and hasattr(lm_head, "first_flag"): + w = lm_head.weight.data + lm_head.weight.data = torch.nn.functional.normalize(w, dim=-1) + + def preprocess_weights(self, state_dict: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]: + """Split fused W_pack QKV and optionally fold layer norms.""" + fold_ln = getattr(self, "_fold_ln_requested", True) + if not fold_ln: + # Still need to split W_pack into Q/K/V for weight conversions + for i in range(self.cfg.n_layers): + qkv_key = f"blocks.{i}.attn.qkv.weight" + if qkv_key not in state_dict: + continue + w = state_dict[qkv_key] + hidden_size = w.shape[1] + q_w = w[:hidden_size, :] + k_w = w[hidden_size : 2 * hidden_size, :] + v_w = w[2 * hidden_size :, :] + state_dict[f"blocks.{i}.attn.q.weight"] = q_w + state_dict[f"blocks.{i}.attn.k.weight"] = k_w + state_dict[f"blocks.{i}.attn.v.weight"] = v_w + del state_dict[qkv_key] + return state_dict + + for i in range(self.cfg.n_layers): + # --- Fold ln1 into Q/K/V (split from W_pack) --- + qkv_key = f"blocks.{i}.attn.qkv.weight" + ln1_key = f"blocks.{i}.ln1.weight" + if qkv_key in state_dict and ln1_key in state_dict: + ln1_w = state_dict[ln1_key].float() + w = state_dict[qkv_key].float() + orig_dtype = state_dict[qkv_key].dtype + hidden_size = w.shape[1] + + q_w = w[:hidden_size, :] + k_w = w[hidden_size : 2 * hidden_size, :] + v_w = w[2 * hidden_size :, :] + + state_dict[f"blocks.{i}.attn.q.weight"] = (q_w * ln1_w[None, :]).to(orig_dtype) + state_dict[f"blocks.{i}.attn.k.weight"] = (k_w * ln1_w[None, :]).to(orig_dtype) + state_dict[f"blocks.{i}.attn.v.weight"] = (v_w * ln1_w[None, :]).to(orig_dtype) + del state_dict[qkv_key] + state_dict[ln1_key] = torch.ones_like(state_dict[ln1_key]) + + # --- Fold ln2 into MLP gate and up projections --- + ln2_key = f"blocks.{i}.ln2.weight" + if ln2_key in state_dict: + ln2_w = state_dict[ln2_key].float() + for mlp_key in [ + f"blocks.{i}.mlp.gate.weight", + f"blocks.{i}.mlp.in.weight", + ]: + if mlp_key in state_dict: + orig_dtype = state_dict[mlp_key].dtype + state_dict[mlp_key] = (state_dict[mlp_key].float() * ln2_w[None, :]).to( + orig_dtype + ) + state_dict[ln2_key] = torch.ones_like(state_dict[ln2_key]) + + # --- Fold ln_final into unembed --- + ln_final_key = "ln_final.weight" + unembed_key = "unembed.weight" + if ln_final_key in state_dict and unembed_key in state_dict: + 
ln_w = state_dict[ln_final_key].float() + u_w = state_dict[unembed_key].float() + orig_dtype = state_dict[unembed_key].dtype + if u_w.shape[-1] == ln_w.shape[0]: + state_dict[unembed_key] = (u_w * ln_w[None, :]).to(orig_dtype) + elif u_w.shape[0] == ln_w.shape[0]: + state_dict[unembed_key] = (u_w * ln_w[:, None]).to(orig_dtype) + state_dict[ln_final_key] = torch.ones_like(state_dict[ln_final_key]) + + return state_dict diff --git a/transformer_lens/tools/model_registry/data/supported_models.json b/transformer_lens/tools/model_registry/data/supported_models.json index 9ea753a6b..c0fb70907 100644 --- a/transformer_lens/tools/model_registry/data/supported_models.json +++ b/transformer_lens/tools/model_registry/data/supported_models.json @@ -6,9 +6,9 @@ "min_downloads": 500, "scan_duration_seconds": 4.9 }, - "total_architectures": 48, - "total_models": 9056, - "total_verified": 709, + "total_architectures": 50, + "total_models": 9068, + "total_verified": 711, "models": [ { "architecture_id": "Qwen3NextForCausalLM", @@ -125195,6 +125195,210 @@ "phase4_score": null, "phase7_score": null, "phase8_score": null + }, + { + "architecture_id": "BaichuanForCausalLM", + "model_id": "baichuan-inc/Baichuan2-7B-Chat", + "status": 3, + "verified_date": "2026-04-21", + "metadata": { + "downloads": 89081, + "total_params": null + }, + "note": "Below threshold: P1=0.0% < 100.0% (failed: load_bridge_unprocessed) \u2014 Failed to load unprocessed TransformerBridge: This modeling file requires the following packages that were not found in your environment: bitsandbytes", + "phase1_score": 0.0, + "phase2_score": null, + "phase3_score": null, + "phase4_score": null, + "phase7_score": null, + "phase8_score": null + }, + { + "architecture_id": "BaichuanForCausalLM", + "model_id": "baichuan-inc/Baichuan2-13B-Chat", + "status": 0, + "verified_date": null, + "metadata": { + "downloads": 7963, + "total_params": null + }, + "note": null, + "phase1_score": null, + "phase2_score": null, + "phase3_score": null, + "phase4_score": null, + "phase7_score": null, + "phase8_score": null + }, + { + "architecture_id": "BaichuanForCausalLM", + "model_id": "baichuan-inc/Baichuan-13B-Chat", + "status": 0, + "verified_date": null, + "metadata": { + "downloads": 6570, + "total_params": null + }, + "note": null, + "phase1_score": null, + "phase2_score": null, + "phase3_score": null, + "phase4_score": null, + "phase7_score": null, + "phase8_score": null + }, + { + "architecture_id": "BaichuanForCausalLM", + "model_id": "baichuan-inc/Baichuan2-7B-Base", + "status": 1, + "verified_date": "2026-04-21", + "metadata": { + "downloads": 2007, + "total_params": null + }, + "note": "Full verification completed", + "phase1_score": 100.0, + "phase2_score": 100.0, + "phase3_score": 100.0, + "phase4_score": 94.6, + "phase7_score": null, + "phase8_score": null + }, + { + "architecture_id": "BaichuanForCausalLM", + "model_id": "sakuraumi/Sakura-13B-Galgame", + "status": 0, + "verified_date": null, + "metadata": { + "downloads": 1800, + "total_params": null + }, + "note": null, + "phase1_score": null, + "phase2_score": null, + "phase3_score": null, + "phase4_score": null, + "phase7_score": null, + "phase8_score": null + }, + { + "architecture_id": "BaichuanForCausalLM", + "model_id": "zxbsmk/NSFW_13B_sft", + "status": 0, + "verified_date": null, + "metadata": { + "downloads": 1786, + "total_params": null + }, + "note": null, + "phase1_score": null, + "phase2_score": null, + "phase3_score": null, + "phase4_score": null, + "phase7_score": null, + 
"phase8_score": null + }, + { + "architecture_id": "BaichuanForCausalLM", + "model_id": "baichuan-inc/Baichuan2-13B-Base", + "status": 0, + "verified_date": null, + "metadata": { + "downloads": 1773, + "total_params": null + }, + "note": null, + "phase1_score": null, + "phase2_score": null, + "phase3_score": null, + "phase4_score": null, + "phase7_score": null, + "phase8_score": null + }, + { + "architecture_id": "BaichuanForCausalLM", + "model_id": "Wuyanzzh/NSFW_13B_sft", + "status": 0, + "verified_date": null, + "metadata": { + "downloads": 1486, + "total_params": null + }, + "note": null, + "phase1_score": null, + "phase2_score": null, + "phase3_score": null, + "phase4_score": null, + "phase7_score": null, + "phase8_score": null + }, + { + "architecture_id": "BaichuanForCausalLM", + "model_id": "baichuan-inc/Baichuan-13B-Base", + "status": 0, + "verified_date": null, + "metadata": { + "downloads": 1295, + "total_params": null + }, + "note": null, + "phase1_score": null, + "phase2_score": null, + "phase3_score": null, + "phase4_score": null, + "phase7_score": null, + "phase8_score": null + }, + { + "architecture_id": "BaichuanForCausalLM", + "model_id": "FreedomIntelligence/HuatuoGPT2-7B", + "status": 0, + "verified_date": null, + "metadata": { + "downloads": 1252, + "total_params": null + }, + "note": null, + "phase1_score": null, + "phase2_score": null, + "phase3_score": null, + "phase4_score": null, + "phase7_score": null, + "phase8_score": null + }, + { + "architecture_id": "BaichuanForCausalLM", + "model_id": "DuJinHua/AiMed2", + "status": 0, + "verified_date": null, + "metadata": { + "downloads": 951, + "total_params": null + }, + "note": null, + "phase1_score": null, + "phase2_score": null, + "phase3_score": null, + "phase4_score": null, + "phase7_score": null, + "phase8_score": null + }, + { + "architecture_id": "BaiChuanForCausalLM", + "model_id": "baichuan-inc/Baichuan-7B", + "status": 1, + "verified_date": "2026-04-21", + "metadata": { + "downloads": 50000, + "total_params": null + }, + "note": "Full verification completed", + "phase1_score": 100.0, + "phase2_score": 100.0, + "phase3_score": 100.0, + "phase4_score": 92.0, + "phase7_score": null, + "phase8_score": null } ] } diff --git a/transformer_lens/tools/model_registry/data/verification_history.json b/transformer_lens/tools/model_registry/data/verification_history.json index 657c09326..c87d21798 100644 --- a/transformer_lens/tools/model_registry/data/verification_history.json +++ b/transformer_lens/tools/model_registry/data/verification_history.json @@ -1,5 +1,5 @@ { - "last_updated": "2026-04-15T16:28:06.314994", + "last_updated": "2026-04-21T20:10:35.469418", "records": [ { "model_id": "Macropodus/macbert4mdcspell_v1", @@ -11730,6 +11730,126 @@ "notes": "Full verification completed with issues, low text quality", "invalidated": false, "invalidation_reason": null + }, + { + "model_id": "baichuan-inc/Baichuan2-7B-Chat", + "architecture_id": "BaichuanForCausalLM", + "verified_date": "2026-04-21", + "verified_by": "verify_models", + "transformerlens_version": null, + "notes": "Below threshold: P1=0.0% < 100.0% (failed: load_bridge_unprocessed) \u2014 Failed to load unprocessed TransformerBridge: This modeling file requires the following packages that were not found in your environment: bitsandbytes", + "invalidated": false, + "invalidation_reason": null + }, + { + "model_id": "baichuan-inc/Baichuan2-7B-Base", + "architecture_id": "BaichuanForCausalLM", + "verified_date": "2026-04-21", + "verified_by": "verify_models", + 
"transformerlens_version": null, + "notes": "Below threshold: P1=0.0% < 100.0% (failed: load_bridge_unprocessed) \u2014 Failed to load unprocessed TransformerBridge: This modeling file requires the following packages that were not found in your environment: bitsandbytes", + "invalidated": false, + "invalidation_reason": null + }, + { + "model_id": "baichuan-inc/Baichuan-7B", + "architecture_id": "BaiChuanForCausalLM", + "verified_date": "2026-04-21", + "verified_by": "verify_models", + "transformerlens_version": null, + "notes": "Below threshold: P1=50.0% < 100.0% (failed: forward_pass_logits) \u2014 Tensors differ: max_diff=74.608353, mean_rel=1.619285", + "invalidated": false, + "invalidation_reason": null + }, + { + "model_id": "baichuan-inc/Baichuan-7B", + "architecture_id": "BaiChuanForCausalLM", + "verified_date": "2026-04-21", + "verified_by": "verify_models", + "transformerlens_version": null, + "notes": "Below threshold: P1=50.0% < 100.0% (failed: forward_pass_logits) \u2014 Tensors differ: max_diff=78.619270, mean_rel=1.866265", + "invalidated": false, + "invalidation_reason": null + }, + { + "model_id": "baichuan-inc/Baichuan-7B", + "architecture_id": "BaiChuanForCausalLM", + "verified_date": "2026-04-21", + "verified_by": "verify_models", + "transformerlens_version": null, + "notes": "Below threshold: P1=50.0% < 100.0% (failed: forward_pass_logits) \u2014 Tensors differ: max_diff=nan, mean_rel=nan", + "invalidated": false, + "invalidation_reason": null + }, + { + "model_id": "baichuan-inc/Baichuan-7B", + "architecture_id": "BaiChuanForCausalLM", + "verified_date": "2026-04-21", + "verified_by": "verify_models", + "transformerlens_version": null, + "notes": "Below threshold: P1=50.0% < 100.0% (failed: forward_pass_logits) \u2014 Tensors differ: max_diff=33.073044, mean_rel=0.316714", + "invalidated": false, + "invalidation_reason": null + }, + { + "model_id": "baichuan-inc/Baichuan-7B", + "architecture_id": "BaiChuanForCausalLM", + "verified_date": "2026-04-21", + "verified_by": "verify_models", + "transformerlens_version": null, + "notes": "Below threshold: P1=50.0% < 100.0% (failed: forward_pass_logits) \u2014 Tensors differ: max_diff=33.073044, mean_rel=0.316714", + "invalidated": false, + "invalidation_reason": null + }, + { + "model_id": "baichuan-inc/Baichuan-7B", + "architecture_id": "BaiChuanForCausalLM", + "verified_date": "2026-04-21", + "verified_by": "verify_models", + "transformerlens_version": null, + "notes": "Below threshold: P2=69.2% < 75.0% (failed: generation, generation_with_kv_cache, multiple_generation \u2014 Generation failed: 'NoneType' object is not subscriptable", + "invalidated": false, + "invalidation_reason": null + }, + { + "model_id": "baichuan-inc/Baichuan-7B", + "architecture_id": "BaiChuanForCausalLM", + "verified_date": "2026-04-21", + "verified_by": "verify_models", + "transformerlens_version": null, + "notes": "Full verification completed", + "invalidated": false, + "invalidation_reason": null + }, + { + "model_id": "baichuan-inc/Baichuan2-7B-Base", + "architecture_id": "BaichuanForCausalLM", + "verified_date": "2026-04-21", + "verified_by": "verify_models", + "transformerlens_version": null, + "notes": "Below threshold: P1=50.0% < 100.0% (failed: forward_pass); P2=7.7% < 75.0% (failed: generation, gene \u2014 Forward pass failed: Cannot copy out of meta tensor; no data!", + "invalidated": false, + "invalidation_reason": null + }, + { + "model_id": "baichuan-inc/Baichuan2-7B-Base", + "architecture_id": "BaichuanForCausalLM", + 
"verified_date": "2026-04-21", + "verified_by": "verify_models", + "transformerlens_version": null, + "notes": "Full verification completed", + "invalidated": false, + "invalidation_reason": null + }, + { + "model_id": "baichuan-inc/Baichuan-7B", + "architecture_id": "BaiChuanForCausalLM", + "verified_date": "2026-04-21", + "verified_by": "verify_models", + "transformerlens_version": null, + "notes": "Full verification completed", + "invalidated": false, + "invalidation_reason": null } ] } diff --git a/transformer_lens/tools/model_registry/verify_models.py b/transformer_lens/tools/model_registry/verify_models.py index 1497a1cef..4798db1c6 100644 --- a/transformer_lens/tools/model_registry/verify_models.py +++ b/transformer_lens/tools/model_registry/verify_models.py @@ -60,6 +60,7 @@ # Architectures added via the TransformerBridge system that need trust_remote_code=True. # These are not in the legacy NEED_REMOTE_CODE_MODELS tuple (loading_from_pretrained.py). _BRIDGE_REMOTE_CODE_PREFIXES: tuple[str, ...] = ( + "baichuan-inc/", # BaichuanForCausalLM — ships own modeling_baichuan.py "internlm/", # InternLM2ForCausalLM — ships own modeling_internlm2.py )