From 98811df593747495c30a20aa1fe1d66661b12e13 Mon Sep 17 00:00:00 2001 From: Jonah Larson Date: Mon, 20 Apr 2026 13:22:55 -0500 Subject: [PATCH 01/21] 3.0 CI Bugs (#1261) * Fixing `utils` imports * skip gated notebooks on PR from forks * Updating notebooks * Ensure LLaMA only runs when HF_TOKEN is available --- .github/workflows/checks.yml | 15 +- demos/Activation_Patching_in_TL_Demo.ipynb | 34 +- demos/Attribution_Patching_Demo.ipynb | 216 +- demos/Exploratory_Analysis_Demo.ipynb | 4 +- demos/Grokking_Demo.ipynb | 20 +- demos/Head_Detector_Demo.ipynb | 80 +- demos/Interactive_Neuroscope.ipynb | 9 +- demos/LLaMA.ipynb | 2 +- demos/LLaMA2_GPU_Quantized.ipynb | 2 +- demos/Main_Demo.ipynb | 2 +- demos/No_Position_Experiment.ipynb | 2760 ++++++++--------- demos/Othello_GPT.ipynb | 4 +- demos/SVD_Interpreter_Demo.ipynb | 2 +- demos/Santa_Coder.ipynb | 2 +- transformer_lens/ActivationCache.py | 4 +- transformer_lens/HookedEncoderDecoder.py | 2 +- transformer_lens/HookedTransformer.py | 8 +- .../components/abstract_attention.py | 2 +- transformer_lens/components/bert_block.py | 2 +- transformer_lens/components/pos_embed.py | 2 +- transformer_lens/components/t5_block.py | 2 +- .../components/transformer_block.py | 2 +- .../config/HookedTransformerConfig.py | 2 +- transformer_lens/evals.py | 4 +- transformer_lens/head_detector.py | 2 +- transformer_lens/hook_points.py | 2 +- transformer_lens/loading_from_pretrained.py | 2 +- transformer_lens/model_bridge/bridge.py | 2 +- .../model_bridge/sources/transformers.py | 2 +- transformer_lens/patching.py | 2 +- transformer_lens/train.py | 2 +- transformer_lens/utilities/devices.py | 2 +- .../utilities/exploratory_utils.py | 4 +- transformer_lens/weight_processing.py | 2 +- 34 files changed, 1619 insertions(+), 1585 deletions(-) diff --git a/.github/workflows/checks.yml b/.github/workflows/checks.yml index 86d809f4a..3a2e1389c 100644 --- a/.github/workflows/checks.yml +++ b/.github/workflows/checks.yml @@ -221,6 +221,8 @@ jobs: name: Notebook Checks runs-on: ubuntu-latest timeout-minutes: 30 + env: + HAS_HF_TOKEN: ${{ secrets.HF_TOKEN != '' }} strategy: fail-fast: false matrix: @@ -233,8 +235,6 @@ jobs: # - "Grokking_Demo" - "Head_Detector_Demo" # - "Interactive_Neuroscope" - - "LLaMA" - - "LLaMA2_GPU_Quantized" # Requires quantization libs + too slow for CI timeout - "Main_Demo" # - "No_Position_Experiment" - "Othello_GPT" @@ -242,6 +242,12 @@ jobs: - "Santa_Coder" # - "stable_lm" - "T5" + requires_hf_token: [false] + include: + - notebook: "LLaMA" + requires_hf_token: true + - notebook: "LLaMA2_GPU_Quantized" + requires_hf_token: true steps: - uses: actions/checkout@v3 - name: Add swap space @@ -278,10 +284,13 @@ jobs: env: HF_TOKEN: ${{ secrets.HF_TOKEN }} - name: Check Notebook Output Consistency - # Note: currently only checks notebooks we have specifically setup for this + if: ${{ !matrix.requires_hf_token || env.HAS_HF_TOKEN == 'true' }} run: pytest --nbval-sanitize-with demos/doc_sanitize.cfg demos/${{ matrix.notebook }}.ipynb env: HF_TOKEN: ${{ secrets.HF_TOKEN }} + - name: Skip (HF_TOKEN not available) + if: ${{ matrix.requires_hf_token && env.HAS_HF_TOKEN != 'true' }} + run: echo "Skipping ${{ matrix.notebook }} — requires HF_TOKEN for gated model access" build-docs: diff --git a/demos/Activation_Patching_in_TL_Demo.ipynb b/demos/Activation_Patching_in_TL_Demo.ipynb index 17fc8f004..cbd7378cb 100644 --- a/demos/Activation_Patching_in_TL_Demo.ipynb +++ b/demos/Activation_Patching_in_TL_Demo.ipynb @@ -126,7 +126,7 @@ "outputs": [], "source": [ 
"import transformer_lens\n", - "import transformer_lens.utils as utils\n", + "import transformer_lens.utilities as utils\n", "from transformer_lens.model_bridge import TransformerBridge" ] }, @@ -674,7 +674,7 @@ ], "source": [ "# NBVAL_SKIP\n", - "# Heavy patching computation \u2014 too slow for CI without GPU\n", + "# Heavy patching computation — too slow for CI without GPU\n", "every_block_result = patching.get_act_patch_block_every(model, corrupted_tokens, clean_cache, ioi_metric)\n", "imshow(every_block_result, facet_col=0, facet_labels=[\"Residual Stream\", \"Attn Output\", \"MLP Output\"], title=\"Activation Patching Per Block\", xaxis=\"Position\", yaxis=\"Layer\", zmax=1, zmin=-1, x= [f\"{tok}_{i}\" for i, tok in enumerate(model.to_str_tokens(clean_tokens[0]))])" ] @@ -885,7 +885,21 @@ "execution_count": null, "metadata": {}, "outputs": [], - "source": "# NBVAL_IGNORE_OUTPUT\ndef induction_loss(logits, answer_token_indices=rand_tokens_A):\n seq_len = answer_token_indices.shape[1]\n\n # logits: batch x seq_len x vocab_size\n # Take the logits for the answers, cut off the final element to get the predictions for all but the first element of the answers (which can't be predicted)\n final_logits = logits[:, -seq_len:-1]\n final_log_probs = final_logits.log_softmax(-1)\n return final_log_probs.gather(-1, answer_token_indices[:, 1:].unsqueeze(-1)).mean()\nCLEAN_BASELINE_INDUCTION = induction_loss(clean_logits_induction).item()\nprint(\"Clean baseline:\", CLEAN_BASELINE_INDUCTION)\nCORRUPTED_BASELINE_INDUCTION = induction_loss(corrupted_logits_induction).item()\nprint(\"Corrupted baseline:\", CORRUPTED_BASELINE_INDUCTION)" + "source": [ + "# NBVAL_IGNORE_OUTPUT\n", + "def induction_loss(logits, answer_token_indices=rand_tokens_A):\n", + " seq_len = answer_token_indices.shape[1]\n", + "\n", + " # logits: batch x seq_len x vocab_size\n", + " # Take the logits for the answers, cut off the final element to get the predictions for all but the first element of the answers (which can't be predicted)\n", + " final_logits = logits[:, -seq_len:-1]\n", + " final_log_probs = final_logits.log_softmax(-1)\n", + " return final_log_probs.gather(-1, answer_token_indices[:, 1:].unsqueeze(-1)).mean()\n", + "CLEAN_BASELINE_INDUCTION = induction_loss(clean_logits_induction).item()\n", + "print(\"Clean baseline:\", CLEAN_BASELINE_INDUCTION)\n", + "CORRUPTED_BASELINE_INDUCTION = induction_loss(corrupted_logits_induction).item()\n", + "print(\"Corrupted baseline:\", CORRUPTED_BASELINE_INDUCTION)" + ] }, { "cell_type": "markdown", @@ -899,7 +913,17 @@ "execution_count": null, "metadata": {}, "outputs": [], - "source": "# NBVAL_SKIP\n# Heavy patching computation \u2014 too slow for CI without GPU\nevery_head_all_pos_act_patch_result = patching.get_act_patch_attn_head_all_pos_every(attn_only, corrupted_tokens_induction, clean_cache_induction, induction_loss)\nimshow(every_head_all_pos_act_patch_result, facet_col=0, facet_labels=[\"Output\", \"Query\", \"Key\", \"Value\", \"Pattern\"], title=\"Activation Patching Per Head (All Pos)\", xaxis=\"Head\", yaxis=\"Layer\", zmax=CLEAN_BASELINE_INDUCTION)\n\nif DO_SLOW_RUNS:\n every_head_act_patch_result = patching.get_act_patch_attn_head_by_pos_every(attn_only, corrupted_tokens_induction, clean_cache_induction, induction_loss)\n every_head_act_patch_result = einops.rearrange(every_head_act_patch_result, \"act_type layer pos head -> act_type (layer head) pos\")\n imshow(every_head_act_patch_result, facet_col=0, facet_labels=[\"Output\", \"Query\", \"Key\", \"Value\", 
\"Pattern\"], title=\"Activation Patching Per Head (By Pos)\", xaxis=\"Position\", yaxis=\"Layer & Head\", zmax=CLEAN_BASELINE_INDUCTION, x= [f\"{tok}_{i}\" for i, tok in enumerate(attn_only.to_str_tokens(clean_tokens_induction[0]))], y=[f\"L{l}H{h}\" for l in range(attn_only.cfg.n_layers) for h in range(attn_only.cfg.n_heads)])" + "source": [ + "# NBVAL_SKIP\n", + "# Heavy patching computation — too slow for CI without GPU\n", + "every_head_all_pos_act_patch_result = patching.get_act_patch_attn_head_all_pos_every(attn_only, corrupted_tokens_induction, clean_cache_induction, induction_loss)\n", + "imshow(every_head_all_pos_act_patch_result, facet_col=0, facet_labels=[\"Output\", \"Query\", \"Key\", \"Value\", \"Pattern\"], title=\"Activation Patching Per Head (All Pos)\", xaxis=\"Head\", yaxis=\"Layer\", zmax=CLEAN_BASELINE_INDUCTION)\n", + "\n", + "if DO_SLOW_RUNS:\n", + " every_head_act_patch_result = patching.get_act_patch_attn_head_by_pos_every(attn_only, corrupted_tokens_induction, clean_cache_induction, induction_loss)\n", + " every_head_act_patch_result = einops.rearrange(every_head_act_patch_result, \"act_type layer pos head -> act_type (layer head) pos\")\n", + " imshow(every_head_act_patch_result, facet_col=0, facet_labels=[\"Output\", \"Query\", \"Key\", \"Value\", \"Pattern\"], title=\"Activation Patching Per Head (By Pos)\", xaxis=\"Position\", yaxis=\"Layer & Head\", zmax=CLEAN_BASELINE_INDUCTION, x= [f\"{tok}_{i}\" for i, tok in enumerate(attn_only.to_str_tokens(clean_tokens_induction[0]))], y=[f\"L{l}H{h}\" for l in range(attn_only.cfg.n_layers) for h in range(attn_only.cfg.n_heads)])" + ] }, { "cell_type": "markdown", @@ -1289,4 +1313,4 @@ }, "nbformat": 4, "nbformat_minor": 2 -} \ No newline at end of file +} diff --git a/demos/Attribution_Patching_Demo.ipynb b/demos/Attribution_Patching_Demo.ipynb index 7cc944ea7..8d9fce790 100644 --- a/demos/Attribution_Patching_Demo.ipynb +++ b/demos/Attribution_Patching_Demo.ipynb @@ -179,7 +179,7 @@ "outputs": [], "source": [ "import transformer_lens\n", - "import transformer_lens.utils as utils\n", + "import transformer_lens.utilities as utils\n", "from transformer_lens import (\n", " ActivationCache,\n", ")\n", @@ -3806,7 +3806,7 @@ "Top 0th token. Logit: 20.73 Prob: 95.80% Token: | Paris|\n", "Top 1th token. Logit: 16.49 Prob: 1.39% Token: | E|\n", "Top 2th token. Logit: 14.69 Prob: 0.23% Token: | the|\n", - "Top 3th token. Logit: 14.58 Prob: 0.21% Token: | \u00c9|\n", + "Top 3th token. Logit: 14.58 Prob: 0.21% Token: | É|\n", "Top 4th token. Logit: 14.44 Prob: 0.18% Token: | France|\n", "Top 5th token. Logit: 14.36 Prob: 0.16% Token: | Mont|\n", "Top 6th token. Logit: 13.77 Prob: 0.09% Token: | Le|\n", @@ -4236,11 +4236,11 @@ "description": "", "description_allow_html": false, "layout": "IPY_MODEL_d230304b88114f2a9b85f5a48f441ce6", - "placeholder": "\u200b", + "placeholder": "​", "style": "IPY_MODEL_ee535e2cfa694be1a7857b1867b8b608", "tabbable": null, "tooltip": null, - "value": "\u2007456k/?\u2007[00:00<00:00,\u200712.7MB/s]" + "value": " 456k/? 
[00:00<00:00, 12.7MB/s]" } }, "020cf001eb7d496295a325cbc0ee8718": { @@ -4277,7 +4277,7 @@ "description": "", "description_allow_html": false, "layout": "IPY_MODEL_af67816d50074ae498ef9b600b4175ed", - "placeholder": "\u200b", + "placeholder": "​", "style": "IPY_MODEL_7ac7b11ef3e34bf1a12926c745e08707", "tabbable": null, "tooltip": null, @@ -4300,11 +4300,11 @@ "description": "", "description_allow_html": false, "layout": "IPY_MODEL_7a32c6104cdb4eee8dadc248c129040c", - "placeholder": "\u200b", + "placeholder": "​", "style": "IPY_MODEL_8c0e7d4b46c14e2bb2167752820a9274", "tabbable": null, "tooltip": null, - "value": "\u20071.36M/?\u2007[00:00<00:00,\u200717.1MB/s]" + "value": " 1.36M/? [00:00<00:00, 17.1MB/s]" } }, "043e0e7fe43744589b7bad2527c2eac0": { @@ -4323,11 +4323,11 @@ "description": "", "description_allow_html": false, "layout": "IPY_MODEL_d3dab66a1c254f07afa02e73e6fd121d", - "placeholder": "\u200b", + "placeholder": "​", "style": "IPY_MODEL_a386fa811d524ae08ac67cce5ebf3a15", "tabbable": null, "tooltip": null, - "value": "\u20071.36M/?\u2007[00:00<00:00,\u200720.0MB/s]" + "value": " 1.36M/? [00:00<00:00, 20.0MB/s]" } }, "04d1b3296c75497bb314206d6c7d5341": { @@ -4346,11 +4346,11 @@ "description": "", "description_allow_html": false, "layout": "IPY_MODEL_8e6cf78296b14bc381f13658ebf99912", - "placeholder": "\u200b", + "placeholder": "​", "style": "IPY_MODEL_0ebc4d3f1e94415086e749f4cd41b783", "tabbable": null, "tooltip": null, - "value": "\u20071.04M/?\u2007[00:00<00:00,\u200710.6MB/s]" + "value": " 1.04M/? [00:00<00:00, 10.6MB/s]" } }, "0582e71e725a4851a1905aceaa3c36ae": { @@ -4669,11 +4669,11 @@ "description": "", "description_allow_html": false, "layout": "IPY_MODEL_2569c461e9144c4c82e856e4533449ba", - "placeholder": "\u200b", + "placeholder": "​", "style": "IPY_MODEL_2f332b079a3044fa8ab87f028a7b80b0", "tabbable": null, "tooltip": null, - "value": "\u200748/48\u2007[01:10<00:00,\u2007\u20071.46s/it]" + "value": " 48/48 [01:10<00:00,  1.46s/it]" } }, "0fd7652c5e624ef7b2a36a0b0397f51d": { @@ -4805,11 +4805,11 @@ "description": "", "description_allow_html": false, "layout": "IPY_MODEL_5cc6434927224f72aadd34dc0e0c2894", - "placeholder": "\u200b", + "placeholder": "​", "style": "IPY_MODEL_432ff9b9f3574d69b47e8272dd762923", "tabbable": null, "tooltip": null, - "value": "merges.txt:\u2007" + "value": "merges.txt: " } }, "12f960167a1c417aacdd77ce3a997e35": { @@ -4828,7 +4828,7 @@ "description": "", "description_allow_html": false, "layout": "IPY_MODEL_8dd28e04200641a9a2a4e5ed241db518", - "placeholder": "\u200b", + "placeholder": "​", "style": "IPY_MODEL_912fc8506a6f45e38f1573d27eff6457", "tabbable": null, "tooltip": null, @@ -4875,11 +4875,11 @@ "description": "", "description_allow_html": false, "layout": "IPY_MODEL_ae5d013bec884f4b97d1852b1fb52432", - "placeholder": "\u200b", + "placeholder": "​", "style": "IPY_MODEL_87c101f3cc0f4553870ca9de688b9e83", "tabbable": null, "tooltip": null, - "value": "vocab.json:\u2007" + "value": "vocab.json: " } }, "1775bd14b2104a078aa63991cc11ba85": { @@ -4967,11 +4967,11 @@ "description": "", "description_allow_html": false, "layout": "IPY_MODEL_6dd11b3c888a461aaa85372b044ccd53", - "placeholder": "\u200b", + "placeholder": "​", "style": "IPY_MODEL_dc90adc8272e4e3e929844e2ceef149b", "tabbable": null, "tooltip": null, - "value": "\u2007124/124\u2007[00:00<00:00,\u200754.7kB/s]" + "value": " 124/124 [00:00<00:00, 54.7kB/s]" } }, "180d2ba6e10e4e808eba69a8517d5080": { @@ -4990,7 +4990,7 @@ "description": "", "description_allow_html": false, "layout": 
"IPY_MODEL_de109b45a18f42b5aa83f63ef683379f", - "placeholder": "\u200b", + "placeholder": "​", "style": "IPY_MODEL_4e44c9999b664ea9bad1b4e360ee76c7", "tabbable": null, "tooltip": null, @@ -5686,7 +5686,7 @@ "description": "", "description_allow_html": false, "layout": "IPY_MODEL_59190b0bd8e74ee1bab6aea2f931856d", - "placeholder": "\u200b", + "placeholder": "​", "style": "IPY_MODEL_c0aa3c04c0a74717b8fc6700213bf579", "tabbable": null, "tooltip": null, @@ -5796,11 +5796,11 @@ "description": "", "description_allow_html": false, "layout": "IPY_MODEL_700b88d4443848bab341b9d7b00cab54", - "placeholder": "\u200b", + "placeholder": "​", "style": "IPY_MODEL_b2e784a339524df682698e606959668e", "tabbable": null, "tooltip": null, - "value": "tokenizer.json:\u2007" + "value": "tokenizer.json: " } }, "31a28b69348b40bfbd14a54380bfb766": { @@ -5819,7 +5819,7 @@ "description": "", "description_allow_html": false, "layout": "IPY_MODEL_eddca0196bdf4e1eb5605e557bfe597b", - "placeholder": "\u200b", + "placeholder": "​", "style": "IPY_MODEL_c762da254b434ffea5b7c35e73009302", "tabbable": null, "tooltip": null, @@ -6314,7 +6314,7 @@ "description": "", "description_allow_html": false, "layout": "IPY_MODEL_5ed82326612a4505a34bc16d6b0b5fa8", - "placeholder": "\u200b", + "placeholder": "​", "style": "IPY_MODEL_69d19b1cf82443ff8994ffd7b156921c", "tabbable": null, "tooltip": null, @@ -6337,11 +6337,11 @@ "description": "", "description_allow_html": false, "layout": "IPY_MODEL_c9571f91e4894ac0ab6f9433d6dd7258", - "placeholder": "\u200b", + "placeholder": "​", "style": "IPY_MODEL_c08b955fa9494958bee9f565c568fc31", "tabbable": null, "tooltip": null, - "value": "tokenizer.json:\u2007" + "value": "tokenizer.json: " } }, "3c377339930f49a5891caeb0639a8360": { @@ -6561,11 +6561,11 @@ "description": "", "description_allow_html": false, "layout": "IPY_MODEL_68f6906692b447f1acec3cef5772fe5a", - "placeholder": "\u200b", + "placeholder": "​", "style": "IPY_MODEL_67cda37b9ae74d3484442b7d3bb19a26", "tabbable": null, "tooltip": null, - "value": "merges.txt:\u2007" + "value": "merges.txt: " } }, "4000e5115c6d48d687ab9b9695a0d826": { @@ -6600,7 +6600,7 @@ "description": "", "description_allow_html": false, "layout": "IPY_MODEL_d94d7a4f5ab34fc9a4a4ee0a07764461", - "placeholder": "\u200b", + "placeholder": "​", "style": "IPY_MODEL_0fd7652c5e624ef7b2a36a0b0397f51d", "tabbable": null, "tooltip": null, @@ -6821,11 +6821,11 @@ "description": "", "description_allow_html": false, "layout": "IPY_MODEL_212c5660764844db8cdc6e3a16099521", - "placeholder": "\u200b", + "placeholder": "​", "style": "IPY_MODEL_7de33c6a558d40a69a85b3db9e203ae8", "tabbable": null, "tooltip": null, - "value": "model.safetensors:\u2007100%" + "value": "model.safetensors: 100%" } }, "46c907a0ac31481f9147bf22e2ac5864": { @@ -6897,11 +6897,11 @@ "description": "", "description_allow_html": false, "layout": "IPY_MODEL_91db9856db97451196e16d433896af48", - "placeholder": "\u200b", + "placeholder": "​", "style": "IPY_MODEL_4e68d17b4b7a45369d6917887ddd7a28", "tabbable": null, "tooltip": null, - "value": "vocab.json:\u2007" + "value": "vocab.json: " } }, "49be2b480d5847a3af7835c317236280": { @@ -7253,11 +7253,11 @@ "description": "", "description_allow_html": false, "layout": "IPY_MODEL_d7e105c660824d349c4ee17006f04437", - "placeholder": "\u200b", + "placeholder": "​", "style": "IPY_MODEL_2848d61b3098431baa1d82fb85f469fe", "tabbable": null, "tooltip": null, - "value": "\u2007548M/548M\u2007[00:18<00:00,\u200760.2MB/s]" + "value": " 548M/548M [00:18<00:00, 60.2MB/s]" } }, 
"54af6102260d458db54e634c9814aa6f": { @@ -7294,11 +7294,11 @@ "description": "", "description_allow_html": false, "layout": "IPY_MODEL_6d3de9443ae74b75930d8397e9c7ed9a", - "placeholder": "\u200b", + "placeholder": "​", "style": "IPY_MODEL_63c8f7fdcf6346d9b3ca140d0f63f8e4", "tabbable": null, "tooltip": null, - "value": "\u2007144/144\u2007[00:07<00:00,\u200718.78it/s]" + "value": " 144/144 [00:07<00:00, 18.78it/s]" } }, "54daaf323029464b9b67f8a4f53b3002": { @@ -7333,7 +7333,7 @@ "description": "", "description_allow_html": false, "layout": "IPY_MODEL_39f8ccf31f6c49c7a95c59989236a3cf", - "placeholder": "\u200b", + "placeholder": "​", "style": "IPY_MODEL_54af6102260d458db54e634c9814aa6f", "tabbable": null, "tooltip": null, @@ -7653,11 +7653,11 @@ "description": "", "description_allow_html": false, "layout": "IPY_MODEL_7ffa006c49564aa8ad58f08f48b98955", - "placeholder": "\u200b", + "placeholder": "​", "style": "IPY_MODEL_357dcb56edba4564b6ec3051f3e977a5", "tabbable": null, "tooltip": null, - "value": "\u2007144/144\u2007[00:08<00:00,\u200715.78it/s]" + "value": " 144/144 [00:08<00:00, 15.78it/s]" } }, "5f730b9ef10e4b8bb97c4fef1bd7cbb2": { @@ -7789,7 +7789,7 @@ "description": "", "description_allow_html": false, "layout": "IPY_MODEL_4b378e2fe92a4bb5a0cc2adee8a9372d", - "placeholder": "\u200b", + "placeholder": "​", "style": "IPY_MODEL_8e4b9fcabfbc4a37a54f08a99e28b220", "tabbable": null, "tooltip": null, @@ -7838,11 +7838,11 @@ "description": "", "description_allow_html": false, "layout": "IPY_MODEL_d7699d95a0ab4240bfa2754ac81a4dea", - "placeholder": "\u200b", + "placeholder": "​", "style": "IPY_MODEL_b92108f127ec4341af59d110b5f991c4", "tabbable": null, "tooltip": null, - "value": "Loading\u2007weights:\u2007100%" + "value": "Loading weights: 100%" } }, "61ffe9ae7ab44e94b5729cae80e49437": { @@ -7998,11 +7998,11 @@ "description": "", "description_allow_html": false, "layout": "IPY_MODEL_38e4ab89d9ed4f16a2a481054a18977f", - "placeholder": "\u200b", + "placeholder": "​", "style": "IPY_MODEL_3673315e0fe741048d2ba304360be671", "tabbable": null, "tooltip": null, - "value": "Loading\u2007weights:\u2007100%" + "value": "Loading weights: 100%" } }, "67cda37b9ae74d3484442b7d3bb19a26": { @@ -8203,11 +8203,11 @@ "description": "", "description_allow_html": false, "layout": "IPY_MODEL_e3abfde7cfd24e938684e179059edd9d", - "placeholder": "\u200b", + "placeholder": "​", "style": "IPY_MODEL_293230277eac481e84ab200e7cd5bdc1", "tabbable": null, "tooltip": null, - "value": "\u200726.0/26.0\u2007[00:00<00:00,\u20073.64kB/s]" + "value": " 26.0/26.0 [00:00<00:00, 3.64kB/s]" } }, "6ca82062740348f2bee12629de7f8e2f": { @@ -8350,11 +8350,11 @@ "description": "", "description_allow_html": false, "layout": "IPY_MODEL_39409ff188e6463bab5bf783828cdbd6", - "placeholder": "\u200b", + "placeholder": "​", "style": "IPY_MODEL_7bd58f2c5b444ed9b5f21c6b364c9dce", "tabbable": null, "tooltip": null, - "value": "\u2007456k/?\u2007[00:00<00:00,\u20079.34MB/s]" + "value": " 456k/? 
[00:00<00:00, 9.34MB/s]" } }, "6dd11b3c888a461aaa85372b044ccd53": { @@ -8550,11 +8550,11 @@ "description": "", "description_allow_html": false, "layout": "IPY_MODEL_c08dfc48a3574ac1ba2e416960d1d3ea", - "placeholder": "\u200b", + "placeholder": "​", "style": "IPY_MODEL_3d29acbe122346388e66e0741971a810", "tabbable": null, "tooltip": null, - "value": "generation_config.json:\u2007100%" + "value": "generation_config.json: 100%" } }, "737d22cc16184d6a92cf045c476c7a01": { @@ -8573,11 +8573,11 @@ "description": "", "description_allow_html": false, "layout": "IPY_MODEL_8f5f5df1b0314449a113f7bb959fa273", - "placeholder": "\u200b", + "placeholder": "​", "style": "IPY_MODEL_7671d9fee96d4e9f921046a1fb092672", "tabbable": null, "tooltip": null, - "value": "\u2007689/689\u2007[00:00<00:00,\u2007131kB/s]" + "value": " 689/689 [00:00<00:00, 131kB/s]" } }, "73cd80c764784b4197af01198ba6b886": { @@ -8596,7 +8596,7 @@ "description": "", "description_allow_html": false, "layout": "IPY_MODEL_f796802e863c48138fdcec92f546a372", - "placeholder": "\u200b", + "placeholder": "​", "style": "IPY_MODEL_22e77d7546334e038b64cc2c856a6a13", "tabbable": null, "tooltip": null, @@ -8619,11 +8619,11 @@ "description": "", "description_allow_html": false, "layout": "IPY_MODEL_cb86f659e5ad4c5296c32b97d99d357c", - "placeholder": "\u200b", + "placeholder": "​", "style": "IPY_MODEL_8dacdd0b8b5e4433ae4511433eb7df1d", "tabbable": null, "tooltip": null, - "value": "generation_config.json:\u2007100%" + "value": "generation_config.json: 100%" } }, "741e4c3c5c56426c91955f6b0622f629": { @@ -8642,11 +8642,11 @@ "description": "", "description_allow_html": false, "layout": "IPY_MODEL_5e617862d88745c792f610f6651662f6", - "placeholder": "\u200b", + "placeholder": "​", "style": "IPY_MODEL_60e5aae936bb41faad191cf2155f78d3", "tabbable": null, "tooltip": null, - "value": "tokenizer_config.json:\u2007100%" + "value": "tokenizer_config.json: 100%" } }, "75b76bb0ff2f491a8e5febeb8166cbd2": { @@ -8813,11 +8813,11 @@ "description": "", "description_allow_html": false, "layout": "IPY_MODEL_07149da010c5489696b03653df25cdd2", - "placeholder": "\u200b", + "placeholder": "​", "style": "IPY_MODEL_8ccfc9585b794ba293cbe376311c42ba", "tabbable": null, "tooltip": null, - "value": "\u2007144/144\u2007[00:07<00:00,\u200717.62it/s]" + "value": " 144/144 [00:07<00:00, 17.62it/s]" } }, "775073e43a7a4b2f970c810d2a05c73e": { @@ -9414,11 +9414,11 @@ "description": "", "description_allow_html": false, "layout": "IPY_MODEL_cfee0a94009c461dbdca5b73961f7fbe", - "placeholder": "\u200b", + "placeholder": "​", "style": "IPY_MODEL_f732f9391dd14bcaace6fd5c27a8335a", "tabbable": null, "tooltip": null, - "value": "\u2007144/144\u2007[00:07<00:00,\u200718.66it/s]" + "value": " 144/144 [00:07<00:00, 18.66it/s]" } }, "84867a129d0043b4910ac244e8a984df": { @@ -10089,11 +10089,11 @@ "description": "", "description_allow_html": false, "layout": "IPY_MODEL_a996f052d0ed4d938f0c71fc72e8c1b6", - "placeholder": "\u200b", + "placeholder": "​", "style": "IPY_MODEL_2a6fd373a6524eb3ae033ea78d3cb61e", "tabbable": null, "tooltip": null, - "value": "config.json:\u2007100%" + "value": "config.json: 100%" } }, "9157c241f7064a8596e1ffaeb850e59c": { @@ -10205,7 +10205,7 @@ "description": "", "description_allow_html": false, "layout": "IPY_MODEL_c51496d64234439ebbfec98b59f44803", - "placeholder": "\u200b", + "placeholder": "​", "style": "IPY_MODEL_29bf4e8f0e6042b498c6ac50d8fedf68", "tabbable": null, "tooltip": null, @@ -10508,11 +10508,11 @@ "description": "", "description_allow_html": false, "layout": 
"IPY_MODEL_d95b7b9cf6914acb9d1152502d2ba41b", - "placeholder": "\u200b", + "placeholder": "​", "style": "IPY_MODEL_d7c69daa5fa44a6487f5dc66380ec31a", "tabbable": null, "tooltip": null, - "value": "\u2007180/180\u2007[00:10<00:00,\u200718.55it/s]" + "value": " 180/180 [00:10<00:00, 18.55it/s]" } }, "9b50599c4a084d9eb5b8755040c9bd32": { @@ -10547,11 +10547,11 @@ "description": "", "description_allow_html": false, "layout": "IPY_MODEL_f5bbf314d840422b9e486386de3f5bb6", - "placeholder": "\u200b", + "placeholder": "​", "style": "IPY_MODEL_dc68cf2bf3e94df881377a106754a350", "tabbable": null, "tooltip": null, - "value": "\u20071.04M/?\u2007[00:00<00:00,\u200710.5MB/s]" + "value": " 1.04M/? [00:00<00:00, 10.5MB/s]" } }, "9ba87779605c4969bf48b4071a94c630": { @@ -10671,11 +10671,11 @@ "description": "", "description_allow_html": false, "layout": "IPY_MODEL_363c2c2e96624875af87c420c7e2cf95", - "placeholder": "\u200b", + "placeholder": "​", "style": "IPY_MODEL_b732ea0e03674d4384ac0d2dbf2a5f69", "tabbable": null, "tooltip": null, - "value": "config.json:\u2007100%" + "value": "config.json: 100%" } }, "a0df7ae7fcc1441a8c9cca5a80b539b0": { @@ -10694,11 +10694,11 @@ "description": "", "description_allow_html": false, "layout": "IPY_MODEL_bcaf8ebca41240bd86dde0c617b68f01", - "placeholder": "\u200b", + "placeholder": "​", "style": "IPY_MODEL_dd6bf89931a64c63bbcd2cf526835c2d", "tabbable": null, "tooltip": null, - "value": "\u20072160/2160\u2007[02:00<00:00,\u200718.64it/s]" + "value": " 2160/2160 [02:00<00:00, 18.64it/s]" } }, "a1f007e8fb68491daa3ba444ca49f505": { @@ -10809,11 +10809,11 @@ "description": "", "description_allow_html": false, "layout": "IPY_MODEL_7655d70259dd4dbc8bd4d288f8850b7c", - "placeholder": "\u200b", + "placeholder": "​", "style": "IPY_MODEL_eb49abf5e8ea41e69c1307f78fda4a90", "tabbable": null, "tooltip": null, - "value": "\u2007180/180\u2007[00:09<00:00,\u200718.56it/s]" + "value": " 180/180 [00:09<00:00, 18.56it/s]" } }, "a711143026bc46a5b1b7bc3dccca1850": { @@ -10858,11 +10858,11 @@ "description": "", "description_allow_html": false, "layout": "IPY_MODEL_2970477ecd6545b2bb698748d4019dac", - "placeholder": "\u200b", + "placeholder": "​", "style": "IPY_MODEL_7c2906dcc3d34ae6846332eb5375cc58", "tabbable": null, "tooltip": null, - "value": "\u2007124/124\u2007[00:00<00:00,\u200773.2kB/s]" + "value": " 124/124 [00:00<00:00, 73.2kB/s]" } }, "a92811787eb84dd19d9ec2fb2eab7eee": { @@ -11381,11 +11381,11 @@ "description": "", "description_allow_html": false, "layout": "IPY_MODEL_b75438a235a243d49404467d54b63373", - "placeholder": "\u200b", + "placeholder": "​", "style": "IPY_MODEL_8f8ffe07f8314800ad637194c8c7d10f", "tabbable": null, "tooltip": null, - "value": "\u2007665/665\u2007[00:00<00:00,\u2007122kB/s]" + "value": " 665/665 [00:00<00:00, 122kB/s]" } }, "b59d1c8089e04592a1b87a7c198d1f6c": { @@ -11481,7 +11481,7 @@ "description": "", "description_allow_html": false, "layout": "IPY_MODEL_97f6ad1918334a3ab503d4a5da11c9ef", - "placeholder": "\u200b", + "placeholder": "​", "style": "IPY_MODEL_0e453235a18e4f5c9008040b1420f718", "tabbable": null, "tooltip": null, @@ -11609,11 +11609,11 @@ "description": "", "description_allow_html": false, "layout": "IPY_MODEL_45e840f42f8c46d491678fad7820e835", - "placeholder": "\u200b", + "placeholder": "​", "style": "IPY_MODEL_e29dd892780a4208b61593998e588e1f", "tabbable": null, "tooltip": null, - "value": "model.safetensors:\u2007100%" + "value": "model.safetensors: 100%" } }, "bb91ef19e455404aac3d283f868f9687": { @@ -12050,11 +12050,11 @@ "description": "", 
"description_allow_html": false, "layout": "IPY_MODEL_2452dd39a79742b29964500360f4a478", - "placeholder": "\u200b", + "placeholder": "​", "style": "IPY_MODEL_6ca82062740348f2bee12629de7f8e2f", "tabbable": null, "tooltip": null, - "value": "\u2007580/580\u2007[00:00<00:00,\u20075679.21it/s,\u2007Materializing\u2007param=transformer.wte.weight]" + "value": " 580/580 [00:00<00:00, 5679.21it/s, Materializing param=transformer.wte.weight]" } }, "c455478a557645b29777950e364a5006": { @@ -12258,11 +12258,11 @@ "description": "", "description_allow_html": false, "layout": "IPY_MODEL_cb791fd8c7ae49079f9162125e98ff79", - "placeholder": "\u200b", + "placeholder": "​", "style": "IPY_MODEL_52cd22d9c796437692165d3f3ed48e82", "tabbable": null, "tooltip": null, - "value": "\u20072160/2160\u2007[02:00<00:00,\u200715.47it/s]" + "value": " 2160/2160 [02:00<00:00, 15.47it/s]" } }, "c5bca44eefb940d39fba70d4fff71571": { @@ -12358,11 +12358,11 @@ "description": "", "description_allow_html": false, "layout": "IPY_MODEL_f44f01a296fb4d198718574f1f802ba2", - "placeholder": "\u200b", + "placeholder": "​", "style": "IPY_MODEL_c6b1b39c22564a59bc5b950d7a46708f", "tabbable": null, "tooltip": null, - "value": "\u200726.0/26.0\u2007[00:00<00:00,\u20074.24kB/s]" + "value": " 26.0/26.0 [00:00<00:00, 4.24kB/s]" } }, "c6b1b39c22564a59bc5b950d7a46708f": { @@ -12417,11 +12417,11 @@ "description": "", "description_allow_html": false, "layout": "IPY_MODEL_7d0ba24e89554742a562ce50135447f0", - "placeholder": "\u200b", + "placeholder": "​", "style": "IPY_MODEL_502e8cbf7b7c47b480b8a85405fc24cf", "tabbable": null, "tooltip": null, - "value": "\u20072160/2160\u2007[02:00<00:00,\u200717.95it/s]" + "value": " 2160/2160 [02:00<00:00, 17.95it/s]" } }, "c836782daaa847248009a626db347182": { @@ -12562,11 +12562,11 @@ "description": "", "description_allow_html": false, "layout": "IPY_MODEL_3d83725b10254139a51cb688a495459d", - "placeholder": "\u200b", + "placeholder": "​", "style": "IPY_MODEL_020cf001eb7d496295a325cbc0ee8718", "tabbable": null, "tooltip": null, - "value": "\u20072160/2160\u2007[02:00<00:00,\u200718.75it/s]" + "value": " 2160/2160 [02:00<00:00, 18.75it/s]" } }, "cac816368dee481a9fee6b196a2b16d6": { @@ -12760,7 +12760,7 @@ "description": "", "description_allow_html": false, "layout": "IPY_MODEL_72035846c38c4a439583cc5e974c0a52", - "placeholder": "\u200b", + "placeholder": "​", "style": "IPY_MODEL_7aebe1ba464949849da0e182d90d0669", "tabbable": null, "tooltip": null, @@ -12986,11 +12986,11 @@ "description": "", "description_allow_html": false, "layout": "IPY_MODEL_8ab4ec47b46e401883c185417c452f17", - "placeholder": "\u200b", + "placeholder": "​", "style": "IPY_MODEL_fdb4d541f8414fe4bee9265554cd7522", "tabbable": null, "tooltip": null, - "value": "tokenizer_config.json:\u2007100%" + "value": "tokenizer_config.json: 100%" } }, "d3dab66a1c254f07afa02e73e6fd121d": { @@ -13692,11 +13692,11 @@ "description": "", "description_allow_html": false, "layout": "IPY_MODEL_b41e54c58689400b86fc0dbf18e4bbaa", - "placeholder": "\u200b", + "placeholder": "​", "style": "IPY_MODEL_98bce027ddf046c29bbbb03c6e9b1de3", "tabbable": null, "tooltip": null, - "value": "\u2007144/144\u2007[00:08<00:00,\u200718.70it/s]" + "value": " 144/144 [00:08<00:00, 18.70it/s]" } }, "df2f48a5055b4cb7a9db8115c407fd8c": { @@ -13892,11 +13892,11 @@ "description": "", "description_allow_html": false, "layout": "IPY_MODEL_29376b858cd4489a8fcefc2b096df1e5", - "placeholder": "\u200b", + "placeholder": "​", "style": "IPY_MODEL_6f7575088f10441c87ded2f68ac37e9f", "tabbable": null, 
"tooltip": null, - "value": "\u20072160/2160\u2007[01:59<00:00,\u200718.72it/s]" + "value": " 2160/2160 [01:59<00:00, 18.72it/s]" } }, "e3abfde7cfd24e938684e179059edd9d": { @@ -14219,11 +14219,11 @@ "description": "", "description_allow_html": false, "layout": "IPY_MODEL_db91637067cb428c94674237eabda8f7", - "placeholder": "\u200b", + "placeholder": "​", "style": "IPY_MODEL_dd98180b8cba40aa82fc6221dd4676d0", "tabbable": null, "tooltip": null, - "value": "\u2007180/180\u2007[00:09<00:00,\u200717.32it/s]" + "value": " 180/180 [00:09<00:00, 17.32it/s]" } }, "eb49abf5e8ea41e69c1307f78fda4a90": { @@ -14764,7 +14764,7 @@ "description": "", "description_allow_html": false, "layout": "IPY_MODEL_e2c844c07e434e718186afaf72312371", - "placeholder": "\u200b", + "placeholder": "​", "style": "IPY_MODEL_3704e0be72444fd9a07038fbe1b19156", "tabbable": null, "tooltip": null, @@ -14840,11 +14840,11 @@ "description": "", "description_allow_html": false, "layout": "IPY_MODEL_61ffe9ae7ab44e94b5729cae80e49437", - "placeholder": "\u200b", + "placeholder": "​", "style": "IPY_MODEL_a314648aea3d40f89ac931c47f81c9e6", "tabbable": null, "tooltip": null, - "value": "\u2007148/148\u2007[00:00<00:00,\u20075498.19it/s,\u2007Materializing\u2007param=transformer.wte.weight]" + "value": " 148/148 [00:00<00:00, 5498.19it/s, Materializing param=transformer.wte.weight]" } }, "fd150a5176074e959dfa52a35770b5f0": { @@ -14863,11 +14863,11 @@ "description": "", "description_allow_html": false, "layout": "IPY_MODEL_a05625d67e634f2a80c65cdcfcbe8f8c", - "placeholder": "\u200b", + "placeholder": "​", "style": "IPY_MODEL_d1e752f79bbc40ddbae7a02895c9b74e", "tabbable": null, "tooltip": null, - "value": "\u20076.43G/6.43G\u2007[04:52<00:00,\u2007111MB/s]" + "value": " 6.43G/6.43G [04:52<00:00, 111MB/s]" } }, "fdb4d541f8414fe4bee9265554cd7522": { diff --git a/demos/Exploratory_Analysis_Demo.ipynb b/demos/Exploratory_Analysis_Demo.ipynb index c2c86cea8..097f25c10 100644 --- a/demos/Exploratory_Analysis_Demo.ipynb +++ b/demos/Exploratory_Analysis_Demo.ipynb @@ -117,7 +117,7 @@ "from IPython.display import HTML, IFrame\n", "from jaxtyping import Float\n", "\n", - "import transformer_lens.utils as utils\n", + "import transformer_lens.utilities as utils\n", "from transformer_lens import ActivationCache\n", "from transformer_lens.model_bridge import TransformerBridge" ] @@ -1883,4 +1883,4 @@ }, "nbformat": 4, "nbformat_minor": 2 -} \ No newline at end of file +} diff --git a/demos/Grokking_Demo.ipynb b/demos/Grokking_Demo.ipynb index 7b7fe5243..09af51e9a 100644 --- a/demos/Grokking_Demo.ipynb +++ b/demos/Grokking_Demo.ipynb @@ -138,7 +138,7 @@ "outputs": [], "source": [ "import transformer_lens\n", - "import transformer_lens.utils as utils\n", + "import transformer_lens.utilities as utils\n", "from transformer_lens.hook_points import (\n", " HookedRootModule,\n", " HookPoint,\n", @@ -2921,10 +2921,10 @@ "evalue": "name 'train_losses' is not defined", "output_type": "error", "traceback": [ - "\u001B[0;31m---------------------------------------------------------------------------\u001B[0m", - "\u001B[0;31mNameError\u001B[0m Traceback (most recent call last)", - "\u001B[0;32m/tmp/ipykernel_1229617/2975677256.py\u001B[0m in \u001B[0;36m\u001B[0;34m\u001B[0m\n\u001B[1;32m 1\u001B[0m \u001B[0;32mimport\u001B[0m \u001B[0mneel\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0mplot\u001B[0m \u001B[0;32mas\u001B[0m \u001B[0mnpx\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[0;32m----> 2\u001B[0;31m \u001B[0mfig\u001B[0m \u001B[0;34m=\u001B[0m 
\u001B[0mnpx\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0mline\u001B[0m\u001B[0;34m(\u001B[0m\u001B[0;34m[\u001B[0m\u001B[0mtrain_losses\u001B[0m\u001B[0;34m[\u001B[0m\u001B[0;34m:\u001B[0m\u001B[0;34m:\u001B[0m\u001B[0;36m100\u001B[0m\u001B[0;34m]\u001B[0m\u001B[0;34m,\u001B[0m \u001B[0mtest_losses\u001B[0m\u001B[0;34m[\u001B[0m\u001B[0;34m:\u001B[0m\u001B[0;34m:\u001B[0m\u001B[0;36m100\u001B[0m\u001B[0;34m]\u001B[0m\u001B[0;34m]\u001B[0m\u001B[0;34m,\u001B[0m \u001B[0mx\u001B[0m\u001B[0;34m=\u001B[0m\u001B[0mnp\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0marange\u001B[0m\u001B[0;34m(\u001B[0m\u001B[0;36m0\u001B[0m\u001B[0;34m,\u001B[0m \u001B[0mlen\u001B[0m\u001B[0;34m(\u001B[0m\u001B[0mtrain_losses\u001B[0m\u001B[0;34m)\u001B[0m\u001B[0;34m,\u001B[0m \u001B[0;36m100\u001B[0m\u001B[0;34m)\u001B[0m\u001B[0;34m,\u001B[0m \u001B[0mxaxis\u001B[0m\u001B[0;34m=\u001B[0m\u001B[0;34m\"Epoch\"\u001B[0m\u001B[0;34m,\u001B[0m \u001B[0myaxis\u001B[0m\u001B[0;34m=\u001B[0m\u001B[0;34m\"Loss\"\u001B[0m\u001B[0;34m,\u001B[0m \u001B[0mlog_y\u001B[0m\u001B[0;34m=\u001B[0m\u001B[0;32mTrue\u001B[0m\u001B[0;34m,\u001B[0m \u001B[0mtitle\u001B[0m\u001B[0;34m=\u001B[0m\u001B[0;34m\"Training Curve for Modular Addition\"\u001B[0m\u001B[0;34m,\u001B[0m \u001B[0mline_labels\u001B[0m\u001B[0;34m=\u001B[0m\u001B[0;34m[\u001B[0m\u001B[0;34m'train'\u001B[0m\u001B[0;34m,\u001B[0m \u001B[0;34m'test'\u001B[0m\u001B[0;34m]\u001B[0m\u001B[0;34m,\u001B[0m \u001B[0mtoggle_x\u001B[0m\u001B[0;34m=\u001B[0m\u001B[0;32mTrue\u001B[0m\u001B[0;34m,\u001B[0m \u001B[0mtoggle_y\u001B[0m\u001B[0;34m=\u001B[0m\u001B[0;32mTrue\u001B[0m\u001B[0;34m,\u001B[0m \u001B[0mreturn_fig\u001B[0m\u001B[0;34m=\u001B[0m\u001B[0;32mTrue\u001B[0m\u001B[0;34m)\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[0m\u001B[1;32m 3\u001B[0m \u001B[0madd_lines\u001B[0m\u001B[0;34m(\u001B[0m\u001B[0mfig\u001B[0m\u001B[0;34m)\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n", - "\u001B[0;31mNameError\u001B[0m: name 'train_losses' is not defined" + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)", + "\u001b[0;32m/tmp/ipykernel_1229617/2975677256.py\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mneel\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mplot\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0mnpx\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 2\u001b[0;31m \u001b[0mfig\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mnpx\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mline\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mtrain_losses\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;36m100\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtest_losses\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;36m100\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0marange\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtrain_losses\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m100\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mxaxis\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m\"Epoch\"\u001b[0m\u001b[0;34m,\u001b[0m 
\u001b[0myaxis\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m\"Loss\"\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlog_y\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mTrue\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtitle\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m\"Training Curve for Modular Addition\"\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mline_labels\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'train'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'test'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtoggle_x\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mTrue\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtoggle_y\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mTrue\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mreturn_fig\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mTrue\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 3\u001b[0m \u001b[0madd_lines\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfig\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;31mNameError\u001b[0m: name 'train_losses' is not defined" ] } ], @@ -3500,11 +3500,11 @@ "evalue": "Size does not match at dimension 0 expected index [12769, 1] to be smaller than self [113, 113] apart from dimension 1", "output_type": "error", "traceback": [ - "\u001B[0;31m---------------------------------------------------------------------------\u001B[0m", - "\u001B[0;31mRuntimeError\u001B[0m Traceback (most recent call last)", - "\u001B[0;32m/tmp/ipykernel_1215793/3004607503.py\u001B[0m in \u001B[0;36m\u001B[0;34m\u001B[0m\n\u001B[0;32m----> 1\u001B[0;31m \u001B[0mprint\u001B[0m\u001B[0;34m(\u001B[0m\u001B[0mloss_fn\u001B[0m\u001B[0;34m(\u001B[0m\u001B[0mall_logits\u001B[0m\u001B[0;34m,\u001B[0m \u001B[0mlabels\u001B[0m\u001B[0;34m)\u001B[0m\u001B[0;34m)\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[0m", - "\u001B[0;32m/tmp/ipykernel_1215793/4096650173.py\u001B[0m in \u001B[0;36mloss_fn\u001B[0;34m(logits, labels)\u001B[0m\n\u001B[1;32m 4\u001B[0m \u001B[0mlogits\u001B[0m \u001B[0;34m=\u001B[0m \u001B[0mlogits\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0mto\u001B[0m\u001B[0;34m(\u001B[0m\u001B[0mtorch\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0mfloat64\u001B[0m\u001B[0;34m)\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[1;32m 5\u001B[0m \u001B[0mlog_probs\u001B[0m \u001B[0;34m=\u001B[0m \u001B[0mlogits\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0mlog_softmax\u001B[0m\u001B[0;34m(\u001B[0m\u001B[0mdim\u001B[0m\u001B[0;34m=\u001B[0m\u001B[0;34m-\u001B[0m\u001B[0;36m1\u001B[0m\u001B[0;34m)\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[0;32m----> 6\u001B[0;31m \u001B[0mcorrect_log_probs\u001B[0m \u001B[0;34m=\u001B[0m \u001B[0mlog_probs\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0mgather\u001B[0m\u001B[0;34m(\u001B[0m\u001B[0mdim\u001B[0m\u001B[0;34m=\u001B[0m\u001B[0;34m-\u001B[0m\u001B[0;36m1\u001B[0m\u001B[0;34m,\u001B[0m \u001B[0mindex\u001B[0m\u001B[0;34m=\u001B[0m\u001B[0mlabels\u001B[0m\u001B[0;34m[\u001B[0m\u001B[0;34m:\u001B[0m\u001B[0;34m,\u001B[0m \u001B[0;32mNone\u001B[0m\u001B[0;34m]\u001B[0m\u001B[0;34m)\u001B[0m\u001B[0;34m[\u001B[0m\u001B[0;34m:\u001B[0m\u001B[0;34m,\u001B[0m \u001B[0;36m0\u001B[0m\u001B[0;34m]\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[0m\u001B[1;32m 7\u001B[0m \u001B[0;32mreturn\u001B[0m 
\u001B[0;34m-\u001B[0m\u001B[0mcorrect_log_probs\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0mmean\u001B[0m\u001B[0;34m(\u001B[0m\u001B[0;34m)\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[1;32m 8\u001B[0m \u001B[0mtrain_logits\u001B[0m \u001B[0;34m=\u001B[0m \u001B[0mmodel\u001B[0m\u001B[0;34m(\u001B[0m\u001B[0mtrain_data\u001B[0m\u001B[0;34m)\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n", - "\u001B[0;31mRuntimeError\u001B[0m: Size does not match at dimension 0 expected index [12769, 1] to be smaller than self [113, 113] apart from dimension 1" + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mRuntimeError\u001b[0m Traceback (most recent call last)", + "\u001b[0;32m/tmp/ipykernel_1215793/3004607503.py\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mloss_fn\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mall_logits\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlabels\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", + "\u001b[0;32m/tmp/ipykernel_1215793/4096650173.py\u001b[0m in \u001b[0;36mloss_fn\u001b[0;34m(logits, labels)\u001b[0m\n\u001b[1;32m 4\u001b[0m \u001b[0mlogits\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mlogits\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mto\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfloat64\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 5\u001b[0m \u001b[0mlog_probs\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mlogits\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mlog_softmax\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdim\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m-\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 6\u001b[0;31m \u001b[0mcorrect_log_probs\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mlog_probs\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mgather\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdim\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m-\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mindex\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mlabels\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 7\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0;34m-\u001b[0m\u001b[0mcorrect_log_probs\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmean\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 8\u001b[0m \u001b[0mtrain_logits\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mmodel\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtrain_data\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;31mRuntimeError\u001b[0m: Size does not match at dimension 0 expected index [12769, 1] to be smaller than self [113, 113] apart from dimension 1" ] } ], diff --git a/demos/Head_Detector_Demo.ipynb b/demos/Head_Detector_Demo.ipynb index 5354abeb8..2caa4231f 100644 --- a/demos/Head_Detector_Demo.ipynb +++ b/demos/Head_Detector_Demo.ipynb @@ -321,7 +321,7 @@ "import numpy as np\n", "import torch\n", "\n", - "# from transformer_lens.utils import 
is_lower_triangular, is_square\n", + "# from transformer_lens.utilities import is_lower_triangular, is_square\n", "\n", "HeadName = Literal[\"previous_token_head\", \"duplicate_token_head\", \"induction_head\"]\n", "HEAD_NAMES = cast(List[HeadName], get_args(HeadName))\n", @@ -393,7 +393,7 @@ " --------\n", " .. code-block:: python\n", "\n", - " >>> from transformer_lens import utils\n", + " >>> from transformer_lens import utilities as utils\n", " >>> from transformer_lens.model_bridge import TransformerBridge\n", " >>> from transformer_lens.head_detector import detect_head\n", " >>> import plotly.express as px\n", @@ -1476,7 +1476,7 @@ "output_type": "stream", "text": [ "\r", - " 3%|\u258e | 3/100 [00:00<00:03, 27.52it/s]" + " 3%|▎ | 3/100 [00:00<00:03, 27.52it/s]" ] }, { @@ -1484,7 +1484,7 @@ "output_type": "stream", "text": [ "\r", - " 6%|\u258c | 6/100 [00:00<00:03, 27.79it/s]" + " 6%|▌ | 6/100 [00:00<00:03, 27.79it/s]" ] }, { @@ -1492,7 +1492,7 @@ "output_type": "stream", "text": [ "\r", - " 9%|\u2589 | 9/100 [00:00<00:03, 27.62it/s]" + " 9%|▉ | 9/100 [00:00<00:03, 27.62it/s]" ] }, { @@ -1500,7 +1500,7 @@ "output_type": "stream", "text": [ "\r", - " 12%|\u2588\u258f | 12/100 [00:00<00:03, 27.87it/s]" + " 12%|█▏ | 12/100 [00:00<00:03, 27.87it/s]" ] }, { @@ -1508,7 +1508,7 @@ "output_type": "stream", "text": [ "\r", - " 15%|\u2588\u258c | 15/100 [00:00<00:03, 22.61it/s]" + " 15%|█▌ | 15/100 [00:00<00:03, 22.61it/s]" ] }, { @@ -1516,7 +1516,7 @@ "output_type": "stream", "text": [ "\r", - " 18%|\u2588\u258a | 18/100 [00:00<00:03, 23.94it/s]" + " 18%|█▊ | 18/100 [00:00<00:03, 23.94it/s]" ] }, { @@ -1524,7 +1524,7 @@ "output_type": "stream", "text": [ "\r", - " 21%|\u2588\u2588 | 21/100 [00:00<00:03, 25.03it/s]" + " 21%|██ | 21/100 [00:00<00:03, 25.03it/s]" ] }, { @@ -1532,7 +1532,7 @@ "output_type": "stream", "text": [ "\r", - " 24%|\u2588\u2588\u258d | 24/100 [00:00<00:02, 26.01it/s]" + " 24%|██▍ | 24/100 [00:00<00:02, 26.01it/s]" ] }, { @@ -1540,7 +1540,7 @@ "output_type": "stream", "text": [ "\r", - " 27%|\u2588\u2588\u258b | 27/100 [00:01<00:02, 26.76it/s]" + " 27%|██▋ | 27/100 [00:01<00:02, 26.76it/s]" ] }, { @@ -1548,7 +1548,7 @@ "output_type": "stream", "text": [ "\r", - " 30%|\u2588\u2588\u2588 | 30/100 [00:01<00:02, 27.30it/s]" + " 30%|███ | 30/100 [00:01<00:02, 27.30it/s]" ] }, { @@ -1556,7 +1556,7 @@ "output_type": "stream", "text": [ "\r", - " 33%|\u2588\u2588\u2588\u258e | 33/100 [00:01<00:02, 27.58it/s]" + " 33%|███▎ | 33/100 [00:01<00:02, 27.58it/s]" ] }, { @@ -1564,7 +1564,7 @@ "output_type": "stream", "text": [ "\r", - " 36%|\u2588\u2588\u2588\u258c | 36/100 [00:01<00:02, 27.98it/s]" + " 36%|███▌ | 36/100 [00:01<00:02, 27.98it/s]" ] }, { @@ -1572,7 +1572,7 @@ "output_type": "stream", "text": [ "\r", - " 39%|\u2588\u2588\u2588\u2589 | 39/100 [00:01<00:02, 23.28it/s]" + " 39%|███▉ | 39/100 [00:01<00:02, 23.28it/s]" ] }, { @@ -1580,7 +1580,7 @@ "output_type": "stream", "text": [ "\r", - " 42%|\u2588\u2588\u2588\u2588\u258f | 42/100 [00:01<00:02, 24.58it/s]" + " 42%|████▏ | 42/100 [00:01<00:02, 24.58it/s]" ] }, { @@ -1588,7 +1588,7 @@ "output_type": "stream", "text": [ "\r", - " 45%|\u2588\u2588\u2588\u2588\u258c | 45/100 [00:01<00:02, 25.73it/s]" + " 45%|████▌ | 45/100 [00:01<00:02, 25.73it/s]" ] }, { @@ -1596,7 +1596,7 @@ "output_type": "stream", "text": [ "\r", - " 48%|\u2588\u2588\u2588\u2588\u258a | 48/100 [00:01<00:01, 26.74it/s]" + " 48%|████▊ | 48/100 [00:01<00:01, 26.74it/s]" ] }, { @@ -1604,7 +1604,7 @@ "output_type": "stream", "text": [ "\r", - " 
51%|\u2588\u2588\u2588\u2588\u2588 | 51/100 [00:01<00:01, 27.57it/s]" + " 51%|█████ | 51/100 [00:01<00:01, 27.57it/s]" ] }, { @@ -1612,7 +1612,7 @@ "output_type": "stream", "text": [ "\r", - " 54%|\u2588\u2588\u2588\u2588\u2588\u258d | 54/100 [00:02<00:01, 27.88it/s]" + " 54%|█████▍ | 54/100 [00:02<00:01, 27.88it/s]" ] }, { @@ -1620,7 +1620,7 @@ "output_type": "stream", "text": [ "\r", - " 57%|\u2588\u2588\u2588\u2588\u2588\u258b | 57/100 [00:02<00:01, 28.25it/s]" + " 57%|█████▋ | 57/100 [00:02<00:01, 28.25it/s]" ] }, { @@ -1628,7 +1628,7 @@ "output_type": "stream", "text": [ "\r", - " 60%|\u2588\u2588\u2588\u2588\u2588\u2588 | 60/100 [00:02<00:01, 28.54it/s]" + " 60%|██████ | 60/100 [00:02<00:01, 28.54it/s]" ] }, { @@ -1636,7 +1636,7 @@ "output_type": "stream", "text": [ "\r", - " 63%|\u2588\u2588\u2588\u2588\u2588\u2588\u258e | 63/100 [00:02<00:01, 28.49it/s]" + " 63%|██████▎ | 63/100 [00:02<00:01, 28.49it/s]" ] }, { @@ -1644,7 +1644,7 @@ "output_type": "stream", "text": [ "\r", - " 66%|\u2588\u2588\u2588\u2588\u2588\u2588\u258c | 66/100 [00:02<00:01, 24.36it/s]" + " 66%|██████▌ | 66/100 [00:02<00:01, 24.36it/s]" ] }, { @@ -1652,7 +1652,7 @@ "output_type": "stream", "text": [ "\r", - " 69%|\u2588\u2588\u2588\u2588\u2588\u2588\u2589 | 69/100 [00:02<00:01, 25.76it/s]" + " 69%|██████▉ | 69/100 [00:02<00:01, 25.76it/s]" ] }, { @@ -1660,7 +1660,7 @@ "output_type": "stream", "text": [ "\r", - " 72%|\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u258f | 72/100 [00:02<00:01, 26.88it/s]" + " 72%|███████▏ | 72/100 [00:02<00:01, 26.88it/s]" ] }, { @@ -1668,7 +1668,7 @@ "output_type": "stream", "text": [ "\r", - " 75%|\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u258c | 75/100 [00:02<00:00, 27.68it/s]" + " 75%|███████▌ | 75/100 [00:02<00:00, 27.68it/s]" ] }, { @@ -1676,7 +1676,7 @@ "output_type": "stream", "text": [ "\r", - " 78%|\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u258a | 78/100 [00:02<00:00, 28.29it/s]" + " 78%|███████▊ | 78/100 [00:02<00:00, 28.29it/s]" ] }, { @@ -1684,7 +1684,7 @@ "output_type": "stream", "text": [ "\r", - " 81%|\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588 | 81/100 [00:03<00:00, 28.48it/s]" + " 81%|████████ | 81/100 [00:03<00:00, 28.48it/s]" ] }, { @@ -1692,7 +1692,7 @@ "output_type": "stream", "text": [ "\r", - " 84%|\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u258d | 84/100 [00:03<00:00, 27.51it/s]" + " 84%|████████▍ | 84/100 [00:03<00:00, 27.51it/s]" ] }, { @@ -1700,7 +1700,7 @@ "output_type": "stream", "text": [ "\r", - " 87%|\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u258b | 87/100 [00:03<00:00, 27.90it/s]" + " 87%|████████▋ | 87/100 [00:03<00:00, 27.90it/s]" ] }, { @@ -1708,7 +1708,7 @@ "output_type": "stream", "text": [ "\r", - " 90%|\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588 | 90/100 [00:03<00:00, 23.68it/s]" + " 90%|█████████ | 90/100 [00:03<00:00, 23.68it/s]" ] }, { @@ -1716,7 +1716,7 @@ "output_type": "stream", "text": [ "\r", - " 93%|\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u258e| 93/100 [00:03<00:00, 25.09it/s]" + " 93%|█████████▎| 93/100 [00:03<00:00, 25.09it/s]" ] }, { @@ -1724,7 +1724,7 @@ "output_type": "stream", "text": [ "\r", - " 96%|\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u258c| 96/100 [00:03<00:00, 26.30it/s]" + " 96%|█████████▌| 96/100 [00:03<00:00, 26.30it/s]" ] }, { @@ -1732,7 +1732,7 @@ "output_type": "stream", "text": [ "\r", - " 99%|\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2589| 99/100 [00:03<00:00, 27.08it/s]" + " 99%|█████████▉| 99/100 [00:03<00:00, 27.08it/s]" ] }, { @@ -1740,7 
+1740,7 @@ "output_type": "stream", "text": [ "\r", - "100%|\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588| 100/100 [00:03<00:00, 26.54it/s]" + "100%|██████████| 100/100 [00:03<00:00, 26.54it/s]" ] }, { @@ -2217,11 +2217,11 @@ "description": "", "description_allow_html": false, "layout": "IPY_MODEL_e16620f692214737b99cd956859b0c34", - "placeholder": "\u200b", + "placeholder": "​", "style": "IPY_MODEL_fc412d50f34a4568a7c286ed2be51fc9", "tabbable": null, "tooltip": null, - "value": "\u2007148/148\u2007[00:00<00:00,\u20075290.11it/s,\u2007Materializing\u2007param=transformer.wte.weight]" + "value": " 148/148 [00:00<00:00, 5290.11it/s, Materializing param=transformer.wte.weight]" } }, "6662c4f913f74e6181469daa1b9e5ed3": { @@ -2293,11 +2293,11 @@ "description": "", "description_allow_html": false, "layout": "IPY_MODEL_e3770c75a2f74f2e966c9b7b8b397bf2", - "placeholder": "\u200b", + "placeholder": "​", "style": "IPY_MODEL_535051899ad84087afbaf824a2f7bb73", "tabbable": null, "tooltip": null, - "value": "Loading\u2007weights:\u2007100%" + "value": "Loading weights: 100%" } }, "ded58218d5854ad7be3ffcc9f0165c31": { diff --git a/demos/Interactive_Neuroscope.ipynb b/demos/Interactive_Neuroscope.ipynb index eb4931de1..c372aa4b0 100644 --- a/demos/Interactive_Neuroscope.ipynb +++ b/demos/Interactive_Neuroscope.ipynb @@ -75,7 +75,7 @@ "source": [ "import gradio as gr\n", "from transformer_lens.model_bridge import TransformerBridge\n", - "from transformer_lens.utils import to_numpy\n", + "from transformer_lens.utilities import to_numpy\n", "from IPython.display import HTML" ] }, @@ -446,7 +446,10 @@ "execution_count": null, "metadata": {}, "outputs": [], - "source": "# NBVAL_SKIP\ndemo.launch(share=True, height=1000)" + "source": [ + "# NBVAL_SKIP\n", + "demo.launch(share=True, height=1000)" + ] } ], "metadata": { @@ -471,4 +474,4 @@ }, "nbformat": 4, "nbformat_minor": 2 -} \ No newline at end of file +} diff --git a/demos/LLaMA.ipynb b/demos/LLaMA.ipynb index faf1fed52..61d0e54d3 100644 --- a/demos/LLaMA.ipynb +++ b/demos/LLaMA.ipynb @@ -99,7 +99,7 @@ "from jaxtyping import Float\n", "\n", "import transformer_lens\n", - "import transformer_lens.utils as utils\n", + "import transformer_lens.utilities as utils\n", "from transformer_lens.hook_points import (\n", " HookPoint,\n", ") # Hooking utilities\n", diff --git a/demos/LLaMA2_GPU_Quantized.ipynb b/demos/LLaMA2_GPU_Quantized.ipynb index 3129e6489..4722428a9 100644 --- a/demos/LLaMA2_GPU_Quantized.ipynb +++ b/demos/LLaMA2_GPU_Quantized.ipynb @@ -123,7 +123,7 @@ "from jaxtyping import Float\n", "\n", "import transformer_lens\n", - "import transformer_lens.utils as utils\n", + "import transformer_lens.utilities as utils\n", "from transformer_lens.hook_points import (\n", " HookPoint,\n", ") # Hooking utilities\n", diff --git a/demos/Main_Demo.ipynb b/demos/Main_Demo.ipynb index 543d3dcac..3c4e72dc5 100644 --- a/demos/Main_Demo.ipynb +++ b/demos/Main_Demo.ipynb @@ -133,7 +133,7 @@ "outputs": [], "source": [ "# import transformer_lens\n", - "import transformer_lens.utils as utils\n", + "import transformer_lens.utilities as utils\n", "from transformer_lens.hook_points import (\n", " HookPoint,\n", ") # Hooking utilities\n", diff --git a/demos/No_Position_Experiment.ipynb b/demos/No_Position_Experiment.ipynb index 95fe03659..979c99e42 100644 --- a/demos/No_Position_Experiment.ipynb +++ b/demos/No_Position_Experiment.ipynb @@ -1,1383 +1,1383 @@ { - "cells": [ - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "\n", - " \"Open\n", - "" 
- ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Introduction\n", - "\n", - "The accompanying notebook to my [real-time research](https://www.youtube.com/watch?v=yo4QvDn-vsU) video. Trains a model with no positional embeddings to predict the previous token, and makes a start at analysing what's going on there!\n", - "\n", - "EDIT: The loss spikes were due to the learning rate being max(step/100, 1.0) not min! Thanks to MadHatter for catching that." - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Setup" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [] - }, - { - "cell_type": "code", - "execution_count": 33, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Running as a Jupyter notebook - intended for development only!\n" - ] - } - ], - "source": [ - "# NBVAL_IGNORE_OUTPUT\n", - "import os\n", - "\n", - "# Janky code to do different setup when run in a Colab notebook vs VSCode\n", - "DEVELOPMENT_MODE = False\n", - "IN_GITHUB = os.getenv(\"GITHUB_ACTIONS\") == \"true\"\n", - "try:\n", - " import google.colab\n", - "\n", - " IN_COLAB = True\n", - " print(\"Running as a Colab notebook\")\n", - "except:\n", - " IN_COLAB = False\n", - " print(\"Running as a Jupyter notebook - intended for development only!\")\n", - "\n", - "if IN_COLAB or IN_GITHUB:\n", - " %pip install einops\n", - " %pip install transformer_lens@v1.15.0\n", - "\n", - " # PySvelte is an unmaintained visualization library, use it as a backup if circuitsvis isn't working\n", - " # # Install another version of node that makes PySvelte work way faster\n", - " # !curl -fsSL https://deb.nodesource.com/setup_16.x | sudo -E bash -; sudo apt-get install -y nodejs\n", - " # %pip install git+https://github.com/neelnanda-io/PySvelte.git\n", - "\n", - "from transformer_lens import HookedTransformer, HookedTransformerConfig\n", - "import torch\n", - "import numpy as np\n", - "import plotly.express as px\n", - "import plotly.io as pio\n", - "\n", - "pio.renderers.default = \"colab\"\n", - "import tqdm.auto as tqdm\n", - "import einops\n", - "from transformer_lens.utils import to_numpy\n", - "\n", - "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Some plotting code. Wrappers around Plotly, not important to understand." 
- ] - }, - { - "cell_type": "code", - "execution_count": 34, - "metadata": {}, - "outputs": [], - "source": [ - "def line(tensor, line_labels=None, yaxis=\"\", xaxis=\"\", **kwargs):\n", - " tensor = to_numpy(tensor)\n", - " labels = {\"y\": yaxis, \"x\": xaxis}\n", - " fig = px.line(tensor, labels=labels, **kwargs)\n", - " if line_labels:\n", - " for c, label in enumerate(line_labels):\n", - " fig.data[c].name = label\n", - " fig.show()\n", - "\n", - "\n", - "def imshow(tensor, yaxis=\"\", xaxis=\"\", **kwargs):\n", - " tensor = to_numpy(tensor)\n", - " plot_kwargs = {\n", - " \"color_continuous_scale\": \"RdBu\",\n", - " \"color_continuous_midpoint\": 0.0,\n", - " \"labels\": {\"x\": xaxis, \"y\": yaxis},\n", - " }\n", - " plot_kwargs.update(kwargs)\n", - " px.imshow(tensor, **plot_kwargs).show()" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Model Training" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Setup" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Defining the Model" - ] - }, - { - "cell_type": "code", - "execution_count": 35, - "metadata": {}, - "outputs": [], - "source": [ - "cfg = HookedTransformerConfig(\n", - " n_layers=2,\n", - " d_model=64,\n", - " d_head=64,\n", - " n_heads=1,\n", - " d_mlp=256,\n", - " d_vocab=300,\n", - " n_ctx=50,\n", - " act_fn=\"relu\",\n", - " normalization_type=\"LN\",\n", - " device=device,\n", - ")\n", - "model = HookedTransformer(cfg)" - ] - }, - { - "cell_type": "code", - "execution_count": 36, - "metadata": {}, - "outputs": [], - "source": [ - "def deactivate_position(model):\n", - " model.pos_embed.W_pos.data[:] = 0.0\n", - " model.pos_embed.W_pos.requires_grad = False\n", - "\n", - "\n", - "deactivate_position(model)" - ] - }, - { - "cell_type": "code", - "execution_count": 37, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "HookedTransformer(\n", - " (embed): Embed()\n", - " (hook_embed): HookPoint()\n", - " (pos_embed): PosEmbed()\n", - " (hook_pos_embed): HookPoint()\n", - " (blocks): ModuleList(\n", - " (0-1): 2 x TransformerBlock(\n", - " (ln1): LayerNorm(\n", - " (hook_scale): HookPoint()\n", - " (hook_normalized): HookPoint()\n", - " )\n", - " (ln2): LayerNorm(\n", - " (hook_scale): HookPoint()\n", - " (hook_normalized): HookPoint()\n", - " )\n", - " (attn): Attention(\n", - " (hook_k): HookPoint()\n", - " (hook_q): HookPoint()\n", - " (hook_v): HookPoint()\n", - " (hook_z): HookPoint()\n", - " (hook_attn_scores): HookPoint()\n", - " (hook_pattern): HookPoint()\n", - " (hook_result): HookPoint()\n", - " )\n", - " (mlp): MLP(\n", - " (hook_pre): HookPoint()\n", - " (hook_post): HookPoint()\n", - " )\n", - " (hook_attn_in): HookPoint()\n", - " (hook_q_input): HookPoint()\n", - " (hook_k_input): HookPoint()\n", - " (hook_v_input): HookPoint()\n", - " (hook_mlp_in): HookPoint()\n", - " (hook_attn_out): HookPoint()\n", - " (hook_mlp_out): HookPoint()\n", - " (hook_resid_pre): HookPoint()\n", - " (hook_resid_mid): HookPoint()\n", - " (hook_resid_post): HookPoint()\n", - " )\n", - " )\n", - " (ln_final): LayerNorm(\n", - " (hook_scale): HookPoint()\n", - " (hook_normalized): HookPoint()\n", - " )\n", - " (unembed): Unembed()\n", - ")\n" - ] - } - ], - "source": [ - "print(model)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Define data + Loss function" - ] - }, - { - "cell_type": "code", - "execution_count": 38, - "metadata": {}, - "outputs": [ - { - "name": 
"stdout", - "output_type": "stream", - "text": [ - "tensor([[ 0, 93, 34, 155, 274, 116, 114, 248, 68, 3, 298, 83, 194, 20,\n", - " 8, 133, 32, 66, 62, 73, 210, 273, 46, 243, 104, 232, 161, 125,\n", - " 123, 251, 7, 4, 115, 127, 21, 1, 89, 142, 6, 15, 298, 251,\n", - " 88, 229, 108, 114, 23, 88, 3, 265],\n", - " [ 0, 118, 46, 274, 105, 268, 131, 35, 19, 58, 226, 278, 27, 25,\n", - " 276, 180, 164, 4, 95, 27, 74, 201, 105, 65, 80, 185, 44, 258,\n", - " 105, 60, 58, 47, 126, 60, 294, 253, 258, 136, 29, 101, 258, 77,\n", - " 80, 180, 159, 169, 122, 117, 27, 194]])\n" - ] - } - ], - "source": [ - "def make_data_generator(cfg, batch_size, seed=123, incl_bos_token=True):\n", - " torch.manual_seed(seed)\n", - " while True:\n", - " x = torch.randint(1, cfg.d_vocab, (batch_size, cfg.n_ctx))\n", - " if incl_bos_token:\n", - " x[:, 0] = 0\n", - " yield x\n", - "\n", - "\n", - "data_generator = make_data_generator(cfg, 2)\n", - "print(next(data_generator))" - ] - }, - { - "cell_type": "code", - "execution_count": 39, - "metadata": {}, - "outputs": [], - "source": [ - "def loss_fn(logits, tokens, per_token=False):\n", - " # logit shape: [batch, pos, vocab]\n", - " # token shape: [batch, pos]\n", - " logits = logits[:, 1:]\n", - " tokens = tokens[:, :-1]\n", - " log_probs = logits.log_softmax(-1)\n", - " correct_log_probs = log_probs.gather(-1, tokens[..., None])[..., 0]\n", - " if per_token:\n", - " return -correct_log_probs\n", - " else:\n", - " return -correct_log_probs.mean()" - ] - }, - { - "cell_type": "code", - "execution_count": 40, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "tensor([[0.0004, 0.0003, 0.0031, 0.0005]])\n", - "tensor(0.0011)\n" - ] - } - ], - "source": [ - "# Test the loss function works\n", - "test_tokens = torch.arange(5)[None, :]\n", - "test_logits = torch.randn(1, 5, 10)\n", - "test_logits[:, 1, 0] = 10.0\n", - "test_logits[:, 2, 1] = 10.0\n", - "test_logits[:, 3, 2] = 10.0\n", - "test_logits[:, 4, 3] = 10.0\n", - "print(loss_fn(test_logits, test_tokens, per_token=True))\n", - "print(loss_fn(test_logits, test_tokens, per_token=False))" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Setup Optimizer\n" - ] - }, - { - "cell_type": "code", - "execution_count": 41, - "metadata": {}, - "outputs": [], - "source": [ - "batch_size = 256\n", - "num_epochs = 4000\n", - "lr = 1e-4\n", - "betas = (0.9, 0.95)\n", - "max_grad_norm = 1.0\n", - "wd = 0.1\n", - "optimizer = torch.optim.AdamW(model.parameters(), lr=lr, betas=betas, weight_decay=wd)\n", - "scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lambda i: min(i / 100, 1.0))\n", - "\n", - "data_loader = make_data_generator(cfg, batch_size)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Model Training" - ] - }, - { - "cell_type": "code", - "execution_count": 42, - "metadata": {}, - "outputs": [ - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "122c183908104b04a600bfe4aca9f009", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - " 0%| | 0/4000 [00:00\n", - "\n", - "\n", - "
\n", - "
\n", - "\n", - "" - ] - }, - "metadata": {}, - "output_type": "display_data" - } - ], - "source": [ - "losses = []\n", - "for epoch in tqdm.tqdm(range(num_epochs)):\n", - " tokens = next(data_loader)\n", - " tokens = tokens.to(device)\n", - " logits = model(tokens)\n", - " loss = loss_fn(logits, tokens)\n", - " loss.backward()\n", - " if max_grad_norm is not None:\n", - " torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)\n", - " optimizer.step()\n", - " optimizer.zero_grad()\n", - " scheduler.step()\n", - " losses.append(loss.item())\n", - " if epoch % 100 == 0:\n", - " print(f\"Epoch {epoch}: {loss.item()}\")\n", - "px.line(losses, labels={\"x\": \"Epoch\", \"y\": \"Loss\"})" - ] - }, - { - "cell_type": "code", - "execution_count": 43, - "metadata": {}, - "outputs": [], - "source": [ - "# torch.save(model.state_dict(), \"no_pos_experiment_state_dict_v0.pth\")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Model Interpretability" - ] - }, - { - "cell_type": "code", - "execution_count": 44, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "tensor(0.)" - ] - }, - "execution_count": 44, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "model.pos_embed.W_pos.norm()" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Look at attention patterns" - ] - }, - { - "cell_type": "code", - "execution_count": 45, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Loss: 0.003689224598929286\n" - ] - } - ], - "source": [ - "big_data_loader = make_data_generator(cfg, 4000)\n", - "big_tokens = next(big_data_loader)\n", - "big_tokens = big_tokens.to(device)\n", - "logits, cache = model.run_with_cache(big_tokens)\n", - "print(\"Loss:\", loss_fn(logits, big_tokens).item())" - ] - }, - { - "cell_type": "code", - "execution_count": 46, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "ActivationCache with keys ['hook_embed', 'hook_pos_embed', 'blocks.0.hook_resid_pre', 'blocks.0.ln1.hook_scale', 'blocks.0.ln1.hook_normalized', 'blocks.0.attn.hook_q', 'blocks.0.attn.hook_k', 'blocks.0.attn.hook_v', 'blocks.0.attn.hook_attn_scores', 'blocks.0.attn.hook_pattern', 'blocks.0.attn.hook_z', 'blocks.0.hook_attn_out', 'blocks.0.hook_resid_mid', 'blocks.0.ln2.hook_scale', 'blocks.0.ln2.hook_normalized', 'blocks.0.mlp.hook_pre', 'blocks.0.mlp.hook_post', 'blocks.0.hook_mlp_out', 'blocks.0.hook_resid_post', 'blocks.1.hook_resid_pre', 'blocks.1.ln1.hook_scale', 'blocks.1.ln1.hook_normalized', 'blocks.1.attn.hook_q', 'blocks.1.attn.hook_k', 'blocks.1.attn.hook_v', 'blocks.1.attn.hook_attn_scores', 'blocks.1.attn.hook_pattern', 'blocks.1.attn.hook_z', 'blocks.1.hook_attn_out', 'blocks.1.hook_resid_mid', 'blocks.1.ln2.hook_scale', 'blocks.1.ln2.hook_normalized', 'blocks.1.mlp.hook_pre', 'blocks.1.mlp.hook_post', 'blocks.1.hook_mlp_out', 'blocks.1.hook_resid_post', 'ln_final.hook_scale', 'ln_final.hook_normalized']\n" - ] - } - ], - "source": [ - "print(cache)" - ] - }, - { - "cell_type": "code", - "execution_count": 47, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "torch.Size([4000, 1, 50, 50])" - ] - }, - "execution_count": 47, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "cache[\"blocks.0.attn.hook_pattern\"].shape" - ] - }, - { - "cell_type": "code", - "execution_count": 48, - "metadata": {}, - "outputs": [ - { - "data": { - "text/html": [ - "\n", - "\n", 
- "\n", - "
\n", - "
\n", - "\n", - "" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "\n", - "\n", - "\n", - "
\n", - "
\n", - "\n", - "" - ] - }, - "metadata": {}, - "output_type": "display_data" - } - ], - "source": [ - "batch_index = 0\n", - "tokens = big_tokens[batch_index]\n", - "imshow(\n", - " to_numpy(cache[\"attn\", 0].mean([0, 1])),\n", - " title=\"Layer 0 Attention Pattern\",\n", - " height=500,\n", - " width=500,\n", - ")\n", - "imshow(\n", - " to_numpy(cache[\"attn\", 1].mean([0, 1])),\n", - " title=\"Layer 1 Attention Pattern\",\n", - " height=500,\n", - " width=500,\n", - ")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Look at how different bits of the model directly contribute to the logits" - ] - }, - { - "cell_type": "code", - "execution_count": 49, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "torch.Size([5, 4000, 50, 64])\n" - ] - } - ], - "source": [ - "resid_components = [\n", - " cache[\"embed\"],\n", - " cache[\"attn_out\", 0],\n", - " cache[\"mlp_out\", 0],\n", - " cache[\"attn_out\", 1],\n", - " cache[\"mlp_out\", 1],\n", - "]\n", - "labels = [\"embed\", \"A0\", \"M0\", \"A1\", \"M2\"]\n", - "resid_stack = torch.stack(resid_components, 0)\n", - "resid_stack = resid_stack - resid_stack.mean(-1, keepdim=True)\n", - "print(resid_stack.shape)" - ] - }, - { - "cell_type": "code", - "execution_count": 50, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "torch.Size([5, 50, 300])\n" - ] - } - ], - "source": [ - "fold_W_U = model.ln_final.w[:, None] * model.unembed.W_U\n", - "logit_components = resid_stack[:, batch_index] @ fold_W_U / cache[\"scale\"][batch_index]\n", - "print(logit_components.shape)" - ] - }, - { - "cell_type": "code", - "execution_count": 51, - "metadata": {}, - "outputs": [ - { - "data": { - "text/html": [ - "\n", - "\n", - "\n", - "
\n", - "
\n", - "\n", - "" - ] - }, - "metadata": {}, - "output_type": "display_data" - } - ], - "source": [ - "logit_components = logit_components - logit_components.mean(-1, keepdim=True)\n", - "line(\n", - " logit_components[:, torch.arange(1, model.cfg.n_ctx).to(device), tokens[:-1]].T,\n", - " line_labels=labels,\n", - ")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Folding In LayerNorm" - ] - }, - { - "cell_type": "code", - "execution_count": 52, - "metadata": {}, - "outputs": [], - "source": [ - "analysis_cfg = HookedTransformerConfig(\n", - " n_layers=2,\n", - " d_model=64,\n", - " d_head=64,\n", - " n_heads=1,\n", - " d_mlp=256,\n", - " d_vocab=300,\n", - " n_ctx=50,\n", - " act_fn=\"relu\",\n", - " normalization_type=\"LNPre\",\n", - " init_weights=False,\n", - ")\n", - "analysis_model = HookedTransformer(analysis_cfg)\n", - "state_dict = model.state_dict()\n", - "analysis_model.load_and_process_state_dict(\n", - " state_dict, fold_ln=True, center_writing_weights=True, center_unembed=True\n", - ")\n", - "deactivate_position(analysis_model)" - ] - }, - { - "cell_type": "code", - "execution_count": 53, - "metadata": {}, - "outputs": [], - "source": [ - "# analysis_model()" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Understand Attn 0\n" - ] - }, - { - "cell_type": "code", - "execution_count": 54, - "metadata": {}, - "outputs": [ - { - "data": { - "text/html": [ - "\n", - "\n", - "\n", - "
\n", - "
\n", - "\n", - "" - ] - }, - "metadata": {}, - "output_type": "display_data" - } - ], - "source": [ - "QK = model.W_E @ model.W_Q[0, 0] @ model.W_K[0, 0].T @ model.W_E.T\n", - "imshow(QK, yaxis=\"Query\", xaxis=\"Key\")" - ] - }, - { - "cell_type": "code", - "execution_count": 55, - "metadata": {}, - "outputs": [ - { - "data": { - "text/html": [ - "\n", - "\n", - "\n", - "
\n", - "
\n", - "\n", - "" - ] - }, - "metadata": {}, - "output_type": "display_data" - } - ], - "source": [ - "OV = model.W_E @ model.W_V[0, 0] @ model.W_O[0, 0] @ model.W_in[0]\n", - "imshow(OV, yaxis=\"Input Vocab\", xaxis=\"Neuron\")" - ] - }, - { - "cell_type": "code", - "execution_count": 56, - "metadata": {}, - "outputs": [ - { - "data": { - "text/html": [ - "\n", - "\n", - "\n", - "
\n", - "
\n", - "\n", - "" - ] - }, - "metadata": {}, - "output_type": "display_data" - } - ], - "source": [ - "line(OV[:, torch.randint(0, 256, (5,))])" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Understand MLP 0" - ] - }, - { - "cell_type": "code", - "execution_count": 57, - "metadata": {}, - "outputs": [ - { - "data": { - "text/html": [ - "\n", - "\n", - "\n", - "
\n", - "
\n", - "\n", - "" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "\n", - "\n", - "\n", - "
\n", - "
\n", - "\n", - "" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "\n", - "\n", - "\n", - "
\n", - "
\n", - "\n", - "" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "\n", - "\n", - "\n", - "
\n", - "
\n", - "\n", - "" - ] - }, - "metadata": {}, - "output_type": "display_data" - } - ], - "source": [ - "imshow(cache[\"post\", 0][batch_index], yaxis=\"Pos\", xaxis=\"Neuron\")\n", - "imshow(cache[\"post\", 0].mean(0), yaxis=\"Pos\", xaxis=\"Neuron\")\n", - "imshow((cache[\"post\", 0] > 0).float()[batch_index], yaxis=\"Pos\", xaxis=\"Neuron\")\n", - "imshow((cache[\"post\", 0] > 0).float().mean(0), yaxis=\"Pos\", xaxis=\"Neuron\")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Understand Attn 1" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Understand MLP 1" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Experiment" - ] - }, - { - "cell_type": "code", - "execution_count": 58, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Baseline loss: 0.0036276562605053186\n" - ] - } - ], - "source": [ - "new_token_batch = next(big_data_loader).to(device)\n", - "baseline_loss = loss_fn(model(new_token_batch), new_token_batch).item()\n", - "print(\"Baseline loss:\", baseline_loss)" - ] - }, - { - "cell_type": "code", - "execution_count": 59, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "hook_embed 10.975448608398438\n", - "blocks.0.ln1.hook_scale 0.4518754482269287\n", - "blocks.0.ln1.hook_normalized 2.4589312076568604\n", - "blocks.0.ln2.hook_scale 0.02511436492204666\n", - "blocks.0.ln2.hook_normalized 9.506545066833496\n", - "blocks.0.attn.hook_k 0.026024164631962776\n", - "blocks.0.attn.hook_q 1.9935917854309082\n", - "blocks.0.attn.hook_v 0.19012247025966644\n", - "blocks.0.attn.hook_z 2.4572794437408447\n", - "blocks.0.attn.hook_attn_scores 1.9351273775100708\n", - "blocks.0.attn.hook_pattern 1.9483344554901123\n", - "blocks.0.mlp.hook_pre 9.506546020507812\n", - "blocks.0.mlp.hook_post 9.526301383972168\n", - "blocks.0.hook_attn_out 2.4572784900665283\n", - "blocks.0.hook_mlp_out 9.526301383972168\n", - "blocks.0.hook_resid_pre 10.975448608398438\n", - "blocks.0.hook_resid_mid 11.21129035949707\n", - "blocks.0.hook_resid_post 10.834088325500488\n", - "blocks.1.ln1.hook_scale 0.021276870742440224\n", - "blocks.1.ln1.hook_normalized 9.080503463745117\n", - "blocks.1.ln2.hook_scale 0.003745849709957838\n", - "blocks.1.ln2.hook_normalized 4.472580432891846\n", - "blocks.1.attn.hook_k 0.0021377610974013805\n", - "blocks.1.attn.hook_q 0.016889303922653198\n", - "blocks.1.attn.hook_v 9.080748558044434\n", - "blocks.1.attn.hook_z 9.080577850341797\n", - "blocks.1.attn.hook_attn_scores 0.0017388787819072604\n", - "blocks.1.attn.hook_pattern 0.0018037232803180814\n", - "blocks.1.mlp.hook_pre 4.472580432891846\n", - "blocks.1.mlp.hook_post 4.4686713218688965\n", - "blocks.1.hook_attn_out 9.080577850341797\n", - "blocks.1.hook_mlp_out 4.4686713218688965\n", - "blocks.1.hook_resid_pre 10.834088325500488\n", - "blocks.1.hook_resid_mid 10.262277603149414\n", - "blocks.1.hook_resid_post 11.917344093322754\n", - "ln_final.hook_scale 0.009998265653848648\n", - "ln_final.hook_normalized 5.719696521759033\n" - ] - } - ], - "source": [ - "hook_list = list(model.hook_dict.keys())\n", - "losses = []\n", - "loss_labels = []\n", - "for hook_name in hook_list:\n", - " if (\n", - " hook_name in cache\n", - " and hook_name != \"hook_pos_embed\"\n", - " and \"result\" not in hook_name\n", - " ):\n", - " average_act = 
cache[hook_name].mean(0)\n", - "\n", - " def replacing_with_average_act(activation, hook):\n", - " activation[:] = einops.repeat(\n", - " average_act, \"... -> batch ...\", batch=new_token_batch.size(0)\n", - " )\n", - " return activation\n", - "\n", - " logits = model.run_with_hooks(\n", - " new_token_batch, fwd_hooks=[(hook_name, replacing_with_average_act)]\n", - " )\n", - " loss = loss_fn(logits, new_token_batch)\n", - " print(hook_name, loss.item())\n", - " losses.append(loss.item())\n", - " loss_labels.append(hook_name)" - ] - }, - { - "cell_type": "code", - "execution_count": 60, - "metadata": {}, - "outputs": [ - { - "data": { - "text/html": [ - "\n", - "\n", - "\n", - "
\n", - "
\n", - "\n", - "" - ] - }, - "metadata": {}, - "output_type": "display_data" - } - ], - "source": [ - "line(losses, hover_name=loss_labels)" - ] - }, - { - "cell_type": "code", - "execution_count": 61, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "dict_keys(['hook_embed', 'hook_pos_embed', 'blocks.0.hook_resid_pre', 'blocks.0.ln1.hook_scale', 'blocks.0.ln1.hook_normalized', 'blocks.0.attn.hook_q', 'blocks.0.attn.hook_k', 'blocks.0.attn.hook_v', 'blocks.0.attn.hook_attn_scores', 'blocks.0.attn.hook_pattern', 'blocks.0.attn.hook_z', 'blocks.0.hook_attn_out', 'blocks.0.hook_resid_mid', 'blocks.0.ln2.hook_scale', 'blocks.0.ln2.hook_normalized', 'blocks.0.mlp.hook_pre', 'blocks.0.mlp.hook_post', 'blocks.0.hook_mlp_out', 'blocks.0.hook_resid_post', 'blocks.1.hook_resid_pre', 'blocks.1.ln1.hook_scale', 'blocks.1.ln1.hook_normalized', 'blocks.1.attn.hook_q', 'blocks.1.attn.hook_k', 'blocks.1.attn.hook_v', 'blocks.1.attn.hook_attn_scores', 'blocks.1.attn.hook_pattern', 'blocks.1.attn.hook_z', 'blocks.1.hook_attn_out', 'blocks.1.hook_resid_mid', 'blocks.1.ln2.hook_scale', 'blocks.1.ln2.hook_normalized', 'blocks.1.mlp.hook_pre', 'blocks.1.mlp.hook_post', 'blocks.1.hook_mlp_out', 'blocks.1.hook_resid_post', 'ln_final.hook_scale', 'ln_final.hook_normalized'])" - ] - }, - "execution_count": 61, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "cache.cache_dict.keys()" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3.7.13 ('base')", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.11.8" - }, - "vscode": { - "interpreter": { - "hash": "d4d1e4263499bec80672ea0156c357c1ee493ec2b1c70f0acce89fc37c4a6abe" - } - } + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "\n", + " \"Open\n", + "" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Introduction\n", + "\n", + "The accompanying notebook to my [real-time research](https://www.youtube.com/watch?v=yo4QvDn-vsU) video. Trains a model with no positional embeddings to predict the previous token, and makes a start at analysing what's going on there!\n", + "\n", + "EDIT: The loss spikes were due to the learning rate being max(step/100, 1.0) not min! Thanks to MadHatter for catching that." 
+ ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Setup" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [] + }, + { + "cell_type": "code", + "execution_count": 33, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Running as a Jupyter notebook - intended for development only!\n" + ] + } + ], + "source": [ + "# NBVAL_IGNORE_OUTPUT\n", + "import os\n", + "\n", + "# Janky code to do different setup when run in a Colab notebook vs VSCode\n", + "DEVELOPMENT_MODE = False\n", + "IN_GITHUB = os.getenv(\"GITHUB_ACTIONS\") == \"true\"\n", + "try:\n", + " import google.colab\n", + "\n", + " IN_COLAB = True\n", + " print(\"Running as a Colab notebook\")\n", + "except:\n", + " IN_COLAB = False\n", + " print(\"Running as a Jupyter notebook - intended for development only!\")\n", + "\n", + "if IN_COLAB or IN_GITHUB:\n", + " %pip install einops\n", + " %pip install transformer_lens@v1.15.0\n", + "\n", + " # PySvelte is an unmaintained visualization library, use it as a backup if circuitsvis isn't working\n", + " # # Install another version of node that makes PySvelte work way faster\n", + " # !curl -fsSL https://deb.nodesource.com/setup_16.x | sudo -E bash -; sudo apt-get install -y nodejs\n", + " # %pip install git+https://github.com/neelnanda-io/PySvelte.git\n", + "\n", + "from transformer_lens import HookedTransformer, HookedTransformerConfig\n", + "import torch\n", + "import numpy as np\n", + "import plotly.express as px\n", + "import plotly.io as pio\n", + "\n", + "pio.renderers.default = \"colab\"\n", + "import tqdm.auto as tqdm\n", + "import einops\n", + "from transformer_lens.utilities import to_numpy\n", + "\n", + "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Some plotting code. Wrappers around Plotly, not important to understand." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 34, + "metadata": {}, + "outputs": [], + "source": [ + "def line(tensor, line_labels=None, yaxis=\"\", xaxis=\"\", **kwargs):\n", + " tensor = to_numpy(tensor)\n", + " labels = {\"y\": yaxis, \"x\": xaxis}\n", + " fig = px.line(tensor, labels=labels, **kwargs)\n", + " if line_labels:\n", + " for c, label in enumerate(line_labels):\n", + " fig.data[c].name = label\n", + " fig.show()\n", + "\n", + "\n", + "def imshow(tensor, yaxis=\"\", xaxis=\"\", **kwargs):\n", + " tensor = to_numpy(tensor)\n", + " plot_kwargs = {\n", + " \"color_continuous_scale\": \"RdBu\",\n", + " \"color_continuous_midpoint\": 0.0,\n", + " \"labels\": {\"x\": xaxis, \"y\": yaxis},\n", + " }\n", + " plot_kwargs.update(kwargs)\n", + " px.imshow(tensor, **plot_kwargs).show()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Model Training" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Setup" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Defining the Model" + ] + }, + { + "cell_type": "code", + "execution_count": 35, + "metadata": {}, + "outputs": [], + "source": [ + "cfg = HookedTransformerConfig(\n", + " n_layers=2,\n", + " d_model=64,\n", + " d_head=64,\n", + " n_heads=1,\n", + " d_mlp=256,\n", + " d_vocab=300,\n", + " n_ctx=50,\n", + " act_fn=\"relu\",\n", + " normalization_type=\"LN\",\n", + " device=device,\n", + ")\n", + "model = HookedTransformer(cfg)" + ] + }, + { + "cell_type": "code", + "execution_count": 36, + "metadata": {}, + "outputs": [], + "source": [ + "def deactivate_position(model):\n", + " model.pos_embed.W_pos.data[:] = 0.0\n", + " model.pos_embed.W_pos.requires_grad = False\n", + "\n", + "\n", + "deactivate_position(model)" + ] + }, + { + "cell_type": "code", + "execution_count": 37, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "HookedTransformer(\n", + " (embed): Embed()\n", + " (hook_embed): HookPoint()\n", + " (pos_embed): PosEmbed()\n", + " (hook_pos_embed): HookPoint()\n", + " (blocks): ModuleList(\n", + " (0-1): 2 x TransformerBlock(\n", + " (ln1): LayerNorm(\n", + " (hook_scale): HookPoint()\n", + " (hook_normalized): HookPoint()\n", + " )\n", + " (ln2): LayerNorm(\n", + " (hook_scale): HookPoint()\n", + " (hook_normalized): HookPoint()\n", + " )\n", + " (attn): Attention(\n", + " (hook_k): HookPoint()\n", + " (hook_q): HookPoint()\n", + " (hook_v): HookPoint()\n", + " (hook_z): HookPoint()\n", + " (hook_attn_scores): HookPoint()\n", + " (hook_pattern): HookPoint()\n", + " (hook_result): HookPoint()\n", + " )\n", + " (mlp): MLP(\n", + " (hook_pre): HookPoint()\n", + " (hook_post): HookPoint()\n", + " )\n", + " (hook_attn_in): HookPoint()\n", + " (hook_q_input): HookPoint()\n", + " (hook_k_input): HookPoint()\n", + " (hook_v_input): HookPoint()\n", + " (hook_mlp_in): HookPoint()\n", + " (hook_attn_out): HookPoint()\n", + " (hook_mlp_out): HookPoint()\n", + " (hook_resid_pre): HookPoint()\n", + " (hook_resid_mid): HookPoint()\n", + " (hook_resid_post): HookPoint()\n", + " )\n", + " )\n", + " (ln_final): LayerNorm(\n", + " (hook_scale): HookPoint()\n", + " (hook_normalized): HookPoint()\n", + " )\n", + " (unembed): Unembed()\n", + ")\n" + ] + } + ], + "source": [ + "print(model)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Define data + Loss function" + ] + }, + { + "cell_type": "code", + "execution_count": 38, + "metadata": {}, + "outputs": [ + { + "name": 
"stdout", + "output_type": "stream", + "text": [ + "tensor([[ 0, 93, 34, 155, 274, 116, 114, 248, 68, 3, 298, 83, 194, 20,\n", + " 8, 133, 32, 66, 62, 73, 210, 273, 46, 243, 104, 232, 161, 125,\n", + " 123, 251, 7, 4, 115, 127, 21, 1, 89, 142, 6, 15, 298, 251,\n", + " 88, 229, 108, 114, 23, 88, 3, 265],\n", + " [ 0, 118, 46, 274, 105, 268, 131, 35, 19, 58, 226, 278, 27, 25,\n", + " 276, 180, 164, 4, 95, 27, 74, 201, 105, 65, 80, 185, 44, 258,\n", + " 105, 60, 58, 47, 126, 60, 294, 253, 258, 136, 29, 101, 258, 77,\n", + " 80, 180, 159, 169, 122, 117, 27, 194]])\n" + ] + } + ], + "source": [ + "def make_data_generator(cfg, batch_size, seed=123, incl_bos_token=True):\n", + " torch.manual_seed(seed)\n", + " while True:\n", + " x = torch.randint(1, cfg.d_vocab, (batch_size, cfg.n_ctx))\n", + " if incl_bos_token:\n", + " x[:, 0] = 0\n", + " yield x\n", + "\n", + "\n", + "data_generator = make_data_generator(cfg, 2)\n", + "print(next(data_generator))" + ] + }, + { + "cell_type": "code", + "execution_count": 39, + "metadata": {}, + "outputs": [], + "source": [ + "def loss_fn(logits, tokens, per_token=False):\n", + " # logit shape: [batch, pos, vocab]\n", + " # token shape: [batch, pos]\n", + " logits = logits[:, 1:]\n", + " tokens = tokens[:, :-1]\n", + " log_probs = logits.log_softmax(-1)\n", + " correct_log_probs = log_probs.gather(-1, tokens[..., None])[..., 0]\n", + " if per_token:\n", + " return -correct_log_probs\n", + " else:\n", + " return -correct_log_probs.mean()" + ] + }, + { + "cell_type": "code", + "execution_count": 40, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "tensor([[0.0004, 0.0003, 0.0031, 0.0005]])\n", + "tensor(0.0011)\n" + ] + } + ], + "source": [ + "# Test the loss function works\n", + "test_tokens = torch.arange(5)[None, :]\n", + "test_logits = torch.randn(1, 5, 10)\n", + "test_logits[:, 1, 0] = 10.0\n", + "test_logits[:, 2, 1] = 10.0\n", + "test_logits[:, 3, 2] = 10.0\n", + "test_logits[:, 4, 3] = 10.0\n", + "print(loss_fn(test_logits, test_tokens, per_token=True))\n", + "print(loss_fn(test_logits, test_tokens, per_token=False))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Setup Optimizer\n" + ] + }, + { + "cell_type": "code", + "execution_count": 41, + "metadata": {}, + "outputs": [], + "source": [ + "batch_size = 256\n", + "num_epochs = 4000\n", + "lr = 1e-4\n", + "betas = (0.9, 0.95)\n", + "max_grad_norm = 1.0\n", + "wd = 0.1\n", + "optimizer = torch.optim.AdamW(model.parameters(), lr=lr, betas=betas, weight_decay=wd)\n", + "scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lambda i: min(i / 100, 1.0))\n", + "\n", + "data_loader = make_data_generator(cfg, batch_size)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Model Training" + ] + }, + { + "cell_type": "code", + "execution_count": 42, + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "122c183908104b04a600bfe4aca9f009", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + " 0%| | 0/4000 [00:00\n", + "\n", + "\n", + "
\n", + "
\n", + "\n", + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "losses = []\n", + "for epoch in tqdm.tqdm(range(num_epochs)):\n", + " tokens = next(data_loader)\n", + " tokens = tokens.to(device)\n", + " logits = model(tokens)\n", + " loss = loss_fn(logits, tokens)\n", + " loss.backward()\n", + " if max_grad_norm is not None:\n", + " torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)\n", + " optimizer.step()\n", + " optimizer.zero_grad()\n", + " scheduler.step()\n", + " losses.append(loss.item())\n", + " if epoch % 100 == 0:\n", + " print(f\"Epoch {epoch}: {loss.item()}\")\n", + "px.line(losses, labels={\"x\": \"Epoch\", \"y\": \"Loss\"})" + ] + }, + { + "cell_type": "code", + "execution_count": 43, + "metadata": {}, + "outputs": [], + "source": [ + "# torch.save(model.state_dict(), \"no_pos_experiment_state_dict_v0.pth\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Model Interpretability" + ] + }, + { + "cell_type": "code", + "execution_count": 44, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor(0.)" + ] + }, + "execution_count": 44, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "model.pos_embed.W_pos.norm()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Look at attention patterns" + ] + }, + { + "cell_type": "code", + "execution_count": 45, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Loss: 0.003689224598929286\n" + ] + } + ], + "source": [ + "big_data_loader = make_data_generator(cfg, 4000)\n", + "big_tokens = next(big_data_loader)\n", + "big_tokens = big_tokens.to(device)\n", + "logits, cache = model.run_with_cache(big_tokens)\n", + "print(\"Loss:\", loss_fn(logits, big_tokens).item())" + ] + }, + { + "cell_type": "code", + "execution_count": 46, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "ActivationCache with keys ['hook_embed', 'hook_pos_embed', 'blocks.0.hook_resid_pre', 'blocks.0.ln1.hook_scale', 'blocks.0.ln1.hook_normalized', 'blocks.0.attn.hook_q', 'blocks.0.attn.hook_k', 'blocks.0.attn.hook_v', 'blocks.0.attn.hook_attn_scores', 'blocks.0.attn.hook_pattern', 'blocks.0.attn.hook_z', 'blocks.0.hook_attn_out', 'blocks.0.hook_resid_mid', 'blocks.0.ln2.hook_scale', 'blocks.0.ln2.hook_normalized', 'blocks.0.mlp.hook_pre', 'blocks.0.mlp.hook_post', 'blocks.0.hook_mlp_out', 'blocks.0.hook_resid_post', 'blocks.1.hook_resid_pre', 'blocks.1.ln1.hook_scale', 'blocks.1.ln1.hook_normalized', 'blocks.1.attn.hook_q', 'blocks.1.attn.hook_k', 'blocks.1.attn.hook_v', 'blocks.1.attn.hook_attn_scores', 'blocks.1.attn.hook_pattern', 'blocks.1.attn.hook_z', 'blocks.1.hook_attn_out', 'blocks.1.hook_resid_mid', 'blocks.1.ln2.hook_scale', 'blocks.1.ln2.hook_normalized', 'blocks.1.mlp.hook_pre', 'blocks.1.mlp.hook_post', 'blocks.1.hook_mlp_out', 'blocks.1.hook_resid_post', 'ln_final.hook_scale', 'ln_final.hook_normalized']\n" + ] + } + ], + "source": [ + "print(cache)" + ] + }, + { + "cell_type": "code", + "execution_count": 47, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "torch.Size([4000, 1, 50, 50])" + ] + }, + "execution_count": 47, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "cache[\"blocks.0.attn.hook_pattern\"].shape" + ] + }, + { + "cell_type": "code", + "execution_count": 48, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "\n", + "\n", 
+ "\n", + "
\n", + "
\n", + "\n", + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "\n", + "\n", + "\n", + "
\n", + "
\n", + "\n", + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "batch_index = 0\n", + "tokens = big_tokens[batch_index]\n", + "imshow(\n", + " to_numpy(cache[\"attn\", 0].mean([0, 1])),\n", + " title=\"Layer 0 Attention Pattern\",\n", + " height=500,\n", + " width=500,\n", + ")\n", + "imshow(\n", + " to_numpy(cache[\"attn\", 1].mean([0, 1])),\n", + " title=\"Layer 1 Attention Pattern\",\n", + " height=500,\n", + " width=500,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Look at how different bits of the model directly contribute to the logits" + ] + }, + { + "cell_type": "code", + "execution_count": 49, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "torch.Size([5, 4000, 50, 64])\n" + ] + } + ], + "source": [ + "resid_components = [\n", + " cache[\"embed\"],\n", + " cache[\"attn_out\", 0],\n", + " cache[\"mlp_out\", 0],\n", + " cache[\"attn_out\", 1],\n", + " cache[\"mlp_out\", 1],\n", + "]\n", + "labels = [\"embed\", \"A0\", \"M0\", \"A1\", \"M2\"]\n", + "resid_stack = torch.stack(resid_components, 0)\n", + "resid_stack = resid_stack - resid_stack.mean(-1, keepdim=True)\n", + "print(resid_stack.shape)" + ] + }, + { + "cell_type": "code", + "execution_count": 50, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "torch.Size([5, 50, 300])\n" + ] + } + ], + "source": [ + "fold_W_U = model.ln_final.w[:, None] * model.unembed.W_U\n", + "logit_components = resid_stack[:, batch_index] @ fold_W_U / cache[\"scale\"][batch_index]\n", + "print(logit_components.shape)" + ] + }, + { + "cell_type": "code", + "execution_count": 51, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "\n", + "\n", + "\n", + "
\n", + "
\n", + "\n", + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "logit_components = logit_components - logit_components.mean(-1, keepdim=True)\n", + "line(\n", + " logit_components[:, torch.arange(1, model.cfg.n_ctx).to(device), tokens[:-1]].T,\n", + " line_labels=labels,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Folding In LayerNorm" + ] + }, + { + "cell_type": "code", + "execution_count": 52, + "metadata": {}, + "outputs": [], + "source": [ + "analysis_cfg = HookedTransformerConfig(\n", + " n_layers=2,\n", + " d_model=64,\n", + " d_head=64,\n", + " n_heads=1,\n", + " d_mlp=256,\n", + " d_vocab=300,\n", + " n_ctx=50,\n", + " act_fn=\"relu\",\n", + " normalization_type=\"LNPre\",\n", + " init_weights=False,\n", + ")\n", + "analysis_model = HookedTransformer(analysis_cfg)\n", + "state_dict = model.state_dict()\n", + "analysis_model.load_and_process_state_dict(\n", + " state_dict, fold_ln=True, center_writing_weights=True, center_unembed=True\n", + ")\n", + "deactivate_position(analysis_model)" + ] + }, + { + "cell_type": "code", + "execution_count": 53, + "metadata": {}, + "outputs": [], + "source": [ + "# analysis_model()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Understand Attn 0\n" + ] + }, + { + "cell_type": "code", + "execution_count": 54, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "\n", + "\n", + "\n", + "
\n", + "
\n", + "\n", + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "QK = model.W_E @ model.W_Q[0, 0] @ model.W_K[0, 0].T @ model.W_E.T\n", + "imshow(QK, yaxis=\"Query\", xaxis=\"Key\")" + ] + }, + { + "cell_type": "code", + "execution_count": 55, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "\n", + "\n", + "\n", + "
\n", + "
\n", + "\n", + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "OV = model.W_E @ model.W_V[0, 0] @ model.W_O[0, 0] @ model.W_in[0]\n", + "imshow(OV, yaxis=\"Input Vocab\", xaxis=\"Neuron\")" + ] + }, + { + "cell_type": "code", + "execution_count": 56, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "\n", + "\n", + "\n", + "
\n", + "
\n", + "\n", + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "line(OV[:, torch.randint(0, 256, (5,))])" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Understand MLP 0" + ] + }, + { + "cell_type": "code", + "execution_count": 57, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "\n", + "\n", + "\n", + "
\n", + "
\n", + "\n", + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "\n", + "\n", + "\n", + "
\n", + "
\n", + "\n", + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "\n", + "\n", + "\n", + "
\n", + "
\n", + "\n", + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "\n", + "\n", + "\n", + "
\n", + "
\n", + "\n", + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "imshow(cache[\"post\", 0][batch_index], yaxis=\"Pos\", xaxis=\"Neuron\")\n", + "imshow(cache[\"post\", 0].mean(0), yaxis=\"Pos\", xaxis=\"Neuron\")\n", + "imshow((cache[\"post\", 0] > 0).float()[batch_index], yaxis=\"Pos\", xaxis=\"Neuron\")\n", + "imshow((cache[\"post\", 0] > 0).float().mean(0), yaxis=\"Pos\", xaxis=\"Neuron\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Understand Attn 1" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Understand MLP 1" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Experiment" + ] + }, + { + "cell_type": "code", + "execution_count": 58, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Baseline loss: 0.0036276562605053186\n" + ] + } + ], + "source": [ + "new_token_batch = next(big_data_loader).to(device)\n", + "baseline_loss = loss_fn(model(new_token_batch), new_token_batch).item()\n", + "print(\"Baseline loss:\", baseline_loss)" + ] + }, + { + "cell_type": "code", + "execution_count": 59, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "hook_embed 10.975448608398438\n", + "blocks.0.ln1.hook_scale 0.4518754482269287\n", + "blocks.0.ln1.hook_normalized 2.4589312076568604\n", + "blocks.0.ln2.hook_scale 0.02511436492204666\n", + "blocks.0.ln2.hook_normalized 9.506545066833496\n", + "blocks.0.attn.hook_k 0.026024164631962776\n", + "blocks.0.attn.hook_q 1.9935917854309082\n", + "blocks.0.attn.hook_v 0.19012247025966644\n", + "blocks.0.attn.hook_z 2.4572794437408447\n", + "blocks.0.attn.hook_attn_scores 1.9351273775100708\n", + "blocks.0.attn.hook_pattern 1.9483344554901123\n", + "blocks.0.mlp.hook_pre 9.506546020507812\n", + "blocks.0.mlp.hook_post 9.526301383972168\n", + "blocks.0.hook_attn_out 2.4572784900665283\n", + "blocks.0.hook_mlp_out 9.526301383972168\n", + "blocks.0.hook_resid_pre 10.975448608398438\n", + "blocks.0.hook_resid_mid 11.21129035949707\n", + "blocks.0.hook_resid_post 10.834088325500488\n", + "blocks.1.ln1.hook_scale 0.021276870742440224\n", + "blocks.1.ln1.hook_normalized 9.080503463745117\n", + "blocks.1.ln2.hook_scale 0.003745849709957838\n", + "blocks.1.ln2.hook_normalized 4.472580432891846\n", + "blocks.1.attn.hook_k 0.0021377610974013805\n", + "blocks.1.attn.hook_q 0.016889303922653198\n", + "blocks.1.attn.hook_v 9.080748558044434\n", + "blocks.1.attn.hook_z 9.080577850341797\n", + "blocks.1.attn.hook_attn_scores 0.0017388787819072604\n", + "blocks.1.attn.hook_pattern 0.0018037232803180814\n", + "blocks.1.mlp.hook_pre 4.472580432891846\n", + "blocks.1.mlp.hook_post 4.4686713218688965\n", + "blocks.1.hook_attn_out 9.080577850341797\n", + "blocks.1.hook_mlp_out 4.4686713218688965\n", + "blocks.1.hook_resid_pre 10.834088325500488\n", + "blocks.1.hook_resid_mid 10.262277603149414\n", + "blocks.1.hook_resid_post 11.917344093322754\n", + "ln_final.hook_scale 0.009998265653848648\n", + "ln_final.hook_normalized 5.719696521759033\n" + ] + } + ], + "source": [ + "hook_list = list(model.hook_dict.keys())\n", + "losses = []\n", + "loss_labels = []\n", + "for hook_name in hook_list:\n", + " if (\n", + " hook_name in cache\n", + " and hook_name != \"hook_pos_embed\"\n", + " and \"result\" not in hook_name\n", + " ):\n", + " average_act = 
cache[hook_name].mean(0)\n", + "\n", + " def replacing_with_average_act(activation, hook):\n", + " activation[:] = einops.repeat(\n", + " average_act, \"... -> batch ...\", batch=new_token_batch.size(0)\n", + " )\n", + " return activation\n", + "\n", + " logits = model.run_with_hooks(\n", + " new_token_batch, fwd_hooks=[(hook_name, replacing_with_average_act)]\n", + " )\n", + " loss = loss_fn(logits, new_token_batch)\n", + " print(hook_name, loss.item())\n", + " losses.append(loss.item())\n", + " loss_labels.append(hook_name)" + ] + }, + { + "cell_type": "code", + "execution_count": 60, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "\n", + "\n", + "\n", + "
\n", + "
\n", + "\n", + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "line(losses, hover_name=loss_labels)" + ] + }, + { + "cell_type": "code", + "execution_count": 61, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "dict_keys(['hook_embed', 'hook_pos_embed', 'blocks.0.hook_resid_pre', 'blocks.0.ln1.hook_scale', 'blocks.0.ln1.hook_normalized', 'blocks.0.attn.hook_q', 'blocks.0.attn.hook_k', 'blocks.0.attn.hook_v', 'blocks.0.attn.hook_attn_scores', 'blocks.0.attn.hook_pattern', 'blocks.0.attn.hook_z', 'blocks.0.hook_attn_out', 'blocks.0.hook_resid_mid', 'blocks.0.ln2.hook_scale', 'blocks.0.ln2.hook_normalized', 'blocks.0.mlp.hook_pre', 'blocks.0.mlp.hook_post', 'blocks.0.hook_mlp_out', 'blocks.0.hook_resid_post', 'blocks.1.hook_resid_pre', 'blocks.1.ln1.hook_scale', 'blocks.1.ln1.hook_normalized', 'blocks.1.attn.hook_q', 'blocks.1.attn.hook_k', 'blocks.1.attn.hook_v', 'blocks.1.attn.hook_attn_scores', 'blocks.1.attn.hook_pattern', 'blocks.1.attn.hook_z', 'blocks.1.hook_attn_out', 'blocks.1.hook_resid_mid', 'blocks.1.ln2.hook_scale', 'blocks.1.ln2.hook_normalized', 'blocks.1.mlp.hook_pre', 'blocks.1.mlp.hook_post', 'blocks.1.hook_mlp_out', 'blocks.1.hook_resid_post', 'ln_final.hook_scale', 'ln_final.hook_normalized'])" + ] + }, + "execution_count": 61, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "cache.cache_dict.keys()" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3.7.13 ('base')", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.8" + }, + "vscode": { + "interpreter": { + "hash": "d4d1e4263499bec80672ea0156c357c1ee493ec2b1c70f0acce89fc37c4a6abe" + } + } + }, + "nbformat": 4, + "nbformat_minor": 2 } diff --git a/demos/Othello_GPT.ipynb b/demos/Othello_GPT.ipynb index d69f1a166..40469f357 100644 --- a/demos/Othello_GPT.ipynb +++ b/demos/Othello_GPT.ipynb @@ -180,7 +180,7 @@ "outputs": [], "source": [ "import transformer_lens\n", - "import transformer_lens.utils as utils\n", + "import transformer_lens.utilities as utils\n", "from transformer_lens.hook_points import (\n", " HookedRootModule,\n", " HookPoint,\n", @@ -281,7 +281,7 @@ "metadata": {}, "outputs": [], "source": [ - "import transformer_lens.utils as utils\n", + "import transformer_lens.utilities as utils\n", "\n", "cfg = HookedTransformerConfig(\n", " n_layers=8,\n", diff --git a/demos/SVD_Interpreter_Demo.ipynb b/demos/SVD_Interpreter_Demo.ipynb index 82b85a06e..4a2f93a69 100644 --- a/demos/SVD_Interpreter_Demo.ipynb +++ b/demos/SVD_Interpreter_Demo.ipynb @@ -119,7 +119,7 @@ "import pysvelte\n", "import numpy as np\n", "import transformer_lens\n", - "import transformer_lens.utils as utils\n", + "import transformer_lens.utilities as utils\n", "from transformer_lens import HookedTransformer, SVDInterpreter" ] }, diff --git a/demos/Santa_Coder.ipynb b/demos/Santa_Coder.ipynb index af5a5b0cb..99f7dcab0 100644 --- a/demos/Santa_Coder.ipynb +++ b/demos/Santa_Coder.ipynb @@ -91,7 +91,7 @@ "# import circuitsvis as cv\n", "\n", "import transformer_lens\n", - "import transformer_lens.utils as utils\n", + "import transformer_lens.utilities as utils\n", "from transformer_lens.model_bridge import TransformerBridge\n", "\n", "_ = torch.set_grad_enabled(False)\n", diff --git 
a/transformer_lens/ActivationCache.py b/transformer_lens/ActivationCache.py index 17d8bc2fd..06e99c511 100644 --- a/transformer_lens/ActivationCache.py +++ b/transformer_lens/ActivationCache.py @@ -23,8 +23,8 @@ class first, including the examples, and then skimming the available methods. Yo from jaxtyping import Float, Int from typing_extensions import Literal -import transformer_lens.utils as utils -from transformer_lens.utils import Slice, SliceInput, warn_if_mps +import transformer_lens.utilities as utils +from transformer_lens.utilities import Slice, SliceInput, warn_if_mps class ActivationCache: diff --git a/transformer_lens/HookedEncoderDecoder.py b/transformer_lens/HookedEncoderDecoder.py index 89ad258fb..6b753690a 100644 --- a/transformer_lens/HookedEncoderDecoder.py +++ b/transformer_lens/HookedEncoderDecoder.py @@ -37,8 +37,8 @@ from transformer_lens.config.HookedTransformerConfig import HookedTransformerConfig from transformer_lens.FactoredMatrix import FactoredMatrix from transformer_lens.hook_points import HookedRootModule, HookPoint +from transformer_lens.utilities import sample_logits, warn_if_mps from transformer_lens.utilities.multi_gpu import get_device_for_block_index -from transformer_lens.utils import sample_logits, warn_if_mps T = TypeVar("T", bound="HookedEncoderDecoder") diff --git a/transformer_lens/HookedTransformer.py b/transformer_lens/HookedTransformer.py index 17816b9c3..bb65f5f5c 100644 --- a/transformer_lens/HookedTransformer.py +++ b/transformer_lens/HookedTransformer.py @@ -41,7 +41,7 @@ from typing_extensions import Literal import transformer_lens.loading_from_pretrained as loading -import transformer_lens.utils as utils +import transformer_lens.utilities as utils from transformer_lens.ActivationCache import ActivationCache # Activation cache for run_with_cache; KV cache for generation @@ -63,17 +63,15 @@ from transformer_lens.hook_points import HookedRootModule, HookPoint from transformer_lens.loading_from_pretrained import NON_HF_HOSTED_MODEL_NAMES from transformer_lens.utilities import ( + USE_DEFAULT_VALUE, get_best_available_device, get_device_for_block_index, -) -from transformer_lens.utilities.devices import move_to_and_update_config -from transformer_lens.utils import ( - USE_DEFAULT_VALUE, init_kaiming_normal_, init_kaiming_uniform_, init_xavier_normal_, init_xavier_uniform_, ) +from transformer_lens.utilities.devices import move_to_and_update_config from transformer_lens.weight_processing import ProcessWeights SingleLoss = Float[torch.Tensor, ""] # Type alias for a single element tensor diff --git a/transformer_lens/components/abstract_attention.py b/transformer_lens/components/abstract_attention.py index 5b744ca0f..470c21231 100644 --- a/transformer_lens/components/abstract_attention.py +++ b/transformer_lens/components/abstract_attention.py @@ -18,8 +18,8 @@ from transformer_lens.config.HookedTransformerConfig import HookedTransformerConfig from transformer_lens.FactoredMatrix import FactoredMatrix from transformer_lens.hook_points import HookPoint +from transformer_lens.utilities import get_offset_position_ids from transformer_lens.utilities.attention import complex_attn_linear, simple_attn_linear -from transformer_lens.utils import get_offset_position_ids if is_bitsandbytes_available(): import bitsandbytes as bnb diff --git a/transformer_lens/components/bert_block.py b/transformer_lens/components/bert_block.py index 98fd7c563..8eb0e45ac 100644 --- a/transformer_lens/components/bert_block.py +++ b/transformer_lens/components/bert_block.py @@ 
-13,7 +13,7 @@ from transformer_lens.config.HookedTransformerConfig import HookedTransformerConfig from transformer_lens.factories.mlp_factory import MLPFactory from transformer_lens.hook_points import HookPoint -from transformer_lens.utils import repeat_along_head_dimension +from transformer_lens.utilities import repeat_along_head_dimension class BertBlock(nn.Module): diff --git a/transformer_lens/components/pos_embed.py b/transformer_lens/components/pos_embed.py index 91d6237e5..e6319e1b3 100644 --- a/transformer_lens/components/pos_embed.py +++ b/transformer_lens/components/pos_embed.py @@ -11,7 +11,7 @@ from jaxtyping import Float, Int from transformer_lens.config.HookedTransformerConfig import HookedTransformerConfig -from transformer_lens.utils import get_offset_position_ids +from transformer_lens.utilities import get_offset_position_ids # Positional Embeddings diff --git a/transformer_lens/components/t5_block.py b/transformer_lens/components/t5_block.py index 461ac053d..e83fc7b7f 100644 --- a/transformer_lens/components/t5_block.py +++ b/transformer_lens/components/t5_block.py @@ -11,7 +11,7 @@ from transformer_lens.config.HookedTransformerConfig import HookedTransformerConfig from transformer_lens.factories.mlp_factory import MLPFactory from transformer_lens.hook_points import HookPoint -from transformer_lens.utils import repeat_along_head_dimension +from transformer_lens.utilities import repeat_along_head_dimension class T5Block(nn.Module): diff --git a/transformer_lens/components/transformer_block.py b/transformer_lens/components/transformer_block.py index 020e654d1..deff95b85 100644 --- a/transformer_lens/components/transformer_block.py +++ b/transformer_lens/components/transformer_block.py @@ -24,7 +24,7 @@ from transformer_lens.config.HookedTransformerConfig import HookedTransformerConfig from transformer_lens.factories.mlp_factory import MLPFactory from transformer_lens.hook_points import HookPoint -from transformer_lens.utils import repeat_along_head_dimension +from transformer_lens.utilities import repeat_along_head_dimension # Transformer Block diff --git a/transformer_lens/config/HookedTransformerConfig.py b/transformer_lens/config/HookedTransformerConfig.py index cacf49180..6b3d58bb9 100644 --- a/transformer_lens/config/HookedTransformerConfig.py +++ b/transformer_lens/config/HookedTransformerConfig.py @@ -336,7 +336,7 @@ def __post_init__(self): if self.device is None: self.device = str(get_device()) else: - from transformer_lens.utils import warn_if_mps + from transformer_lens.utilities import warn_if_mps warn_if_mps(self.device) diff --git a/transformer_lens/evals.py b/transformer_lens/evals.py index b4ed61a30..5c318d013 100644 --- a/transformer_lens/evals.py +++ b/transformer_lens/evals.py @@ -14,8 +14,8 @@ from datasets import load_dataset from torch.utils.data import DataLoader, Dataset -from transformer_lens import utils -from transformer_lens.utils import warn_if_mps +from transformer_lens import utilities as utils +from transformer_lens.utilities import warn_if_mps # %% diff --git a/transformer_lens/head_detector.py b/transformer_lens/head_detector.py index fbb50fae8..9efd237ff 100644 --- a/transformer_lens/head_detector.py +++ b/transformer_lens/head_detector.py @@ -13,7 +13,7 @@ from transformer_lens.ActivationCache import ActivationCache from transformer_lens.HookedTransformer import HookedTransformer -from transformer_lens.utils import is_lower_triangular, is_square +from transformer_lens.utilities import is_lower_triangular, is_square HeadName = 
Literal["previous_token_head", "duplicate_token_head", "induction_head"] HEAD_NAMES = cast(List[HeadName], get_args(HeadName)) diff --git a/transformer_lens/hook_points.py b/transformer_lens/hook_points.py index a675cb736..a942ec39d 100644 --- a/transformer_lens/hook_points.py +++ b/transformer_lens/hook_points.py @@ -32,7 +32,7 @@ from transformer_lens.conversion_utils.conversion_steps.base_tensor_conversion import ( BaseTensorConversion, ) -from transformer_lens.utils import Slice, SliceInput, warn_if_mps +from transformer_lens.utilities import Slice, SliceInput, warn_if_mps @dataclass diff --git a/transformer_lens/loading_from_pretrained.py b/transformer_lens/loading_from_pretrained.py index 30184ec85..6c703e754 100644 --- a/transformer_lens/loading_from_pretrained.py +++ b/transformer_lens/loading_from_pretrained.py @@ -24,7 +24,7 @@ Wav2Vec2Model, ) -import transformer_lens.utils as utils +import transformer_lens.utilities as utils from transformer_lens.config.HookedTransformerConfig import HookedTransformerConfig from transformer_lens.pretrained.weight_conversions import ( convert_apertus_weights, diff --git a/transformer_lens/model_bridge/bridge.py b/transformer_lens/model_bridge/bridge.py index 050cf485a..a07a4f330 100644 --- a/transformer_lens/model_bridge/bridge.py +++ b/transformer_lens/model_bridge/bridge.py @@ -28,7 +28,7 @@ import torch from torch import nn -from transformer_lens import utils +from transformer_lens import utilities as utils from transformer_lens.ActivationCache import ActivationCache from transformer_lens.FactoredMatrix import FactoredMatrix from transformer_lens.hook_points import HookPoint diff --git a/transformer_lens/model_bridge/sources/transformers.py b/transformer_lens/model_bridge/sources/transformers.py index ce9907384..0169da4dc 100644 --- a/transformer_lens/model_bridge/sources/transformers.py +++ b/transformer_lens/model_bridge/sources/transformers.py @@ -26,7 +26,7 @@ ) from transformer_lens.model_bridge.bridge import TransformerBridge from transformer_lens.supported_models import MODEL_ALIASES -from transformer_lens.utils import get_device, get_tokenizer_with_bos +from transformer_lens.utilities import get_device, get_tokenizer_with_bos # Suppress transformers warnings that go to stderr # This prevents notebook tests from failing due to unexpected stderr output diff --git a/transformer_lens/patching.py b/transformer_lens/patching.py index ac95bf441..868b1f215 100644 --- a/transformer_lens/patching.py +++ b/transformer_lens/patching.py @@ -60,7 +60,7 @@ from tqdm.auto import tqdm from typing_extensions import Literal -import transformer_lens.utils as utils +import transformer_lens.utilities as utils from transformer_lens.ActivationCache import ActivationCache from transformer_lens.HookedTransformer import HookedTransformer diff --git a/transformer_lens/train.py b/transformer_lens/train.py index d1eecfc12..24acfd0be 100644 --- a/transformer_lens/train.py +++ b/transformer_lens/train.py @@ -13,7 +13,7 @@ from torch.utils.data import DataLoader, Dataset from tqdm.auto import tqdm -from transformer_lens import utils +from transformer_lens import utilities as utils from transformer_lens.HookedTransformer import HookedTransformer from transformer_lens.utilities.library_utils import is_library_available diff --git a/transformer_lens/utilities/devices.py b/transformer_lens/utilities/devices.py index d0b14e4d1..d1265f002 100644 --- a/transformer_lens/utilities/devices.py +++ b/transformer_lens/utilities/devices.py @@ -129,7 +129,7 @@ def 
move_to_and_update_config( Returns: The model after the operation """ - from transformer_lens.utils import warn_if_mps + from transformer_lens.utilities import warn_if_mps if isinstance(device_or_dtype, torch.device): warn_if_mps(device_or_dtype) diff --git a/transformer_lens/utilities/exploratory_utils.py b/transformer_lens/utilities/exploratory_utils.py index 6cb2eb992..426aee6a0 100644 --- a/transformer_lens/utilities/exploratory_utils.py +++ b/transformer_lens/utilities/exploratory_utils.py @@ -31,13 +31,13 @@ def test_prompt( Examples: - >>> from transformer_lens import HookedTransformer, utils + >>> from transformer_lens import HookedTransformer, utilities >>> model = HookedTransformer.from_pretrained("tiny-stories-1M") Loaded pretrained model tiny-stories-1M into HookedTransformer >>> prompt = "Why did the elephant cross the" >>> answer = "road" - >>> utils.test_prompt(prompt, answer, model) + >>> utilities.test_prompt(prompt, answer, model) Tokenized prompt: ['<|endoftext|>', 'Why', ' did', ' the', ' elephant', ' cross', ' the'] Tokenized answer: [' road'] Performance on answer token: diff --git a/transformer_lens/weight_processing.py b/transformer_lens/weight_processing.py index 1d219973a..d9de9b481 100644 --- a/transformer_lens/weight_processing.py +++ b/transformer_lens/weight_processing.py @@ -11,7 +11,7 @@ import einops import torch -import transformer_lens.utils as utils +import transformer_lens.utilities as utils from transformer_lens.config.TransformerLensConfig import TransformerLensConfig from transformer_lens.FactoredMatrix import FactoredMatrix from transformer_lens.model_bridge.architecture_adapter import ArchitectureAdapter From 9ef4e4ce5190c8905ce2ee53790f8c40535fe58d Mon Sep 17 00:00:00 2001 From: davidcyze Date: Mon, 20 Apr 2026 14:06:02 -0500 Subject: [PATCH 02/21] fix: use cfg.dtype instead of torch.get_default_dtype for KV cache init (#1260) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * fix: use cfg.dtype instead of torch.get_default_dtype for KV cache init TransformerLensKeyValueCacheEntry.init_cache_entry initialised past_keys and past_values with torch.get_default_dtype(), which is torch.float32 unless the caller has explicitly overridden the global default. When a model runs in float16 or bfloat16, the subsequent torch.cat([past_keys, new_keys], dim=1) inside append() promoted the result to float32. Downstream attention-score computation then failed with: RuntimeError: expected scalar type Half but found Float at AbstractAttention.calculate_attention_scores (q_ @ k_ / attn_scale). This blocked generate() with use_past_kv_cache=True (the default) for any reduced-precision model. Disabling the KV cache worked but turned generation into O(seq_len^2) per step, which is prohibitive for any practical use. The fix uses cfg.dtype — the same dtype the rest of the model is loaded with. This is what every production fp16 inference stack does (HuggingFace transformers, vLLM, TGI, llama.cpp, TensorRT-LLM). 
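For illustration, a minimal standalone sketch of the promotion the old code
triggered (the shapes and variable names here are made up for the example,
this is not code from the repo):

    import torch

    past_keys = torch.empty((1, 0, 4, 8), dtype=torch.get_default_dtype())  # float32
    new_keys = torch.randn((1, 3, 4, 8), dtype=torch.float16)                # model dtype
    merged = torch.cat([past_keys, new_keys], dim=1)
    print(merged.dtype)  # torch.float32: the fp16 keys are silently promoted,
                         # so the later q_ @ k_ matmul mixes Half and Float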
Added tests/unit/test_key_value_cache_entry.py covering: - init_cache_entry respects cfg.dtype for fp32, fp16, bfloat16 - behaviour is independent of torch.get_default_dtype() - append() preserves cfg.dtype without promoting to fp32 - grouped-query-attention path uses n_key_value_heads correctly * isort imports --------- Co-authored-by: jlarson4 --- tests/unit/test_key_value_cache_entry.py | 106 ++++++++++++++++++ .../cache/key_value_cache_entry.py | 10 +- 2 files changed, 114 insertions(+), 2 deletions(-) create mode 100644 tests/unit/test_key_value_cache_entry.py diff --git a/tests/unit/test_key_value_cache_entry.py b/tests/unit/test_key_value_cache_entry.py new file mode 100644 index 000000000..b7806454a --- /dev/null +++ b/tests/unit/test_key_value_cache_entry.py @@ -0,0 +1,106 @@ +"""Tests for TransformerLensKeyValueCacheEntry.init_cache_entry dtype behaviour. + +The buggy pre-fix code used ``torch.get_default_dtype()`` to initialise +``past_keys`` and ``past_values``. PyTorch's default is ``torch.float32``, +so the bug silently produced the correct dtype for fp32 models but the +wrong dtype (fp32 instead of fp16/bf16) for reduced-precision ones. Of +the tests below, ``test_init_cache_entry_uses_cfg_dtype_float32`` is +therefore a baseline sanity check that passes against both the buggy +and fixed code — it verifies the common case works, not that the bug is +absent. The real regression guards are +``test_init_cache_entry_uses_cfg_dtype_float16``, +``..._bfloat16``, ``..._dtype_independent_of_global_default``, and +``test_append_preserves_cfg_dtype``, which all fail against the buggy +code (the fp16 cache was getting promoted to fp32 by the bug, breaking +the downstream attention-score matmul). +""" + +import torch + +from transformer_lens.cache.key_value_cache_entry import ( + TransformerLensKeyValueCacheEntry, +) +from transformer_lens.config.TransformerLensConfig import TransformerLensConfig + + +def _make_cfg(dtype: torch.dtype, n_heads: int = 4, d_head: int = 8, n_key_value_heads=None): + return TransformerLensConfig( + d_model=n_heads * d_head, + d_head=d_head, + n_layers=1, + n_ctx=32, + n_heads=n_heads, + n_key_value_heads=n_key_value_heads, + dtype=dtype, + ) + + +def test_init_cache_entry_uses_cfg_dtype_float32(): + """Baseline: cfg.dtype=float32 produces fp32 buffers. + + Note: this test passes against both the buggy and fixed implementations + because torch's default dtype is also float32. It is a sanity check + that the common case works, not a regression guard for the specific + bug this module was added to prevent. See module docstring and + ``test_init_cache_entry_dtype_independent_of_global_default`` for the + tests that discriminate fix vs bug. 
+ """ + cfg = _make_cfg(dtype=torch.float32) + entry = TransformerLensKeyValueCacheEntry.init_cache_entry(cfg, device="cpu") + assert entry.past_keys.dtype == torch.float32 + assert entry.past_values.dtype == torch.float32 + + +def test_init_cache_entry_uses_cfg_dtype_float16(): + cfg = _make_cfg(dtype=torch.float16) + entry = TransformerLensKeyValueCacheEntry.init_cache_entry(cfg, device="cpu") + assert entry.past_keys.dtype == torch.float16 + assert entry.past_values.dtype == torch.float16 + + +def test_init_cache_entry_uses_cfg_dtype_bfloat16(): + cfg = _make_cfg(dtype=torch.bfloat16) + entry = TransformerLensKeyValueCacheEntry.init_cache_entry(cfg, device="cpu") + assert entry.past_keys.dtype == torch.bfloat16 + assert entry.past_values.dtype == torch.bfloat16 + + +def test_init_cache_entry_dtype_independent_of_global_default(): + """Regression guard: cache dtype follows cfg.dtype, not the global default. + + Also covers the fp32 case indirectly: if someone reintroduces the old + ``torch.get_default_dtype()`` behaviour, this test plus the fp16 / + bfloat16 / append / GQA tests catch it; the fp32-only baseline above + would not, since fp32 happens to be torch's global default. + """ + cfg = _make_cfg(dtype=torch.float16) + original_default = torch.get_default_dtype() + try: + torch.set_default_dtype(torch.float32) + entry = TransformerLensKeyValueCacheEntry.init_cache_entry(cfg, device="cpu") + assert entry.past_keys.dtype == torch.float16 + assert entry.past_values.dtype == torch.float16 + finally: + torch.set_default_dtype(original_default) + + +def test_append_preserves_cfg_dtype(): + """After append, past_keys stays in cfg.dtype — no float promotion.""" + cfg = _make_cfg(dtype=torch.float16) + entry = TransformerLensKeyValueCacheEntry.init_cache_entry(cfg, device="cpu") + new_keys = torch.randn(1, 3, cfg.n_heads, cfg.d_head, dtype=torch.float16) + new_values = torch.randn(1, 3, cfg.n_heads, cfg.d_head, dtype=torch.float16) + updated_keys, updated_values = entry.append(new_keys, new_values) + assert updated_keys.dtype == torch.float16 + assert updated_values.dtype == torch.float16 + assert entry.past_keys.dtype == torch.float16 + assert entry.past_values.dtype == torch.float16 + + +def test_init_cache_entry_handles_grouped_query_attention(): + """When n_key_value_heads is set (GQA), it should be used instead of n_heads.""" + cfg = _make_cfg(dtype=torch.float16, n_heads=32, d_head=128, n_key_value_heads=8) + entry = TransformerLensKeyValueCacheEntry.init_cache_entry(cfg, device="cpu", batch_size=2) + assert entry.past_keys.shape == (2, 0, 8, 128) + assert entry.past_values.shape == (2, 0, 8, 128) + assert entry.past_keys.dtype == torch.float16 diff --git a/transformer_lens/cache/key_value_cache_entry.py b/transformer_lens/cache/key_value_cache_entry.py index eec478f7e..b8c8c57a8 100644 --- a/transformer_lens/cache/key_value_cache_entry.py +++ b/transformer_lens/cache/key_value_cache_entry.py @@ -27,12 +27,18 @@ def init_cache_entry( batch_size: int = 1, ): n_heads = cfg.n_key_value_heads if cfg.n_key_value_heads is not None else cfg.n_heads + # Use cfg.dtype so the cache matches the model's dtype. Using + # torch.get_default_dtype() (which is float32 unless the caller has + # set it) caused the subsequent torch.cat([past_keys, new_keys]) to + # promote the result to float32 when the model runs in float16 or + # bfloat16, which in turn broke the attention-score matmul with + # "expected scalar type Half but found Float". 
return cls(
             past_keys=torch.empty(
-                (batch_size, 0, n_heads, cfg.d_head), device=device, dtype=torch.get_default_dtype()
+                (batch_size, 0, n_heads, cfg.d_head), device=device, dtype=cfg.dtype
             ),
             past_values=torch.empty(
-                (batch_size, 0, n_heads, cfg.d_head), device=device, dtype=torch.get_default_dtype()
+                (batch_size, 0, n_heads, cfg.d_head), device=device, dtype=cfg.dtype
             ),
         )

From c1e5d4b0c73c8619ae4c51ca12e41b3827f41ea2 Mon Sep 17 00:00:00 2001
From: Brendan Long
Date: Mon, 20 Apr 2026 13:44:02 -0700
Subject: [PATCH 03/21] Fix tests broken by a local GPU (#1219)

* Fix apertus test failing on machines with GPU

Tensor equality includes the device, so set device="cpu" so that weight
tensors always match the expected values, even if there's a GPU they
could be created on.

Co-Authored-By: Claude Opus 4.6 (1M context)

* Fix test_cuda using nonexistent mlm_tokens fixture

The test_cuda function referenced a fixture named mlm_tokens which was
never defined, causing a fixture-not-found error. Changed to use the
existing tokens fixture which provides the same MLM-style tokenized
input.

Co-Authored-By: Claude Opus 4.6 (1M context)

* Resolve conflict

---------

Co-authored-by: Claude Opus 4.6 (1M context)
Co-authored-by: Jonah Larson
---
 tests/unit/pretrained_weight_conversions/test_apertus.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/tests/unit/pretrained_weight_conversions/test_apertus.py b/tests/unit/pretrained_weight_conversions/test_apertus.py
index d7e5760e1..68ba5089a 100644
--- a/tests/unit/pretrained_weight_conversions/test_apertus.py
+++ b/tests/unit/pretrained_weight_conversions/test_apertus.py
@@ -27,6 +27,7 @@ def make_cfg(use_qk_norm=True, n_key_value_heads=4):
         use_qk_norm=use_qk_norm,
         n_key_value_heads=n_key_value_heads,
         dtype=torch.float32,
+        device="cpu",
     )


@@ -183,7 +184,7 @@ def test_zero_biases_have_correct_device(self):
             "blocks.0.mlp.b_out",
             "unembed.b_U",
         ]:
-            assert sd[key].device.type == str(cfg.device), f"{key} on wrong device"
+            assert sd[key].device.type == cfg.device, f"{key} on wrong device"

     def test_unembed_shapes(self):
         cfg = make_cfg()

From c67a0a15f1f132e4f94e0c9a6fe547e48ab07374 Mon Sep 17 00:00:00 2001
From: Vedant Madane <6527493+VedantMadane@users.noreply.github.com>
Date: Tue, 21 Apr 2026 03:44:28 +0530
Subject: [PATCH 04/21] fix: handle LayerNorm folding correctly in
 load_and_process_state_dict (#1215)

* fix: handle LayerNorm folding correctly in load_and_process_state_dict

Previously, calling load_and_process_state_dict(state_dict, fold_ln=True)
had two failure modes:

1. If the state_dict had unfolded LN weights, fold_layer_norm removed the
   LN keys but the model's modules were not replaced with LNPre, leaving
   mismatched architecture and broken hooks.

2. If the state_dict was already folded (no LN keys), fold_layer_norm
   crashed with a KeyError trying to access missing LN weight keys.
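For context, an illustrative sketch of what "folding" means here, with
made-up shapes rather than the library's actual fold_layer_norm code: the
LN scale is absorbed into the next weight matrix and the LN parameters are
then dropped from the state_dict, which is why the modules must become
LNPre and why a second fold finds no LN keys.

    import torch

    d_model, n_heads, d_head = 8, 2, 4
    w_ln = torch.rand(d_model)                  # blocks.0.ln1.w
    W_Q = torch.rand(n_heads, d_model, d_head)  # blocks.0.attn.W_Q
    W_Q_folded = W_Q * w_ln[None, :, None]      # LN scale absorbed into W_Q
    # blocks.0.ln1.w is then deleted from the state_dict, so ln1 must be
    # swapped for a parameter-free LayerNormPre that only centers and
    # normalizes its input.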
Fix both by: - Checking whether LN keys exist before attempting to fold (skip with warning if already folded) - Replacing LN/RMS modules with LNPre/RMSPre before folding, matching the logic previously only in process_weights_ - Calling self.setup() after loading to re-attach hooks - Simplifying process_weights_ to delegate fully to the fixed method Fixes #219 Signed-off-by: Vedant Madane <6527493+VedantMadane@users.noreply.github.com> * style: fix black formatting in HookedTransformer.py * fix: handle LayerNorm folding correctly in load_and_process_state_dict * Make sure not to double fold --------- Signed-off-by: Vedant Madane <6527493+VedantMadane@users.noreply.github.com> Co-authored-by: jlarson4 --- transformer_lens/HookedTransformer.py | 68 +++++++++++++++++---------- 1 file changed, 43 insertions(+), 25 deletions(-) diff --git a/transformer_lens/HookedTransformer.py b/transformer_lens/HookedTransformer.py index bb65f5f5c..7a5001948 100644 --- a/transformer_lens/HookedTransformer.py +++ b/transformer_lens/HookedTransformer.py @@ -1656,8 +1656,48 @@ def load_and_process_state_dict( ) state_dict = self.fill_missing_keys(state_dict) + if fold_ln: + if self.cfg.num_experts and self.cfg.num_experts > 1: + logging.warning( + "You are using MoE, so the layer norm weights can't be folded! Skipping" + ) + fold_ln = False + elif self.cfg.normalization_type not in ["LN", "LNPre", "RMS", "RMSPre"]: + logging.warning( + "You are not using LayerNorm or RMSNorm, so the layer norm weights can't be folded! Skipping" + ) + fold_ln = False + else: + ln_keys_present = any( + k.endswith((".ln1.w", ".ln2.w", "ln_final.w")) for k in state_dict + ) + if not ln_keys_present: + logging.warning( + "fold_ln=True but no LayerNorm weights found in state_dict. " + "The model may have been saved with already-folded LayerNorms. " + "Skipping fold." + ) + fold_ln = False + else: + if self.cfg.normalization_type == "LN": + self.cfg.normalization_type = "LNPre" + self.ln_final = LayerNormPre(self.cfg) + for layer in self.blocks: + layer.ln1 = LayerNormPre(self.cfg) + layer.ln2 = LayerNormPre(self.cfg) + if self.cfg.is_layer_norm_activation(): + layer.mlp.ln = LayerNormPre(self.cfg) + elif self.cfg.normalization_type == "RMS": + self.cfg.normalization_type = "RMSPre" + self.ln_final = RMSNormPre(self.cfg) + for layer in self.blocks: + layer.ln1 = RMSNormPre(self.cfg) + layer.ln2 = RMSNormPre(self.cfg) + if self.cfg.is_layer_norm_activation(): + layer.mlp.ln = RMSNormPre(self.cfg) # Use the centralized ProcessWeights class for all weight processing + # (fold_ln is passed through — if we skipped above, it's now False) state_dict = ProcessWeights.process_weights( state_dict, self.cfg, @@ -1678,6 +1718,9 @@ def load_and_process_state_dict( self.load_state_dict({key: state_dict[key]}, strict=False) del state_dict[key] + if fold_ln: + self.setup() + def fill_missing_keys(self, state_dict): return loading.fill_missing_keys(self, state_dict) @@ -1817,31 +1860,6 @@ def process_weights_( version of the same model. """ state_dict = self.state_dict() - if fold_ln and self.cfg.num_experts and self.cfg.num_experts > 1: - # If we're using MoE, we don't fold the layer norm weights, so we don't need to do any preprocessing - # A warning is already issued in `load_and_process_state_dict` - pass - elif fold_ln and self.cfg.normalization_type == "LN": - # If we're folding the LN into the weights, we need to replace all the layernorm layers - # with LayerNormPres, which do not have learnable parameters. 
This is somewhat hacky,
-            # but it's the easiest way to do it.
-            self.cfg.normalization_type = "LNPre"
-            self.ln_final = LayerNormPre(self.cfg)
-            for layer in self._get_blocks():
-                layer.ln1 = LayerNormPre(self.cfg)
-                layer.ln2 = LayerNormPre(self.cfg)
-                if self.cfg.is_layer_norm_activation():
-                    layer.mlp.ln = LayerNormPre(self.cfg)
-        elif fold_ln and self.cfg.normalization_type == "RMS":
-            # We do the same for RMSNorm if used
-            self.cfg.normalization_type = "RMSPre"
-            self.ln_final = RMSNormPre(self.cfg)
-            for layer in self._get_blocks():
-                layer.ln1 = RMSNormPre(self.cfg)
-                layer.ln2 = RMSNormPre(self.cfg)
-                if self.cfg.is_layer_norm_activation():
-                    layer.mlp.ln = RMSNormPre(self.cfg)
-
         self.load_and_process_state_dict(
             state_dict,
             fold_ln=fold_ln,

From 524bca931ef5ff8e74c69c3ce393fbc5e0f031c0 Mon Sep 17 00:00:00 2001
From: Brendan Long
Date: Mon, 20 Apr 2026 16:00:08 -0700
Subject: [PATCH 05/21] Fix HookedTransformerConfig rotary_base types (#1231)

rotary_base is frequently set to floats in the code but was typed as an
int, causing beartype errors if the configs get loaded in a test:

https://github.com/TransformerLensOrg/TransformerLens/blob/9c5a2a81674d5bcefa641c816b66e9827ccdf637/transformer_lens/loading_from_pretrained.py#L1984

HF configs allegedly always have rope_theta as a float:

https://github.com/huggingface/transformers/blob/c38b2fb78eaedd4261a0e446f7976345cd1c7f1b/src/transformers/modeling_rope_utils.py#L645

But sometimes they're actually ints, and beartype doesn't consider int to
be a subtype of float:

https://github.com/beartype/beartype/issues/66

This updates the type to Union[float, int] to be accurate while keeping
beartype happy.

Co-authored-by: jlarson4
---
 transformer_lens/components/abstract_attention.py  | 2 +-
 transformer_lens/config/HookedTransformerConfig.py | 6 +++---
 2 files changed, 4 insertions(+), 4 deletions(-)

diff --git a/transformer_lens/components/abstract_attention.py b/transformer_lens/components/abstract_attention.py
index 470c21231..8cd51ea91 100644
--- a/transformer_lens/components/abstract_attention.py
+++ b/transformer_lens/components/abstract_attention.py
@@ -600,7 +600,7 @@ def calculate_sin_cos_rotary(
         self,
         rotary_dim: int,
         n_ctx: int,
-        base: int = 10000,
+        base: Union[float, int] = 10000,
         dtype: torch.dtype = torch.float32,
     ) -> Tuple[Float[torch.Tensor, "n_ctx rotary_dim"], Float[torch.Tensor, "n_ctx rotary_dim"]]:
         """
diff --git a/transformer_lens/config/HookedTransformerConfig.py b/transformer_lens/config/HookedTransformerConfig.py
index 6b3d58bb9..e8450df1f 100644
--- a/transformer_lens/config/HookedTransformerConfig.py
+++ b/transformer_lens/config/HookedTransformerConfig.py
@@ -206,7 +206,7 @@ class HookedTransformerConfig(TransformerLensConfig):
             YARN extension. Defaults to 4096.
         use_qk_norm (bool): Whether to apply RMSNorm to the query and key projections before
             computing attention scores. Used by Gemma 3 models. Defaults to False.
-        rotary_base_local (int, *optional*): The base for rotary positional embeddings in local
+        rotary_base_local (float, *optional*): The base for rotary positional embeddings in local
             attention layers. Used by models with hybrid local/global attention (e.g., Gemma 3)
             which use different RoPE bases for local (10k) and global (1M) attention.
             Defaults to None, which means the standard rotary_base is used for all layers.
@@ -250,9 +250,9 @@ class HookedTransformerConfig(TransformerLensConfig): dtype: torch.dtype = torch.float32 tokenizer_prepends_bos: Optional[bool] = None post_embedding_ln: bool = False - rotary_base: int = 10000 + rotary_base: Union[float, int] = 10000 rotary_base_local: Optional[ - int + Union[float, int] ] = None # For models with different RoPE bases per attention type (e.g., Gemma 3) rotary_scaling_factor: float = ( 1.0 # Linear RoPE scaling factor for global attention (e.g., 8.0 for Gemma 3 4B) From bd67b0f5053ad215164b60553703f04e798df3de Mon Sep 17 00:00:00 2001 From: Tuomas Oikarinen <33813139+tuomaso@users.noreply.github.com> Date: Tue, 21 Apr 2026 12:30:55 -0700 Subject: [PATCH 06/21] Fixed Masking in HookedTransformer.generate (#999) * fixed batching in generate * added test case * Move & improve tests * make check format and mypy * fix mypy errors * Stop jaxtyping failures * Updated to also fix TransformerBridge for the same issue --------- Co-authored-by: jlarson4 --- tests/acceptance/conftest.py | 15 ++++++ tests/acceptance/test_generate_batch.py | 30 ++++++++++++ transformer_lens/HookedTransformer.py | 63 ++++++++++++++++++++---- transformer_lens/model_bridge/bridge.py | 64 ++++++++++++++++++------- 4 files changed, 148 insertions(+), 24 deletions(-) create mode 100644 tests/acceptance/conftest.py create mode 100644 tests/acceptance/test_generate_batch.py diff --git a/tests/acceptance/conftest.py b/tests/acceptance/conftest.py new file mode 100644 index 000000000..50ad394a6 --- /dev/null +++ b/tests/acceptance/conftest.py @@ -0,0 +1,15 @@ +"""Shared fixtures for acceptance tests. + +Session-scoped fixtures avoid redundant model loads across test files. +All models used here must be in the CI cache (see .github/workflows/checks.yml). +""" + +import pytest + + +@pytest.fixture(scope="session") +def gpt2_model(): + """Session-scoped HookedTransformer gpt2 with default weight processing.""" + from transformer_lens import HookedTransformer + + return HookedTransformer.from_pretrained("gpt2", device="cpu") diff --git a/tests/acceptance/test_generate_batch.py b/tests/acceptance/test_generate_batch.py new file mode 100644 index 000000000..8b333d5f7 --- /dev/null +++ b/tests/acceptance/test_generate_batch.py @@ -0,0 +1,30 @@ +"""Tests that batched HookedTransformer generation matches individual generation.""" + + +def test_ht_generate_batch_matches_individual(gpt2_model): + """Batched generate() should match one-by-one generate() for left-padded inputs.""" + prompts = ["Hello, my dog is cute", "This is a much longer text. Hello, my cat is cute"] + individual_outputs = [gpt2_model.generate(p, verbose=False, do_sample=False) for p in prompts] + + batched_outputs = gpt2_model.generate(prompts, verbose=False, do_sample=False) + for i, prompt in enumerate(prompts): + assert ( + individual_outputs[i] == batched_outputs[i] + ), f"Prompt {i} mismatch:\n individual: {individual_outputs[i]}\n batched: {batched_outputs[i]}" + + +def test_ht_generate_batch_without_kv_cache(gpt2_model): + """Same test with use_past_kv_cache=False.""" + prompts = ["Hello, my dog is cute", "This is a much longer text. 
Hello, my cat is cute"] + individual_outputs = [ + gpt2_model.generate(p, verbose=False, do_sample=False, use_past_kv_cache=False) + for p in prompts + ] + + batched_outputs = gpt2_model.generate( + prompts, verbose=False, do_sample=False, use_past_kv_cache=False + ) + for i, prompt in enumerate(prompts): + assert ( + individual_outputs[i] == batched_outputs[i] + ), f"Prompt {i} mismatch:\n individual: {individual_outputs[i]}\n batched: {batched_outputs[i]}" diff --git a/transformer_lens/HookedTransformer.py b/transformer_lens/HookedTransformer.py index 7a5001948..eaff70094 100644 --- a/transformer_lens/HookedTransformer.py +++ b/transformer_lens/HookedTransformer.py @@ -1936,8 +1936,9 @@ def generate( implying usage of self.cfg.default_prepend_bos (default is True unless specified otherwise). Pass True or False to override the default. padding_side (Union[Literal["left", "right"], None], optional): Overrides - self.tokenizer.padding_side. Specifies which side to pad when tokenizing multiple - strings of different lengths. + self.tokenizer.padding_side. Specifies which side to pad when tokenizing + multiple strings of different lengths. For batched list inputs, left-padding + is forced internally for correct generation behavior. return_type (Optional[str]): The type of the output to return - a string or a list of strings ('str'), a tensor of tokens ('tokens'), a tensor of output embeddings ('embeds') or whatever the format of the input was ('input'). @@ -1974,13 +1975,25 @@ def generate( else: return_type = "embeds" + # initial_attention_mask is always computed so that single-prompt and + # batched generation go through the same masked code path, producing + # consistent results for the same prompt regardless of batching. + initial_attention_mask: Optional[torch.Tensor] = None + _is_batched_list = isinstance(input, list) and len(input) > 1 + if isinstance(input, (str, list)): input_type = "str" - # If text, convert to tokens (batch_size=1) assert ( self.tokenizer is not None ), "Must provide a tokenizer if passing a string to the model" - input = self.to_tokens(input, prepend_bos=prepend_bos, padding_side=padding_side) + if _is_batched_list: + # Force left-padding for batched generation so real tokens + # are flush-right and logits[:, -1, :] is always correct. + input = self.to_tokens(input, prepend_bos=prepend_bos, padding_side="left") + else: + input = self.to_tokens( + input, prepend_bos=prepend_bos, padding_side=padding_side + ) elif input.ndim == 2: input_type = "tokens" else: @@ -1988,6 +2001,27 @@ def generate( input_tokens = input if input_type in ["str", "tokens"] else None batch_size, ctx_length = input.shape[0], input.shape[1] + + # Compute initial attention mask. For batched inputs with padding, + # this correctly masks pad tokens. For single/unpadded inputs, this + # is all-ones which matches the no-mask code path but ensures both + # go through the same PosEmbed/attention logic for consistency. + if input_tokens is not None and self.tokenizer is not None: + _prepend_bos = ( + self.cfg.default_prepend_bos + if prepend_bos is USE_DEFAULT_VALUE + else (False if prepend_bos is None else prepend_bos) + ) + # Temporarily set padding_side="left" so get_attention_mask + # scans for leading pads (matching the left-padded tokens). 
+ _orig_padding_side = self.tokenizer.padding_side + if _is_batched_list: + self.tokenizer.padding_side = "left" + initial_attention_mask = utils.get_attention_mask( + self.tokenizer, input_tokens, _prepend_bos + ) + if _is_batched_list: + self.tokenizer.padding_side = _orig_padding_side device = get_device_for_block_index(0, self.cfg) input = input.to(device) if use_past_kv_cache: @@ -2062,10 +2096,20 @@ def generate( for index in tqdm.tqdm(range(max_new_tokens), disable=not verbose): pos_offset = self.get_pos_offset(past_kv_cache, batch_size) - tokens = torch.zeros((embeds.size(0), embeds.size(1))).to(torch.int) - attention_mask = utils.get_attention_mask( - self.tokenizer, tokens, False if prepend_bos is None else prepend_bos - ).to(device) + # Extend the initial attention mask with 1s for generated tokens. + attention_mask: Optional[torch.Tensor] = None + if initial_attention_mask is not None: + n_new = len(sampled_tokens_list) + if n_new > 0: + ones = torch.ones( + batch_size, + n_new, + dtype=initial_attention_mask.dtype, + device=device, + ) + attention_mask = torch.cat([initial_attention_mask.to(device), ones], dim=1) + else: + attention_mask = initial_attention_mask.to(device) residual, shortformer_pos_embed = self.get_residual( embeds, pos_offset, @@ -2089,6 +2133,7 @@ def generate( past_kv_cache=past_kv_cache, start_at_layer=start_at_layer, shortformer_pos_embed=shortformer_pos_embed, + attention_mask=attention_mask, ) else: logits = self.forward( @@ -2099,6 +2144,7 @@ def generate( past_kv_cache=past_kv_cache, start_at_layer=start_at_layer, shortformer_pos_embed=shortformer_pos_embed, + attention_mask=attention_mask, ) else: # We input the entire sequence, as a [batch, pos] tensor, since we aren't using @@ -2110,6 +2156,7 @@ def generate( padding_side=padding_side, start_at_layer=start_at_layer, shortformer_pos_embed=shortformer_pos_embed, + attention_mask=attention_mask, ) final_logits = logits[:, -1, :] diff --git a/transformer_lens/model_bridge/bridge.py b/transformer_lens/model_bridge/bridge.py index a07a4f330..93778e014 100644 --- a/transformer_lens/model_bridge/bridge.py +++ b/transformer_lens/model_bridge/bridge.py @@ -2121,9 +2121,9 @@ def generate( prepend_bos: Accepted for API compatibility but not applied during generation. The HF model expects tokens in its native format (tokenizer defaults). Overriding BOS can silently degrade generation quality. - padding_side: Accepted for API compatibility but not applied during generation. - The generation loop always extends tokens to the right, so overriding - initial padding_side creates inconsistent token layout. + padding_side: Which side to pad when tokenizing multiple strings of different + lengths. For batched list inputs, left-padding is forced internally for + correct generation behavior. Defaults to None (tokenizer default). return_type: The type of output to return - 'input', 'str', or 'tokens' verbose: Not used in Bridge (kept for API compatibility) output_logits: If True, return a ModelOutput with sequences and logits tuple @@ -2135,10 +2135,9 @@ def generate( Generated sequence as string, list of strings, or tensor depending on input type and return_type. If output_logits=True, returns a ModelOutput-like object with 'sequences' and 'logits' attributes. """ - # prepend_bos and padding_side are intentionally not applied during generation. + # prepend_bos is intentionally not applied during generation. # The HF model expects tokens in its native format. 
Overriding BOS can silently - # degrade quality, and overriding padding_side conflicts with the generation loop - # which always extends tokens to the right. + # degrade quality. if prepend_bos is not None: import warnings @@ -2149,27 +2148,28 @@ def generate( "resulting tensor to generate().", stacklevel=2, ) - if padding_side is not None: - import warnings - - warnings.warn( - "padding_side is ignored during TransformerBridge.generate(). " - "The generation loop extends tokens to the right regardless of initial " - "padding. To control padding, tokenize with to_tokens(padding_side=...) " - "and pass the resulting tensor to generate().", - stacklevel=2, - ) + # padding_side is handled internally: for batched list inputs, left-padding + # is forced to ensure correct generation. See _is_batched_list logic below. # Stateful dispatch is decided after input parsing so we can fall back # to hf_generate() for input types the stateful loop doesn't handle. is_stateful_model = getattr(self.cfg, "is_stateful", False) + _is_batched_list = isinstance(input, list) and len(input) > 1 + _generate_from_embeds = False if isinstance(input, str): input_tokens = self.to_tokens(input, move_to_device=True, truncate=False) input_type = "str" elif isinstance(input, list): + # Force left-padding for batched generation so real tokens are + # flush-right and logits[:, -1, :] is always the last real token. + if _is_batched_list: + _orig_padding_side = self.tokenizer.padding_side + self.tokenizer.padding_side = "left" input_tokens = self.to_tokens(input, move_to_device=True, truncate=False) + if _is_batched_list: + self.tokenizer.padding_side = _orig_padding_side input_type = "list" elif isinstance(input, torch.Tensor) and input.is_floating_point(): # inputs_embeds: pre-computed embeddings (e.g., from multimodal models) @@ -2307,6 +2307,30 @@ def generate( ) else: forward_kwargs: Dict[str, Any] = {} + # Compute attention mask and position_ids for batched + # inputs with padding. HF models default to all-ones + # when no mask is given, which ignores padding tokens. + if ( + _is_batched_list + and self.tokenizer is not None + and self.tokenizer.pad_token_id is not None + ): + # Temp-swap to "left" so get_attention_mask scans + # for leading pads (matching the left-padded tokens). + _prev_side = self.tokenizer.padding_side + self.tokenizer.padding_side = "left" + attn_mask = utils.get_attention_mask( + self.tokenizer, + current_tokens, + prepend_bos=getattr(self.cfg, "default_prepend_bos", True), + ).to(self.cfg.device) + self.tokenizer.padding_side = _prev_side + forward_kwargs["attention_mask"] = attn_mask + # Adjust position_ids for left-padding so pad + # tokens don't consume real position embeddings. + position_ids = attn_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attn_mask == 0, 1) + forward_kwargs["position_ids"] = position_ids # Pass multimodal inputs only on the first step — the vision # encoder processes the image once, embedding it into the # token sequence. 
This includes pixel_values plus any extra @@ -2346,6 +2370,10 @@ def generate( [input_seq_pos], device=self.cfg.device ) forward_kwargs["cache_position"] = cache_position + if "position_ids" in forward_kwargs: + forward_kwargs["position_ids"] = forward_kwargs["position_ids"][ + :, -1: + ] logits = self( current_tokens[:, -1:], return_type="logits", @@ -2356,6 +2384,10 @@ def generate( if _hf_kv_cache is not None: # Cached step: pass only the last token + cache forward_kwargs["past_key_values"] = _hf_kv_cache + if "position_ids" in forward_kwargs: + forward_kwargs["position_ids"] = forward_kwargs["position_ids"][ + :, -1: + ] logits = self( current_tokens[:, -1:], return_type="logits", From 5fe490e10a639d1ba2b3f1eea07e2bf4b64ecf65 Mon Sep 17 00:00:00 2001 From: Anthony Duong <42191920+anthonyduong9@users.noreply.github.com> Date: Wed, 22 Apr 2026 02:44:33 +0700 Subject: [PATCH 07/21] Add hooked transformer generate stream (#908) * adds HookedTransformer.generate_stream() * fixes mypy errors * Adjusted for TransformerLens 3 changes --------- Co-authored-by: Bryce Meyer Co-authored-by: jlarson4 --- tests/acceptance/test_hooked_transformer.py | 32 +++ transformer_lens/HookedEncoderDecoder.py | 4 +- transformer_lens/HookedTransformer.py | 227 ++++++++++++++++++++ 3 files changed, 261 insertions(+), 2 deletions(-) diff --git a/tests/acceptance/test_hooked_transformer.py b/tests/acceptance/test_hooked_transformer.py index 42fcddae3..e830c5bd7 100644 --- a/tests/acceptance/test_hooked_transformer.py +++ b/tests/acceptance/test_hooked_transformer.py @@ -236,6 +236,38 @@ def test_bloom_similarity_with_hf_model_with_kv_cache_activated(): assert output_tf == output_hf_str +def test_bloom_similarity_with_hf_model_with_kv_cache_activated_stream(): + tf_model = HookedTransformer.from_pretrained( + "bigscience/bloom-560m", default_prepend_bos=False, device="cpu" + ) + + hf_model = AutoModelForCausalLM.from_pretrained("bigscience/bloom-560m") + hf_tokenizer = AutoTokenizer.from_pretrained("bigscience/bloom-560m") + + final_output = "" + for result in tf_model.generate_stream( + text, + do_sample=False, + use_past_kv_cache=True, + verbose=False, + max_new_tokens=10, + max_tokens_per_yield=10, + ): + final_output += tf_model.to_string(result[0]) + + hf_input_ids = hf_tokenizer(text, return_tensors="pt").input_ids + output_hf_tokens = hf_model.generate( + hf_input_ids, + do_sample=False, + max_new_tokens=10, + ) + output_hf_str = hf_tokenizer.decode(output_hf_tokens[0], skip_special_tokens=True) + + assert ( + final_output == output_hf_str + ), f"\nStreaming output: {final_output}\nHF output: {output_hf_str}" + + def check_norm_folding( model_name, hf_model=None, diff --git a/transformer_lens/HookedEncoderDecoder.py b/transformer_lens/HookedEncoderDecoder.py index 6b753690a..e683d2f91 100644 --- a/transformer_lens/HookedEncoderDecoder.py +++ b/transformer_lens/HookedEncoderDecoder.py @@ -484,13 +484,13 @@ def generate( else: return decoder_input - @overload + @overload # type: ignore[overload-overlap] def run_with_cache( self, *model_args: Any, return_cache_object: Literal[True] = True, **kwargs: Any ) -> Tuple[Float[torch.Tensor, "batch pos d_vocab"], ActivationCache]: ... 
- @overload + @overload # type: ignore[overload-overlap] def run_with_cache( self, *model_args: Any, return_cache_object: Literal[False] = False, **kwargs: Any ) -> Tuple[Float[torch.Tensor, "batch pos d_vocab"], Dict[str, torch.Tensor]]: diff --git a/transformer_lens/HookedTransformer.py b/transformer_lens/HookedTransformer.py index eaff70094..6baa46420 100644 --- a/transformer_lens/HookedTransformer.py +++ b/transformer_lens/HookedTransformer.py @@ -13,6 +13,7 @@ import logging import os +from collections.abc import Generator from typing import ( Any, Dict, @@ -2250,6 +2251,232 @@ def _logits_to_tuple(logits_list: list[torch.Tensor]) -> tuple[torch.Tensor, ... else: return result + @torch.inference_mode() + def generate_stream( + self, + input: Union[str, Float[torch.Tensor, "batch pos"]] = "", + max_new_tokens: int = 10, + max_tokens_per_yield: int = 25, + stop_at_eos: bool = True, + eos_token_id: Optional[int] = None, + do_sample: bool = True, + top_k: Optional[int] = None, + top_p: Optional[float] = None, + temperature: float = 1.0, + freq_penalty: float = 0.0, + use_past_kv_cache: bool = True, + prepend_bos: Optional[bool] = USE_DEFAULT_VALUE, + padding_side: Optional[Literal["left", "right"]] = USE_DEFAULT_VALUE, + return_type: Optional[str] = "input", + verbose: bool = True, + ) -> Generator[Union[Int[torch.Tensor, "batch"], str], None, None]: + """Stream tokens from the Model as they are generated. + + Sample tokens from the model until the model outputs eos_token or max_new_tokens is reached, + yielding batches of tokens progressively during generation rather than waiting for the entire + sequence to be generated. + + To avoid fiddling with ragged tensors, if we input a batch of text and some sequences finish + (by producing an EOT token), we keep running the model on the entire batch, but throw away + the output for a finished sequence and just keep adding EOTs to pad. + + This supports entering a single string, but not a list of strings - if the strings don't + tokenize to exactly the same length, this gets messy. If that functionality is needed, + convert them to a batch of tokens and input that instead. + + Args: + input (Union[str, Int[torch.Tensor, "batch pos"])]): Either a batch of tokens ([batch, + pos]) or a text string (this will be converted to a batch of tokens with batch size + 1). + max_new_tokens (int): Maximum number of tokens to generate. + max_tokens_per_yield (int): Maximum number of tokens to accumulate before yielding. + Controls how frequently the function yields tokens during generation. + stop_at_eos (bool): If True, stop generating tokens when the model outputs eos_token. + eos_token_id (Optional[Union[int, Sequence]]): The token ID to use for end + of sentence. If None, use the tokenizer's eos_token_id - required if using + stop_at_eos. It's also possible to provide a list of token IDs (not just the + eos_token_id), in which case the generation will stop when any of them are output + (useful e.g. for stable_lm). + do_sample (bool): If True, sample from the model's output distribution. Otherwise, use + greedy search (take the max logit each time). + top_k (int): Number of tokens to sample from. If None, sample from all tokens. + top_p (float): Probability mass to sample from. If 1.0, sample from all tokens. If <1.0, + we take the top tokens with cumulative probability >= top_p. + temperature (float): Temperature for sampling. 
Higher values will make the model more + random (limit of temp -> 0 is just taking the top token, limit of temp -> inf is + sampling from a uniform distribution). + freq_penalty (float): Frequency penalty for sampling - how much to penalise previous + tokens. Higher values will make the model more random. + use_past_kv_cache (bool): If True, create and use cache to speed up generation. + prepend_bos (bool, optional): Overrides self.cfg.default_prepend_bos. Whether to prepend + the BOS token to the input (applicable when input is a string). Defaults to None, + implying usage of self.cfg.default_prepend_bos (default is True unless specified + otherwise). Pass True or False to override the default. + padding_side (Union[Literal["left", "right"], None], optional): Overrides + self.tokenizer.padding_side. Specifies which side to pad when tokenizing multiple + strings of different lengths. + return_type (Optional[str]): The type of the output to return - either a string (str), + a tensor of tokens (tensor) or whatever the format of the input was (input). + verbose (bool): If True, show tqdm progress bars for generation. + + Yields: + outputs (Union[Int[torch.Tensor, "batch"], str]): Batches of generated tokens, yielded + progressively during generation. Each yield contains accumulated tokens since the last + yield, up to max_tokens_per_yield. + """ + + with utils.LocallyOverridenDefaults( + self, prepend_bos=prepend_bos, padding_side=padding_side + ): + if type(input) == str: + # If text, convert to tokens (batch_size=1) + assert ( + self.tokenizer is not None + ), "Must provide a tokenizer if passing a string to the model" + tokens = self.to_tokens(input, prepend_bos=prepend_bos, padding_side=padding_side) + else: + assert isinstance(input, torch.Tensor), "Input must be a tensor when not a string" + tokens = input + + if return_type == "input": + if type(input) == str: + return_type = "str" + else: + return_type = "tensor" + + assert isinstance(tokens, torch.Tensor) + batch_size, ctx_length = tokens.shape + device = get_device_for_block_index(0, self.cfg) + tokens = tokens.to(device) + if use_past_kv_cache: + past_kv_cache = TransformerLensKeyValueCache.init_cache( + self.cfg, self.cfg.device, batch_size + ) + else: + past_kv_cache = None + + stop_tokens: List[int] = [] + eos_token_for_padding = 0 + assert self.tokenizer is not None + if stop_at_eos: + tokenizer_has_eos_token = ( + self.tokenizer is not None and self.tokenizer.eos_token_id is not None + ) + if eos_token_id is None: + assert ( + tokenizer_has_eos_token + ), "Must pass a eos_token_id if stop_at_eos is True and tokenizer is None or has no eos_token_id" + + eos_token_id = self.tokenizer.eos_token_id + + if isinstance(eos_token_id, int): + stop_tokens = [eos_token_id] + eos_token_for_padding = eos_token_id + else: + # eos_token_id is a Sequence (e.g. list or tuple) + stop_tokens = eos_token_id + eos_token_for_padding = ( + self.tokenizer.eos_token_id if tokenizer_has_eos_token else eos_token_id[0] + ) + + # An array to track which sequences in the batch have finished. + finished_sequences = torch.zeros(batch_size, dtype=torch.bool, device=self.cfg.device) + + accumulated_tokens: Optional[torch.Tensor] = None + tokens_since_last_yield = 0 + + # Currently nothing in HookedTransformer changes with eval, but this is here in case + # that changes in the future. 
+ self.eval() + for index in tqdm.tqdm(range(max_new_tokens), disable=not verbose): + # While generating, we keep generating logits, throw away all but the final logits, + # and then use those logits to sample from the distribution We keep adding the + # sampled tokens to the end of tokens. + if use_past_kv_cache: + # We just take the final tokens, as a [batch, 1] tensor + if index > 0: + logits = self.forward( + tokens[:, -1:], + return_type="logits", + prepend_bos=prepend_bos, + padding_side=padding_side, + past_kv_cache=past_kv_cache, + ) + else: + logits = self.forward( + tokens, + return_type="logits", + prepend_bos=prepend_bos, + padding_side=padding_side, + past_kv_cache=past_kv_cache, + ) + else: + # We input the entire sequence, as a [batch, pos] tensor, since we aren't using + # the cache. + logits = self.forward( + tokens, + return_type="logits", + prepend_bos=prepend_bos, + padding_side=padding_side, + ) + final_logits = logits[:, -1, :] + + if do_sample: + sampled_tokens = utils.sample_logits( + final_logits, + top_k=top_k, + top_p=top_p, + temperature=temperature, + freq_penalty=freq_penalty, + tokens=tokens, + ).to(get_device_for_block_index(0, self.cfg)) + else: + sampled_tokens = final_logits.argmax(-1).to( + get_device_for_block_index(0, self.cfg) + ) + + if stop_at_eos: + # For all unfinished sequences, add on the next token. If a sequence was + # finished, throw away the generated token and add eos_token_for_padding + # instead. + sampled_tokens[finished_sequences] = eos_token_for_padding + finished_sequences.logical_or_( + torch.isin( + sampled_tokens.to(self.cfg.device), + torch.tensor(stop_tokens).to(self.cfg.device), + ) + ) + + new_tokens = sampled_tokens.unsqueeze(-1) + + # Accumulate tokens until we hit max_tokens_per_yield + if index == 0: + accumulated_tokens = torch.cat([tokens, new_tokens], dim=-1) + tokens_since_last_yield = accumulated_tokens.shape[1] + else: + if accumulated_tokens is None: + accumulated_tokens = new_tokens + else: + accumulated_tokens = torch.cat([accumulated_tokens, new_tokens], dim=-1) + tokens_since_last_yield += 1 + + if tokens_since_last_yield >= max_tokens_per_yield: + yield accumulated_tokens + tokens_since_last_yield = 0 + accumulated_tokens = None + + tokens = torch.cat([tokens, new_tokens], dim=-1) + + if stop_at_eos and finished_sequences.all(): + # Yield any remaining accumulated tokens before breaking + if accumulated_tokens is not None: + yield accumulated_tokens + break + + # Only yield remaining tokens if we didn't already yield them in the break case + if accumulated_tokens is not None and not (stop_at_eos and finished_sequences.all()): + yield accumulated_tokens + # Give access to all weights as properties. @property def W_U(self) -> Float[torch.Tensor, "d_model d_vocab"]: From 0db6d2284523d2a5afb080cbc0cd93bfc36d0d22 Mon Sep 17 00:00:00 2001 From: UFO-101 <47218308+UFO-101@users.noreply.github.com> Date: Wed, 22 Apr 2026 03:39:32 +0100 Subject: [PATCH 08/21] Add py.typed for type hints (#760) * Add attention_mask argument to loss_fn() and lm_cross_entropy_loss() and adjust the cross entropy calculation to ignore masked (padding) tokens. 
* updated lock file

* locked numpy below 2

---------

Co-authored-by: Bryce Meyer
Co-authored-by: jlarson4
---
 pyproject.toml            | 9 ++++-----
 transformer_lens/py.typed | 0
 2 files changed, 4 insertions(+), 5 deletions(-)
 create mode 100644 transformer_lens/py.typed

diff --git a/pyproject.toml b/pyproject.toml
index df57acf10..1af32755f 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,7 +1,7 @@
 [project]
   authors=[{name="Neel Nanda", email="77788841+TransformerLensOrg@users.noreply.github.com"}]
   dependencies=[
-    "accelerate>=0.23.0", # Needed for Llama Models
+    "accelerate>=0.23.0",  # Needed for Llama Models
     "beartype>=0.14.1",
     "better-abc>=0.0.3",
     "datasets>=2.7.1",
@@ -26,6 +26,7 @@
   description="An implementation of transformers tailored for mechanistic interpretability."
   license={text="MIT"}
   name="transformer-lens"
+  packages=[{include="transformer_lens"}, {include="transformer_lens/py.typed"}]
   readme="README.md"
   requires-python=">=3.10,<4.0"
   version="0.0.0"
@@ -85,13 +86,11 @@
     "-W ignore::beartype.roar.BeartypeDecorHintPep585DeprecationWarning",
   ]
   doctest_optionflags="NORMALIZE_WHITESPACE ELLIPSIS FLOAT_CMP"
-  markers=[
-    "slow: marks tests as slow (deselect with '-m \"not slow\"')",
-  ]
   filterwarnings=[
-    "ignore:pkg_resources is deprecated as an API:DeprecationWarning",
     "ignore:distutils Version classes are deprecated:DeprecationWarning",
+    "ignore:pkg_resources is deprecated as an API:DeprecationWarning",
   ]
+  markers=["slow: marks tests as slow (deselect with '-m \"not slow\"')"]
   pythonpath=["."]
   testpaths=["tests", "transformer_lens"] # Only test these directories
diff --git a/transformer_lens/py.typed b/transformer_lens/py.typed
new file mode 100644
index 000000000..e69de29bb

From a4379f89738dbe72aeefeacdb000101a4da2d3d1 Mon Sep 17 00:00:00 2001
From: Jonah Larson
Date: Tue, 21 Apr 2026 22:42:48 -0500
Subject: [PATCH 09/21] Created Baichuan Architecture adapter (#1262)

---
 .../test_baichuan_adapter.py                  | 771 ++++++++++++++++++
 .../factories/architecture_adapter_factory.py |   3 +
 .../supported_architectures/__init__.py       |   4 +
 .../supported_architectures/baichuan.py       | 447 ++++++++++
 .../model_registry/data/supported_models.json | 210 ++++-
 .../data/verification_history.json            | 122 ++-
 .../tools/model_registry/verify_models.py     |   1 +
 7 files changed, 1554 insertions(+), 4 deletions(-)
 create mode 100644 tests/unit/model_bridge/supported_architectures/test_baichuan_adapter.py
 create mode 100644 transformer_lens/model_bridge/supported_architectures/baichuan.py

diff --git a/tests/unit/model_bridge/supported_architectures/test_baichuan_adapter.py b/tests/unit/model_bridge/supported_architectures/test_baichuan_adapter.py
new file mode 100644
index 000000000..d1e2348df
--- /dev/null
+++ b/tests/unit/model_bridge/supported_architectures/test_baichuan_adapter.py
@@ -0,0 +1,771 @@
+"""Unit tests for BaichuanArchitectureAdapter.
+ +Tests cover: +- Config attributes +- Component mapping structure and HF module names +- Weight conversion keys/types +- split_qkv_matrix (W_pack) numerical correctness +- preprocess_weights (QKV split, fold_ln, NormHead normalization) +- Factory registration (both v1 and v2 class names) +""" + +from types import SimpleNamespace +from typing import Any + +import pytest +import torch +import torch.nn as nn + +from transformer_lens.config import TransformerBridgeConfig +from transformer_lens.conversion_utils.conversion_steps import RearrangeTensorConversion +from transformer_lens.conversion_utils.param_processing_conversion import ( + ParamProcessingConversion, +) +from transformer_lens.model_bridge.generalized_components import ( + BlockBridge, + EmbeddingBridge, + GatedMLPBridge, + JointQKVPositionEmbeddingsAttentionBridge, + RMSNormalizationBridge, + UnembeddingBridge, +) +from transformer_lens.model_bridge.supported_architectures.baichuan import ( + BaichuanArchitectureAdapter, + _BaichuanAttentionBridge, +) + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + + +def _make_cfg( + n_heads: int = 32, + d_model: int = 64, + n_layers: int = 2, + d_vocab: int = 100, + n_ctx: int = 128, +) -> TransformerBridgeConfig: + """Minimal TransformerBridgeConfig for Baichuan adapter tests.""" + return TransformerBridgeConfig( + d_model=d_model, + d_head=d_model // n_heads, + n_layers=n_layers, + n_ctx=n_ctx, + n_heads=n_heads, + d_vocab=d_vocab, + default_prepend_bos=True, + architecture="BaichuanForCausalLM", + ) + + +@pytest.fixture +def cfg() -> TransformerBridgeConfig: + return _make_cfg(n_heads=8, d_model=64) + + +@pytest.fixture +def adapter(cfg: TransformerBridgeConfig) -> BaichuanArchitectureAdapter: + return BaichuanArchitectureAdapter(cfg) + + +def _make_w_pack_component(d_model: int) -> Any: + """Synthetic attention namespace with W_pack linear.""" + ns = SimpleNamespace() + ns.W_pack = nn.Linear(d_model, 3 * d_model, bias=False) + return ns + + +# --------------------------------------------------------------------------- +# Config attribute tests +# --------------------------------------------------------------------------- + + +class TestBaichuanAdapterConfig: + def test_normalization_type(self, adapter: BaichuanArchitectureAdapter) -> None: + assert adapter.cfg.normalization_type == "RMS" + + def test_positional_embedding_type(self, adapter: BaichuanArchitectureAdapter) -> None: + assert adapter.cfg.positional_embedding_type == "rotary" + + def test_final_rms(self, adapter: BaichuanArchitectureAdapter) -> None: + assert adapter.cfg.final_rms is True + + def test_gated_mlp(self, adapter: BaichuanArchitectureAdapter) -> None: + assert adapter.cfg.gated_mlp is True + + def test_attn_only(self, adapter: BaichuanArchitectureAdapter) -> None: + assert adapter.cfg.attn_only is False + + def test_uses_rms_norm(self, adapter: BaichuanArchitectureAdapter) -> None: + assert adapter.cfg.uses_rms_norm is True + + def test_eps_attr(self, adapter: BaichuanArchitectureAdapter) -> None: + assert adapter.cfg.eps_attr == "variance_epsilon" + + def test_supports_fold_ln_false(self, adapter: BaichuanArchitectureAdapter) -> None: + assert adapter.supports_fold_ln is False + + +# --------------------------------------------------------------------------- +# Component mapping tests +# --------------------------------------------------------------------------- + + +class 
TestBaichuanAdapterComponentMapping: + @staticmethod + def _mapping(adapter: BaichuanArchitectureAdapter) -> dict[str, Any]: + mapping = adapter.component_mapping + assert mapping is not None + return mapping + + def test_embed_type_and_name(self, adapter: BaichuanArchitectureAdapter) -> None: + mapping = self._mapping(adapter) + assert isinstance(mapping["embed"], EmbeddingBridge) + assert mapping["embed"].name == "model.embed_tokens" + + def test_no_top_level_rotary_emb(self, adapter: BaichuanArchitectureAdapter) -> None: + mapping = self._mapping(adapter) + assert "rotary_emb" not in mapping + + def test_blocks_type_and_name(self, adapter: BaichuanArchitectureAdapter) -> None: + mapping = self._mapping(adapter) + assert isinstance(mapping["blocks"], BlockBridge) + assert mapping["blocks"].name == "model.layers" + + def test_ln_final_type_and_name(self, adapter: BaichuanArchitectureAdapter) -> None: + mapping = self._mapping(adapter) + assert isinstance(mapping["ln_final"], RMSNormalizationBridge) + assert mapping["ln_final"].name == "model.norm" + + def test_unembed_type_and_name(self, adapter: BaichuanArchitectureAdapter) -> None: + mapping = self._mapping(adapter) + assert isinstance(mapping["unembed"], UnembeddingBridge) + assert mapping["unembed"].name == "lm_head" + + def test_ln1_type_and_name(self, adapter: BaichuanArchitectureAdapter) -> None: + blocks = self._mapping(adapter)["blocks"] + assert isinstance(blocks.submodules["ln1"], RMSNormalizationBridge) + assert blocks.submodules["ln1"].name == "input_layernorm" + + def test_ln2_type_and_name(self, adapter: BaichuanArchitectureAdapter) -> None: + blocks = self._mapping(adapter)["blocks"] + assert isinstance(blocks.submodules["ln2"], RMSNormalizationBridge) + assert blocks.submodules["ln2"].name == "post_attention_layernorm" + + def test_attn_type_and_name(self, adapter: BaichuanArchitectureAdapter) -> None: + blocks = self._mapping(adapter)["blocks"] + assert isinstance(blocks.submodules["attn"], JointQKVPositionEmbeddingsAttentionBridge) + assert blocks.submodules["attn"].name == "self_attn" + + def test_attn_qkv_name(self, adapter: BaichuanArchitectureAdapter) -> None: + blocks = self._mapping(adapter)["blocks"] + assert blocks.submodules["attn"].submodules["qkv"].name == "W_pack" + + def test_attn_o_name(self, adapter: BaichuanArchitectureAdapter) -> None: + blocks = self._mapping(adapter)["blocks"] + assert blocks.submodules["attn"].submodules["o"].name == "o_proj" + + def test_mlp_type_and_name(self, adapter: BaichuanArchitectureAdapter) -> None: + blocks = self._mapping(adapter)["blocks"] + assert isinstance(blocks.submodules["mlp"], GatedMLPBridge) + assert blocks.submodules["mlp"].name == "mlp" + + def test_mlp_gate_name(self, adapter: BaichuanArchitectureAdapter) -> None: + blocks = self._mapping(adapter)["blocks"] + assert blocks.submodules["mlp"].submodules["gate"].name == "gate_proj" + + def test_mlp_in_name(self, adapter: BaichuanArchitectureAdapter) -> None: + blocks = self._mapping(adapter)["blocks"] + assert blocks.submodules["mlp"].submodules["in"].name == "up_proj" + + def test_mlp_out_name(self, adapter: BaichuanArchitectureAdapter) -> None: + blocks = self._mapping(adapter)["blocks"] + assert blocks.submodules["mlp"].submodules["out"].name == "down_proj" + + +# --------------------------------------------------------------------------- +# Weight conversion tests +# --------------------------------------------------------------------------- + + +class TestBaichuanAdapterWeightConversions: + def 
test_four_conversion_keys(self, adapter: BaichuanArchitectureAdapter) -> None: + convs = adapter.weight_processing_conversions + assert convs is not None + assert len(convs) == 4 + + def test_qkvo_keys_present(self, adapter: BaichuanArchitectureAdapter) -> None: + convs = adapter.weight_processing_conversions + assert convs is not None + for key in [ + "blocks.{i}.attn.q.weight", + "blocks.{i}.attn.k.weight", + "blocks.{i}.attn.v.weight", + "blocks.{i}.attn.o.weight", + ]: + assert key in convs + + def test_q_conversion_type(self, adapter: BaichuanArchitectureAdapter) -> None: + convs = adapter.weight_processing_conversions + assert convs is not None + conv = convs["blocks.{i}.attn.q.weight"] + assert isinstance(conv, ParamProcessingConversion) + assert isinstance(conv.tensor_conversion, RearrangeTensorConversion) + + def test_q_rearrange_pattern(self, adapter: BaichuanArchitectureAdapter) -> None: + convs = adapter.weight_processing_conversions + assert convs is not None + conv = convs["blocks.{i}.attn.q.weight"] + assert isinstance(conv, ParamProcessingConversion) + assert isinstance(conv.tensor_conversion, RearrangeTensorConversion) + assert conv.tensor_conversion.pattern == "(n h) m -> n m h" + + def test_q_rearrange_n_equals_n_heads(self, adapter: BaichuanArchitectureAdapter) -> None: + convs = adapter.weight_processing_conversions + assert convs is not None + conv = convs["blocks.{i}.attn.q.weight"] + assert isinstance(conv, ParamProcessingConversion) + assert isinstance(conv.tensor_conversion, RearrangeTensorConversion) + assert conv.tensor_conversion.axes_lengths["n"] == adapter.cfg.n_heads + + def test_k_rearrange_n_equals_n_heads(self, adapter: BaichuanArchitectureAdapter) -> None: + # Baichuan is MHA (no GQA), so K also uses n_heads + convs = adapter.weight_processing_conversions + assert convs is not None + conv = convs["blocks.{i}.attn.k.weight"] + assert isinstance(conv, ParamProcessingConversion) + assert isinstance(conv.tensor_conversion, RearrangeTensorConversion) + assert conv.tensor_conversion.axes_lengths["n"] == adapter.cfg.n_heads + + def test_o_rearrange_pattern(self, adapter: BaichuanArchitectureAdapter) -> None: + convs = adapter.weight_processing_conversions + assert convs is not None + conv = convs["blocks.{i}.attn.o.weight"] + assert isinstance(conv, ParamProcessingConversion) + assert isinstance(conv.tensor_conversion, RearrangeTensorConversion) + assert conv.tensor_conversion.pattern == "m (n h) -> n h m" + + def test_no_source_key_on_q(self, adapter: BaichuanArchitectureAdapter) -> None: + convs = adapter.weight_processing_conversions + assert convs is not None + conv = convs["blocks.{i}.attn.q.weight"] + assert isinstance(conv, ParamProcessingConversion) + assert conv.source_key is None + + +# --------------------------------------------------------------------------- +# split_qkv_matrix (W_pack) tests +# --------------------------------------------------------------------------- + + +class TestBaichuanSplitWPack: + def _adapter(self, n_heads: int = 8, d_model: int = 64) -> BaichuanArchitectureAdapter: + return BaichuanArchitectureAdapter(_make_cfg(n_heads=n_heads, d_model=d_model)) + + def test_returns_three_linears(self) -> None: + adapter = self._adapter() + attn = _make_w_pack_component(64) + q, k, v = adapter._split_baichuan_w_pack(attn) + assert isinstance(q, nn.Linear) + assert isinstance(k, nn.Linear) + assert isinstance(v, nn.Linear) + + def test_output_shapes(self) -> None: + d_model = 64 + adapter = self._adapter(d_model=d_model) + attn = 
_make_w_pack_component(d_model) + q, k, v = adapter._split_baichuan_w_pack(attn) + assert q.weight.shape == (d_model, d_model) + assert k.weight.shape == (d_model, d_model) + assert v.weight.shape == (d_model, d_model) + + def test_no_bias(self) -> None: + adapter = self._adapter() + attn = _make_w_pack_component(64) + q, k, v = adapter._split_baichuan_w_pack(attn) + assert q.bias is None + assert k.bias is None + assert v.bias is None + + def test_concatenated_split_correctness(self) -> None: + """W_pack = [Q|K|V] concatenated — verify split recovers each part.""" + d_model = 32 + adapter = self._adapter(n_heads=4, d_model=d_model) + attn = _make_w_pack_component(d_model) + # Fill W_pack: Q=1.0, K=2.0, V=3.0 + w = torch.zeros(3 * d_model, d_model) + w[:d_model, :] = 1.0 + w[d_model : 2 * d_model, :] = 2.0 + w[2 * d_model :, :] = 3.0 + attn.W_pack.weight = nn.Parameter(w) + + q, k, v = adapter._split_baichuan_w_pack(attn) + assert torch.all(q.weight == 1.0), "Q should be 1.0" + assert torch.all(k.weight == 2.0), "K should be 2.0" + assert torch.all(v.weight == 3.0), "V should be 3.0" + + def test_round_trip_recombine(self) -> None: + """Split → recombine must equal original W_pack weights.""" + d_model = 64 + adapter = self._adapter(d_model=d_model) + attn = _make_w_pack_component(d_model) + original_w = attn.W_pack.weight.data.clone() + + q, k, v = adapter._split_baichuan_w_pack(attn) + recombined = torch.cat([q.weight.data, k.weight.data, v.weight.data], dim=0) + assert torch.equal(recombined, original_w) + + def test_forward_output_shapes(self) -> None: + d_model = 64 + adapter = self._adapter(d_model=d_model) + attn = _make_w_pack_component(d_model) + q, k, v = adapter._split_baichuan_w_pack(attn) + x = torch.randn(2, 5, d_model) + assert q(x).shape == (2, 5, d_model) + assert k(x).shape == (2, 5, d_model) + assert v(x).shape == (2, 5, d_model) + + +# --------------------------------------------------------------------------- +# preprocess_weights tests +# --------------------------------------------------------------------------- + + +class TestBaichuanPreprocessWeights: + def _make_state_dict( + self, + adapter: BaichuanArchitectureAdapter, + d_model: int = 64, + n_layers: int = 2, + d_mlp: int = 16, + d_vocab: int = 100, + ln1_scale: float = 1.0, + qkv_val: float = 1.0, + ) -> dict[str, torch.Tensor]: + """Bridge-format state dict with fused W_pack for each layer.""" + state: dict[str, torch.Tensor] = {} + for i in range(n_layers): + state[f"blocks.{i}.attn.qkv.weight"] = torch.full((3 * d_model, d_model), qkv_val) + state[f"blocks.{i}.ln1.weight"] = torch.full((d_model,), ln1_scale) + state[f"blocks.{i}.ln2.weight"] = torch.ones(d_model) + state[f"blocks.{i}.mlp.gate.weight"] = torch.ones(d_mlp, d_model) + state[f"blocks.{i}.mlp.in.weight"] = torch.ones(d_mlp, d_model) + state[f"blocks.{i}.attn.o.weight"] = torch.ones(d_model, d_model) + state["ln_final.weight"] = torch.ones(d_model) + state["unembed.weight"] = torch.ones(d_vocab, d_model) + return state + + def test_fused_key_removed_and_split_keys_written(self) -> None: + adapter = BaichuanArchitectureAdapter(_make_cfg(n_heads=8, d_model=64)) + adapter._fold_ln_requested = True + sd = self._make_state_dict(adapter) + result = adapter.preprocess_weights(sd) + assert "blocks.0.attn.qkv.weight" not in result + assert "blocks.0.attn.q.weight" in result + assert "blocks.0.attn.k.weight" in result + assert "blocks.0.attn.v.weight" in result + + def test_split_shapes(self) -> None: + d_model = 64 + adapter = 
BaichuanArchitectureAdapter(_make_cfg(n_heads=8, d_model=d_model)) + adapter._fold_ln_requested = True + sd = self._make_state_dict(adapter, d_model=d_model) + result = adapter.preprocess_weights(sd) + # Baichuan is MHA: Q, K, V each have shape [d_model, d_model] + assert result["blocks.0.attn.q.weight"].shape == (d_model, d_model) + assert result["blocks.0.attn.k.weight"].shape == (d_model, d_model) + assert result["blocks.0.attn.v.weight"].shape == (d_model, d_model) + + def test_ln1_fold_applied(self) -> None: + d_model = 64 + adapter = BaichuanArchitectureAdapter(_make_cfg(n_heads=8, d_model=d_model)) + adapter._fold_ln_requested = True + sd = self._make_state_dict(adapter, d_model=d_model, ln1_scale=2.0, qkv_val=1.0) + result = adapter.preprocess_weights(sd) + assert torch.all(result["blocks.0.attn.q.weight"] == 2.0) + assert torch.all(result["blocks.0.attn.k.weight"] == 2.0) + assert torch.all(result["blocks.0.attn.v.weight"] == 2.0) + + def test_ln1_reset_to_ones(self) -> None: + adapter = BaichuanArchitectureAdapter(_make_cfg(n_heads=8, d_model=64)) + adapter._fold_ln_requested = True + sd = self._make_state_dict(adapter, ln1_scale=3.0) + result = adapter.preprocess_weights(sd) + assert torch.all(result["blocks.0.ln1.weight"] == 1.0) + + def test_ln2_fold_applied(self) -> None: + d_model = 64 + adapter = BaichuanArchitectureAdapter(_make_cfg(n_heads=8, d_model=d_model)) + adapter._fold_ln_requested = True + sd = self._make_state_dict(adapter, d_model=d_model) + sd["blocks.0.ln2.weight"] = torch.full((d_model,), 3.0) + result = adapter.preprocess_weights(sd) + assert torch.all(result["blocks.0.mlp.gate.weight"] == 3.0) + assert torch.all(result["blocks.0.mlp.in.weight"] == 3.0) + + def test_no_fold_still_splits_qkv(self) -> None: + """Without fold_ln, W_pack must still be split for weight conversions.""" + adapter = BaichuanArchitectureAdapter(_make_cfg(n_heads=8, d_model=64)) + adapter._fold_ln_requested = False + sd = self._make_state_dict(adapter) + result = adapter.preprocess_weights(sd) + assert "blocks.0.attn.qkv.weight" not in result + assert "blocks.0.attn.q.weight" in result + assert "blocks.0.attn.k.weight" in result + assert "blocks.0.attn.v.weight" in result + + def test_ln_final_fold_values(self) -> None: + """ln_final fold multiplies unembed weights by ln_final scale.""" + d_model = 64 + adapter = BaichuanArchitectureAdapter(_make_cfg(n_heads=8, d_model=d_model)) + adapter._fold_ln_requested = True + sd = self._make_state_dict(adapter, d_model=d_model) + sd["ln_final.weight"] = torch.full((d_model,), 2.0) + sd["unembed.weight"] = torch.ones(100, d_model) + result = adapter.preprocess_weights(sd) + assert torch.all(result["unembed.weight"] == 2.0) + assert torch.all(result["ln_final.weight"] == 1.0) + + def test_dtype_preserved(self) -> None: + adapter = BaichuanArchitectureAdapter(_make_cfg(n_heads=8, d_model=64)) + adapter._fold_ln_requested = True + sd = self._make_state_dict(adapter) + sd = {k: v.to(torch.bfloat16) for k, v in sd.items()} + result = adapter.preprocess_weights(sd) + assert result["blocks.0.attn.q.weight"].dtype == torch.bfloat16 + + def test_all_layers_processed(self) -> None: + adapter = BaichuanArchitectureAdapter(_make_cfg(n_heads=8, d_model=64, n_layers=3)) + adapter._fold_ln_requested = True + sd = self._make_state_dict(adapter, n_layers=3) + result = adapter.preprocess_weights(sd) + for i in range(3): + assert f"blocks.{i}.attn.qkv.weight" not in result + assert f"blocks.{i}.attn.q.weight" in result + + +# 
--------------------------------------------------------------------------- +# prepare_model tests (NormHead normalization) +# --------------------------------------------------------------------------- + + +class TestBaichuanPrepareModel: + def _adapter(self) -> BaichuanArchitectureAdapter: + return BaichuanArchitectureAdapter(_make_cfg(n_heads=8, d_model=64)) + + def test_normhead_weights_normalized(self) -> None: + """NormHead (has first_flag) should have row-normalized weights after prepare_model.""" + adapter = self._adapter() + lm_head = SimpleNamespace( + weight=nn.Parameter(torch.full((100, 64), 2.0)), + first_flag=True, + ) + hf_model = SimpleNamespace(lm_head=lm_head) + adapter.prepare_model(hf_model) + row_norms = lm_head.weight.data.float().norm(dim=-1) + assert torch.allclose(row_norms, torch.ones_like(row_norms), atol=1e-5) + + def test_regular_linear_unchanged(self) -> None: + """nn.Linear lm_head (no first_flag) should not be modified.""" + adapter = self._adapter() + lm_head = nn.Linear(64, 100, bias=False) + original_w = lm_head.weight.data.clone() + hf_model = SimpleNamespace(lm_head=lm_head) + adapter.prepare_model(hf_model) + assert torch.equal(lm_head.weight.data, original_w) + + def test_no_lm_head_is_noop(self) -> None: + """Model without lm_head should not raise.""" + adapter = self._adapter() + hf_model = SimpleNamespace() + adapter.prepare_model(hf_model) # should not raise + + def test_recomputes_rotary_from_scratch_when_inv_freq_is_meta(self) -> None: + """Baichuan2's inv_freq/cos_cached are plain attrs that land on meta under + HF v5 meta-init; prepare_model must recompute real values regardless.""" + adapter = self._adapter() + head_dim = adapter.cfg.d_model // adapter.cfg.n_heads + # Meta-device rotary matching v2's plain-attribute shape + rotary = SimpleNamespace( + inv_freq=torch.empty(head_dim // 2, device="meta"), + cos_cached=torch.empty(1, 1, 16, head_dim, device="meta"), + sin_cached=torch.empty(1, 1, 16, head_dim, device="meta"), + max_seq_len_cached=16, + ) + layer = SimpleNamespace(self_attn=SimpleNamespace(rotary_emb=rotary)) + hf_model = SimpleNamespace(model=SimpleNamespace(layers=[layer])) + + adapter.prepare_model(hf_model) + + assert rotary.inv_freq.device.type == "cpu" + assert rotary.cos_cached.device.type == "cpu" + assert rotary.sin_cached.device.type == "cpu" + assert rotary.cos_cached.shape == (1, 1, 16, head_dim) + # Sanity: cos(0) == 1 and position 0 of each head_dim element equals 1. 
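+        # (The recomputed cache is cos(t * inv_freq) with t == 0 in its first row, so
+        # every entry there is exactly 1 regardless of the value of inv_freq.)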
+ assert torch.allclose( + rotary.cos_cached[0, 0, 0, :], + torch.ones(head_dim), + atol=1e-6, + ) + + +# --------------------------------------------------------------------------- +# Factory registration tests +# --------------------------------------------------------------------------- + + +class TestBaichuanFactoryRegistration: + def test_factory_v2_key(self) -> None: + from transformer_lens.factories.architecture_adapter_factory import ( + SUPPORTED_ARCHITECTURES, + ) + + assert "BaichuanForCausalLM" in SUPPORTED_ARCHITECTURES + + def test_factory_v1_key(self) -> None: + from transformer_lens.factories.architecture_adapter_factory import ( + SUPPORTED_ARCHITECTURES, + ) + + assert "BaiChuanForCausalLM" in SUPPORTED_ARCHITECTURES + + def test_factory_v2_returns_baichuan_adapter(self) -> None: + from transformer_lens.factories.architecture_adapter_factory import ( + ArchitectureAdapterFactory, + ) + + cfg = _make_cfg(n_heads=8, d_model=64) + cfg.architecture = "BaichuanForCausalLM" + adapter = ArchitectureAdapterFactory.select_architecture_adapter(cfg) + assert isinstance(adapter, BaichuanArchitectureAdapter) + + def test_factory_v1_returns_baichuan_adapter(self) -> None: + from transformer_lens.factories.architecture_adapter_factory import ( + ArchitectureAdapterFactory, + ) + + cfg = _make_cfg(n_heads=8, d_model=64) + cfg.architecture = "BaiChuanForCausalLM" + adapter = ArchitectureAdapterFactory.select_architecture_adapter(cfg) + assert isinstance(adapter, BaichuanArchitectureAdapter) + + def test_import_from_init(self) -> None: + from transformer_lens.model_bridge.supported_architectures import ( + BaichuanArchitectureAdapter as FromInit, + ) + + assert FromInit is BaichuanArchitectureAdapter + + +# --------------------------------------------------------------------------- +# Attention bridge: position_ids → position_embeddings conversion +# --------------------------------------------------------------------------- + + +class _FakeRotary(nn.Module): + """Minimal stand-in for Baichuan's RotaryEmbedding (returns 4D cached cos/sin).""" + + def __init__(self, head_dim: int, max_seq_len: int) -> None: + super().__init__() + self.max_seq_len_cached = max_seq_len + # Fill with position-dependent values so tests can verify indexing. 
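+        # Row p of cos_cached is filled with the value p (and sin_cached with -p), so
+        # any slice of length seq_len is trivially distinguishable from other slices.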
+ cos = ( + torch.arange(max_seq_len, dtype=torch.float32)[:, None] + .expand(max_seq_len, head_dim) + .clone() + ) + sin = -cos + self.register_buffer("cos_cached", cos[None, None, :, :]) + self.register_buffer("sin_cached", sin[None, None, :, :]) + self.calls: list[int] = [] + + def forward(self, x: torch.Tensor, seq_len: int) -> tuple[torch.Tensor, torch.Tensor]: + self.calls.append(seq_len) + cos_cached = self.cos_cached + sin_cached = self.sin_cached + assert isinstance(cos_cached, torch.Tensor) + assert isinstance(sin_cached, torch.Tensor) + return ( + cos_cached[:, :, :seq_len, :].to(dtype=x.dtype), + sin_cached[:, :, :seq_len, :].to(dtype=x.dtype), + ) + + +class _FakeAttention(nn.Module): + """nn.Module container that exposes a `rotary_emb` + `o_proj` to the bridge.""" + + def __init__(self, rotary: _FakeRotary, d_model: int) -> None: + super().__init__() + self.rotary_emb = rotary + self.o_proj = nn.Linear(d_model, d_model, bias=False) + nn.init.zeros_(self.o_proj.weight) + + +def _make_attention_bridge(cfg: TransformerBridgeConfig) -> _BaichuanAttentionBridge: + from transformer_lens.model_bridge.generalized_components import LinearBridge + + return _BaichuanAttentionBridge( + name="self_attn", + config=cfg, + split_qkv_matrix=lambda _c: ( + nn.Linear(cfg.d_model, cfg.d_model, bias=False), + nn.Linear(cfg.d_model, cfg.d_model, bias=False), + nn.Linear(cfg.d_model, cfg.d_model, bias=False), + ), + submodules={ + "qkv": LinearBridge(name="W_pack"), + "o": LinearBridge(name="o_proj"), + }, + ) + + +def _wire_bridge( + cfg: TransformerBridgeConfig, +) -> tuple[_BaichuanAttentionBridge, _FakeRotary, int]: + """Build a bridge with a fake HF attention (rotary + o_proj) attached.""" + head_dim = cfg.d_model // cfg.n_heads + bridge = _make_attention_bridge(cfg) + rotary = _FakeRotary(head_dim=head_dim, max_seq_len=32) + fake_attn = _FakeAttention(rotary, cfg.d_model) + bridge.set_original_component(fake_attn) + # `o` LinearBridge is normally wired by setup_components via component_mapping; + # wire it directly for unit tests that construct the bridge standalone. 
+ bridge.o.set_original_component(fake_attn.o_proj) + return bridge, rotary, head_dim + + +class TestBaichuanAttentionBridgeRotary: + """Regression tests for the attention bridge's rotary + KV-cache contract.""" + + def test_uses_position_ids_when_position_embeddings_absent( + self, cfg: TransformerBridgeConfig + ) -> None: + bridge, rotary, head_dim = _wire_bridge(cfg) + + batch, seq = 1, 4 + q = torch.zeros(batch, seq, cfg.d_model) + k = torch.zeros_like(q) + v = torch.zeros_like(q) + position_ids = torch.tensor([[0, 1, 2, 3]]) + + attn_output, _, present = bridge._reconstruct_attention( + q, k, v, position_ids=position_ids, use_cache=True + ) + + # rotary_emb called once, with kv_seq_len=seq (no past) + assert rotary.calls == [seq] + assert attn_output.shape == (batch, seq, cfg.d_model) + assert present is not None + present_k, present_v = present + assert present_k.shape == (batch, cfg.n_heads, seq, head_dim) + assert present_v.shape == (batch, cfg.n_heads, seq, head_dim) + + def test_preserves_explicit_position_embeddings(self, cfg: TransformerBridgeConfig) -> None: + bridge, rotary, head_dim = _wire_bridge(cfg) + + batch, seq = 1, 4 + q = torch.zeros(batch, seq, cfg.d_model) + k = torch.zeros_like(q) + v = torch.zeros_like(q) + explicit = ( + torch.ones(batch, seq, head_dim) * 7, + torch.ones(batch, seq, head_dim) * 9, + ) + + bridge._reconstruct_attention( + q, + k, + v, + position_embeddings=explicit, + position_ids=torch.tensor([[0, 1, 2, 3]]), + use_cache=True, + ) + # Caller-supplied embeddings must win; rotary_emb must not be called. + assert rotary.calls == [] + + def test_use_cache_false_returns_none_present(self, cfg: TransformerBridgeConfig) -> None: + bridge, _, _ = _wire_bridge(cfg) + q = torch.zeros(1, 4, cfg.d_model) + _, _, present = bridge._reconstruct_attention( + q, q.clone(), q.clone(), position_ids=torch.tensor([[0, 1, 2, 3]]) + ) + assert present is None + + def test_concats_past_key_value_along_seq_dim(self, cfg: TransformerBridgeConfig) -> None: + """With past cache of length P and current seq S, the present cache's + k/v have seq dim P+S and rotary is requested with kv_seq_len=P+S.""" + bridge, rotary, head_dim = _wire_bridge(cfg) + + batch, past_len, seq = 1, 3, 2 + past_k = torch.randn(batch, cfg.n_heads, past_len, head_dim) + past_v = torch.randn(batch, cfg.n_heads, past_len, head_dim) + + q = torch.zeros(batch, seq, cfg.d_model) + k = torch.zeros_like(q) + v = torch.zeros_like(q) + # HF's Model.forward generates position_ids offset by past_len. + position_ids = torch.tensor([[past_len, past_len + 1]]) + + _, _, present = bridge._reconstruct_attention( + q, + k, + v, + past_key_value=(past_k, past_v), + position_ids=position_ids, + use_cache=True, + ) + assert rotary.calls == [past_len + seq] + assert present is not None + present_k, present_v = present + assert present_k.shape == (batch, cfg.n_heads, past_len + seq, head_dim) + assert present_v.shape == (batch, cfg.n_heads, past_len + seq, head_dim) + # First past_len slots must be the provided past, unchanged. 
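+        # torch.equal demands exact equality: the past entries were rotated on their
+        # own step, so the bridge must concatenate them untouched, not re-rotate them.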
+ assert torch.equal(present_k[:, :, :past_len, :], past_k) + assert torch.equal(present_v[:, :, :past_len, :], past_v) + + +# --------------------------------------------------------------------------- +# prepare_loading: bitsandbytes preflight +# --------------------------------------------------------------------------- + + +class TestBaichuanPrepareLoadingBitsandbytes: + """The adapter must point users at `uv sync --group quantization` when bnb is missing.""" + + def test_preflight_raises_clean_import_error( + self, adapter: BaichuanArchitectureAdapter, monkeypatch: pytest.MonkeyPatch + ) -> None: + import transformer_lens.model_bridge.supported_architectures.baichuan as baichuan_mod + + # Force the preflight path: make find_spec report bitsandbytes missing, + # and make get_class_from_dynamic_module surface the transformers-style + # "requires the following packages... bitsandbytes" error. + monkeypatch.setattr(baichuan_mod.importlib.util, "find_spec", lambda name: None) + + def _raise_bnb(*_a: Any, **_k: Any) -> None: + raise ImportError( + "This modeling file requires the following packages that were " + "not found in your environment: bitsandbytes" + ) + + import transformers.dynamic_module_utils as dmu + + monkeypatch.setattr(dmu, "get_class_from_dynamic_module", _raise_bnb) + + with pytest.raises(ImportError, match="uv sync --group quantization"): + adapter.prepare_loading("baichuan-inc/Baichuan2-7B-Chat", {}) + + def test_preflight_no_false_positive_when_bnb_installed( + self, adapter: BaichuanArchitectureAdapter, monkeypatch: pytest.MonkeyPatch + ) -> None: + """If bnb IS installed, the transformers error won't mention bnb, so no raise.""" + import transformers.dynamic_module_utils as dmu + + def _raise_generic(*_a: Any, **_k: Any) -> None: + raise ValueError("some unrelated loader failure") + + monkeypatch.setattr(dmu, "get_class_from_dynamic_module", _raise_generic) + # Must not raise — the generic failure path is swallowed (remote load + # may legitimately fail for offline tests, e.g. no network access). 
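+        # The bare call below is the whole assertion; any exception would fail the test.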
+ adapter.prepare_loading("baichuan-inc/Baichuan2-7B-Chat", {}) diff --git a/transformer_lens/factories/architecture_adapter_factory.py b/transformer_lens/factories/architecture_adapter_factory.py index a55f51a5a..b5432aff1 100644 --- a/transformer_lens/factories/architecture_adapter_factory.py +++ b/transformer_lens/factories/architecture_adapter_factory.py @@ -7,6 +7,7 @@ from transformer_lens.model_bridge.architecture_adapter import ArchitectureAdapter from transformer_lens.model_bridge.supported_architectures import ( ApertusArchitectureAdapter, + BaichuanArchitectureAdapter, BertArchitectureAdapter, BloomArchitectureAdapter, CodeGenArchitectureAdapter, @@ -63,6 +64,8 @@ # Export supported architectures SUPPORTED_ARCHITECTURES = { "ApertusForCausalLM": ApertusArchitectureAdapter, + "BaiChuanForCausalLM": BaichuanArchitectureAdapter, + "BaichuanForCausalLM": BaichuanArchitectureAdapter, "BertForMaskedLM": BertArchitectureAdapter, "BloomForCausalLM": BloomArchitectureAdapter, "CodeGenForCausalLM": CodeGenArchitectureAdapter, diff --git a/transformer_lens/model_bridge/supported_architectures/__init__.py b/transformer_lens/model_bridge/supported_architectures/__init__.py index 7f990e393..772d76942 100644 --- a/transformer_lens/model_bridge/supported_architectures/__init__.py +++ b/transformer_lens/model_bridge/supported_architectures/__init__.py @@ -6,6 +6,9 @@ from transformer_lens.model_bridge.supported_architectures.apertus import ( ApertusArchitectureAdapter, ) +from transformer_lens.model_bridge.supported_architectures.baichuan import ( + BaichuanArchitectureAdapter, +) from transformer_lens.model_bridge.supported_architectures.bert import ( BertArchitectureAdapter, ) @@ -165,6 +168,7 @@ __all__ = [ "ApertusArchitectureAdapter", + "BaichuanArchitectureAdapter", "BertArchitectureAdapter", "BloomArchitectureAdapter", "CodeGenArchitectureAdapter", diff --git a/transformer_lens/model_bridge/supported_architectures/baichuan.py b/transformer_lens/model_bridge/supported_architectures/baichuan.py new file mode 100644 index 000000000..a50fabc37 --- /dev/null +++ b/transformer_lens/model_bridge/supported_architectures/baichuan.py @@ -0,0 +1,447 @@ +"""Baichuan architecture adapter. + +Supports both BaiChuanForCausalLM (v1) and BaichuanForCausalLM (v2). +Both use combined QKV via W_pack with RoPE, RMSNorm, and gated MLP. +""" + +import importlib.util +import sys +from typing import Any + +import torch +import torch.nn as nn + +from transformer_lens.conversion_utils.conversion_steps import RearrangeTensorConversion +from transformer_lens.conversion_utils.param_processing_conversion import ( + ParamProcessingConversion, +) +from transformer_lens.model_bridge.architecture_adapter import ArchitectureAdapter +from transformer_lens.model_bridge.compat import patch_dynamic_cache_v5 +from transformer_lens.model_bridge.generalized_components import ( + BlockBridge, + EmbeddingBridge, + GatedMLPBridge, + JointQKVPositionEmbeddingsAttentionBridge, + LinearBridge, + RMSNormalizationBridge, + UnembeddingBridge, +) + + +class _BaichuanAttentionBridge(JointQKVPositionEmbeddingsAttentionBridge): + """Attention bridge for Baichuan's v4-era decoder-layer contract. + + Baichuan predates HF's Cache API and differs from the base bridge in two + ways we have to own: + + 1. **Rotary from position_ids**: HF passes `position_ids` (not a + pre-computed `position_embeddings` tuple), so we call the per-layer + `rotary_emb(v, seq_len=kv_seq_len)` ourselves and slice cos/sin by + `position_ids`. + 2. 
**Legacy (k, v) cache tuple**: HF's DecoderLayer passes + `past_key_value=(k, v)` (singular, per-layer legacy tuple) and expects + `self_attn(...)` to return a matching `(k_full, v_full)` as + `present_key_value` so Model.forward's `next_decoder_cache` accumulates + real tensors. The base bridge's `_update_kv_cache` only handles the + Cache-object plural path, so we reimplement the attention body here + (mirroring HF's own Attention.forward). + """ + + def _reconstruct_attention( + self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, **kwargs + ) -> tuple: + assert self.original_component is not None + assert self.config is not None + num_heads = self.config.n_heads + num_kv_heads = getattr(self.config, "n_key_value_heads", None) or num_heads + + q, k, v, batch_size, seq_len, head_dim = self._reshape_qkv_to_heads( + q, k, v, num_heads, num_kv_heads + ) + + past_kv_raw = kwargs.get("past_key_value") + past_key_value: tuple[torch.Tensor, torch.Tensor] | None = None + if ( + isinstance(past_kv_raw, tuple) + and len(past_kv_raw) >= 2 + and isinstance(past_kv_raw[0], torch.Tensor) + and isinstance(past_kv_raw[1], torch.Tensor) + ): + past_key_value = (past_kv_raw[0], past_kv_raw[1]) + past_len = past_key_value[0].shape[-2] if past_key_value is not None else 0 + + # Rotary: derive cos/sin over the full kv_seq_len, index by position_ids. + if "position_embeddings" not in kwargs: + rotary_emb = getattr(self.original_component, "rotary_emb", None) + position_ids = kwargs.get("position_ids") + if rotary_emb is not None and position_ids is not None: + kv_seq_len = seq_len + past_len + cos, sin = rotary_emb(v, seq_len=kv_seq_len) + cos = cos.squeeze(1).squeeze(0)[position_ids] + sin = sin.squeeze(1).squeeze(0)[position_ids] + kwargs["position_embeddings"] = (cos, sin) + + position_embeddings = kwargs.get("position_embeddings") + if position_embeddings is not None and isinstance(position_embeddings, tuple): + cos, sin = self._apply_position_embedding_hooks(position_embeddings) + q, k = self._apply_rotary_pos_emb(q, k, cos, sin) + + # Concat prior (k, v) — already rotary-applied from its own step. + if past_key_value is not None: + k = torch.cat([past_key_value[0], k], dim=-2) + v = torch.cat([past_key_value[1], v], dim=-2) + + # Build present cache from pre-GQA-expansion (k, v) so downstream + # steps don't pay for duplicated heads. 
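+        # For Baichuan itself the head-expansion branch below is a no-op (MHA, so
+        # n_key_value_heads == n_heads); caching the un-expanded (k, v) simply keeps
+        # this path correct if a grouped-query config ever reuses the bridge.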
+ use_cache = bool(kwargs.get("use_cache", False)) + present_key_value = (k, v) if use_cache else None + + if num_kv_heads != num_heads: + n_rep = num_heads // num_kv_heads + k = k.repeat_interleave(n_rep, dim=1) + v = v.repeat_interleave(n_rep, dim=1) + + kv_seq_len = k.shape[-2] + attn_scores = torch.matmul(q, k.transpose(-2, -1)) * (head_dim ** (-0.5)) + attention_mask = kwargs.get("attention_mask", None) + attn_scores = self._apply_reconstruct_attention_mask( + attn_scores=attn_scores, + attention_mask=attention_mask, + seq_len=kv_seq_len, + q_seq_len=seq_len, + ) + attn_scores = self.hook_attn_scores(attn_scores) + attn_weights = self._softmax_dropout_pattern(attn_scores) + attn_output = torch.matmul(attn_weights, v) + attn_output = self._reshape_attn_output( + attn_output, batch_size, seq_len, num_heads, head_dim + ) + if ( + bool(getattr(self.config, "use_attn_result", False)) + and hasattr(self, "o") + and self.o.original_component is not None + ): + attn_output = self.o.hook_in(attn_output) + z_4d = attn_output.view(batch_size, seq_len, num_heads, head_dim) + attn_output = self._compute_per_head_result(z_4d, num_heads, head_dim) + else: + attn_output = self._apply_output_projection(attn_output) + + return (attn_output, attn_weights, present_key_value) + + +def _patch_init_weights_for_baichuan() -> None: + """Prevent _init_weights from re-randomizing loaded checkpoint weights. + + Transformers v5 calls _init_weights on all modules after weight + materialization. For modules with real (non-meta) tensors, we must + skip re-initialization to preserve the loaded checkpoint values. + """ + for key in list(sys.modules.keys()): + if "baichuan" not in key.lower() or "modeling" not in key.lower(): + continue + module = sys.modules[key] + # Both v1 (BaiChuan) and v2 (Baichuan) define a PreTrainedModel subclass + for cls_name in ("BaiChuanPreTrainedModel", "BaichuanPreTrainedModel", "PreTrainedModel"): + pretrained_cls = getattr(module, cls_name, None) + if pretrained_cls is None or getattr(pretrained_cls, "_tl_patched", False): + continue + # Only patch classes that define their own _init_weights + if "_init_weights" not in pretrained_cls.__dict__: + continue + + original_init_weights = pretrained_cls._init_weights + + def safe_init_weights(self, mod, _original=original_init_weights): # type: ignore[no-untyped-def] + first_param = next(mod.parameters(), None) + if first_param is not None and first_param.device.type != "meta": + return + _original(self, mod) + + pretrained_cls._init_weights = safe_init_weights + pretrained_cls._tl_patched = True + + +class BaichuanArchitectureAdapter(ArchitectureAdapter): + """Architecture adapter for Baichuan models (v1 and v2). + + Baichuan uses combined QKV via W_pack (nn.Linear(h, 3*h)) with RoPE, + RMSNorm, and gated MLP (SwiGLU). Per-layer rotary embeddings. 
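+    The fused W_pack projection is split into separate Q/K/V weights inside
+    preprocess_weights so the standard per-head weight conversions can be applied.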
+ + Optional Parameters (may not exist in state_dict): + ------------------------------------------------- + Baichuan models do NOT have biases on any projection: + + - blocks.{i}.attn.b_Q / b_K / b_V / b_O — no bias + - blocks.{i}.mlp.b_gate / b_in / b_out — no bias + - blocks.{i}.ln1.b / ln2.b / ln_final.b — RMSNorm has no bias + """ + + def __init__(self, cfg: Any) -> None: + super().__init__(cfg) + + self.cfg.normalization_type = "RMS" + self.cfg.positional_embedding_type = "rotary" + self.cfg.final_rms = True + self.cfg.gated_mlp = True + self.cfg.attn_only = False + self.cfg.uses_rms_norm = True + self.cfg.eps_attr = "variance_epsilon" + + # Fused W_pack prevents standard fold_ln from reaching Q/K/V separately. + # preprocess_weights() handles it instead. + self.supports_fold_ln = False + + self.weight_processing_conversions = { + "blocks.{i}.attn.q.weight": ParamProcessingConversion( + tensor_conversion=RearrangeTensorConversion("(n h) m -> n m h", n=cfg.n_heads), + ), + "blocks.{i}.attn.k.weight": ParamProcessingConversion( + tensor_conversion=RearrangeTensorConversion("(n h) m -> n m h", n=cfg.n_heads), + ), + "blocks.{i}.attn.v.weight": ParamProcessingConversion( + tensor_conversion=RearrangeTensorConversion("(n h) m -> n m h", n=cfg.n_heads), + ), + "blocks.{i}.attn.o.weight": ParamProcessingConversion( + tensor_conversion=RearrangeTensorConversion("m (n h) -> n h m", n=cfg.n_heads), + ), + } + + self.component_mapping = { + "embed": EmbeddingBridge(name="model.embed_tokens"), + "blocks": BlockBridge( + name="model.layers", + submodules={ + "ln1": RMSNormalizationBridge(name="input_layernorm", config=self.cfg), + "ln2": RMSNormalizationBridge(name="post_attention_layernorm", config=self.cfg), + "attn": _BaichuanAttentionBridge( + name="self_attn", + config=self.cfg, + split_qkv_matrix=self._split_baichuan_w_pack, + submodules={ + "qkv": LinearBridge(name="W_pack"), + "o": LinearBridge(name="o_proj"), + }, + ), + "mlp": GatedMLPBridge( + name="mlp", + config=self.cfg, + submodules={ + "gate": LinearBridge(name="gate_proj"), + "in": LinearBridge(name="up_proj"), + "out": LinearBridge(name="down_proj"), + }, + ), + }, + ), + "ln_final": RMSNormalizationBridge(name="model.norm", config=self.cfg), + "unembed": UnembeddingBridge(name="lm_head", config=self.cfg), + } + + def _split_baichuan_w_pack( + self, attention_component: Any + ) -> tuple[nn.Linear, nn.Linear, nn.Linear]: + """Split Baichuan's W_pack into separate Q, K, V linear modules. + + W_pack is a simple concatenation: [Q | K | V], each of size hidden_size. + No interleaving, no GQA — all three chunks are equal size. 
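+        For hidden_size h, W_pack.weight has shape (3h, h): rows [0, h) are Q,
+        rows [h, 2h) are K, and rows [2h, 3h) are V.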
+ """ + w_pack = attention_component.W_pack + weight = w_pack.weight.data + d_model = weight.shape[1] + hidden_size = d_model # Q, K, V each have hidden_size output features + + q_w = weight[:hidden_size, :] + k_w = weight[hidden_size : 2 * hidden_size, :] + v_w = weight[2 * hidden_size :, :] + + def _make_linear(w: torch.Tensor) -> nn.Linear: + lin = nn.Linear(d_model, hidden_size, bias=False) + lin.weight = nn.Parameter(w) + return lin + + return _make_linear(q_w), _make_linear(k_w), _make_linear(v_w) + + def setup_component_testing(self, hf_model: Any, bridge_model: Any = None) -> None: + """Inject per-layer rotary embedding for component testing.""" + try: + rotary_emb = hf_model.model.layers[0].self_attn.rotary_emb + except (AttributeError, IndexError): + return + + if bridge_model is not None and hasattr(bridge_model, "blocks"): + for block in bridge_model.blocks: + if hasattr(block, "attn"): + block.attn.set_rotary_emb(rotary_emb) + + attn_bridge = self.get_generalized_component("blocks.0.attn") + attn_bridge.set_rotary_emb(rotary_emb) + + def prepare_loading(self, model_name: str, model_kwargs: dict) -> None: + """Patch transformers v5 incompatibilities before from_pretrained runs.""" + patch_dynamic_cache_v5() + + # Force-import the remote modeling module so we can patch _init_weights. + # Baichuan2 variants ship quantizer.py which imports bitsandbytes; + # transformers' check_imports scans every .py file in the repo and + # raises ImportError if bitsandbytes is missing, even though quantizer + # is not used in normal inference. Catch that case and tell the user + # how to install the optional dependency group. + try: + from transformers.dynamic_module_utils import get_class_from_dynamic_module + + last_exc: Exception | None = None + # Try both class names (v1 and v2) + for cls_name in ( + "modeling_baichuan.BaichuanForCausalLM", + "modeling_baichuan.BaiChuanForCausalLM", + ): + try: + get_class_from_dynamic_module(cls_name, model_name) + last_exc = None + break + except Exception as exc: + last_exc = exc + continue + if last_exc is not None and "bitsandbytes" in str(last_exc): + if importlib.util.find_spec("bitsandbytes") is None: + raise ImportError( + "Baichuan2 variants require `bitsandbytes` for " + "trust_remote_code loading (their shipped quantizer.py " + "imports it). Install the quantization extras: " + "`uv sync --group quantization`." + ) from last_exc + except ImportError: + raise + except Exception: + pass + + _patch_init_weights_for_baichuan() + + def prepare_model(self, hf_model: Any) -> None: + """Fix rotary caches and normalize NormHead weights before bridge creation. + + RotaryEmbedding differs between v1 and v2: + - v1 (Baichuan-7B): `inv_freq` is a persistent buffer, loaded from the + checkpoint as bfloat16, but `cos_cached`/`sin_cached` are non-persistent + and materialize as garbage under meta-init. + - v2 (Baichuan2-*): `inv_freq`, `cos_cached`, `sin_cached` are all plain + attributes (no `register_buffer`). v5's meta-init materializes them on + meta, and nothing in the checkpoint overwrites them. + + Both cases are resolved by computing inv_freq + caches from scratch at + float32 using config-derived head_dim and base=10000. Recomputing v1 at + float32 is also an upgrade over its bfloat16 checkpoint values. + + Baichuan2 Chat also uses NormHead which row-normalizes lm_head during + forward. We apply that once here so the bridge sees the normalized + weights directly without needing NormHead's forward path. 
+ """ + # Pick a real device/dtype by scanning real (non-meta) parameters. + target_device = torch.device("cpu") + params_fn = getattr(hf_model, "parameters", None) + if callable(params_fn): + for param in params_fn(): + if param.device.type != "meta": + target_device = param.device + break + + head_dim = self.cfg.d_model // self.cfg.n_heads + base = 10000.0 + + model_core = getattr(hf_model, "model", None) + if model_core is not None: + for layer in getattr(model_core, "layers", []): + rotary = getattr(getattr(layer, "self_attn", None), "rotary_emb", None) + if rotary is None: + continue + max_seq = getattr(rotary, "max_seq_len_cached", self.cfg.n_ctx or 4096) + inv_freq = 1.0 / ( + base + ** ( + torch.arange(0, head_dim, 2, device=target_device, dtype=torch.float32) + / head_dim + ) + ) + t = torch.arange(max_seq, device=target_device, dtype=torch.float32) + freqs = torch.einsum("i,j->ij", t, inv_freq) + emb = torch.cat((freqs, freqs), dim=-1) + rotary.inv_freq = inv_freq + rotary.cos_cached = emb.cos()[None, None, :, :] + rotary.sin_cached = emb.sin()[None, None, :, :] + + # Normalize NormHead weights (Baichuan2 Chat) + lm_head = getattr(hf_model, "lm_head", None) + if lm_head is not None and hasattr(lm_head, "first_flag"): + w = lm_head.weight.data + lm_head.weight.data = torch.nn.functional.normalize(w, dim=-1) + + def preprocess_weights(self, state_dict: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]: + """Split fused W_pack QKV and optionally fold layer norms.""" + fold_ln = getattr(self, "_fold_ln_requested", True) + if not fold_ln: + # Still need to split W_pack into Q/K/V for weight conversions + for i in range(self.cfg.n_layers): + qkv_key = f"blocks.{i}.attn.qkv.weight" + if qkv_key not in state_dict: + continue + w = state_dict[qkv_key] + hidden_size = w.shape[1] + q_w = w[:hidden_size, :] + k_w = w[hidden_size : 2 * hidden_size, :] + v_w = w[2 * hidden_size :, :] + state_dict[f"blocks.{i}.attn.q.weight"] = q_w + state_dict[f"blocks.{i}.attn.k.weight"] = k_w + state_dict[f"blocks.{i}.attn.v.weight"] = v_w + del state_dict[qkv_key] + return state_dict + + for i in range(self.cfg.n_layers): + # --- Fold ln1 into Q/K/V (split from W_pack) --- + qkv_key = f"blocks.{i}.attn.qkv.weight" + ln1_key = f"blocks.{i}.ln1.weight" + if qkv_key in state_dict and ln1_key in state_dict: + ln1_w = state_dict[ln1_key].float() + w = state_dict[qkv_key].float() + orig_dtype = state_dict[qkv_key].dtype + hidden_size = w.shape[1] + + q_w = w[:hidden_size, :] + k_w = w[hidden_size : 2 * hidden_size, :] + v_w = w[2 * hidden_size :, :] + + state_dict[f"blocks.{i}.attn.q.weight"] = (q_w * ln1_w[None, :]).to(orig_dtype) + state_dict[f"blocks.{i}.attn.k.weight"] = (k_w * ln1_w[None, :]).to(orig_dtype) + state_dict[f"blocks.{i}.attn.v.weight"] = (v_w * ln1_w[None, :]).to(orig_dtype) + del state_dict[qkv_key] + state_dict[ln1_key] = torch.ones_like(state_dict[ln1_key]) + + # --- Fold ln2 into MLP gate and up projections --- + ln2_key = f"blocks.{i}.ln2.weight" + if ln2_key in state_dict: + ln2_w = state_dict[ln2_key].float() + for mlp_key in [ + f"blocks.{i}.mlp.gate.weight", + f"blocks.{i}.mlp.in.weight", + ]: + if mlp_key in state_dict: + orig_dtype = state_dict[mlp_key].dtype + state_dict[mlp_key] = (state_dict[mlp_key].float() * ln2_w[None, :]).to( + orig_dtype + ) + state_dict[ln2_key] = torch.ones_like(state_dict[ln2_key]) + + # --- Fold ln_final into unembed --- + ln_final_key = "ln_final.weight" + unembed_key = "unembed.weight" + if ln_final_key in state_dict and unembed_key in state_dict: + 
ln_w = state_dict[ln_final_key].float() + u_w = state_dict[unembed_key].float() + orig_dtype = state_dict[unembed_key].dtype + if u_w.shape[-1] == ln_w.shape[0]: + state_dict[unembed_key] = (u_w * ln_w[None, :]).to(orig_dtype) + elif u_w.shape[0] == ln_w.shape[0]: + state_dict[unembed_key] = (u_w * ln_w[:, None]).to(orig_dtype) + state_dict[ln_final_key] = torch.ones_like(state_dict[ln_final_key]) + + return state_dict diff --git a/transformer_lens/tools/model_registry/data/supported_models.json b/transformer_lens/tools/model_registry/data/supported_models.json index 9ea753a6b..c0fb70907 100644 --- a/transformer_lens/tools/model_registry/data/supported_models.json +++ b/transformer_lens/tools/model_registry/data/supported_models.json @@ -6,9 +6,9 @@ "min_downloads": 500, "scan_duration_seconds": 4.9 }, - "total_architectures": 48, - "total_models": 9056, - "total_verified": 709, + "total_architectures": 50, + "total_models": 9068, + "total_verified": 711, "models": [ { "architecture_id": "Qwen3NextForCausalLM", @@ -125195,6 +125195,210 @@ "phase4_score": null, "phase7_score": null, "phase8_score": null + }, + { + "architecture_id": "BaichuanForCausalLM", + "model_id": "baichuan-inc/Baichuan2-7B-Chat", + "status": 3, + "verified_date": "2026-04-21", + "metadata": { + "downloads": 89081, + "total_params": null + }, + "note": "Below threshold: P1=0.0% < 100.0% (failed: load_bridge_unprocessed) \u2014 Failed to load unprocessed TransformerBridge: This modeling file requires the following packages that were not found in your environment: bitsandbytes", + "phase1_score": 0.0, + "phase2_score": null, + "phase3_score": null, + "phase4_score": null, + "phase7_score": null, + "phase8_score": null + }, + { + "architecture_id": "BaichuanForCausalLM", + "model_id": "baichuan-inc/Baichuan2-13B-Chat", + "status": 0, + "verified_date": null, + "metadata": { + "downloads": 7963, + "total_params": null + }, + "note": null, + "phase1_score": null, + "phase2_score": null, + "phase3_score": null, + "phase4_score": null, + "phase7_score": null, + "phase8_score": null + }, + { + "architecture_id": "BaichuanForCausalLM", + "model_id": "baichuan-inc/Baichuan-13B-Chat", + "status": 0, + "verified_date": null, + "metadata": { + "downloads": 6570, + "total_params": null + }, + "note": null, + "phase1_score": null, + "phase2_score": null, + "phase3_score": null, + "phase4_score": null, + "phase7_score": null, + "phase8_score": null + }, + { + "architecture_id": "BaichuanForCausalLM", + "model_id": "baichuan-inc/Baichuan2-7B-Base", + "status": 1, + "verified_date": "2026-04-21", + "metadata": { + "downloads": 2007, + "total_params": null + }, + "note": "Full verification completed", + "phase1_score": 100.0, + "phase2_score": 100.0, + "phase3_score": 100.0, + "phase4_score": 94.6, + "phase7_score": null, + "phase8_score": null + }, + { + "architecture_id": "BaichuanForCausalLM", + "model_id": "sakuraumi/Sakura-13B-Galgame", + "status": 0, + "verified_date": null, + "metadata": { + "downloads": 1800, + "total_params": null + }, + "note": null, + "phase1_score": null, + "phase2_score": null, + "phase3_score": null, + "phase4_score": null, + "phase7_score": null, + "phase8_score": null + }, + { + "architecture_id": "BaichuanForCausalLM", + "model_id": "zxbsmk/NSFW_13B_sft", + "status": 0, + "verified_date": null, + "metadata": { + "downloads": 1786, + "total_params": null + }, + "note": null, + "phase1_score": null, + "phase2_score": null, + "phase3_score": null, + "phase4_score": null, + "phase7_score": null, + 
"phase8_score": null + }, + { + "architecture_id": "BaichuanForCausalLM", + "model_id": "baichuan-inc/Baichuan2-13B-Base", + "status": 0, + "verified_date": null, + "metadata": { + "downloads": 1773, + "total_params": null + }, + "note": null, + "phase1_score": null, + "phase2_score": null, + "phase3_score": null, + "phase4_score": null, + "phase7_score": null, + "phase8_score": null + }, + { + "architecture_id": "BaichuanForCausalLM", + "model_id": "Wuyanzzh/NSFW_13B_sft", + "status": 0, + "verified_date": null, + "metadata": { + "downloads": 1486, + "total_params": null + }, + "note": null, + "phase1_score": null, + "phase2_score": null, + "phase3_score": null, + "phase4_score": null, + "phase7_score": null, + "phase8_score": null + }, + { + "architecture_id": "BaichuanForCausalLM", + "model_id": "baichuan-inc/Baichuan-13B-Base", + "status": 0, + "verified_date": null, + "metadata": { + "downloads": 1295, + "total_params": null + }, + "note": null, + "phase1_score": null, + "phase2_score": null, + "phase3_score": null, + "phase4_score": null, + "phase7_score": null, + "phase8_score": null + }, + { + "architecture_id": "BaichuanForCausalLM", + "model_id": "FreedomIntelligence/HuatuoGPT2-7B", + "status": 0, + "verified_date": null, + "metadata": { + "downloads": 1252, + "total_params": null + }, + "note": null, + "phase1_score": null, + "phase2_score": null, + "phase3_score": null, + "phase4_score": null, + "phase7_score": null, + "phase8_score": null + }, + { + "architecture_id": "BaichuanForCausalLM", + "model_id": "DuJinHua/AiMed2", + "status": 0, + "verified_date": null, + "metadata": { + "downloads": 951, + "total_params": null + }, + "note": null, + "phase1_score": null, + "phase2_score": null, + "phase3_score": null, + "phase4_score": null, + "phase7_score": null, + "phase8_score": null + }, + { + "architecture_id": "BaiChuanForCausalLM", + "model_id": "baichuan-inc/Baichuan-7B", + "status": 1, + "verified_date": "2026-04-21", + "metadata": { + "downloads": 50000, + "total_params": null + }, + "note": "Full verification completed", + "phase1_score": 100.0, + "phase2_score": 100.0, + "phase3_score": 100.0, + "phase4_score": 92.0, + "phase7_score": null, + "phase8_score": null } ] } diff --git a/transformer_lens/tools/model_registry/data/verification_history.json b/transformer_lens/tools/model_registry/data/verification_history.json index 657c09326..c87d21798 100644 --- a/transformer_lens/tools/model_registry/data/verification_history.json +++ b/transformer_lens/tools/model_registry/data/verification_history.json @@ -1,5 +1,5 @@ { - "last_updated": "2026-04-15T16:28:06.314994", + "last_updated": "2026-04-21T20:10:35.469418", "records": [ { "model_id": "Macropodus/macbert4mdcspell_v1", @@ -11730,6 +11730,126 @@ "notes": "Full verification completed with issues, low text quality", "invalidated": false, "invalidation_reason": null + }, + { + "model_id": "baichuan-inc/Baichuan2-7B-Chat", + "architecture_id": "BaichuanForCausalLM", + "verified_date": "2026-04-21", + "verified_by": "verify_models", + "transformerlens_version": null, + "notes": "Below threshold: P1=0.0% < 100.0% (failed: load_bridge_unprocessed) \u2014 Failed to load unprocessed TransformerBridge: This modeling file requires the following packages that were not found in your environment: bitsandbytes", + "invalidated": false, + "invalidation_reason": null + }, + { + "model_id": "baichuan-inc/Baichuan2-7B-Base", + "architecture_id": "BaichuanForCausalLM", + "verified_date": "2026-04-21", + "verified_by": "verify_models", + 
"transformerlens_version": null, + "notes": "Below threshold: P1=0.0% < 100.0% (failed: load_bridge_unprocessed) \u2014 Failed to load unprocessed TransformerBridge: This modeling file requires the following packages that were not found in your environment: bitsandbytes", + "invalidated": false, + "invalidation_reason": null + }, + { + "model_id": "baichuan-inc/Baichuan-7B", + "architecture_id": "BaiChuanForCausalLM", + "verified_date": "2026-04-21", + "verified_by": "verify_models", + "transformerlens_version": null, + "notes": "Below threshold: P1=50.0% < 100.0% (failed: forward_pass_logits) \u2014 Tensors differ: max_diff=74.608353, mean_rel=1.619285", + "invalidated": false, + "invalidation_reason": null + }, + { + "model_id": "baichuan-inc/Baichuan-7B", + "architecture_id": "BaiChuanForCausalLM", + "verified_date": "2026-04-21", + "verified_by": "verify_models", + "transformerlens_version": null, + "notes": "Below threshold: P1=50.0% < 100.0% (failed: forward_pass_logits) \u2014 Tensors differ: max_diff=78.619270, mean_rel=1.866265", + "invalidated": false, + "invalidation_reason": null + }, + { + "model_id": "baichuan-inc/Baichuan-7B", + "architecture_id": "BaiChuanForCausalLM", + "verified_date": "2026-04-21", + "verified_by": "verify_models", + "transformerlens_version": null, + "notes": "Below threshold: P1=50.0% < 100.0% (failed: forward_pass_logits) \u2014 Tensors differ: max_diff=nan, mean_rel=nan", + "invalidated": false, + "invalidation_reason": null + }, + { + "model_id": "baichuan-inc/Baichuan-7B", + "architecture_id": "BaiChuanForCausalLM", + "verified_date": "2026-04-21", + "verified_by": "verify_models", + "transformerlens_version": null, + "notes": "Below threshold: P1=50.0% < 100.0% (failed: forward_pass_logits) \u2014 Tensors differ: max_diff=33.073044, mean_rel=0.316714", + "invalidated": false, + "invalidation_reason": null + }, + { + "model_id": "baichuan-inc/Baichuan-7B", + "architecture_id": "BaiChuanForCausalLM", + "verified_date": "2026-04-21", + "verified_by": "verify_models", + "transformerlens_version": null, + "notes": "Below threshold: P1=50.0% < 100.0% (failed: forward_pass_logits) \u2014 Tensors differ: max_diff=33.073044, mean_rel=0.316714", + "invalidated": false, + "invalidation_reason": null + }, + { + "model_id": "baichuan-inc/Baichuan-7B", + "architecture_id": "BaiChuanForCausalLM", + "verified_date": "2026-04-21", + "verified_by": "verify_models", + "transformerlens_version": null, + "notes": "Below threshold: P2=69.2% < 75.0% (failed: generation, generation_with_kv_cache, multiple_generation \u2014 Generation failed: 'NoneType' object is not subscriptable", + "invalidated": false, + "invalidation_reason": null + }, + { + "model_id": "baichuan-inc/Baichuan-7B", + "architecture_id": "BaiChuanForCausalLM", + "verified_date": "2026-04-21", + "verified_by": "verify_models", + "transformerlens_version": null, + "notes": "Full verification completed", + "invalidated": false, + "invalidation_reason": null + }, + { + "model_id": "baichuan-inc/Baichuan2-7B-Base", + "architecture_id": "BaichuanForCausalLM", + "verified_date": "2026-04-21", + "verified_by": "verify_models", + "transformerlens_version": null, + "notes": "Below threshold: P1=50.0% < 100.0% (failed: forward_pass); P2=7.7% < 75.0% (failed: generation, gene \u2014 Forward pass failed: Cannot copy out of meta tensor; no data!", + "invalidated": false, + "invalidation_reason": null + }, + { + "model_id": "baichuan-inc/Baichuan2-7B-Base", + "architecture_id": "BaichuanForCausalLM", + 
"verified_date": "2026-04-21", + "verified_by": "verify_models", + "transformerlens_version": null, + "notes": "Full verification completed", + "invalidated": false, + "invalidation_reason": null + }, + { + "model_id": "baichuan-inc/Baichuan-7B", + "architecture_id": "BaiChuanForCausalLM", + "verified_date": "2026-04-21", + "verified_by": "verify_models", + "transformerlens_version": null, + "notes": "Full verification completed", + "invalidated": false, + "invalidation_reason": null } ] } diff --git a/transformer_lens/tools/model_registry/verify_models.py b/transformer_lens/tools/model_registry/verify_models.py index 1497a1cef..4798db1c6 100644 --- a/transformer_lens/tools/model_registry/verify_models.py +++ b/transformer_lens/tools/model_registry/verify_models.py @@ -60,6 +60,7 @@ # Architectures added via the TransformerBridge system that need trust_remote_code=True. # These are not in the legacy NEED_REMOTE_CODE_MODELS tuple (loading_from_pretrained.py). _BRIDGE_REMOTE_CODE_PREFIXES: tuple[str, ...] = ( + "baichuan-inc/", # BaichuanForCausalLM — ships own modeling_baichuan.py "internlm/", # InternLM2ForCausalLM — ships own modeling_internlm2.py ) From 717899ea0f58f00ea725b15d24083ea698354cfc Mon Sep 17 00:00:00 2001 From: Jason Gross Date: Tue, 21 Apr 2026 21:25:42 -0700 Subject: [PATCH 10/21] Make `FactoredMatrix` compatible with tensor-like arguments (#599) * Make `FactoredMatrix` compatible with tensor-like arguments I'd like to be able to use `FactoredMatrix` with things that implement the interface of `torch.Tensor` without subclassing it. This slight change allows `FactoredMatrix` to work with such classes rather than returning `None` in various places. * Added test and properly typed he methods to include details of this fix --------- Co-authored-by: Bryce Meyer Co-authored-by: jlarson4 --- .../test_multiply_by_tensor_like.py | 109 ++++++++++++++++++ transformer_lens/FactoredMatrix.py | 82 ++++++++++--- 2 files changed, 175 insertions(+), 16 deletions(-) create mode 100644 tests/unit/factored_matrix/test_multiply_by_tensor_like.py diff --git a/tests/unit/factored_matrix/test_multiply_by_tensor_like.py b/tests/unit/factored_matrix/test_multiply_by_tensor_like.py new file mode 100644 index 000000000..e9482b3e3 --- /dev/null +++ b/tests/unit/factored_matrix/test_multiply_by_tensor_like.py @@ -0,0 +1,109 @@ +"""Tests that FactoredMatrix matmul works with tensor-like objects. + +A "tensor-like" object is one that quacks like a torch.Tensor (supports the +operations FactoredMatrix needs — .ndim, .size, .unsqueeze, __matmul__) but +isn't a torch.Tensor subclass. This is useful for things like jaxtyping wrappers +or custom array types. +""" + +from torch import randn +from torch.testing import assert_close + +from transformer_lens import FactoredMatrix + + +class TensorLike: + """A wrapper that exposes the tensor protocol without subclassing torch.Tensor. + + Implements just enough of the protocol that FactoredMatrix can multiply + with it: matmul, ndim, size, shape, unsqueeze, squeeze, broadcast_to. 
+ """ + + def __init__(self, tensor): + self._tensor = tensor + + @property + def ndim(self): + return self._tensor.ndim + + @property + def shape(self): + return self._tensor.shape + + def size(self, dim=None): + return self._tensor.size() if dim is None else self._tensor.size(dim) + + def unsqueeze(self, dim): + return TensorLike(self._tensor.unsqueeze(dim)) + + def squeeze(self, dim): + return TensorLike(self._tensor.squeeze(dim)) + + def broadcast_to(self, shape): + return TensorLike(self._tensor.broadcast_to(shape)) + + def __matmul__(self, other): + if isinstance(other, FactoredMatrix): + # Defer to FactoredMatrix.__rmatmul__ so the result is a FactoredMatrix + return NotImplemented + if isinstance(other, TensorLike): + return TensorLike(self._tensor @ other._tensor) + return TensorLike(self._tensor @ other) + + def __rmatmul__(self, other): + if isinstance(other, TensorLike): + return TensorLike(other._tensor @ self._tensor) + return TensorLike(other @ self._tensor) + + +def test_left_multiply_factored_matrix_by_tensor_like_matrix(): + """factored_matrix @ tensor_like_matrix should not return None.""" + a = randn(2, 3) + b = randn(3, 4) + matrix = randn(4, 5) + factored_matrix = FactoredMatrix(a, b) + + result = factored_matrix @ TensorLike(matrix) + + assert result is not None, "matmul with tensor-like silently returned None" + assert isinstance(result, FactoredMatrix) + expected = (a @ b) @ matrix + assert isinstance(result.AB, TensorLike) + assert_close(result.AB._tensor, expected) + + +def test_right_multiply_factored_matrix_by_tensor_like_matrix(): + """tensor_like_matrix @ factored_matrix should not return None.""" + a = randn(3, 4) + b = randn(4, 6) + matrix = randn(5, 3) + factored_matrix = FactoredMatrix(a, b) + + result = TensorLike(matrix) @ factored_matrix + + assert result is not None, "rmatmul with tensor-like silently returned None" + assert isinstance(result, FactoredMatrix) + expected = matrix @ (a @ b) + assert isinstance(result.AB, TensorLike) + assert_close(result.AB._tensor, expected) + + +def test_left_multiply_factored_matrix_by_tensor_like_vector(): + """factored_matrix @ tensor_like_vector should dispatch through the vector path. + + The vector branch of FactoredMatrix.__matmul__ collapses to a single tensor + via unsqueeze/squeeze rather than wrapping in a new FactoredMatrix. This test + exercises that path and verifies the TensorLike protocol methods (unsqueeze, + squeeze, __rmatmul__) are correctly invoked. + """ + a = randn(2, 3) + b = randn(3, 4) + vector = randn(4) + factored_matrix = FactoredMatrix(a, b) + + result = factored_matrix @ TensorLike(vector) + + # The fix's core guarantee: the dispatch produces a result instead of None + assert isinstance(result, TensorLike) + expected = (a @ b) @ vector + assert_close(result._tensor, expected) diff --git a/transformer_lens/FactoredMatrix.py b/transformer_lens/FactoredMatrix.py index 2f69220df..0c7ce3610 100644 --- a/transformer_lens/FactoredMatrix.py +++ b/transformer_lens/FactoredMatrix.py @@ -7,7 +7,7 @@ from __future__ import annotations from functools import lru_cache -from typing import List, Tuple, Union, overload +from typing import Any, List, Protocol, Tuple, Union, cast, overload, runtime_checkable import torch from jaxtyping import Complex, Float @@ -15,6 +15,39 @@ import transformer_lens.utilities.tensors as tensor_utils +@runtime_checkable +class TensorLike(Protocol): + """Minimal tensor protocol that FactoredMatrix accepts in place of torch.Tensor. + + Allows duck-typed inputs (e.g. 
jaxtyping wrappers, custom array types) that + aren't torch.Tensor subclasses but support the operations FactoredMatrix uses + when constructing, multiplying, and broadcasting its A and B factors. + """ + + @property + def ndim(self) -> int: + ... + + @property + def shape(self) -> Any: + ... + + def size(self, dim: int) -> int: + ... + + def unsqueeze(self, dim: int) -> Any: + ... + + def squeeze(self, dim: int) -> Any: + ... + + def broadcast_to(self, shape: Any) -> Any: + ... + + def __matmul__(self, other: Any) -> Any: + ... + + class FactoredMatrix: """ Class to represent low rank factored matrices, where the matrix is represented as a product of two matrices. Has utilities for efficient calculation of eigenvalues, norm and SVD. @@ -22,11 +55,21 @@ class FactoredMatrix: def __init__( self, - A: Float[torch.Tensor, "... ldim mdim"], - B: Float[torch.Tensor, "... mdim rdim"], + A: Union[Float[torch.Tensor, "... ldim mdim"], TensorLike], + B: Union[Float[torch.Tensor, "... mdim rdim"], TensorLike], ): - self.A = A - self.B = B + """Construct a FactoredMatrix from factors A and B. + + A and B may be torch.Tensor or TensorLike duck types. TensorLike inputs + are only fully supported by matmul-family operations (``@``, ``AB``, + ``BA``); operations like ``svd()``, ``norm()``, ``transpose()``, + ``__getitem__``, and eigenvalue methods require both factors to be + actual torch.Tensor and will raise AttributeError on TensorLike inputs. + """ + # Cast to Tensor for type-checker purposes. At runtime A and B may be + # TensorLike duck types; the class methods trust the protocol. + self.A: torch.Tensor = cast(torch.Tensor, A) + self.B: torch.Tensor = cast(torch.Tensor, B) assert self.A.size(-1) == self.B.size( -2 ), f"Factored matrix must match on inner dimension, shapes were a: {self.A.shape}, b:{self.B.shape}" @@ -74,9 +117,12 @@ def __matmul__( Float[torch.Tensor, "... rdim new_rdim"], Float[torch.Tensor, "rdim"], "FactoredMatrix", + TensorLike, ], - ) -> Union["FactoredMatrix", Float[torch.Tensor, "... ldim"]]: - if isinstance(other, torch.Tensor): + ) -> Union["FactoredMatrix", Float[torch.Tensor, "... ldim"], TensorLike]: + if isinstance(other, FactoredMatrix): + return (self @ other.A) @ other.B + else: if other.ndim < 2: # It's a vector, so we collapse the factorisation and just return a vector # Squeezing/Unsqueezing is to preserve broadcasting working nicely @@ -86,11 +132,11 @@ def __matmul__( other.size(-2) == self.rdim ), f"Right matrix must match on inner dimension, shapes were self: {self.shape}, other:{other.shape}" if self.rdim > self.mdim: - return FactoredMatrix(self.A, self.B @ other) + # other is Tensor or TensorLike; runtime delegates to + # the appropriate __matmul__/__rmatmul__ overload. + return FactoredMatrix(self.A, self.B @ cast(torch.Tensor, other)) else: return FactoredMatrix(self.AB, other) - elif isinstance(other, FactoredMatrix): - return (self @ other.A) @ other.B @overload def __rmatmul__( # type: ignore @@ -115,9 +161,12 @@ def __rmatmul__( # type: ignore Float[torch.Tensor, "... new_rdim ldim"], Float[torch.Tensor, "ldim"], "FactoredMatrix", + TensorLike, ], - ) -> Union["FactoredMatrix", Float[torch.Tensor, "... rdim"]]: - if isinstance(other, torch.Tensor): + ) -> Union["FactoredMatrix", Float[torch.Tensor, "... 
rdim"], TensorLike]: + if isinstance(other, FactoredMatrix): + return other.A @ (other.B @ self) + else: assert ( other.size(-1) == self.ldim ), f"Left matrix must match on inner dimension, shapes were self: {self.shape}, other:{other.shape}" @@ -128,8 +177,6 @@ def __rmatmul__( # type: ignore return FactoredMatrix(other @ self.A, self.B) else: return FactoredMatrix(other, self.AB) - elif isinstance(other, FactoredMatrix): - return other.A @ (other.B @ self) def __mul__(self, scalar: Union[int, float, torch.Tensor]) -> FactoredMatrix: """ @@ -148,8 +195,11 @@ def __rmul__(self, scalar: Union[int, float, torch.Tensor]) -> FactoredMatrix: return self * scalar @property - def AB(self) -> Float[torch.Tensor, "*leading_dims ldim rdim"]: - """The product matrix - expensive to compute, and can consume a lot of GPU memory""" + def AB(self) -> Union[Float[torch.Tensor, "*leading_dims ldim rdim"], TensorLike]: + """The product matrix - expensive to compute, and can consume a lot of GPU memory. + + Returns a TensorLike when A or B is a non-Tensor TensorLike duck type. + """ return self.A @ self.B @property From 26d51a2b553105e9a392a4b08d5c47be0f524cb9 Mon Sep 17 00:00:00 2001 From: Dashiell Stander Date: Wed, 22 Apr 2026 07:50:56 -0700 Subject: [PATCH 11/21] NanoGPT Conversation did not handle case when there were no biases in model (#629) * Update convert_nanogpt_weights to have attention mask and handle case when there is no bias. Signed-off-by: Dashiell Stander * ran format * Make beartyping dependency more forgiving Signed-off-by: Dashiell Stander * generated lock file * Added an error if cfg.d_mlp is None --------- Signed-off-by: Dashiell Stander Co-authored-by: Bryce Meyer Co-authored-by: jlarson4 --- .../pretrained/weight_conversions/nanogpt.py | 24 +++++++++++++++++++ 1 file changed, 24 insertions(+) diff --git a/transformer_lens/pretrained/weight_conversions/nanogpt.py b/transformer_lens/pretrained/weight_conversions/nanogpt.py index 0a138cdd0..235575861 100644 --- a/transformer_lens/pretrained/weight_conversions/nanogpt.py +++ b/transformer_lens/pretrained/weight_conversions/nanogpt.py @@ -29,6 +29,8 @@ def convert_nanogpt_weights(old_state_dict, cfg: HookedTransformerConfig): if "transformer.ln_f.bias" in old_state_dict: bias = True new_state_dict["ln_final.b"] = old_state_dict["transformer.ln_f.bias"] + else: + new_state_dict["unembed.b_U"] = torch.zeros(cfg.d_vocab, dtype=cfg.dtype) for layer in range(cfg.n_layers): layer_key = f"transformer.h.{layer}" @@ -43,6 +45,11 @@ def convert_nanogpt_weights(old_state_dict, cfg: HookedTransformerConfig): old_state_dict[f"{layer_key}.ln_2.weight"] ) + new_state_dict[f"blocks.{layer}.attn.mask"] = torch.tril( + torch.ones((cfg.n_ctx, cfg.n_ctx)).bool() + ) + new_state_dict[f"blocks.{layer}.attn.IGNORE"] = torch.tensor(-torch.inf) + W = old_state_dict[f"{layer_key}.attn.c_attn.weight"] W_Q, W_K, W_V = torch.tensor_split(W, 3, dim=0) W_Q = einops.rearrange(W_Q, "(i h) m->i m h", i=cfg.n_heads) @@ -84,5 +91,22 @@ def convert_nanogpt_weights(old_state_dict, cfg: HookedTransformerConfig): new_state_dict[f"blocks.{layer}.attn.b_O"] = old_state_dict[ f"{layer_key}.attn.c_proj.bias" ] + else: + if cfg.d_mlp is None: + raise ValueError( + "cfg.d_mlp must be set to convert nanoGPT weights for the no-bias case." 
+ ) + new_state_dict[f"blocks.{layer}.mlp.b_out"] = torch.zeros(cfg.d_model, dtype=cfg.dtype) + new_state_dict[f"blocks.{layer}.mlp.b_in"] = torch.zeros(cfg.d_mlp, dtype=cfg.dtype) + new_state_dict[f"blocks.{layer}.attn.b_Q"] = torch.zeros( + (cfg.n_heads, cfg.d_head), dtype=cfg.dtype + ) + new_state_dict[f"blocks.{layer}.attn.b_K"] = torch.zeros( + cfg.n_heads, cfg.d_head, dtype=cfg.dtype + ) + new_state_dict[f"blocks.{layer}.attn.b_V"] = torch.zeros( + cfg.n_heads, cfg.d_head, dtype=cfg.dtype + ) + new_state_dict[f"blocks.{layer}.attn.b_O"] = torch.zeros(cfg.d_model, dtype=cfg.dtype) return new_state_dict From 2e89f7fb91f2782234203895ee438ff20ff1fa1d Mon Sep 17 00:00:00 2001 From: Jonah Larson Date: Wed, 22 Apr 2026 12:57:55 -0500 Subject: [PATCH 12/21] fixed batched generation on run_with_cache and run_with_hooks on transformerbridge (#1265) --- .../model_bridge/test_run_with_cache_batch.py | 80 +++++++++++++++++++ transformer_lens/model_bridge/bridge.py | 51 +++++++++++- 2 files changed, 128 insertions(+), 3 deletions(-) create mode 100644 tests/acceptance/model_bridge/test_run_with_cache_batch.py diff --git a/tests/acceptance/model_bridge/test_run_with_cache_batch.py b/tests/acceptance/model_bridge/test_run_with_cache_batch.py new file mode 100644 index 000000000..9e163e525 --- /dev/null +++ b/tests/acceptance/model_bridge/test_run_with_cache_batch.py @@ -0,0 +1,80 @@ +"""Tests that batched run_with_cache and run_with_hooks produce correct results. + +Without an attention mask, HF models attend to padding tokens and contaminate +both logits and cached activations for shorter sequences in a batch. These +tests guard against that regression. +""" + +import torch + + +def _last_real_token_idx(bridge, tokens): + """Find the index of the last real token for each sequence in a batch.""" + if bridge.tokenizer.pad_token_id is None: + return torch.full((tokens.shape[0],), tokens.shape[1] - 1) + # With left-padding, the last real token is always at position -1 + return torch.full((tokens.shape[0],), tokens.shape[1] - 1) + + +def test_run_with_cache_batch_matches_individual(gpt2_bridge): + """Batched run_with_cache logits at the last real token should match per-prompt runs.""" + prompts = [ + "Hello, my dog is cute", + "This is a much longer text. Hello, my cat is cute", + ] + + # Individual runs + individual_logits = [] + for p in prompts: + logits, _ = gpt2_bridge.run_with_cache(p) + individual_logits.append(logits[0, -1, :]) + + # Batched run + batched_logits, _ = gpt2_bridge.run_with_cache(prompts) + # With left-padding forced internally, position -1 is the last real token + for i in range(len(prompts)): + batched_last = batched_logits[i, -1, :] + assert torch.allclose( + individual_logits[i], batched_last, atol=1e-4 + ), f"Prompt {i} logit mismatch between individual and batched run_with_cache" + + +def test_run_with_hooks_batch_matches_individual(gpt2_bridge): + """Batched run_with_hooks should produce the same hook values as per-prompt runs + (for the last real token position of each sequence).""" + prompts = [ + "Hello, my dog is cute", + "This is a much longer text. 
Hello, my cat is cute", + ] + + # Capture resid_post at last layer for last token + captured_individual = [] + + def capture_individual(tensor, hook): + # Last token's residual + captured_individual.append(tensor[0, -1, :].detach().clone()) + + for p in prompts: + gpt2_bridge.run_with_hooks( + p, + fwd_hooks=[("blocks.11.hook_resid_post", capture_individual)], + ) + + # Batched run + captured_batched = [] + + def capture_batched(tensor, hook): + # For left-padded batch, last real token is at position -1 for all + for i in range(tensor.shape[0]): + captured_batched.append(tensor[i, -1, :].detach().clone()) + + gpt2_bridge.run_with_hooks( + prompts, + fwd_hooks=[("blocks.11.hook_resid_post", capture_batched)], + ) + + assert len(captured_individual) == len(captured_batched) == len(prompts) + for i in range(len(prompts)): + assert torch.allclose( + captured_individual[i], captured_batched[i], atol=1e-4 + ), f"Prompt {i} hook value mismatch between individual and batched run_with_hooks" diff --git a/transformer_lens/model_bridge/bridge.py b/transformer_lens/model_bridge/bridge.py index 93778e014..cfac0b51c 100644 --- a/transformer_lens/model_bridge/bridge.py +++ b/transformer_lens/model_bridge/bridge.py @@ -1534,6 +1534,16 @@ def forward( else: kwargs.pop("one_zero_attention_mask") + # Detect batched list input that will need padding. For this case we force + # left-padding internally and auto-compute attention_mask + position_ids + # (unless the caller passed them explicitly) so pad tokens don't contaminate + # attention or position embeddings. + _is_batched_list = ( + isinstance(input, list) + and len(input) > 1 + and not getattr(self.cfg, "is_audio_model", False) + ) + try: if isinstance(input, (str, list)): if getattr(self.cfg, "is_audio_model", False): @@ -1541,9 +1551,20 @@ def forward( "Audio models require tensor input (raw waveform), not text. " "Pass a torch.Tensor or use the input_values parameter." ) - input_ids = self.to_tokens( - input, prepend_bos=prepend_bos, padding_side=padding_side - ) + if _is_batched_list and padding_side is None: + # Force left-padding so real tokens are flush-right. + _orig_padding_side = self.tokenizer.padding_side + self.tokenizer.padding_side = "left" + try: + input_ids = self.to_tokens( + input, prepend_bos=prepend_bos, padding_side=padding_side + ) + finally: + self.tokenizer.padding_side = _orig_padding_side + else: + input_ids = self.to_tokens( + input, prepend_bos=prepend_bos, padding_side=padding_side + ) else: input_ids = input @@ -1553,6 +1574,30 @@ def forward( isinstance(input_ids, torch.Tensor) and input_ids.is_floating_point() ) + # Auto-compute attention_mask + position_ids for batched list input + # when the caller didn't supply them. Matches HF generation convention. 
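+ # Worked example (illustrative): for a left-padded batch
+ #   tokens         = [[pad, pad, a, b], [c, d, e, f]]
+ # the mask and derived positions come out as
+ #   attention_mask = [[0, 0, 1, 1],     [1, 1, 1, 1]]
+ #   position_ids   = [[1, 1, 0, 1],     [0, 1, 2, 3]]
+ # so pad tokens are ignored by attention and do not consume real
+ # position embeddings.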
+ if ( + _is_batched_list + and attention_mask is None + and self.tokenizer is not None + and self.tokenizer.pad_token_id is not None + and not _is_inputs_embeds + ): + _prev_side = self.tokenizer.padding_side + self.tokenizer.padding_side = "left" + try: + attention_mask = utils.get_attention_mask( + self.tokenizer, + input_ids, + prepend_bos=getattr(self.cfg, "default_prepend_bos", True), + ).to(self.cfg.device) + finally: + self.tokenizer.padding_side = _prev_side + if "position_ids" not in kwargs: + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + kwargs["position_ids"] = position_ids + if attention_mask is not None: kwargs["attention_mask"] = attention_mask if kwargs.pop("use_past_kv_cache", False) or kwargs.get("use_cache", False): From bbc3ec7ab3c4d02d00067884ae33dbd4ceb70920 Mon Sep 17 00:00:00 2001 From: Jonah Larson Date: Thu, 23 Apr 2026 11:47:26 -0500 Subject: [PATCH 13/21] Added 1D tensor handling in line with HookedTransformer (#1266) --- .../compatibility/test_run_with_cache.py | 16 ++++++++++++++++ transformer_lens/model_bridge/bridge.py | 9 +++++++++ 2 files changed, 25 insertions(+) diff --git a/tests/acceptance/model_bridge/compatibility/test_run_with_cache.py b/tests/acceptance/model_bridge/compatibility/test_run_with_cache.py index 7a2e6a169..f653de956 100644 --- a/tests/acceptance/model_bridge/compatibility/test_run_with_cache.py +++ b/tests/acceptance/model_bridge/compatibility/test_run_with_cache.py @@ -61,3 +61,19 @@ def hook_fn(acts, hook): f"TransformerBridge run_with_cache should match manual hooks. " f"Max difference: {cache_diff:.6f}" ) + + def test_run_with_cache_accepts_1d_tensor(self, gpt2_bridge_compat_no_processing): + """1D token tensors should be auto-promoted to [1, seq], matching HookedTransformer.""" + bridge_model = gpt2_bridge_compat_no_processing + + tokens_1d = torch.tensor([1, 2, 3]) + tokens_2d = tokens_1d.unsqueeze(0) + + logits_1d, cache_1d = bridge_model.run_with_cache(tokens_1d) + logits_2d, cache_2d = bridge_model.run_with_cache(tokens_2d) + + assert logits_1d.shape == logits_2d.shape + assert torch.allclose(logits_1d, logits_2d, atol=1e-5) + assert torch.allclose( + cache_1d["blocks.0.hook_mlp_out"], cache_2d["blocks.0.hook_mlp_out"], atol=1e-5 + ) diff --git a/transformer_lens/model_bridge/bridge.py b/transformer_lens/model_bridge/bridge.py index cfac0b51c..c2bd5dbd5 100644 --- a/transformer_lens/model_bridge/bridge.py +++ b/transformer_lens/model_bridge/bridge.py @@ -1567,6 +1567,15 @@ def forward( ) else: input_ids = input + # Promote 1D integer token tensors to 2D [batch=1, seq] to match + # HookedTransformer's contract. Float tensors (inputs_embeds, + # audio waveforms) are passed through unchanged. + if ( + isinstance(input_ids, torch.Tensor) + and input_ids.ndim == 1 + and not input_ids.is_floating_point() + ): + input_ids = input_ids.unsqueeze(0) # Detect inputs_embeds: if the tensor is floating point, it's pre-computed # embeddings (e.g., from multimodal models) rather than token IDs. 
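A minimal sketch of the behaviour guaranteed by this change, assuming a GPT-2 bridge booted on CPU as in the accompanying test and the README (tolerances mirror the test):

```python
import torch
from transformer_lens.model_bridge import TransformerBridge

bridge = TransformerBridge.boot_transformers("gpt2", device="cpu")

tokens_1d = torch.tensor([1, 2, 3])   # [seq], auto-promoted to [1, seq]
tokens_2d = tokens_1d.unsqueeze(0)    # [batch=1, seq]

logits_1d, cache_1d = bridge.run_with_cache(tokens_1d)
logits_2d, cache_2d = bridge.run_with_cache(tokens_2d)

# Promotion happens internally, so both call forms agree.
assert logits_1d.shape == logits_2d.shape
assert torch.allclose(logits_1d, logits_2d, atol=1e-5)
assert torch.allclose(
    cache_1d["blocks.0.hook_mlp_out"], cache_2d["blocks.0.hook_mlp_out"], atol=1e-5
)
```
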
From 3c5896ee1f4c57e96f8d68b1f1c9d9fb00351a1b Mon Sep 17 00:00:00 2001 From: Jonah Larson Date: Thu, 23 Apr 2026 20:30:45 -0500 Subject: [PATCH 14/21] Added n_ctx override to TransformerBridge (#1269) * Added n_ctx override to TransformerBridge * Prevent output of progress bars in T5 demo --- demos/T5.ipynb | 37 ++---- .../model_bridge/test_n_ctx_override.py | 117 ++++++++++++++++++ transformer_lens/model_bridge/bridge.py | 5 + .../model_bridge/sources/transformers.py | 89 ++++++++++++- 4 files changed, 222 insertions(+), 26 deletions(-) create mode 100644 tests/acceptance/model_bridge/test_n_ctx_override.py diff --git a/demos/T5.ipynb b/demos/T5.ipynb index 1d225da96..fd4bc319e 100644 --- a/demos/T5.ipynb +++ b/demos/T5.ipynb @@ -88,14 +88,14 @@ "generated token: \",\", token id: 6\n", "generated token: \"comment\", token id: 1670\n", "generated token: \"\", token id: 3\n", - "generated token: \"\u00eates\", token id: 6738\n", + "generated token: \"êtes\", token id: 6738\n", "generated token: \"-\", token id: 18\n", "generated token: \"vous\", token id: 3249\n", "generated token: \"\", token id: 3\n", "generated token: \"?\", token id: 58\n", "generated token: \"\", token id: 1\n", "translate English to French: Hello, how are you? \n", - " Bonjour, comment \u00eates-vous?\n" + " Bonjour, comment êtes-vous?\n" ] } ], @@ -206,7 +206,7 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": null, "metadata": { "execution": { "iopub.execute_input": "2026-03-05T18:28:00.478310Z", @@ -215,21 +215,8 @@ "shell.execute_reply": "2026-03-05T18:28:00.629766Z" } }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Hallo, magst du Bananen?\n" - ] - } - ], - "source": [ - "prompt=\"translate English to German: Hello, do you like bananas?\"\n", - "\n", - "output = model.generate(prompt, do_sample=False, max_new_tokens=20)\n", - "print(output)" - ] + "outputs": [], + "source": "prompt=\"translate English to German: Hello, do you like bananas?\"\n\noutput = model.generate(prompt, do_sample=False, max_new_tokens=20, verbose=False)\nprint(output)" }, { "cell_type": "markdown", @@ -928,7 +915,7 @@ "outputs": [], "source": [ "encoder_attn_pattern = cache[\"encoder_blocks.0.attn.hook_pattern\"]\n", - "input_str_tokens = [w.lstrip(\"\u2581\") for w in tokenizer.convert_ids_to_tokens(input_ids[0])]" + "input_str_tokens = [w.lstrip(\"▁\") for w in tokenizer.convert_ids_to_tokens(input_ids[0])]" ] }, { @@ -993,14 +980,14 @@ "data": { "text/plain": [ "['',\n", - " '\u2581Bonjour',\n", + " '▁Bonjour',\n", " ',',\n", - " '\u2581comment',\n", - " '\u2581',\n", - " '\u00eates',\n", + " '▁comment',\n", + " '▁',\n", + " 'êtes',\n", " '-',\n", " 'vous',\n", - " '\u2581',\n", + " '▁',\n", " '?',\n", " '']" ] @@ -1143,4 +1130,4 @@ }, "nbformat": 4, "nbformat_minor": 2 -} +} \ No newline at end of file diff --git a/tests/acceptance/model_bridge/test_n_ctx_override.py b/tests/acceptance/model_bridge/test_n_ctx_override.py new file mode 100644 index 000000000..be35474e6 --- /dev/null +++ b/tests/acceptance/model_bridge/test_n_ctx_override.py @@ -0,0 +1,117 @@ +"""Tests for the n_ctx override parameter on TransformerBridge.boot_transformers(). + +Uses load_weights=False so we can verify config plumbing without fighting HF's +weight-loading checks. Models with learned positional embeddings (e.g. GPT-2) +cannot have their n_ctx reduced at weight-load time — only rotary models can +freely resize. 
These tests verify the config is written correctly; users are +responsible for choosing n_ctx values their model supports. +""" + +import logging + +import pytest + +from transformer_lens.model_bridge import TransformerBridge + + +def test_n_ctx_override_writes_to_correct_hf_field(): + """For GPT-2 the field is n_positions — overriding n_ctx should update it.""" + bridge = TransformerBridge.boot_transformers( + "gpt2", device="cpu", n_ctx=256, load_weights=False + ) + assert bridge.cfg.n_ctx == 256 + assert bridge.original_model.config.n_positions == 256 + + +def test_n_ctx_default_uses_model_max(): + """Without an override, cfg.n_ctx reflects the HF config's value.""" + bridge = TransformerBridge.boot_transformers("gpt2", device="cpu", load_weights=False) + # GPT-2's n_positions default is 1024 + assert bridge.cfg.n_ctx == 1024 + + +def test_n_ctx_warns_when_above_default(caplog): + """Overriding n_ctx above the model default should emit a logging.warning.""" + with caplog.at_level(logging.WARNING): + TransformerBridge.boot_transformers("gpt2", device="cpu", n_ctx=2048, load_weights=False) + assert any( + "larger than the model's default context length" in rec.message for rec in caplog.records + ) + + +def test_n_ctx_combined_with_hf_config_overrides(): + """Explicit n_ctx should take precedence over hf_config_overrides for that field.""" + bridge = TransformerBridge.boot_transformers( + "gpt2", + device="cpu", + n_ctx=256, + hf_config_overrides={"n_positions": 512}, # should be overridden by n_ctx=256 + load_weights=False, + ) + assert bridge.cfg.n_ctx == 256 + + +# --- Coverage for code-review items #2, #4, #5, #7 --- + + +def test_n_ctx_zero_raises_value_error(): + """#2: n_ctx must be positive; zero should raise ValueError.""" + with pytest.raises(ValueError, match="positive integer"): + TransformerBridge.boot_transformers("gpt2", device="cpu", n_ctx=0, load_weights=False) + + +def test_n_ctx_negative_raises_value_error(): + """#2: n_ctx must be positive; negative should raise ValueError.""" + with pytest.raises(ValueError, match="positive integer"): + TransformerBridge.boot_transformers("gpt2", device="cpu", n_ctx=-1, load_weights=False) + + +def test_n_ctx_conflict_with_hf_config_overrides_warns(caplog): + """#4: When both n_ctx and the same hf_config_overrides field are set with different values, + a warning should be emitted explaining that n_ctx wins.""" + with caplog.at_level(logging.WARNING): + TransformerBridge.boot_transformers( + "gpt2", + device="cpu", + n_ctx=256, + hf_config_overrides={"n_positions": 512}, + load_weights=False, + ) + assert any( + "Both n_ctx=256 and hf_config_overrides['n_positions']" in rec.message + and "takes precedence" in rec.message + for rec in caplog.records + ) + + +def test_n_ctx_no_conflict_when_values_match(caplog): + """#4: If n_ctx and hf_config_overrides agree on the value, no conflict warning is emitted.""" + with caplog.at_level(logging.WARNING): + TransformerBridge.boot_transformers( + "gpt2", + device="cpu", + n_ctx=256, + hf_config_overrides={"n_positions": 256}, # same as n_ctx + load_weights=False, + ) + assert not any("takes precedence" in rec.message for rec in caplog.records) + + +def test_n_ctx_shrink_with_load_weights_gives_clear_error(): + """#5: Shrinking a learned-pos-embed model's n_ctx at weight-load time should raise + with a message explaining the cause and suggesting alternatives.""" + with pytest.raises(RuntimeError) as exc_info: + TransformerBridge.boot_transformers("gpt2", device="cpu", n_ctx=256, load_weights=True) 
+ err = str(exc_info.value) + assert "n_ctx=256" in err + assert "learned positional embeddings" in err or "load_weights=False" in err + + +def test_n_ctx_override_verified_on_loaded_model(): + """#7: After load, the override should be visible on hf_model.config so users + can trust that the longer/shorter context is actually in effect.""" + bridge = TransformerBridge.boot_transformers( + "gpt2", device="cpu", n_ctx=2048, load_weights=False + ) + # The override persisted through model construction + assert bridge.original_model.config.n_positions == 2048 diff --git a/transformer_lens/model_bridge/bridge.py b/transformer_lens/model_bridge/bridge.py index c2bd5dbd5..7930ac08a 100644 --- a/transformer_lens/model_bridge/bridge.py +++ b/transformer_lens/model_bridge/bridge.py @@ -167,6 +167,7 @@ def boot_transformers( trust_remote_code: bool = False, model_class: Optional[type] = None, hf_model: Optional[Any] = None, + n_ctx: Optional[int] = None, ) -> "TransformerBridge": """Boot a model from HuggingFace (alias for sources.transformers.boot). @@ -183,6 +184,9 @@ def boot_transformers( hf_model: Optional pre-loaded HuggingFace model to use instead of loading one. Useful for models loaded with custom configurations (e.g., quantization via BitsAndBytesConfig). When provided, load_weights is ignored. + n_ctx: Optional context length override. Writes to the appropriate HF config field + for this model automatically (callers don't need to know the field name). + Warns if larger than the model's default context length. Returns: The bridge to the loaded model. @@ -199,6 +203,7 @@ def boot_transformers( trust_remote_code=trust_remote_code, model_class=model_class, hf_model=hf_model, + n_ctx=n_ctx, ) @property diff --git a/transformer_lens/model_bridge/sources/transformers.py b/transformer_lens/model_bridge/sources/transformers.py index 0169da4dc..99b90a968 100644 --- a/transformer_lens/model_bridge/sources/transformers.py +++ b/transformer_lens/model_bridge/sources/transformers.py @@ -286,6 +286,7 @@ def boot( trust_remote_code: bool = False, model_class: Any | None = None, hf_model: Any | None = None, + n_ctx: int | None = None, ) -> TransformerBridge: """Boot a model from HuggingFace. @@ -302,6 +303,11 @@ def boot( hf_model: Optional pre-loaded HuggingFace model to use instead of loading one. Useful for models loaded with custom configurations (e.g., quantization via BitsAndBytesConfig). When provided, load_weights is ignored. + n_ctx: Optional context length override. The bridge normally uses the model's documented + max context from the HF config. Setting this writes to whichever HF field the model + uses (n_positions / max_position_embeddings / etc.), so callers don't need to know + the field name. If larger than the model's default, a warning is emitted — quality + may degrade past the trained length for rotary models. Returns: The bridge to the loaded model. @@ -323,6 +329,54 @@ def boot( trust_remote_code=trust_remote_code, token=_hf_token, ) + _n_ctx_field: str | None = None + if n_ctx is not None: + # Validation (#2): reject non-positive values before doing anything else. + if n_ctx <= 0: + raise ValueError(f"n_ctx must be a positive integer, got n_ctx={n_ctx}.") + # Resolve n_ctx to whichever HF config field this model uses. Mirrors + # the order in map_default_transformer_lens_config so the TL config + # derivation picks up the override. 
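+ # (For example, GPT-2-style configs expose n_positions while
+ # Llama-family configs expose max_position_embeddings; the first
+ # matching attribute below wins.)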
+ for _field in ( + "n_positions", + "max_position_embeddings", + "max_context_length", + "max_length", + "seq_length", + ): + if hasattr(hf_config, _field): + _n_ctx_field = _field + break + if _n_ctx_field is None: + raise ValueError( + f"Cannot apply n_ctx={n_ctx}: no recognized context-length field on " + f"HF config for {model_name}. Use hf_config_overrides instead." + ) + _default_n_ctx = getattr(hf_config, _n_ctx_field) + if _default_n_ctx is not None and n_ctx > _default_n_ctx: + logging.warning( + "Setting n_ctx=%d which is larger than the model's default " + "context length of %d. The model was not trained on sequences " + "this long and may produce unreliable results (especially for " + "rotary models without RoPE scaling).", + n_ctx, + _default_n_ctx, + ) + # Conflict detection (#4): warn if the caller also set the same field + # via hf_config_overrides — explicit n_ctx wins but users should know. + if hf_config_overrides and _n_ctx_field in hf_config_overrides: + _conflicting_value = hf_config_overrides[_n_ctx_field] + if _conflicting_value != n_ctx: + logging.warning( + "Both n_ctx=%d and hf_config_overrides['%s']=%s were provided. " + "The explicit n_ctx takes precedence.", + n_ctx, + _n_ctx_field, + _conflicting_value, + ) + # Explicit n_ctx wins over hf_config_overrides for the resolved field. + hf_config_overrides = dict(hf_config_overrides or {}) + hf_config_overrides[_n_ctx_field] = n_ctx if hf_config_overrides: hf_config.__dict__.update(hf_config_overrides) tl_config = map_default_transformer_lens_config(hf_config) @@ -409,13 +463,46 @@ def boot( with contextlib.redirect_stdout(None): hf_model = model_class.from_config(hf_config, **from_config_kwargs) else: - hf_model = model_class.from_pretrained(model_name, **model_kwargs) + try: + hf_model = model_class.from_pretrained(model_name, **model_kwargs) + except RuntimeError as e: + # #5: HF refuses to load when positional-weight shapes don't match. + # If the user requested an n_ctx that conflicts with the saved weights + # (common for learned-pos-embed models like GPT-2), re-raise with a + # clearer message pointing them at the likely cause. + if n_ctx is not None and "ignore_mismatched_sizes" in str(e): + raise RuntimeError( + f"Failed to load {model_name} with n_ctx={n_ctx}: the pretrained " + f"weights' positional-embedding shape does not match the requested " + f"context length. This affects models with learned positional " + f"embeddings (e.g. GPT-2, OPT). Options: (1) use the model's " + f"default n_ctx, (2) pass load_weights=False if you only need " + f"config inspection, or (3) choose a rotary-embedding model " + f"(e.g. Llama, Mistral) which supports n_ctx changes without " + f"weight mismatch." + ) from e + raise if device is not None: hf_model = hf_model.to(device) # Cast params to dtype; preserve float32 buffers (e.g., RotaryEmbedding.inv_freq) for param in hf_model.parameters(): if param.is_floating_point() and param.dtype != dtype: param.data = param.data.to(dtype=dtype) + # #7: Verify the n_ctx override actually took effect on the loaded model. + # If HF's config class silently dropped or normalized the value, warn so + # the user doesn't get misled into thinking longer sequences are supported. + if n_ctx is not None and _n_ctx_field is not None and hf_model is not None: + _actual = getattr(hf_model.config, _n_ctx_field, None) + if _actual != n_ctx: + logging.warning( + "n_ctx=%d was requested but hf_model.config.%s=%s after load. 
" + "The override may not have taken effect; the model may not " + "accept sequences longer than %s.", + n_ctx, + _n_ctx_field, + _actual, + _actual, + ) adapter.prepare_model(hf_model) tokenizer = tokenizer default_padding_side = getattr(adapter.cfg, "default_padding_side", None) From e4222a6ac66f79fcc201e28158b639e8f288630b Mon Sep 17 00:00:00 2001 From: Jonah Larson Date: Thu, 23 Apr 2026 22:06:35 -0500 Subject: [PATCH 15/21] Feature/generate stream on bridge (#1268) * adds HookedTransformer.generate_stream() * fixes mypy errors * Adjusted for TransformerLens 3 changes * Initial bridge generate stream * TransformerBridge Generate Stream --------- Co-authored-by: anthonyduong Co-authored-by: Bryce Meyer --- .../model_bridge/test_generate_stream.py | 97 +++ transformer_lens/model_bridge/bridge.py | 589 ++++++++++++------ 2 files changed, 504 insertions(+), 182 deletions(-) create mode 100644 tests/acceptance/model_bridge/test_generate_stream.py diff --git a/tests/acceptance/model_bridge/test_generate_stream.py b/tests/acceptance/model_bridge/test_generate_stream.py new file mode 100644 index 000000000..70e7a456a --- /dev/null +++ b/tests/acceptance/model_bridge/test_generate_stream.py @@ -0,0 +1,97 @@ +"""Tests for TransformerBridge.generate_stream().""" + +import torch + + +def test_stream_matches_generate(gpt2_bridge): + """Concatenated stream output should match generate() for the same prompt.""" + prompt = "The future of AI" + # Get generate() output as string + expected_text = gpt2_bridge.generate(prompt, max_new_tokens=10, do_sample=False, verbose=False) + assert isinstance(expected_text, str) + + # Stream as tokens so we can concatenate and compare + chunks = list( + gpt2_bridge.generate_stream( + prompt, + max_new_tokens=10, + max_tokens_per_yield=3, + do_sample=False, + verbose=False, + return_type="tokens", + ) + ) + assert len(chunks) >= 1 + + # First chunk = input + first tokens, subsequent = new tokens only. 
+ all_tokens = chunks[0] + for chunk in chunks[1:]: + all_tokens = torch.cat([all_tokens, chunk], dim=-1) + + streamed_text = gpt2_bridge.tokenizer.decode(all_tokens[0], skip_special_tokens=True) + assert ( + expected_text == streamed_text + ), f"Stream output mismatch:\n generate: {expected_text!r}\n stream: {streamed_text!r}" + + +def test_stream_yields_progressively(gpt2_bridge): + """Multiple yields should occur with small max_tokens_per_yield.""" + chunks = list( + gpt2_bridge.generate_stream( + "Hello world", + max_new_tokens=10, + max_tokens_per_yield=3, + do_sample=False, + verbose=False, + return_type="tokens", + ) + ) + assert len(chunks) > 1, f"Expected multiple yields, got {len(chunks)}" + + +def test_stream_single_prompt(gpt2_bridge): + """Basic single-string streaming should produce output.""" + results = list( + gpt2_bridge.generate_stream( + "Test", + max_new_tokens=5, + do_sample=False, + verbose=False, + return_type="tokens", + ) + ) + assert len(results) >= 1 + assert results[0].shape[0] == 1 # batch=1 + assert results[0].shape[1] > 1 # has at least input + 1 generated token + + +def test_stream_stops_at_eos(gpt2_bridge): + """Streaming should respect stop_at_eos.""" + results = list( + gpt2_bridge.generate_stream( + "Test", + max_new_tokens=200, + max_tokens_per_yield=5, + stop_at_eos=True, + do_sample=False, + verbose=False, + return_type="tokens", + ) + ) + total_tokens = sum(r.shape[1] for r in results) + assert total_tokens < 210 + + +def test_stream_returns_strings(gpt2_bridge): + """With return_type='str', yields should be strings.""" + results = list( + gpt2_bridge.generate_stream( + "Hello", + max_new_tokens=5, + do_sample=False, + verbose=False, + return_type="str", + ) + ) + assert len(results) >= 1 + assert all(isinstance(r, str) for r in results) diff --git a/transformer_lens/model_bridge/bridge.py b/transformer_lens/model_bridge/bridge.py index 7930ac08a..9e24fa9fc 100644 --- a/transformer_lens/model_bridge/bridge.py +++ b/transformer_lens/model_bridge/bridge.py @@ -6,6 +6,7 @@ import logging import re import warnings +from collections.abc import Generator from contextlib import contextmanager from functools import lru_cache from typing import ( @@ -26,6 +27,7 @@ import einops import numpy as np import torch +import tqdm from torch import nn from transformer_lens import utilities as utils @@ -2136,6 +2138,190 @@ def wrapped_hook_fn(tensor, hook, _orig_fn=original_hook_fn): for hook_point, name in added_hooks: hook_point.remove_hooks() + def _generate_tokens( + self, + current_tokens: torch.Tensor, + input_tokens: torch.Tensor, + batch_size: int, + *, + max_new_tokens: int, + do_sample: bool, + top_k: Optional[int], + top_p: Optional[float], + temperature: float, + freq_penalty: float, + repetition_penalty: float, + stop_at_eos: bool, + stop_tokens: List[int], + eos_token_for_padding: int, + finished_sequences: torch.Tensor, + use_past_kv_cache: bool, + use_stateful_cache: bool, + mamba_cache: Any, + mamba_conv_kernel: int, + is_encoder_decoder: bool, + _is_batched_list: bool, + _generate_from_embeds: bool, + encoder_input: Optional[torch.Tensor], + decoder_tokens: Optional[torch.Tensor], + generated_token_ids: Optional[List[torch.Tensor]], + pixel_values: Optional[torch.Tensor], + multimodal_kwargs: Dict[str, Any], + verbose: bool, + ) -> Generator[Tuple[torch.Tensor, torch.Tensor, bool], None, None]: + """Core generation loop. Yields (sampled_tokens, final_logits, all_finished) per step. 
+ + Owns the forward pass, sampling, EOS handling, token accumulation, and + KV cache management. Callers are responsible for try/finally cleanup of + ``_capture_hf_cache``. + """ + _hf_kv_cache = None + + for gen_step_idx in tqdm.tqdm(range(max_new_tokens), disable=not verbose): + with torch.no_grad(): + if is_encoder_decoder: + logits = self( + encoder_input, + return_type="logits", + decoder_input=decoder_tokens, + ) + else: + forward_kwargs: Dict[str, Any] = {} + # Compute attention mask and position_ids for batched + # inputs with padding. + if ( + _is_batched_list + and self.tokenizer is not None + and self.tokenizer.pad_token_id is not None + ): + _prev_side = self.tokenizer.padding_side + self.tokenizer.padding_side = "left" + attn_mask = utils.get_attention_mask( + self.tokenizer, + current_tokens, + prepend_bos=getattr(self.cfg, "default_prepend_bos", True), + ).to(self.cfg.device) + self.tokenizer.padding_side = _prev_side + forward_kwargs["attention_mask"] = attn_mask + position_ids = attn_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attn_mask == 0, 1) + forward_kwargs["position_ids"] = position_ids + if gen_step_idx == 0: + if pixel_values is not None: + forward_kwargs["pixel_values"] = pixel_values + if multimodal_kwargs: + forward_kwargs.update(multimodal_kwargs) + if use_stateful_cache: + forward_kwargs["cache_params"] = mamba_cache + forward_kwargs["use_cache"] = True + if gen_step_idx == 0: + cache_position = torch.arange( + 0, mamba_conv_kernel, device=self.cfg.device + ) + forward_kwargs["cache_position"] = cache_position + logits = self( + current_tokens, + return_type="logits", + **forward_kwargs, + ) + else: + input_seq_pos = input_tokens.shape[1] + gen_step_idx - 1 + cache_position = torch.tensor([input_seq_pos], device=self.cfg.device) + forward_kwargs["cache_position"] = cache_position + if "position_ids" in forward_kwargs: + forward_kwargs["position_ids"] = forward_kwargs["position_ids"][ + :, -1: + ] + logits = self( + current_tokens[:, -1:], + return_type="logits", + **forward_kwargs, + ) + elif use_past_kv_cache: + forward_kwargs["use_cache"] = True + if _hf_kv_cache is not None: + forward_kwargs["past_key_values"] = _hf_kv_cache + if "position_ids" in forward_kwargs: + forward_kwargs["position_ids"] = forward_kwargs["position_ids"][ + :, -1: + ] + logits = self( + current_tokens[:, -1:], + return_type="logits", + **forward_kwargs, + ) + else: + logits = self( + current_tokens, + return_type="logits", + **forward_kwargs, + ) + else: + logits = self(current_tokens, return_type="logits", **forward_kwargs) + if use_past_kv_cache and hasattr(self, "_last_hf_cache"): + _hf_kv_cache = self._last_hf_cache or _hf_kv_cache + del self._last_hf_cache + final_logits = logits[:, -1, :] + + # Sample next token + penalty_tokens = ( + torch.stack(generated_token_ids, dim=1) + if _generate_from_embeds and generated_token_ids + else None + ) + if do_sample: + sampled_tokens = utils.sample_logits( + final_logits, + top_k=top_k, + top_p=top_p, + temperature=temperature, + freq_penalty=freq_penalty, + repetition_penalty=repetition_penalty, + tokens=penalty_tokens + if _generate_from_embeds + else (decoder_tokens if is_encoder_decoder else current_tokens), + ).to(self.cfg.device) + else: + sampled_tokens = utils.sample_logits( + final_logits, + temperature=0.0, + repetition_penalty=repetition_penalty, + tokens=penalty_tokens + if _generate_from_embeds + else (decoder_tokens if is_encoder_decoder else current_tokens), + ).to(self.cfg.device) + + # Handle EOS + if 
stop_at_eos: + sampled_tokens[finished_sequences] = eos_token_for_padding + finished_sequences.logical_or_( + torch.isin( + sampled_tokens.to(self.cfg.device), + torch.tensor(stop_tokens).to(self.cfg.device), + ) + ) + + # Update token sequences + if is_encoder_decoder: + assert decoder_tokens is not None + decoder_tokens = torch.cat([decoder_tokens, sampled_tokens.unsqueeze(1)], dim=1) + elif _generate_from_embeds: + assert generated_token_ids is not None + generated_token_ids.append(sampled_tokens) + embed_fn = self.original_model.get_input_embeddings() # type: ignore[operator] + assert embed_fn is not None + new_embed = embed_fn(sampled_tokens.unsqueeze(1)).to(current_tokens.dtype) + current_tokens = torch.cat([current_tokens, new_embed], dim=1) + else: + current_tokens = torch.cat([current_tokens, sampled_tokens.unsqueeze(1)], dim=1) + + all_finished = bool(stop_at_eos and finished_sequences.all().item()) + + yield sampled_tokens, final_logits, all_finished + + if all_finished: + return + def generate( self, input: Union[str, List[str], torch.Tensor] = "", @@ -2355,188 +2541,41 @@ def generate( ) try: - for gen_step_idx in range(max_new_tokens): - # Get logits for next token - with torch.no_grad(): - if is_encoder_decoder: - logits = self( - encoder_input, - return_type="logits", - decoder_input=decoder_tokens, - ) - else: - forward_kwargs: Dict[str, Any] = {} - # Compute attention mask and position_ids for batched - # inputs with padding. HF models default to all-ones - # when no mask is given, which ignores padding tokens. - if ( - _is_batched_list - and self.tokenizer is not None - and self.tokenizer.pad_token_id is not None - ): - # Temp-swap to "left" so get_attention_mask scans - # for leading pads (matching the left-padded tokens). - _prev_side = self.tokenizer.padding_side - self.tokenizer.padding_side = "left" - attn_mask = utils.get_attention_mask( - self.tokenizer, - current_tokens, - prepend_bos=getattr(self.cfg, "default_prepend_bos", True), - ).to(self.cfg.device) - self.tokenizer.padding_side = _prev_side - forward_kwargs["attention_mask"] = attn_mask - # Adjust position_ids for left-padding so pad - # tokens don't consume real position embeddings. - position_ids = attn_mask.long().cumsum(-1) - 1 - position_ids.masked_fill_(attn_mask == 0, 1) - forward_kwargs["position_ids"] = position_ids - # Pass multimodal inputs only on the first step — the vision - # encoder processes the image once, embedding it into the - # token sequence. This includes pixel_values plus any extra - # processor outputs (e.g. image_sizes for LlavaNext). - if gen_step_idx == 0: - if pixel_values is not None: - forward_kwargs["pixel_values"] = pixel_values - if multimodal_kwargs: - forward_kwargs.update(multimodal_kwargs) - if use_stateful_cache: - # Prefill sends arange(conv_kernel) (which both - # Mamba-1's length check and Mamba-2's value check - # accept as "not decode"). Decode sends the input - # token's actual sequence position — a fixed value - # above conv_kernel-1 silently picks the wrong - # slot for short prompts (see - # test_greedy_matches_hf_across_prompt_lengths). - # conv1d hooks fire only on prefill; HF bypasses - # the conv1d module on decode (see DepthwiseConv1DBridge). 
- forward_kwargs["cache_params"] = mamba_cache - forward_kwargs["use_cache"] = True - if gen_step_idx == 0: - cache_position = torch.arange( - 0, mamba_conv_kernel, device=self.cfg.device - ) - forward_kwargs["cache_position"] = cache_position - logits = self( - current_tokens, - return_type="logits", - **forward_kwargs, - ) - else: - # Token generated at step N-1 lives at - # sequence position prompt_len + gen_step_idx - 1 - input_seq_pos = input_tokens.shape[1] + gen_step_idx - 1 - cache_position = torch.tensor( - [input_seq_pos], device=self.cfg.device - ) - forward_kwargs["cache_position"] = cache_position - if "position_ids" in forward_kwargs: - forward_kwargs["position_ids"] = forward_kwargs["position_ids"][ - :, -1: - ] - logits = self( - current_tokens[:, -1:], - return_type="logits", - **forward_kwargs, - ) - elif use_past_kv_cache: - forward_kwargs["use_cache"] = True - if _hf_kv_cache is not None: - # Cached step: pass only the last token + cache - forward_kwargs["past_key_values"] = _hf_kv_cache - if "position_ids" in forward_kwargs: - forward_kwargs["position_ids"] = forward_kwargs["position_ids"][ - :, -1: - ] - logits = self( - current_tokens[:, -1:], - return_type="logits", - **forward_kwargs, - ) - else: - # Step 0: full sequence, cache gets populated - logits = self( - current_tokens, - return_type="logits", - **forward_kwargs, - ) - else: - # No cache: full sequence every step - logits = self(current_tokens, return_type="logits", **forward_kwargs) - # Capture HF cache from forward() for next step. - if use_past_kv_cache and hasattr(self, "_last_hf_cache"): - _hf_kv_cache = self._last_hf_cache or _hf_kv_cache - del self._last_hf_cache - final_logits = logits[:, -1, :] - - # Collect logits if requested - if logits_seq_list is not None: - logits_seq_list.append(final_logits.clone()) - - # Sample next token - # For inputs_embeds, we can't pass the embeddings to freq/rep penalty, - # so use the generated_token_ids for penalty tracking - penalty_tokens = ( - torch.stack(generated_token_ids, dim=1) - if _generate_from_embeds and generated_token_ids - else None - ) - if do_sample: - sampled_tokens = utils.sample_logits( - final_logits, - top_k=top_k, - top_p=top_p, - temperature=temperature, - freq_penalty=freq_penalty, - repetition_penalty=repetition_penalty, - tokens=penalty_tokens - if _generate_from_embeds - else (decoder_tokens if is_encoder_decoder else current_tokens), - ).to(self.cfg.device) - else: - sampled_tokens = utils.sample_logits( - final_logits, - temperature=0.0, - repetition_penalty=repetition_penalty, - tokens=penalty_tokens - if _generate_from_embeds - else (decoder_tokens if is_encoder_decoder else current_tokens), - ).to(self.cfg.device) - - sampled_tokens_list.append(sampled_tokens.unsqueeze(1)) - - # Handle EOS tokens for finished sequences - if stop_at_eos: - sampled_tokens[finished_sequences] = eos_token_for_padding - finished_sequences.logical_or_( - torch.isin( - sampled_tokens.to(self.cfg.device), - torch.tensor(stop_tokens).to(self.cfg.device), - ) - ) - - # Append sampled token to current sequence - if is_encoder_decoder: - decoder_tokens = torch.cat( - [decoder_tokens, sampled_tokens.unsqueeze(1)], dim=1 - ) - elif _generate_from_embeds: - # For inputs_embeds: get the embedding of the new token and append - generated_token_ids.append(sampled_tokens) - embed_fn = self.original_model.get_input_embeddings() # type: ignore[operator] - assert embed_fn is not None - new_embed = embed_fn(sampled_tokens.unsqueeze(1)).to(current_tokens.dtype) - 
current_tokens = torch.cat([current_tokens, new_embed], dim=1) - else: - current_tokens = torch.cat( - [current_tokens, sampled_tokens.unsqueeze(1)], dim=1 - ) - - # Early stopping if all sequences finished - if stop_at_eos and finished_sequences.all(): - break + for sampled_tokens, final_logits, all_finished in self._generate_tokens( + current_tokens, + input_tokens, + batch_size, + max_new_tokens=max_new_tokens, + do_sample=do_sample, + top_k=top_k, + top_p=top_p, + temperature=temperature, + freq_penalty=freq_penalty, + repetition_penalty=repetition_penalty, + stop_at_eos=stop_at_eos, + stop_tokens=stop_tokens, + eos_token_for_padding=eos_token_for_padding, + finished_sequences=finished_sequences, + use_past_kv_cache=use_past_kv_cache, + use_stateful_cache=use_stateful_cache, + mamba_cache=mamba_cache, + mamba_conv_kernel=mamba_conv_kernel, + is_encoder_decoder=is_encoder_decoder, + _is_batched_list=_is_batched_list, + _generate_from_embeds=_generate_from_embeds, + encoder_input=encoder_input if is_encoder_decoder else None, + decoder_tokens=decoder_tokens if is_encoder_decoder else None, + generated_token_ids=generated_token_ids if _generate_from_embeds else None, + pixel_values=pixel_values, + multimodal_kwargs=multimodal_kwargs if multimodal_kwargs else {}, + verbose=verbose, + ): + sampled_tokens_list.append(sampled_tokens.unsqueeze(1)) + if logits_seq_list is not None: + logits_seq_list.append(final_logits.clone()) + if all_finished: + break finally: - # Clean up generate-only state even if an exception occurs, - # so _capture_hf_cache doesn't leak into subsequent forward() calls. self._capture_hf_cache = False if hasattr(self, "_last_hf_cache"): del self._last_hf_cache @@ -2544,7 +2583,8 @@ def generate( # Concatenate all sampled tokens sampled_tokens = torch.cat(sampled_tokens_list, dim=1) if is_encoder_decoder: - output_tokens = decoder_tokens + # Reconstruct full decoder sequence: start token + generated tokens + output_tokens = torch.cat([decoder_tokens[:, :1], sampled_tokens], dim=1) elif _generate_from_embeds: # For inputs_embeds, we only have the generated token IDs (no input token IDs) output_tokens = sampled_tokens @@ -2591,6 +2631,191 @@ def _logits_to_tuple(logits_list: list[torch.Tensor]) -> tuple[torch.Tensor, ... else: # return_type == "tokens" return output_tokens + @torch.no_grad() + def generate_stream( + self, + input: Union[str, List[str], torch.Tensor] = "", + max_new_tokens: int = 10, + max_tokens_per_yield: int = 25, + stop_at_eos: bool = True, + eos_token_id: Optional[int] = None, + do_sample: bool = True, + top_k: Optional[int] = None, + top_p: Optional[float] = None, + temperature: float = 1.0, + freq_penalty: float = 0.0, + repetition_penalty: float = 1.0, + use_past_kv_cache: bool = True, + prepend_bos: Optional[bool] = None, + padding_side: Optional[str] = None, + return_type: Optional[str] = "input", + verbose: bool = True, + ) -> Generator[Union[torch.Tensor, str], None, None]: + """Stream tokens from the model as they are generated. + + Yields batches of tokens progressively during generation rather than + waiting for the entire sequence. Uses the same core loop as generate(). + + Args: + input: Text string, list of strings, or tensor of tokens. + max_new_tokens: Maximum number of tokens to generate. + max_tokens_per_yield: Yield accumulated tokens every this many steps. + stop_at_eos: If True, stop when eos_token is produced. + eos_token_id: Token ID(s) for end of sentence. Defaults to tokenizer's. + do_sample: If True, sample; otherwise greedy. 
+ top_k: Top-k sampling. None means no filtering. + top_p: Nucleus sampling threshold. + temperature: Sampling temperature. + freq_penalty: Frequency penalty for previous tokens. + repetition_penalty: HF-style repetition penalty (>1.0 discourages repeats). + use_past_kv_cache: Use KV caching for faster generation. + prepend_bos: Not applied (API compatibility). See generate() docstring. + padding_side: Which side to pad for batched list inputs. Left-padding + is forced internally for batched generation. + return_type: 'input' (match input type), 'str', or 'tokens'. + verbose: Show progress bar. + + Yields: + Token tensors [batch, seq_len] or strings, accumulated up to + max_tokens_per_yield tokens between yields. First yield includes + the input tokens; subsequent yields contain only new tokens. + """ + if prepend_bos is not None: + warnings.warn( + "prepend_bos is ignored during TransformerBridge.generate_stream(). " + "The HF model expects tokens with the tokenizer's default BOS handling.", + stacklevel=2, + ) + + # --- Input parsing (mirrors generate()) --- + _is_batched_list = isinstance(input, list) and len(input) > 1 + + if isinstance(input, str): + input_tokens = self.to_tokens(input, move_to_device=True, truncate=False) + input_type = "str" + elif isinstance(input, list): + if _is_batched_list: + _orig_ps = self.tokenizer.padding_side + self.tokenizer.padding_side = "left" + try: + input_tokens = self.to_tokens(input, move_to_device=True, truncate=False) + finally: + if _is_batched_list: + self.tokenizer.padding_side = _orig_ps + input_type = "list" + else: + input_tokens = input.to(self.cfg.device) + input_type = "tokens" + + if return_type == "input": + return_type = "str" if input_type in ["str", "list"] else "tokens" + + batch_size = input_tokens.shape[0] + + # --- EOS setup --- + stop_tokens: List[int] = [] + eos_token_for_padding = 0 + if stop_at_eos: + if eos_token_id is None: + assert ( + self.tokenizer.eos_token_id is not None + ), "Must pass eos_token_id if stop_at_eos is True and tokenizer has no eos_token_id" + eos_token_id = self.tokenizer.eos_token_id + if isinstance(eos_token_id, int): + stop_tokens = [eos_token_id] + eos_token_for_padding = eos_token_id + else: + stop_tokens = list(eos_token_id) + eos_token_for_padding = ( + self.tokenizer.eos_token_id + if self.tokenizer.eos_token_id is not None + else eos_token_id[0] + ) + + finished_sequences = torch.zeros(batch_size, dtype=torch.bool, device=self.cfg.device) + + # --- Cache setup --- + if use_past_kv_cache: + self._capture_hf_cache = True + + current_tokens = input_tokens.clone() + + # --- Streaming loop --- + # All yields are token tensors [batch, seq_len]. Each yield contains + # only the newly generated tokens since the previous yield (the first + # yield additionally prepends the input tokens for context). 
+ accumulated_tokens: Optional[torch.Tensor] = None + tokens_since_last_yield = 0 + + def _maybe_decode( + tokens: torch.Tensor, + ) -> Union[torch.Tensor, str]: + if return_type == "str": + return self.tokenizer.decode(tokens[0], skip_special_tokens=True) + return tokens + + try: + for step_idx, (sampled_tokens, _, all_finished) in enumerate( + self._generate_tokens( + current_tokens, + input_tokens, + batch_size, + max_new_tokens=max_new_tokens, + do_sample=do_sample, + top_k=top_k, + top_p=top_p, + temperature=temperature, + freq_penalty=freq_penalty, + repetition_penalty=repetition_penalty, + stop_at_eos=stop_at_eos, + stop_tokens=stop_tokens, + eos_token_for_padding=eos_token_for_padding, + finished_sequences=finished_sequences, + use_past_kv_cache=use_past_kv_cache, + use_stateful_cache=False, + mamba_cache=None, + mamba_conv_kernel=0, + is_encoder_decoder=False, + _is_batched_list=_is_batched_list, + _generate_from_embeds=False, + encoder_input=None, + decoder_tokens=None, + generated_token_ids=None, + pixel_values=None, + multimodal_kwargs={}, + verbose=verbose, + ) + ): + new_tokens = sampled_tokens.unsqueeze(-1) + + if step_idx == 0: + accumulated_tokens = torch.cat([input_tokens, new_tokens], dim=-1) + tokens_since_last_yield = accumulated_tokens.shape[1] + else: + if accumulated_tokens is None: + accumulated_tokens = new_tokens + else: + accumulated_tokens = torch.cat([accumulated_tokens, new_tokens], dim=-1) + tokens_since_last_yield += 1 + + if tokens_since_last_yield >= max_tokens_per_yield: + yield _maybe_decode(accumulated_tokens) + tokens_since_last_yield = 0 + accumulated_tokens = None + + if all_finished: + if accumulated_tokens is not None: + yield _maybe_decode(accumulated_tokens) + break + + # Yield remainder after loop completes without break + if accumulated_tokens is not None: + yield _maybe_decode(accumulated_tokens) + finally: + self._capture_hf_cache = False + if hasattr(self, "_last_hf_cache"): + del self._last_hf_cache + def hf_generate( self, input: str | list[str] | torch.Tensor = "", From 1607ef20d5cf30672310d3c75a61124f15886353 Mon Sep 17 00:00:00 2001 From: Jonah Larson Date: Tue, 28 Apr 2026 14:14:10 -0700 Subject: [PATCH 16/21] Added warnings for users attempting to use MPS with Torch 2.8 (#1271) --- tests/unit/utilities/test_devices.py | 112 +++++++++++++++++++++++++- transformer_lens/utilities/devices.py | 77 +++++++++++++----- 2 files changed, 169 insertions(+), 20 deletions(-) diff --git a/tests/unit/utilities/test_devices.py b/tests/unit/utilities/test_devices.py index 5e1af5632..6ea3b4095 100644 --- a/tests/unit/utilities/test_devices.py +++ b/tests/unit/utilities/test_devices.py @@ -158,12 +158,14 @@ def test_move_to_and_update_config_print_details_false(): @pytest.fixture(autouse=True) def reset_mps_warned(): - """Reset the _mps_warned flag before each test.""" + """Reset the _mps_warned and _mps_broken_torch_warned flags before each test.""" import transformer_lens.utilities.devices as devices_module devices_module._mps_warned = False + devices_module._mps_broken_torch_warned = False yield devices_module._mps_warned = False + devices_module._mps_broken_torch_warned = False @patch.dict("os.environ", {}, clear=False) @@ -291,3 +293,111 @@ def test_warn_if_mps_active_when_torch_version_below_safe(): assert len(w) == 1 finally: devices_module._MPS_MIN_SAFE_TORCH_VERSION = original + + +# --- Known-broken-torch-on-MPS warning tests (issue #1062, torch 2.8.0) --- + + +@patch.dict("os.environ", {}, clear=False) +def 
test_warn_if_mps_warns_about_broken_torch_version(): + """When torch is in _MPS_BROKEN_TORCH_VERSIONS, warn_if_mps emits the broken-version warning.""" + import os + + import transformer_lens.utilities.devices as devices_module + + os.environ.pop("TRANSFORMERLENS_ALLOW_MPS", None) + with patch( + "transformer_lens.utilities.devices._torch_version_tuple", + return_value=(2, 8), + ): + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + warn_if_mps("mps") + messages = [str(warning.message) for warning in w] + assert any( + "known MPS bug that produces silently incorrect results" in m for m in messages + ), f"Expected broken-torch warning in {messages}" + assert any("issues/1062" in m for m in messages) + + +@patch.dict("os.environ", {"TRANSFORMERLENS_ALLOW_MPS": "1"}) +def test_warn_if_mps_broken_torch_warning_fires_even_when_opted_in(): + """The broken-torch warning must fire even with TRANSFORMERLENS_ALLOW_MPS=1, + because the bug produces silently wrong output regardless of opt-in.""" + with patch( + "transformer_lens.utilities.devices._torch_version_tuple", + return_value=(2, 8), + ): + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + warn_if_mps("mps") + messages = [str(warning.message) for warning in w] + assert any("known MPS bug" in m for m in messages) + + +@patch.dict("os.environ", {}, clear=False) +def test_warn_if_mps_no_broken_warning_on_safe_torch_version(): + """Non-broken torch versions should not emit the broken-torch warning.""" + import os + + os.environ.pop("TRANSFORMERLENS_ALLOW_MPS", None) + for version in [(2, 7), (2, 9), (3, 0)]: + with patch( + "transformer_lens.utilities.devices._torch_version_tuple", + return_value=version, + ): + # Reset the broken-warn flag for each iteration + import transformer_lens.utilities.devices as devices_module + + devices_module._mps_broken_torch_warned = False + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + warn_if_mps("mps") + messages = [str(warning.message) for warning in w] + assert not any( + "known MPS bug" in m for m in messages + ), f"Unexpected broken-torch warning on torch {version}: {messages}" + + +@patch.dict("os.environ", {}, clear=False) +def test_warn_if_mps_broken_warning_fires_only_once(): + """The broken-torch warning should only fire once per process.""" + import os + + os.environ.pop("TRANSFORMERLENS_ALLOW_MPS", None) + with patch( + "transformer_lens.utilities.devices._torch_version_tuple", + return_value=(2, 8), + ): + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + warn_if_mps("mps") + warn_if_mps("mps") + warn_if_mps(torch.device("mps")) + broken_warnings = [warning for warning in w if "known MPS bug" in str(warning.message)] + assert len(broken_warnings) == 1 + + +def test_torch_mps_has_known_broken_bug_for_2_8(): + """_torch_mps_has_known_broken_bug should return True for torch 2.8.""" + from transformer_lens.utilities.devices import _torch_mps_has_known_broken_bug + + with patch( + "transformer_lens.utilities.devices._torch_version_tuple", + return_value=(2, 8), + ): + assert _torch_mps_has_known_broken_bug() is True + + +def test_torch_mps_has_known_broken_bug_false_for_other_versions(): + """_torch_mps_has_known_broken_bug should return False for non-broken torch versions.""" + from transformer_lens.utilities.devices import _torch_mps_has_known_broken_bug + + for version in [(2, 7), (2, 9), (3, 0)]: + with patch( + 
"transformer_lens.utilities.devices._torch_version_tuple", + return_value=version, + ): + assert ( + _torch_mps_has_known_broken_bug() is False + ), f"torch {version} incorrectly flagged as broken" diff --git a/transformer_lens/utilities/devices.py b/transformer_lens/utilities/devices.py index d1265f002..470646a41 100644 --- a/transformer_lens/utilities/devices.py +++ b/transformer_lens/utilities/devices.py @@ -24,12 +24,30 @@ # Bump this when a PyTorch release ships verified MPS fixes. _MPS_MIN_SAFE_TORCH_VERSION: tuple[int, ...] | None = None +# torch 2.8.0 on MPS has an upstream bug where torch.nn.functional.linear +# produces incorrect results for non-contiguous tensors. This silently +# corrupts generate() output and attention computations. Fixed in 2.9.0. +# See: https://github.com/pytorch/pytorch/issues/161640 +# See: https://github.com/TransformerLensOrg/TransformerLens/issues/1062 +_MPS_BROKEN_TORCH_VERSIONS: tuple[tuple[int, ...], ...] = ((2, 8),) + +_mps_broken_torch_warned = False + def _torch_version_tuple() -> tuple[int, ...]: """Parse torch.__version__ into a comparable tuple, ignoring pre-release suffixes.""" return tuple(int(x) for x in torch.__version__.split("+")[0].split(".")[:2]) +def _torch_mps_has_known_broken_bug() -> bool: + """True if the installed torch version has a known-broken MPS path. + + Distinct from the generic MPS-may-be-unreliable warning: these are specific, + upstream-fixed bugs where output is silently wrong regardless of opt-in. + """ + return _torch_version_tuple() in _MPS_BROKEN_TORCH_VERSIONS + + # --------------------------------------------------------------------------- # Device helpers # --------------------------------------------------------------------------- @@ -69,28 +87,49 @@ def warn_if_mps(device): Automatically suppressed when the installed PyTorch version meets or exceeds _MPS_MIN_SAFE_TORCH_VERSION (currently unset — no version is considered safe yet). + + Also emits a separate, stronger warning for known-broken torch versions on MPS + (see _MPS_BROKEN_TORCH_VERSIONS). This warning fires even when the user has + opted in via TRANSFORMERLENS_ALLOW_MPS=1, because the affected operations + produce silently wrong outputs regardless of opt-in. """ - global _mps_warned - if _mps_warned: - return + global _mps_warned, _mps_broken_torch_warned if isinstance(device, torch.device): device = device.type - if isinstance(device, str) and device == "mps": - if ( - _MPS_MIN_SAFE_TORCH_VERSION is not None - and _torch_version_tuple() >= _MPS_MIN_SAFE_TORCH_VERSION - ): - return - if os.environ.get("TRANSFORMERLENS_ALLOW_MPS", "") != "1": - _mps_warned = True - warnings.warn( - "MPS backend may produce silently incorrect results (PyTorch " - f"{torch.__version__}). " - "Set TRANSFORMERLENS_ALLOW_MPS=1 to suppress this warning. " - "See: https://github.com/TransformerLensOrg/TransformerLens/issues/1178", - UserWarning, - stacklevel=2, - ) + if not (isinstance(device, str) and device == "mps"): + return + + # Known-broken torch versions always warn (can't be opted-out of). + if _torch_mps_has_known_broken_bug() and not _mps_broken_torch_warned: + _mps_broken_torch_warned = True + warnings.warn( + f"PyTorch {torch.__version__} has a known MPS bug that produces " + "silently incorrect results (torch.nn.functional.linear on " + "non-contiguous tensors). This corrupts generate() output and " + "attention computations. Upgrade to torch >= 2.9.0. 
" + "See: https://github.com/TransformerLensOrg/TransformerLens/issues/1062 " + "and https://github.com/pytorch/pytorch/issues/161640", + UserWarning, + stacklevel=2, + ) + + if _mps_warned: + return + if ( + _MPS_MIN_SAFE_TORCH_VERSION is not None + and _torch_version_tuple() >= _MPS_MIN_SAFE_TORCH_VERSION + ): + return + if os.environ.get("TRANSFORMERLENS_ALLOW_MPS", "") != "1": + _mps_warned = True + warnings.warn( + "MPS backend may produce silently incorrect results (PyTorch " + f"{torch.__version__}). " + "Set TRANSFORMERLENS_ALLOW_MPS=1 to suppress this warning. " + "See: https://github.com/TransformerLensOrg/TransformerLens/issues/1178", + UserWarning, + stacklevel=2, + ) # --------------------------------------------------------------------------- From a92a90a1da64a80e30802efdff7d544a7a718c78 Mon Sep 17 00:00:00 2001 From: jlarson4 Date: Tue, 28 Apr 2026 16:34:49 -0500 Subject: [PATCH 17/21] Documenting 3.1 features, adding additional context to the purpose of compatibility mode --- README.md | 2 +- docs/source/content/migrating_to_v3.md | 31 +++++++++++++++++++++---- docs/source/content/model_structure.md | 2 +- transformer_lens/model_bridge/bridge.py | 13 ++++++++--- 4 files changed, 39 insertions(+), 9 deletions(-) diff --git a/README.md b/README.md index 5f3c40417..7c79cc390 100644 --- a/README.md +++ b/README.md @@ -50,7 +50,7 @@ bridge = TransformerBridge.boot_transformers("gpt2", device="cpu") logits, activations = bridge.run_with_cache("Hello World") ``` -`TransformerBridge` is the recommended 3.0 path and supports 50+ architectures. The legacy `HookedTransformer.from_pretrained` API is still available through a compatibility layer but is deprecated - see the [Migrating to TransformerLens 3](https://TransformerLensOrg.github.io/TransformerLens/content/migrating_to_v3.html) guide for conversion recipes. +`TransformerBridge` is the recommended 3.0 path and supports 50+ architectures. By default it preserves raw HuggingFace weights – logits and activations match HF, *not* legacy `HookedTransformer` (which folds LayerNorm and centers weights by default). Call `bridge.enable_compatibility_mode()` after booting for HookedTransformer-equivalent numerics. The legacy `HookedTransformer.from_pretrained` API is still available but deprecated — see the [Migrating to TransformerLens 3](https://TransformerLensOrg.github.io/TransformerLens/content/migrating_to_v3.html) guide. ## Key Tutorials diff --git a/docs/source/content/migrating_to_v3.md b/docs/source/content/migrating_to_v3.md index 79546e3d4..fc47fceef 100644 --- a/docs/source/content/migrating_to_v3.md +++ b/docs/source/content/migrating_to_v3.md @@ -36,13 +36,14 @@ Weight-processing flags (`fold_ln`, `center_writing_weights`, `center_unembed`, ### Parameters that were removed -`n_devices`, `move_to_device`, `first_n_layers`, and `n_ctx` are not part of `boot_transformers`. If you relied on any of these, file an issue describing your use case — the right pattern for multi-GPU loads under the bridge is still being worked out. +`n_devices`, `move_to_device`, and `first_n_layers` are not part of `boot_transformers`. If you relied on any of these, file an issue describing your use case — the right pattern for multi-GPU loads under the bridge is still being worked out. ### Parameters that are new - `load_weights: bool = True` — set to `False` to construct the bridge with just the config (useful for shape-checking without paying the weight-load cost). 
- `trust_remote_code: bool = False` — pass through to HuggingFace for models that ship custom modeling code. - `hf_config_overrides: dict | None = None` — override specific fields of the HF config before the model is constructed. +- `n_ctx: int | None = None` — override the model's context length. The bridge writes to whichever HF config field this architecture uses (`n_positions` / `max_position_embeddings` / etc.) so callers don't need to know the field name. Warns if larger than the model's default. - `hf_model` / `model_class` — advanced: pass in a pre-loaded HF model or a specific model class. ## Weight processing is now opt-in @@ -74,6 +75,18 @@ bridge.enable_compatibility_mode( If you want no processing at all — the bridge's native default — you can skip `enable_compatibility_mode` entirely, or call it with `no_processing=True` if you still want the hook/component compatibility layer without the weight transforms. +### Will my numbers match HookedTransformer? + +| Computing | Without `enable_compatibility_mode` | With it | +| --- | --- | --- | +| Generated text, CE loss, argmax / top-k | Identical | Identical | +| Raw logits | Differ by per-row constant | Match | +| Logit lens, direct logit attribution | Differ | Match | +| KL divergence vs another model | Differ | Match | +| Residual-stream norms, cached `hook_resid_*` | Differ (grows with depth) | Match | + +Bottom-half analyses → call `enable_compatibility_mode()` after booting. + ## Hook names The canonical hook names on the bridge use a uniform `hook_in` / `hook_out` convention. The old TransformerLens names are preserved through an alias layer, so existing code keeps working without changes: @@ -92,14 +105,24 @@ These work identically on `TransformerBridge` and need no migration: - `to_tokens`, `to_string` - `generate` -- `run_with_hooks` -- `run_with_cache` -- `__call__` / `forward` +- `run_with_hooks`, `run_with_cache` — including batched-list inputs (parity fixed in 3.x) +- `__call__` / `forward` — accepts both 1D `[seq]` and 2D `[batch, seq]` token tensors - `cfg.*` — the bridge exposes a `.cfg` with the same fields (`n_layers`, `n_heads`, `d_model`, `d_vocab`, `n_ctx`, ...) - `W_Q`, `W_K`, `W_V`, `W_O`, `b_Q`, `b_K`, `b_V`, `b_O` — attention weights are exposed with the same `[n_heads, d_model, d_head]` shape conventions If your code only touches these APIs, the migration is genuinely just the loading call and (optionally) `enable_compatibility_mode`. +### New in 3.x: streaming generation + +Both `HookedTransformer` and `TransformerBridge` now expose `generate_stream`, which yields tokens progressively instead of returning the full completion at once: + +```python +for chunk in bridge.generate_stream("The quick brown fox", max_new_tokens=50): + print(chunk, end="", flush=True) +``` + +Same sampling kwargs as `generate` (`temperature`, `top_k`, `top_p`, `do_sample`, etc.). + ## Model name aliases are deprecated `HookedTransformer.from_pretrained` accepted a lot of short aliases (`"llama-7b-hf"`, `"gpt-neo-125M"`, etc.) that mapped to specific HuggingFace paths. The bridge accepts the official HuggingFace names directly, and emits a deprecation warning when you pass a legacy alias. The aliases will be removed in the next major version. 
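[Editor's note: to make the compatibility-mode table above concrete, here is a minimal sketch of the two loading paths described in this guide. It is not part of the patch itself and only uses APIs documented elsewhere in this series (`boot_transformers`, `enable_compatibility_mode`, `run_with_cache`); the legacy hook name `blocks.0.hook_resid_post` is assumed to be reachable through the alias layer described under "Hook names", and the printed comparisons simply restate what the table claims rather than asserting it.]

```python
import torch

from transformer_lens.model_bridge import TransformerBridge

# Default bridge: raw HuggingFace numerics (top half of the table is unaffected).
bridge = TransformerBridge.boot_transformers("gpt2", device="cpu")
logits_raw, cache_raw = bridge.run_with_cache("Hello World")

# Compatibility mode: HookedTransformer-equivalent numerics (fold_ln + centering),
# needed for the analyses in the bottom half of the table.
compat = TransformerBridge.boot_transformers("gpt2", device="cpu")
compat.enable_compatibility_mode()
logits_compat, cache_compat = compat.run_with_cache("Hello World")

# Top half of the table: next-token predictions should agree either way.
print(torch.equal(logits_raw.argmax(dim=-1), logits_compat.argmax(dim=-1)))

# Bottom half of the table: cached residual-stream activations differ without
# compatibility mode (accessed here via the legacy alias hook name).
print(cache_raw["blocks.0.hook_resid_post"].norm().item())
print(cache_compat["blocks.0.hook_resid_post"].norm().item())
```

Since `enable_compatibility_mode()` mutates the bridge in place, the sketch boots two separate instances so the two coordinate systems can be compared side by side.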
diff --git a/docs/source/content/model_structure.md b/docs/source/content/model_structure.md index 653d0c911..593860edd 100644 --- a/docs/source/content/model_structure.md +++ b/docs/source/content/model_structure.md @@ -17,7 +17,7 @@ from transformer_lens.model_bridge import TransformerBridge bridge = TransformerBridge.boot_transformers("gpt2", device="cpu") ``` -You can then call the familiar APIs: `to_tokens`, `to_string`, `generate`, `run_with_hooks`, `run_with_cache`. +You can then call the familiar APIs: `to_tokens`, `to_string`, `generate`, `generate_stream`, `run_with_hooks`, `run_with_cache`. ## Top-Level Components diff --git a/transformer_lens/model_bridge/bridge.py b/transformer_lens/model_bridge/bridge.py index 9e24fa9fc..d0601aa63 100644 --- a/transformer_lens/model_bridge/bridge.py +++ b/transformer_lens/model_bridge/bridge.py @@ -173,6 +173,11 @@ def boot_transformers( ) -> "TransformerBridge": """Boot a model from HuggingFace (alias for sources.transformers.boot). + Returns raw HF weights by default — logits/activations match HF, *not* + legacy ``HookedTransformer`` (which folds LayerNorm + centers weights). + Call ``enable_compatibility_mode()`` on the result for HookedTransformer- + equivalent numerics. Generation, argmax, and CE loss are unaffected. + Args: model_name: The name of the model to load. hf_config_overrides: Optional overrides applied to the HuggingFace config before model load. @@ -617,10 +622,12 @@ def enable_compatibility_mode( fold_value_biases: bool = True, refactor_factored_attn_matrices: bool = False, ) -> None: - """Enable compatibility mode for the bridge. + """Apply HookedTransformer-equivalent weight processing and legacy hook compatibility. - This sets up the bridge to work with legacy TransformerLens components/hooks. - It will also disable warnings about the usage of legacy components/hooks if specified. + Defaults match HookedTransformer's load-time processing (fold_ln + weight + centering) — required for analyses that reason in HookedTransformer's + post-processed coordinate system: logit lens, direct logit attribution, + residual-stream norms. Also enables legacy hook/component name aliases. Args: disable_warnings: Whether to disable warnings about legacy components/hooks From ad8e123b7da494e8bffeba1e2f1c4b42b5058c30 Mon Sep 17 00:00:00 2001 From: Jonah Larson Date: Wed, 29 Apr 2026 07:00:09 -0700 Subject: [PATCH 18/21] Improved Tokenize & Concatenate (#1273) * Improved issue with tokenize and concatenate * dialed in the approach to be per-doc --- tests/unit/test_utils.py | 55 +++++++++++++++ transformer_lens/utilities/tokenize_utils.py | 72 ++++++++------------ 2 files changed, 83 insertions(+), 44 deletions(-) diff --git a/tests/unit/test_utils.py b/tests/unit/test_utils.py index 52934239f..d703dd997 100644 --- a/tests/unit/test_utils.py +++ b/tests/unit/test_utils.py @@ -483,6 +483,61 @@ def test_no_split_tokens_across_chunks(self): f"This indicates a word was split across chunk boundaries." ) + def test_no_split_tokens_in_no_whitespace_text(self): + """No-whitespace multi-doc input — the prior whitespace-lookahead fix + fell through to character cuts here. streaming=True keeps all docs in + one tokenize_function call so EOS markers actually exist to split on. 
+ """ + from datasets import Dataset + from transformers import AutoTokenizer + + tokenizer = AutoTokenizer.from_pretrained("gpt2") + + docs = ["a" * 200 + "MilitaryVehicleEngine" * 100] * 10 + dataset = Dataset.from_dict({"text": docs}) + + result = utils.tokenize_and_concatenate( + dataset, + tokenizer, + streaming=True, + max_length=128, + add_bos_token=False, + ) + + full_text = tokenizer.eos_token.join(docs) + clean_tokens = tokenizer(full_text, return_tensors="np")["input_ids"].flatten() + clean_pairs = set(zip(clean_tokens[:-1], clean_tokens[1:])) + + output_tokens = np.concatenate([np.array(row["tokens"]) for row in result]) + for i in range(len(output_tokens) - 1): + pair = (int(output_tokens[i]), int(output_tokens[i + 1])) + assert pair in clean_pairs, ( + f"Token pair {pair} appears in tokenize_and_concatenate output " + f"but never occurs in natural tokenization. The chunker must " + f"have cut a token in half." + ) + + def test_single_document_batch_does_not_crash(self): + """Single-doc batch has no EOS to split on — fallback to one chunk should be correct.""" + from datasets import Dataset + from transformers import AutoTokenizer + + tokenizer = AutoTokenizer.from_pretrained("gpt2") + dataset = Dataset.from_dict({"text": ["abcdefghij" * 200]}) + + result = utils.tokenize_and_concatenate( + dataset, + tokenizer, + streaming=True, + max_length=64, + add_bos_token=False, + ) + + clean = tokenizer("abcdefghij" * 200, return_tensors="np")["input_ids"].flatten() + output = np.concatenate([np.array(row["tokens"]) for row in result]) + n = len(output) + assert (output == clean[:n]).all() + def test_tokenize_and_concatenate_no_spurious_sequence_length_warning(): """Test that tokenize_and_concatenate does not emit the HF 'sequence length longer than maximum' warning.""" diff --git a/transformer_lens/utilities/tokenize_utils.py b/transformer_lens/utilities/tokenize_utils.py index 16a3ec4f7..c9b6dfefa 100644 --- a/transformer_lens/utilities/tokenize_utils.py +++ b/transformer_lens/utilities/tokenize_utils.py @@ -28,33 +28,30 @@ def tokenize_and_concatenate( add_bos_token: bool = True, num_proc: int = 10, ) -> Dataset: - """Helper function to tokenizer and concatenate a dataset of text. This converts the text to tokens, concatenates them (separated by EOS tokens) and then reshapes them into a 2D array of shape (____, sequence_length), dropping the last batch. Tokenizers are much faster if parallelised, so we chop the string into 20, feed it into the tokenizer, in parallel with padding, then remove padding at the end. + """Tokenize each document, join with token-level EOS between docs, and reshape into ``(batch, sequence_length)`` rows. - This tokenization is useful for training language models, as it allows us to efficiently train on a large corpus of text of varying lengths (without, eg, a lot of truncation or padding). Further, for models with absolute positional encodings, this avoids privileging early tokens (eg, news articles often begin with CNN, and models may learn to use early positional encodings to predict these) + Useful for training language models on a large text corpus without per-doc + truncation or padding. Absolute-position-embedding models also benefit by + avoiding early-token bias (e.g. news articles starting with "CNN"). Args: dataset (Dataset): The dataset to tokenize, assumed to be a HuggingFace text dataset. - tokenizer (PreTrainedTokenizerBase): The tokenizer. Assumed to have a bos_token_id and an eos_token_id. 
- streaming (bool, optional): Whether the dataset is being streamed. If True, avoids using parallelism. Defaults to False. + tokenizer (PreTrainedTokenizerBase): The tokenizer. Must have ``bos_token_id`` and ``eos_token_id``. + streaming (bool, optional): If True, avoids parallelism. Defaults to False. max_length (int, optional): The length of the context window of the sequence. Defaults to 1024. column_name (str, optional): The name of the text column in the dataset. Defaults to 'text'. - add_bos_token (bool, optional): . Defaults to True. + add_bos_token (bool, optional): Whether to prepend ``bos_token_id`` to each output row. Defaults to True. Returns: - Dataset: Returns the tokenized dataset, as a dataset of tensors, with a single column called "tokens" + Dataset: Tokenized dataset of tensors with a single column ``"tokens"``. """ dataset = keep_single_column(dataset, column_name) has_pad_token = tokenizer.pad_token is not None if not has_pad_token: - # Add padding token for tokenizer (removed before model input) tokenizer.add_special_tokens({"pad_token": ""}) - # Define the length to chop things up into - leaving space for a bos_token if required - if add_bos_token: - seq_len = max_length - 1 - else: - seq_len = max_length + seq_len = max_length - 1 if add_bos_token else max_length - # Suppress the "sequence length longer than maximum" warning during chunked tokenization. + # Long docs legitimately exceed model_max_length; we slice into rows after. _deprecation_warnings_saved = None if hasattr(tokenizer, "deprecation_warnings"): _deprecation_warnings_saved = tokenizer.deprecation_warnings.copy() @@ -63,50 +60,37 @@ def tokenize_and_concatenate( ] = False def tokenize_function(examples: Any) -> dict[str, np.ndarray]: - # datasets.map() may pass a LazyBatch, not a plain dict; accept dict-like batches text = examples[column_name] - # Concatenate it all into an enormous string, separated by eos_tokens assert tokenizer.eos_token is not None, "Tokenizer must have an EOS token." - full_text = tokenizer.eos_token.join(text) - - # Handle the case when full_text is empty - if not full_text.strip(): + if not text: return {"tokens": np.array([], dtype=np.int64)} - # Split at whitespace boundaries to avoid mid-word tokens (#1133) - num_chunks = 20 - chunk_length = (len(full_text) - 1) // num_chunks + 1 - chunks = [] - start = 0 - lookahead = chunk_length // 10 - for i in range(num_chunks): - end = min(start + chunk_length, len(full_text)) - # Advance to whitespace; bounded lookahead for pathological inputs - boundary = min(end + lookahead, len(full_text)) - while end < boundary and not full_text[end].isspace(): - end += 1 - chunks.append(full_text[start:end]) - start = end - # Tokenize in parallel with NumPy (HF map rejects tensors) - tokens = tokenizer(chunks, return_tensors="np", padding=True)["input_ids"].flatten() - # Drop padding tokens - tokens = tokens[tokens != tokenizer.pad_token_id] + # Per-doc tokenization with explicit token-level EOS — string chunking + # could cut tokens mid-doc (#1133); add_special_tokens=False prevents + # SentencePiece tokenizers from scattering auto-BOS/EOS per call. 
+ encoded = tokenizer(text, add_special_tokens=False)["input_ids"] + eos_id = tokenizer.eos_token_id + pieces: list[np.ndarray] = [] + for i, row in enumerate(encoded): + pieces.append(np.asarray(row, dtype=np.int64)) + if i < len(encoded) - 1: + pieces.append(np.array([eos_id], dtype=np.int64)) + if not pieces: + return {"tokens": np.array([], dtype=np.int64)} + tokens = np.concatenate(pieces) num_tokens = len(tokens) - # Handle cases where num_tokens is less than seq_len if num_tokens < seq_len: num_batches = 1 - # Pad tokens if necessary tokens = tokens[:seq_len] if len(tokens) < seq_len: - padding_length = seq_len - len(tokens) - # Use EOS as pad to avoid out-of-vocabulary IDs + # Pad with EOS when no native pad token to avoid OOV IDs. padding_id = tokenizer.eos_token_id if not has_pad_token else tokenizer.pad_token_id - padding = np.full(padding_length, padding_id) - tokens = np.concatenate([tokens, padding], axis=0) + tokens = np.concatenate( + [tokens, np.full(seq_len - len(tokens), padding_id)], axis=0 + ) else: num_batches = num_tokens // seq_len - # Drop the final tokens if not enough to make a full sequence tokens = tokens[: seq_len * num_batches] tokens = einops.rearrange( From d95bd9627a516b304b54fe504de86abcf10a7baf Mon Sep 17 00:00:00 2001 From: Jonah Larson Date: Wed, 29 Apr 2026 08:01:26 -0700 Subject: [PATCH 19/21] Multi-Device Processing on Bridge (#1270) * Multi-GPU initial setup for TransformerBridge * Added additional documentation note --- .../model_bridge/test_multi_gpu_bridge.py | 299 ++++++++++++++++++ transformer_lens/model_bridge/bridge.py | 81 ++++- .../model_bridge/sources/transformers.py | 92 +++++- transformer_lens/utilities/__init__.py | 3 + transformer_lens/utilities/multi_gpu.py | 109 ++++++- 5 files changed, 566 insertions(+), 18 deletions(-) create mode 100644 tests/acceptance/model_bridge/test_multi_gpu_bridge.py diff --git a/tests/acceptance/model_bridge/test_multi_gpu_bridge.py b/tests/acceptance/model_bridge/test_multi_gpu_bridge.py new file mode 100644 index 000000000..58f84395d --- /dev/null +++ b/tests/acceptance/model_bridge/test_multi_gpu_bridge.py @@ -0,0 +1,299 @@ +"""Multi-GPU support tests for TransformerBridge. + +CPU-runnable tests exercise the resolver / param-plumbing / .to() guard / +validation logic. Tests requiring real multi-GPU hardware are marked skipif. 
+""" + +from typing import Dict, Union + +import pytest +import torch + +from transformer_lens.model_bridge import TransformerBridge +from transformer_lens.utilities.multi_gpu import ( + count_unique_devices, + find_embedding_device, + resolve_device_map, +) + +# ---------- CPU-runnable tests ---------- + + +class TestResolveDeviceMap: + def test_no_multi_device_returns_none(self): + dm, mm = resolve_device_map(None, None, None) + assert dm is None and mm is None + dm, mm = resolve_device_map(1, None, None) + assert dm is None and mm is None + dm, mm = resolve_device_map(0, None, None) + assert dm is None and mm is None + + def test_explicit_device_map_string_passes_through(self): + dm, mm = resolve_device_map(None, "auto", None) + assert dm == "auto" + assert mm is None + + def test_explicit_device_map_dict_passes_through(self): + explicit: Dict[str, Union[str, int]] = {"transformer.h.0": 0} + dm, mm = resolve_device_map(None, explicit, None) + assert dm is explicit + assert mm is None + + def test_user_max_memory_passes_through(self): + user_mm: Dict[Union[str, int], str] = {0: "20GiB"} + dm, mm = resolve_device_map(None, "auto", None, max_memory=user_mm) + assert dm == "auto" + assert mm is user_mm + + def test_device_and_device_map_mutually_exclusive(self): + with pytest.raises(ValueError, match="mutually exclusive"): + resolve_device_map(None, "auto", "cuda") + + def test_n_devices_without_cuda_raises(self): + if torch.cuda.is_available(): + pytest.skip("CUDA available; this test targets the no-CUDA path.") + with pytest.raises(ValueError, match="requires CUDA"): + resolve_device_map(2, None, None) + + def test_n_devices_exceeds_visible_raises(self): + if not torch.cuda.is_available(): + pytest.skip("CUDA required.") + too_many = torch.cuda.device_count() + 1 + with pytest.raises(ValueError, match="only"): + resolve_device_map(too_many, None, None) + + def test_n_devices_returns_balanced_string_and_max_memory(self): + if not torch.cuda.is_available() or torch.cuda.device_count() < 2: + pytest.skip("Requires 2+ CUDA devices.") + dm, mm = resolve_device_map(2, None, None) + # device_map must be a string directive (HF device_map dicts are keyed by + # submodule path — int keys would fail to match any submodule). + assert dm == "balanced" + assert isinstance(mm, dict) + assert set(mm.keys()) == {0, 1} + + def test_n_devices_respects_user_max_memory(self): + if not torch.cuda.is_available() or torch.cuda.device_count() < 2: + pytest.skip("Requires 2+ CUDA devices.") + user_mm: Dict[Union[str, int], str] = {0: "10GiB", 1: "10GiB"} + dm, mm = resolve_device_map(2, None, None, max_memory=user_mm) + assert dm == "balanced" + assert mm == user_mm + + def test_cpu_value_in_device_map_rejected(self): + bad: Dict[str, Union[str, int]] = {"transformer.h.0": "cpu"} + with pytest.raises(ValueError, match="not supported"): + resolve_device_map(None, bad, None) + + def test_disk_value_in_device_map_rejected(self): + bad: Dict[str, Union[str, int]] = {"transformer.h.0": "disk"} + with pytest.raises(ValueError, match="not supported"): + resolve_device_map(None, bad, None) + + +class TestFindEmbeddingDevice: + def test_returns_none_for_no_device_map(self): + class Stub: + pass + + assert find_embedding_device(Stub()) is None + + def test_uses_get_input_embeddings_when_available(self): + # A stub model with both hf_device_map AND get_input_embeddings should + # consult the embedding module, not the first dict entry. 
This is the key + # difference from the insertion-order heuristic — covers the multimodal / + # encoder-decoder case where the first map entry is the vision tower. + embed = torch.nn.Embedding(10, 4) + embed = embed.to("cpu") + + class Stub: + hf_device_map = {"vision_tower.stuff": 1, "language_model.embed_tokens": "cpu"} + + def get_input_embeddings(self): + return embed + + result = find_embedding_device(Stub()) + assert result is not None + assert result.type == "cpu" + + def test_falls_back_to_first_entry_when_get_input_embeddings_unavailable(self): + class Stub: + hf_device_map = {"embed_tokens": "cpu", "layers.0": "cpu"} + + assert find_embedding_device(Stub()) == torch.device("cpu") + + def test_handles_int_device_ids_in_fallback(self): + class Stub: + hf_device_map = {"embed_tokens": 0, "layers.0": 1} + + result = find_embedding_device(Stub()) + assert result is not None + assert result.type == "cuda" + assert result.index == 0 + + def test_handles_get_input_embeddings_returning_none(self): + class Stub: + hf_device_map = {"embed_tokens": "cpu"} + + def get_input_embeddings(self): + return None + + assert find_embedding_device(Stub()) == torch.device("cpu") + + +class TestCountUniqueDevices: + def test_no_map_returns_1(self): + class Stub: + pass + + assert count_unique_devices(Stub()) == 1 + + def test_counts_unique_values(self): + class Stub: + hf_device_map = {"a": 0, "b": 0, "c": 1, "d": 1, "e": 2} + + assert count_unique_devices(Stub()) == 3 + + +class TestBootParamValidation: + def test_device_and_device_map_mutually_exclusive(self): + with pytest.raises(ValueError, match="mutually exclusive"): + TransformerBridge.boot_transformers("gpt2", device="cpu", device_map="auto") + + def test_preloaded_with_device_map_rejected(self, gpt2_bridge): + # Passing both hf_model= and device_map/n_devices is ambiguous — the device_map + # would be silently ignored. We raise so the caller isn't surprised. + with pytest.raises(ValueError, match="only supported when the bridge loads"): + TransformerBridge.boot_transformers( + "gpt2", hf_model=gpt2_bridge.original_model, device_map="auto" + ) + + def test_preloaded_with_n_devices_rejected(self, gpt2_bridge): + with pytest.raises(ValueError, match="only supported when the bridge loads"): + TransformerBridge.boot_transformers( + "gpt2", hf_model=gpt2_bridge.original_model, n_devices=2 + ) + + +class TestSingleDevicePathUnchanged: + def test_cpu_load_default_unchanged(self, gpt2_bridge): + # If any of our changes broke the baseline path, existing bridge tests would + # catch it too — this is a smoke check that n_devices stays 1 on the default path. + assert gpt2_bridge.cfg.n_devices == 1 + assert gpt2_bridge.cfg.device is not None + + +class TestToMethodGuardsMultiDevice: + def test_to_warns_and_drops_device_when_n_devices_gt_1(self, gpt2_bridge): + # Simulate a dispatched model by bumping n_devices — we don't need multi-GPU + # hardware to verify the .to() guard path. 
+ original_n_devices = gpt2_bridge.cfg.n_devices + gpt2_bridge.cfg.n_devices = 2 + try: + with pytest.warns(UserWarning, match="ignored"): + gpt2_bridge.to("cpu") + assert next(gpt2_bridge.original_model.parameters()).device.type == "cpu" + finally: + gpt2_bridge.cfg.n_devices = original_n_devices + + def test_to_still_honors_dtype_under_multi_device(self, gpt2_bridge): + original_n_devices = gpt2_bridge.cfg.n_devices + original_dtype = next(gpt2_bridge.original_model.parameters()).dtype + gpt2_bridge.cfg.n_devices = 2 + try: + with pytest.warns(UserWarning, match="ignored"): + gpt2_bridge.to("cpu", torch.float64) + assert next(gpt2_bridge.original_model.parameters()).dtype == torch.float64 + finally: + gpt2_bridge.cfg.n_devices = original_n_devices + gpt2_bridge.original_model.to(original_dtype) + + +class TestRunWithCacheGuardsMultiDevice: + def test_run_with_cache_device_arg_warns_under_multi_device(self, gpt2_bridge): + original_n_devices = gpt2_bridge.cfg.n_devices + gpt2_bridge.cfg.n_devices = 2 + try: + with pytest.warns(UserWarning, match="ignored"): + gpt2_bridge.run_with_cache(torch.tensor([[1, 2, 3]]), device="cpu") + finally: + gpt2_bridge.cfg.n_devices = original_n_devices + + +class TestStackedWeightsHandleCrossDevice: + def test_stack_gathers_across_devices(self, gpt2_bridge): + # Fake multi-device state by flipping cfg.n_devices. The GPT-2 bridge's weights + # still live on CPU, so gathering to cfg.device (also CPU) is a no-op — but the + # code path we care about (the `if n_devices > 1` branch) is exercised. + original_n_devices = gpt2_bridge.cfg.n_devices + gpt2_bridge.cfg.n_devices = 2 + try: + # None of these should raise, even with n_devices>1. + W_Q = gpt2_bridge.W_Q + W_K = gpt2_bridge.W_K + W_V = gpt2_bridge.W_V + W_O = gpt2_bridge.W_O + assert W_Q.shape[0] == gpt2_bridge.cfg.n_layers + assert W_K.shape[0] == gpt2_bridge.cfg.n_layers + assert W_V.shape[0] == gpt2_bridge.cfg.n_layers + assert W_O.shape[0] == gpt2_bridge.cfg.n_layers + finally: + gpt2_bridge.cfg.n_devices = original_n_devices + + def test_accumulated_bias_handles_cross_device(self, gpt2_bridge): + original_n_devices = gpt2_bridge.cfg.n_devices + gpt2_bridge.cfg.n_devices = 2 + try: + # Exercises the .to(accumulated.device) branch without requiring real GPUs. 
+ bias = gpt2_bridge.accumulated_bias(layer=gpt2_bridge.cfg.n_layers - 1) + assert bias.shape == (gpt2_bridge.cfg.d_model,) + finally: + gpt2_bridge.cfg.n_devices = original_n_devices + + +# ---------- Multi-GPU tests (require real hardware) ---------- + + +@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="Requires 2+ CUDA devices") +class TestMultiDeviceIntegration: + def test_n_devices_matches_single_device_logits(self): + single = TransformerBridge.boot_transformers("gpt2", device="cuda:0") + multi = TransformerBridge.boot_transformers("gpt2", n_devices=2) + + assert multi.cfg.n_devices == 2 + assert single.cfg.n_devices == 1 + + tokens = torch.tensor([[1, 2, 3, 4]]) + logits_single = single(tokens).to("cpu") + logits_multi = multi(tokens).to("cpu") + assert torch.allclose(logits_single, logits_multi, atol=1e-4, rtol=1e-4) + + def test_parameters_distributed_across_devices(self): + bridge = TransformerBridge.boot_transformers("gpt2", n_devices=2) + cuda_indices = { + p.device.index for p in bridge.original_model.parameters() if p.device.type == "cuda" + } + assert cuda_indices == {0, 1} + + def test_generate_works_with_multi_device(self): + bridge = TransformerBridge.boot_transformers("gpt2", n_devices=2) + out = bridge.generate("Hello", max_new_tokens=3, do_sample=False) + assert isinstance(out, str) + assert len(out) > len("Hello") + + def test_stacked_weights_work_across_devices(self): + # Real multi-device exercise of _stack_block_params (no spoofed n_devices). + bridge = TransformerBridge.boot_transformers("gpt2", n_devices=2) + W_Q = bridge.W_Q + assert W_Q.shape[0] == bridge.cfg.n_layers + # After stacking, all elements should be on cfg.device (the embedding device). + assert bridge.cfg.device is not None + assert W_Q.device == torch.device(bridge.cfg.device) + + def test_preloaded_device_map_model(self): + from transformers import AutoModelForCausalLM + + hf_model = AutoModelForCausalLM.from_pretrained("gpt2", device_map="auto") + bridge = TransformerBridge.boot_transformers("gpt2", hf_model=hf_model) + assert bridge.cfg.n_devices >= 1 + assert bridge.cfg.device is not None diff --git a/transformer_lens/model_bridge/bridge.py b/transformer_lens/model_bridge/bridge.py index d0601aa63..7c9f7b915 100644 --- a/transformer_lens/model_bridge/bridge.py +++ b/transformer_lens/model_bridge/bridge.py @@ -169,6 +169,9 @@ def boot_transformers( trust_remote_code: bool = False, model_class: Optional[type] = None, hf_model: Optional[Any] = None, + device_map: Optional[Union[str, Dict[str, Union[str, int]]]] = None, + n_devices: Optional[int] = None, + max_memory: Optional[Dict[Union[str, int], str]] = None, n_ctx: Optional[int] = None, ) -> "TransformerBridge": """Boot a model from HuggingFace (alias for sources.transformers.boot). @@ -181,7 +184,8 @@ def boot_transformers( Args: model_name: The name of the model to load. hf_config_overrides: Optional overrides applied to the HuggingFace config before model load. - device: The device to use. If None, will be determined automatically. + device: The device to use. If None, will be determined automatically. Mutually exclusive + with ``device_map``. dtype: The dtype to use for the model. tokenizer: Optional pre-initialized tokenizer to use; if not provided one will be created. load_weights: If False, load model without weights (on meta device) for config inspection only. @@ -190,7 +194,17 @@ def boot_transformers( auto-detected class (e.g., BertForNextSentencePrediction). 
hf_model: Optional pre-loaded HuggingFace model to use instead of loading one. Useful for models loaded with custom configurations (e.g., quantization via - BitsAndBytesConfig). When provided, load_weights is ignored. + BitsAndBytesConfig). When provided, load_weights is ignored. If the pre-loaded + model was built with a ``device_map``, ``cfg.device`` and ``cfg.n_devices`` are + derived from its ``hf_device_map`` automatically. + device_map: HuggingFace-style device map for multi-GPU inference. Pass ``"auto"``, + ``"balanced"``, ``"sequential"``, or an explicit ``{submodule_path: device}`` dict. + Mutually exclusive with ``device``. + n_devices: Convenience shortcut: split the model across this many CUDA devices. + Translated to a ``max_memory`` dict over devices 0..n_devices-1 and passed as + ``device_map`` to HF. Requires CUDA with at least this many visible devices. + max_memory: Optional per-device memory budget, passed through to HF's dispatcher. + Only used when ``device_map`` or ``n_devices`` is in effect. n_ctx: Optional context length override. Writes to the appropriate HF config field for this model automatically (callers don't need to know the field name). Warns if larger than the model's default context length. @@ -210,6 +224,9 @@ def boot_transformers( trust_remote_code=trust_remote_code, model_class=model_class, hf_model=hf_model, + device_map=device_map, + n_devices=n_devices, + max_memory=max_memory, n_ctx=n_ctx, ) @@ -1109,6 +1126,12 @@ def _stack_block_params( if reshape_fn is not None: w = reshape_fn(w) weights.append(w) + # Under a device_map split, per-block tensors live on different devices. + # torch.stack requires a common device; gather onto cfg.device (the embedding / + # input device — a natural "home" for cross-layer reductions). + if getattr(self.cfg, "n_devices", 1) > 1 and weights and self.cfg.device is not None: + target_device = torch.device(self.cfg.device) + weights = [w.to(target_device) for w in weights] return torch.stack(weights, dim=0) def _reshape_qkv(self, w: torch.Tensor) -> torch.Tensor: @@ -1314,17 +1337,17 @@ def accumulated_bias( block = self.blocks[i] b_O = self._get_block_variant_bias(block) if b_O is not None: - accumulated = accumulated + b_O + accumulated = accumulated + b_O.to(accumulated.device) if include_mlp_biases and "mlp" in block._modules: b_out = getattr(block.mlp, "b_out", None) if b_out is not None: - accumulated = accumulated + b_out + accumulated = accumulated + b_out.to(accumulated.device) if mlp_input: assert layer < self.cfg.n_layers, "Cannot include attn_bias from beyond the final layer" block = self.blocks[layer] b_O = self._get_block_variant_bias(block) if b_O is not None: - accumulated = accumulated + b_O + accumulated = accumulated + b_O.to(accumulated.device) return accumulated def all_composition_scores(self, mode: str) -> CompositionScores: @@ -1348,6 +1371,10 @@ def _stack(attr_path: str, reshape_fn: Optional[Callable] = None) -> torch.Tenso if reshape_fn is not None: w = reshape_fn(w) weights.append(w) + # See _stack_block_params: gather per-block tensors onto cfg.device when split. 
+ if getattr(self.cfg, "n_devices", 1) > 1 and weights and self.cfg.device is not None: + target_device = torch.device(self.cfg.device) + weights = [w.to(target_device) for w in weights] return torch.stack(weights, dim=0) W_V = _stack("attn.W_V", self._reshape_qkv) @@ -1954,12 +1981,24 @@ def stop_hook(tensor: torch.Tensor, *, hook: Any) -> torch.Tensor: hooks.append((hook_dict[block_hook_name], block_hook_name)) filtered_kwargs = kwargs.copy() if cache_device is not None: - self.original_model = self.original_model.to(cache_device) - if processed_args and isinstance(processed_args[0], torch.Tensor): - processed_args = [processed_args[0].to(cache_device)] + list(processed_args[1:]) - for key, value in filtered_kwargs.items(): - if isinstance(value, torch.Tensor): - filtered_kwargs[key] = value.to(cache_device) + if getattr(self.cfg, "n_devices", 1) > 1: + # Moving a dispatched model to a single device collapses accelerate's + # split and breaks its routing hooks. The cache will stay spread across + # the per-layer devices; callers can .to(cache_device) on cache entries + # after the fact if they need a single-device cache. + warnings.warn( + f"run_with_cache(device={cache_device!r}) ignored: model is dispatched " + f"across {self.cfg.n_devices} devices via device_map. Cached activations " + "will remain on their per-layer devices.", + stacklevel=2, + ) + else: + self.original_model = self.original_model.to(cache_device) + if processed_args and isinstance(processed_args[0], torch.Tensor): + processed_args = [processed_args[0].to(cache_device)] + list(processed_args[1:]) + for key, value in filtered_kwargs.items(): + if isinstance(value, torch.Tensor): + filtered_kwargs[key] = value.to(cache_device) try: if "output_attentions" not in filtered_kwargs: filtered_kwargs["output_attentions"] = True @@ -3070,12 +3109,30 @@ def to(self, *args, **kwargs) -> "TransformerBridge": if "dtype" in kwargs: target_dtype = kwargs["dtype"] + # Moving a multi-device (device_map-dispatched) model to a single device would + # collapse the split and break accelerate's hook routing. Warn and drop the + # device move; still honor dtype changes. + if target_device is not None and getattr(self.cfg, "n_devices", 1) > 1: + warnings.warn( + f"TransformerBridge.to({target_device!r}) ignored: model is dispatched " + f"across {self.cfg.n_devices} devices via device_map. Reload with " + "device=... (and no device_map/n_devices) to move to a single device.", + stacklevel=2, + ) + target_device = None + if target_device is not None: move_to_and_update_config(self, target_device, print_details) if target_dtype is not None: move_to_and_update_config(self, target_dtype, print_details) - # Move the original model with all original args/kwargs (with print_details removed) + # Move the original model with all original args/kwargs (with print_details removed). + # When we've nulled target_device for multi-GPU safety, strip device args so the + # underlying module isn't moved either. + if target_device is None and (len(args) > 0 or "device" in kwargs): + kwargs.pop("device", None) + # Filter positional args: drop devices/strings, keep dtypes. 
+ args = tuple(a for a in args if not isinstance(a, (torch.device, str))) self.original_model = self.original_model.to(*args, **kwargs) return self diff --git a/transformer_lens/model_bridge/sources/transformers.py b/transformer_lens/model_bridge/sources/transformers.py index 99b90a968..afc002eee 100644 --- a/transformer_lens/model_bridge/sources/transformers.py +++ b/transformer_lens/model_bridge/sources/transformers.py @@ -287,13 +287,19 @@ def boot( model_class: Any | None = None, hf_model: Any | None = None, n_ctx: int | None = None, + # Experimental – Have not been fully tested on multi-gpu devices + # Use at your own risk, report any issues here: https://github.com/TransformerLensOrg/TransformerLens/issues + device_map: str | dict[str, str | int] | None = None, + n_devices: int | None = None, + max_memory: dict[str | int, str] | None = None, ) -> TransformerBridge: """Boot a model from HuggingFace. Args: model_name: The name of the model to load. hf_config_overrides: Optional overrides applied to the HuggingFace config before model load. - device: The device to use. If None, will be determined automatically. + device: The device to use. If None, will be determined automatically. Mutually exclusive + with ``device_map``. dtype: The dtype to use for the model. tokenizer: Optional pre-initialized tokenizer to use; if not provided one will be created. load_weights: If False, load model without weights (on meta device) for config inspection only. @@ -303,6 +309,12 @@ def boot( hf_model: Optional pre-loaded HuggingFace model to use instead of loading one. Useful for models loaded with custom configurations (e.g., quantization via BitsAndBytesConfig). When provided, load_weights is ignored. + device_map: HuggingFace-style device map (``"auto"``, ``"balanced"``, dict, etc.) for + multi-GPU inference. Passed straight to ``from_pretrained``. Mutually exclusive + with ``device``. + n_devices: Convenience: split the model across this many CUDA devices (translated to a + ``max_memory`` dict internally). Requires CUDA with at least this many visible devices. + max_memory: Optional per-device memory budget for HF's dispatcher. n_ctx: Optional context length override. The bridge normally uses the model's documented max context from the HF config. Setting this writes to whichever HF field the model uses (n_positions / max_position_embeddings / etc.), so callers don't need to know @@ -430,9 +442,49 @@ def boot( if attn_logit_softcapping is not None: bridge_config.attn_scores_soft_cap = float(attn_logit_softcapping) adapter = ArchitectureAdapterFactory.select_architecture_adapter(bridge_config) - if device is None: - device = get_device() - adapter.cfg.device = str(device) + # Pre-loaded models carry their own weight placement (possibly set by the caller via + # device_map). Passing device_map / n_devices / max_memory alongside hf_model= is + # ambiguous and would silently be ignored, so fail loudly. + if hf_model is not None and ( + device_map is not None or n_devices is not None or max_memory is not None + ): + raise ValueError( + "device_map / n_devices / max_memory are only supported when the bridge loads " + "the HF model itself. When passing hf_model=..., apply device_map via " + "AutoModel.from_pretrained before handing the model to the bridge." + ) + # Stateful/SSM (e.g. Mamba) models keep a per-layer recurrent cache that must live on + # that layer's device. The bridge currently allocates the stateful cache on a single + # cfg.device, so cross-device splits would silently misplace the cache. 
Block this + # combination until a v2 addresses per-layer stateful cache placement. + if (n_devices is not None and n_devices > 1) or device_map is not None: + if getattr(bridge_config, "is_stateful", False): + raise ValueError( + "Multi-device splits are not yet supported for stateful (SSM / Mamba) " + "architectures: the stateful cache allocation is single-device. " + "Load on one device, or wait for v2 support." + ) + # Resolve device_map before defaulting `device` — the two are mutually exclusive, and + # the resolver raises on conflict. If n_devices>1 is passed, it's translated into a + # device_map + max_memory pair here so downstream code only needs to check the + # resolved values. + from transformer_lens.utilities.multi_gpu import ( + count_unique_devices, + find_embedding_device, + resolve_device_map, + ) + + resolved_device_map, resolved_max_memory = resolve_device_map( + n_devices, device_map, device, max_memory + ) + if resolved_device_map is None: + if device is None: + device = get_device() + adapter.cfg.device = str(device) + else: + # cfg.device will be set from hf_device_map after the model is loaded. + # Provisionally keep it None; find_embedding_device fills it in below. + adapter.cfg.device = None if model_class is None: model_class = get_hf_model_class_for_architecture(architecture) # Ensure pad_token_id exists (v5 raises AttributeError if missing) @@ -447,6 +499,10 @@ def boot( model_kwargs["token"] = _hf_token if trust_remote_code: model_kwargs["trust_remote_code"] = True + if resolved_device_map is not None: + model_kwargs["device_map"] = resolved_device_map + if resolved_max_memory is not None: + model_kwargs["max_memory"] = resolved_max_memory if hasattr(adapter.cfg, "attn_implementation") and adapter.cfg.attn_implementation is not None: model_kwargs["attn_implementation"] = adapter.cfg.attn_implementation else: @@ -482,12 +538,38 @@ def boot( f"weight mismatch." ) from e raise - if device is not None: + # Skip explicit .to(device) when accelerate has placed weights via device_map. + if resolved_device_map is None and device is not None: hf_model = hf_model.to(device) # Cast params to dtype; preserve float32 buffers (e.g., RotaryEmbedding.inv_freq) for param in hf_model.parameters(): if param.is_floating_point() and param.dtype != dtype: param.data = param.data.to(dtype=dtype) + # Derive cfg.device / cfg.n_devices from hf_device_map when present. This covers: + # - fresh loads with a resolved device_map (set above) + # - pre-loaded hf_model that the caller dispatched themselves (e.g., device_map="auto") + hf_device_map_post = getattr(hf_model, "hf_device_map", None) + if hf_device_map_post: + # Pre-loaded path can still smuggle CPU/disk offload in; validate here too. + offload_values = {str(v).lower() for v in hf_device_map_post.values() if isinstance(v, str)} + forbidden = offload_values & {"cpu", "disk", "meta"} + if forbidden and ((n_devices is not None and n_devices > 1) or device_map is not None): + # Fresh-load path: we set the device_map ourselves, so this shouldn't happen — + # but if the user asked for n_devices>1 and somehow got CPU offload, surface it. + raise ValueError( + f"hf_device_map contains unsupported offload targets: {sorted(forbidden)}. " + "v1 multi-device support is GPU-only." 
+ ) + embedding_device = find_embedding_device(hf_model) + if embedding_device is not None: + adapter.cfg.device = str(embedding_device) + adapter.cfg.n_devices = count_unique_devices(hf_model) + elif adapter.cfg.device is None: + # Pre-loaded single-device model with no hf_device_map — fall back to first param. + try: + adapter.cfg.device = str(next(hf_model.parameters()).device) + except StopIteration: + adapter.cfg.device = "cpu" # #7: Verify the n_ctx override actually took effect on the loaded model. # If HF's config class silently dropped or normalized the value, warn so # the user doesn't get misled into thinking longer sequences are supported. diff --git a/transformer_lens/utilities/__init__.py b/transformer_lens/utilities/__init__.py index de5e79db6..56975f90d 100644 --- a/transformer_lens/utilities/__init__.py +++ b/transformer_lens/utilities/__init__.py @@ -50,10 +50,13 @@ # Re-export multi-GPU helpers here (devices.py must not import multi_gpu directly) from .multi_gpu import ( calculate_available_device_cuda_memory, + count_unique_devices, determine_available_memory_for_available_devices, + find_embedding_device, get_best_available_cuda_device, get_best_available_device, get_device_for_block_index, + resolve_device_map, sort_devices_based_on_available_memory, ) from .slice import Slice, SliceInput diff --git a/transformer_lens/utilities/multi_gpu.py b/transformer_lens/utilities/multi_gpu.py index 8b5afaf58..0584604af 100644 --- a/transformer_lens/utilities/multi_gpu.py +++ b/transformer_lens/utilities/multi_gpu.py @@ -5,7 +5,7 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Any, Optional, Union +from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union import torch @@ -141,3 +141,110 @@ def get_device_for_block_index( return device device_index = (device.index or 0) + (index // layers_per_device) return torch.device(device.type, device_index) + + +_UNSUPPORTED_DEVICE_MAP_VALUES = {"cpu", "disk", "meta"} +"""v1 multi-GPU scope is GPU-only. CPU offload and disk offload cause dtype-cast loops to +silently miss offloaded params (meta tensors), and cross-layer hook routing has different +semantics. Reject them explicitly until a v2 scopes those paths.""" + + +def _validate_device_map_values( + device_map: Union[str, Dict[str, Union[str, int]]], +) -> None: + """Reject CPU / disk / meta values in a user-supplied device_map dict.""" + if isinstance(device_map, str): + # "balanced_low_0" is fine — still GPU-only; "cpu" as a string-form device_map + # would tell HF to put everything on CPU, which is single-device and meaningless + # as a multi-GPU config. We allow strings through; HF will validate them. + return + for key, value in device_map.items(): + normalized = str(value).lower() if isinstance(value, str) else None + if normalized in _UNSUPPORTED_DEVICE_MAP_VALUES: + raise ValueError( + f"device_map[{key!r}]={value!r} is not supported. Multi-device bridge " + f"support is GPU-only in v1; CPU / disk / meta offload routes are excluded." + ) + + +def resolve_device_map( + n_devices: Optional[int], + device_map: Optional[Union[str, Dict[str, Union[str, int]]]], + device: Optional[Union[str, torch.device]], + max_memory: Optional[Dict[Union[str, int], str]] = None, +) -> Tuple[Optional[Union[str, Dict[str, Union[str, int]]]], Optional[Dict[Union[str, int], str]]]: + """Resolve ``n_devices`` / ``device_map`` / ``device`` into HF ``from_pretrained`` kwargs. + + Returns ``(device_map, max_memory)`` tuple ready to pass into ``model_kwargs``. 
+ + Semantics: + - Explicit ``device_map`` wins; it's validated and passed through unchanged (user- + provided ``max_memory`` is passed through too). + - ``n_devices=None`` or ``1``: returns ``(None, None)`` — single-device path. + - ``n_devices > 1``: returns ``("balanced", {0: "auto", ..., n-1: "auto"})``. + ``"balanced"`` is accelerate's string directive for balanced layer dispatch; + the ``max_memory`` dict caps visibility to exactly ``n_devices`` GPUs. + """ + if device_map is not None and device is not None: + raise ValueError("device and device_map are mutually exclusive — pass one.") + if device_map is not None: + _validate_device_map_values(device_map) + return device_map, max_memory + if n_devices is None or n_devices <= 1: + return None, max_memory + if not torch.cuda.is_available(): + raise ValueError(f"n_devices={n_devices} requires CUDA, which is not available.") + if torch.cuda.device_count() < n_devices: + raise ValueError( + f"n_devices={n_devices} but only {torch.cuda.device_count()} CUDA devices present." + ) + resolved_max_memory: Dict[Union[str, int], str] = ( + dict(max_memory) if max_memory else {i: "auto" for i in range(n_devices)} + ) + return "balanced", resolved_max_memory + + +def find_embedding_device(hf_model: Any) -> Optional[torch.device]: + """Return the device that input tokens should be placed on for a dispatched HF model. + + When a model is loaded with ``device_map``, accelerate populates ``hf_device_map`` + and inserts pre/post-forward hooks that route activations. Input tensors must land on + the device of whichever module first *consumes* them — the input embedding. Returns + ``None`` for single-device models (no ``hf_device_map`` set). + + Resolves via ``hf_model.get_input_embeddings()`` rather than dict insertion order to + cover encoder-decoder / multimodal / audio architectures where the first entry in + ``hf_device_map`` is not the text-token embedding (e.g. the vision tower on LLaVA). + """ + hf_device_map = getattr(hf_model, "hf_device_map", None) + if not hf_device_map: + return None + # Preferred: ask the model for its input embedding module and read its device. + get_input_embeddings = getattr(hf_model, "get_input_embeddings", None) + if callable(get_input_embeddings): + try: + embed_module = get_input_embeddings() + except (AttributeError, NotImplementedError): + embed_module = None + if embed_module is not None: + try: + param = next(embed_module.parameters()) + return param.device + except StopIteration: + pass + # Fallback: first entry in hf_device_map. Less reliable but better than nothing. + first_device = next(iter(hf_device_map.values())) + if isinstance(first_device, int): + return torch.device("cuda", first_device) + return torch.device(first_device) + + +def count_unique_devices(hf_model: Any) -> int: + """Count the number of unique devices across a dispatched HF model's ``hf_device_map``. + + Returns 1 if the model has no ``hf_device_map`` (single-device load). 
+ """ + hf_device_map = getattr(hf_model, "hf_device_map", None) + if not hf_device_map: + return 1 + return len(set(hf_device_map.values())) From fd288dc2c88668dafbec53306a828444be4b9035 Mon Sep 17 00:00:00 2001 From: Jonah Larson Date: Wed, 29 Apr 2026 16:58:46 -0500 Subject: [PATCH 20/21] Adding Architecture Adapter Creation Guide to Docs (#1274) * Adding architecture Adapter creation guide, add split QKV example to quantized LLaMA demo * ignore docs/build from black linting --- demos/LLaMA2_GPU_Quantized.ipynb | 179 +++++------ docs/source/_static/adapter-template.py | 162 ++++++++++ .../adapter-creation-guide.md | 276 ++++++++++++++++ .../adapter-specification.md | 300 ++++++++++++++++++ .../hf-model-analysis-guide.md | 168 ++++++++++ docs/source/content/contributing.md | 21 ++ pyproject.toml | 5 +- 7 files changed, 1008 insertions(+), 103 deletions(-) create mode 100644 docs/source/_static/adapter-template.py create mode 100644 docs/source/content/adapter_development/adapter-creation-guide.md create mode 100644 docs/source/content/adapter_development/adapter-specification.md create mode 100644 docs/source/content/adapter_development/hf-model-analysis-guide.md diff --git a/demos/LLaMA2_GPU_Quantized.ipynb b/demos/LLaMA2_GPU_Quantized.ipynb index 4722428a9..37d378cc8 100644 --- a/demos/LLaMA2_GPU_Quantized.ipynb +++ b/demos/LLaMA2_GPU_Quantized.ipynb @@ -20,7 +20,7 @@ }, { "cell_type": "code", - "execution_count": 18, + "execution_count": 1, "metadata": { "colab": { "base_uri": "https://localhost:8080/" @@ -33,9 +33,7 @@ "name": "stdout", "output_type": "stream", "text": [ - "Running as a Jupyter notebook - intended for development only!\n", - "The autoreload extension is already loaded. To reload it, use:\n", - " %reload_ext autoreload\n" + "Running as a Jupyter notebook - intended for development only!\n" ] } ], @@ -74,7 +72,7 @@ }, { "cell_type": "code", - "execution_count": 19, + "execution_count": 2, "metadata": { "colab": { "base_uri": "https://localhost:8080/" @@ -105,7 +103,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 3, "metadata": { "id": "P8zS3MPkCUsR" }, @@ -186,7 +184,7 @@ }, { "cell_type": "code", - "execution_count": 21, + "execution_count": 11, "metadata": { "id": "RdJ0AuW_CUsS" }, @@ -253,7 +251,7 @@ }, { "cell_type": "code", - "execution_count": 22, + "execution_count": 4, "metadata": { "colab": { "base_uri": "https://localhost:8080/", @@ -389,35 +387,7 @@ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "02974f818bc54305b535861303ca208e", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "config.json: 0%| | 0.00/843 [00:00\n", + "
\n", " " ], "text/plain": [ - "" + "" ] }, "metadata": {}, @@ -747,7 +668,7 @@ }, { "cell_type": "code", - "execution_count": 31, + "execution_count": 8, "metadata": { "colab": { "base_uri": "https://localhost:8080/" @@ -794,6 +715,62 @@ "print(f\"Original Loss: {original_loss.item():.3f}\")\n", "print(f\"Ablated Loss: {ablated_loss.item():.3f}\")" ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Per-head Q/K/V hooks via `use_split_qkv_input`\n", + "\n", + "By default, attention hooks like `hook_q`, `hook_k`, `hook_v` give the per-head **outputs** of the QKV projections. To intervene on the **inputs** to those projections at per-head granularity, set `use_split_qkv_input=True`. This unlocks `hook_q_input`, `hook_k_input`, and `hook_v_input` with shape `[batch, pos, n_heads, d_model]`.\n", + "\n", + "On legacy `HookedTransformer`, this combination broke when the model was loaded in 4-bit ([issue #737](https://github.com/TransformerLensOrg/TransformerLens/issues/737)) because the `bnb.matmul_4bit` reshape didn't account for the per-head input shape. `TransformerBridge` delegates 4-bit attention math to HuggingFace, so the combination works cleanly here." + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Shape of the q_input tensor: torch.Size([1, 32, 32, 2048])\n", + "Original Loss: 2.951\n", + "Q-input ablated Loss: 2.939\n" + ] + } + ], + "source": [ + "# NBVAL_IGNORE_OUTPUT\n", + "model.set_use_split_qkv_input(True)\n", + "\n", + "layer_to_ablate = 0\n", + "head_index_to_ablate = 4\n", + "\n", + "def q_input_ablation_hook(\n", + " q_input: Float[torch.Tensor, \"batch pos head_index d_model\"],\n", + " hook: HookPoint,\n", + ") -> Float[torch.Tensor, \"batch pos head_index d_model\"]:\n", + " print(f\"Shape of the q_input tensor: {q_input.shape}\")\n", + " q_input[:, :, head_index_to_ablate, :] = 0.0\n", + " return q_input\n", + "\n", + "original_loss = model(llama_tokens, return_type=\"loss\")\n", + "ablated_loss = model.run_with_hooks(\n", + " llama_tokens,\n", + " return_type=\"loss\",\n", + " fwd_hooks=[(\n", + " utils.get_act_name(\"q_input\", layer_to_ablate),\n", + " q_input_ablation_hook,\n", + " )],\n", + ")\n", + "print(f\"Original Loss: {original_loss.item():.3f}\")\n", + "print(f\"Q-input ablated Loss: {ablated_loss.item():.3f}\")\n", + "\n", + "model.set_use_split_qkv_input(False)" + ] } ], "metadata": { diff --git a/docs/source/_static/adapter-template.py b/docs/source/_static/adapter-template.py new file mode 100644 index 000000000..7b14928b0 --- /dev/null +++ b/docs/source/_static/adapter-template.py @@ -0,0 +1,162 @@ +""" architecture adapter. + +TODO: Replace with the actual model name throughout this file. +""" + +from typing import Any + +from transformer_lens.model_bridge.architecture_adapter import ArchitectureAdapter +from transformer_lens.model_bridge.generalized_components import ( + BlockBridge, + EmbeddingBridge, + GatedMLPBridge, + LinearBridge, + PositionEmbeddingsAttentionBridge, + RMSNormalizationBridge, + RotaryEmbeddingBridge, + UnembeddingBridge, +) + + +class ModelNameArchitectureAdapter(ArchitectureAdapter): + """Architecture adapter for models. + + TODO: Document which parameters are optional (missing biases, etc.) + + Optional Parameters (may not exist in state_dict): + ------------------------------------------------- + TODO: List parameters that may not exist. 
Example for models without biases: + + - blocks.{i}.attn.b_Q - No bias on query projection + - blocks.{i}.attn.b_K - No bias on key projection + - blocks.{i}.attn.b_V - No bias on value projection + - blocks.{i}.attn.b_O - No bias on output projection + - blocks.{i}.mlp.b_in - No bias on MLP input + - blocks.{i}.mlp.b_gate - No bias on MLP gate projection + - blocks.{i}.mlp.b_out - No bias on MLP output + - blocks.{i}.ln1.b - RMSNorm has no bias + - blocks.{i}.ln2.b - RMSNorm has no bias + - ln_final.b - RMSNorm has no bias + """ + + def __init__(self, cfg: Any) -> None: + """Initialize the architecture adapter.""" + super().__init__(cfg) + + # ===================================================================== + # 1. CONFIG ATTRIBUTES + # Set these based on the HuggingFace model's architecture. + # ===================================================================== + + # TODO: Set normalization type + # "RMS" for RMSNorm (Llama, Qwen, Gemma, etc.) + # "LN" for LayerNorm (GPT-2, GPT-J, etc.) + self.cfg.normalization_type = "RMS" + + # TODO: Set positional embedding type + # "rotary" for RoPE (Llama, Qwen, Mistral, etc.) + # "standard" for learned positional embeddings (GPT-2) + self.cfg.positional_embedding_type = "rotary" + + # TODO: Set these flags + self.cfg.final_rms = True # True if final layer norm is RMSNorm + self.cfg.gated_mlp = True # True if MLP has gate projection (SwiGLU) + self.cfg.attn_only = False # True only for attention-only models (rare) + self.cfg.uses_rms_norm = True # Should match normalization_type + + # TODO: Set the epsilon attribute name used by this model's normalization + # Check the HF model's norm layer to find the correct attribute name + self.cfg.eps_attr = "variance_epsilon" # or "layer_norm_eps", "rms_norm_eps", etc. + + # TODO: Handle GQA if applicable + # If the model uses Grouped Query Attention (n_key_value_heads < n_heads): + if hasattr(cfg, "n_key_value_heads") and cfg.n_key_value_heads is not None: + self.cfg.n_key_value_heads = cfg.n_key_value_heads + + # ===================================================================== + # 2. WEIGHT PROCESSING CONVERSIONS + # Defines how to reshape weights from HF format to TL format. + # For most models with separate Q/K/V/O, use the built-in helper. + # ===================================================================== + + self.weight_processing_conversions = { + **self._qkvo_weight_conversions(), + # TODO: Add any model-specific weight conversions here + } + + # ===================================================================== + # 3. COMPONENT MAPPING + # Maps TransformerLens canonical names to HuggingFace module paths. + # The `name=` parameter is the HF path relative to the model root + # (for top-level) or relative to the block (for block submodules). + # ===================================================================== + + # TODO: Replace all HF paths (name="...") with actual paths from the model. + # Inspect the HF model's named_modules() or config to find the correct paths. 
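+        # One way to list the candidate paths (run in a Python shell, not in this file):
+        #   from transformers import AutoModelForCausalLM
+        #   hf_model = AutoModelForCausalLM.from_pretrained("<model-name>")
+        #   for module_name, module in hf_model.named_modules():
+        #       print(module_name, type(module).__name__)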
+ self.component_mapping = { + # Token embedding + "embed": EmbeddingBridge(name="model.embed_tokens"), + # Rotary position embeddings (remove if model uses standard pos embeddings) + "rotary_emb": RotaryEmbeddingBridge(name="model.rotary_emb"), + # Transformer blocks + "blocks": BlockBridge( + name="model.layers", # TODO: HF path to the layer list + submodules={ + # Pre-attention layer norm + "ln1": RMSNormalizationBridge( + name="input_layernorm", # TODO: HF name within block + config=self.cfg, + ), + # Post-attention layer norm + "ln2": RMSNormalizationBridge( + name="post_attention_layernorm", # TODO: HF name within block + config=self.cfg, + ), + # Self-attention + "attn": PositionEmbeddingsAttentionBridge( + name="self_attn", # TODO: HF name within block + config=self.cfg, + submodules={ + "q": LinearBridge(name="q_proj"), # TODO: HF projection names + "k": LinearBridge(name="k_proj"), + "v": LinearBridge(name="v_proj"), + "o": LinearBridge(name="o_proj"), + }, + requires_attention_mask=True, + requires_position_embeddings=True, + ), + # MLP (gated) + "mlp": GatedMLPBridge( + name="mlp", # TODO: HF name within block + config=self.cfg, + submodules={ + "gate": LinearBridge(name="gate_proj"), # TODO: HF projection names + "in": LinearBridge(name="up_proj"), + "out": LinearBridge(name="down_proj"), + }, + ), + }, + ), + # Final layer norm + "ln_final": RMSNormalizationBridge(name="model.norm", config=self.cfg), + # Output head (unembedding) + "unembed": UnembeddingBridge(name="lm_head", config=self.cfg), + } + + def setup_component_testing(self, hf_model: Any, bridge_model: Any = None) -> None: + """Set up model-specific references for component testing. + + TODO: Required for RoPE models. Remove if model uses standard positional embeddings. + """ + # Get rotary embedding instance from the HF model + rotary_emb = hf_model.model.rotary_emb # TODO: Adjust path if different + + # Set rotary_emb on actual bridge instances + if bridge_model is not None and hasattr(bridge_model, "blocks"): + for block in bridge_model.blocks: + if hasattr(block, "attn"): + block.attn.set_rotary_emb(rotary_emb) + + # Set on template for get_generalized_component() calls + attn_bridge = self.get_generalized_component("blocks.0.attn") + attn_bridge.set_rotary_emb(rotary_emb) diff --git a/docs/source/content/adapter_development/adapter-creation-guide.md b/docs/source/content/adapter_development/adapter-creation-guide.md new file mode 100644 index 000000000..08fff9471 --- /dev/null +++ b/docs/source/content/adapter_development/adapter-creation-guide.md @@ -0,0 +1,276 @@ +# Architecture Adapter Creation Guide + +A walkthrough for developers writing a new Architecture Adapter for the TransformerLens `TransformerBridge` system. This guide distills the process of developing an adapter into a set of steps that can be followed start-to-finish. + +If you just want the API reference, jump to [adapter-specification.md](adapter-specification.md). If you have a specific HF model in hand and want a config-extraction cookbook, see [hf-model-analysis-guide.md](hf-model-analysis-guide.md). This document ties those together with workflow and review practice. + +## What an adapter is + +An **Architecture Adapter** is a Python class that extends `ArchitectureAdapter` and tells `TransformerBridge` three things about a HuggingFace model: + +1. **Config attributes** — set on `self.cfg` in `__init__` (normalization type, positional embedding type, GQA params, etc.) +2. 
**Component mapping** — `self.component_mapping`, a dict mapping TransformerLens canonical names (`embed`, `blocks`, `attn.q`, …) to `GeneralizedComponent` Bridge instances pointed at HF module paths. +3. **Weight processing conversions** — `self.weight_processing_conversions`, a dict of tensor-reshape rules that translate HF weight layouts to TL layouts during loading. + +Once registered, users can `boot_transformers("")` and get a fully hooked TransformerLens model with weights loaded from HF. + +## Prerequisites + +Before starting, make sure you can: + +- Read PyTorch model code and trace a forward pass +- Run a HF model locally with `transformers` +- Use `model.named_modules()` and `model.state_dict()` to inspect structure +- Identify whether a model uses RoPE vs learned positional embeddings, RMSNorm vs LayerNorm, gated vs standard MLP, separate vs joint QKV, MHA vs GQA vs MQA. ([hf-model-analysis-guide.md](hf-model-analysis-guide.md) has a decision tree.) + +You do **not** need to memorize every Bridge component — the existing adapters in `transformer_lens/model_bridge/supported_architectures/` are your reference library. + +## Analyze the architecture + +### Read the HF source + +Open the two files that define the architecture in `transformers`: + +- `models//modeling_.py` — the model code +- `models//configuration_.py` — the config class + +Read every `__init__` and every `forward`. You are looking for: + +- Module hierarchy: what's nested in what, named how +- Forward pass order: norm before/after attention? residual where? +- Bias presence on each linear layer +- Normalization type and the *exact* attribute name of its epsilon (`variance_epsilon`, `rms_norm_eps`, `layer_norm_eps`, `eps`, …) +- Attention type (MHA / GQA / MQA) and whether QKV are separate or joint +- MLP type (gated / standard) and projection names +- Anything that looks weird (special scaling, conditional padding, dtype upcasts in softmax, …) + +Also extract the standard config-to-TL field mapping (see [hf-model-analysis-guide.md](hf-model-analysis-guide.md) for the table). + +### Find the closest reference adapter + +Almost every new model is a variant of an existing pattern. Pick the nearest match from `supported_architectures/`: + +| If your model is like… | Start from… | +|---------------------------------------|--------------------------------------| +| Llama, Mistral, Qwen2, Gemma, OLMo | `llama.py` | +| Qwen2/Qwen3 (gated config, MLPBridge) | `qwen2.py` | +| GPT-2, GPT-J, GPT-Neo | `gpt2.py` | +| BLOOM, Falcon | `bloom.py` or `falcon.py` | +| T5 / encoder-decoder | `t5.py` | +| MoE | `mixtral.py` or `granite_moe.py` | +| Multimodal (vision+text) | `llava.py` or `gemma3_multimodal.py` | + +### Write down what you found + +Before writing any adapter code, take notes on the architecture. This is for your own use, it does not need to be formally documented. It will help inform your decisions going forward. 
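+
+Much of this can be gathered programmatically rather than by eye. A minimal sketch (the checkpoint name and the `model.model.layers` path are illustrative, Llama-style assumptions; substitute your target model):
+
+```python
+import inspect
+
+import torch
+from transformers import AutoModelForCausalLM
+
+model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B", torch_dtype="auto")
+block = model.model.layers[0]
+
+# Exact epsilon attribute name on the norm class (varies between architectures)
+norm = block.input_layernorm
+print(type(norm).__name__, [a for a in vars(norm) if "eps" in a.lower()])
+
+# Bias presence on each linear projection in the block
+for name, module in block.named_modules():
+    if isinstance(module, torch.nn.Linear):
+        print(name, "bias" if module.bias is not None else "no bias")
+
+# Forward-pass order, straight from the source
+print(inspect.getsource(type(block).forward))
+```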
+ +At minimum, capture: + +- **Source files** — exact paths in `transformers` +- **Module hierarchy** — every HF module path you'll need, with line numbers in the source where it's defined +- **Config fields** — the HF names and their TL equivalents +- **Architectural properties** — normalization, position embeddings, attention type, MLP type, biases +- **Forward pass flow** — order of operations in the block, attention, and MLP +- **Reference adapter** — closest existing adapter, and a list of every way your target differs from it +- **Representative models** — small variants (≤7B parameters) you'll use for verification + +## Implement the adapter + +### File layout + +- **Adapter file:** `transformer_lens/model_bridge/supported_architectures/.py` +- **Class name:** `ArchitectureAdapter` (e.g. `LlamaArchitectureAdapter`) +- **Module name:** lowercase + underscores (`llama.py`, `qwen2.py`, `granite_moe.py`) + +Start from [adapter-template.py](../../_static/adapter-template.py). It's a Llama-pattern skeleton with TODOs at every decision point. + +A reasonable order for filling it in: + +1. Config attributes (drives everything else) +2. Weight processing conversions +3. Component mapping +4. Optional overrides (only the ones you actually need) +5. Registration + +### Config attributes + +Set these on `self.cfg` in `__init__` *before* building the component mapping (the bridges read from `self.cfg`): + +| Attribute | Type | Purpose | +|----------------------------|--------|-----------------------------------------------| +| `normalization_type` | `str` | `"RMS"` or `"LN"` | +| `positional_embedding_type`| `str` | `"rotary"` or `"standard"` | +| `final_rms` | `bool` | Final norm is RMSNorm | +| `gated_mlp` | `bool` | MLP has gate projection (SwiGLU) | +| `attn_only` | `bool` | Model has no MLP layers (rare) | +| `uses_rms_norm` | `bool` | Should match `normalization_type == "RMS"` | +| `eps_attr` | `str` | HF attribute name for norm epsilon | + +For GQA models, also forward `n_key_value_heads`: + +```python +if hasattr(cfg, "n_key_value_heads") and cfg.n_key_value_heads is not None: + self.cfg.n_key_value_heads = cfg.n_key_value_heads +``` + +### Component mapping + +For each TL canonical name, instantiate the right Bridge component and point its `name=` parameter at the HF module path (relative to the model root for top-level entries, relative to the block for block submodules). 
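+
+Before writing the mapping out, it can help to confirm that each planned path actually resolves on the HF module tree. A minimal sketch, assuming the HF model is already loaded as `hf_model` and using the Llama-style paths from the mapping below:
+
+```python
+# Top-level entries resolve relative to the model root
+hf_model.get_submodule("model.embed_tokens")
+hf_model.get_submodule("model.layers")
+
+# Block submodules resolve relative to a single block
+block = hf_model.get_submodule("model.layers")[0]
+block.get_submodule("input_layernorm")
+block.get_submodule("self_attn.q_proj")
+```
+
+`get_submodule` raises immediately on a wrong path, which is much cheaper to debug than a shape mismatch during weight loading.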
+ +A standard Llama-style mapping: + +```python +self.component_mapping = { + "embed": EmbeddingBridge(name="model.embed_tokens"), + "rotary_emb": RotaryEmbeddingBridge(name="model.rotary_emb"), + "blocks": BlockBridge( + name="model.layers", + submodules={ + "ln1": RMSNormalizationBridge(name="input_layernorm", config=self.cfg), + "ln2": RMSNormalizationBridge(name="post_attention_layernorm", config=self.cfg), + "attn": PositionEmbeddingsAttentionBridge( + name="self_attn", + config=self.cfg, + submodules={ + "q": LinearBridge(name="q_proj"), + "k": LinearBridge(name="k_proj"), + "v": LinearBridge(name="v_proj"), + "o": LinearBridge(name="o_proj"), + }, + requires_attention_mask=True, + requires_position_embeddings=True, + ), + "mlp": GatedMLPBridge( + name="mlp", + config=self.cfg, + submodules={ + "gate": LinearBridge(name="gate_proj"), + "in": LinearBridge(name="up_proj"), + "out": LinearBridge(name="down_proj"), + }, + ), + }, + ), + "ln_final": RMSNormalizationBridge(name="model.norm", config=self.cfg), + "unembed": UnembeddingBridge(name="lm_head", config=self.cfg), +} +``` + +The full bridge component catalog (attention variants, MLP variants, specialized bridges for BLOOM/CLIP/Siglip/T5/MoE/etc.) is in [adapter-specification.md](adapter-specification.md) under "Available Bridge Components." + +### Weight processing conversions + +For models with separate Q/K/V/O projections, use the built-in helper: + +```python +self.weight_processing_conversions = { + **self._qkvo_weight_conversions(), +} +``` + +It generates the standard `(n h) m -> n m h` rearrangements with the right head/kv-head counts. + +For combined-QKV models (GPT-2 style), see `gpt2.py`'s `QKVSplitRearrangeConversion` for the pattern. For other oddball layouts, define custom `ParamProcessingConversion` or `RearrangeTensorConversion` instances. + +### Optional overrides + +Implement only the ones you need: + +- **`setup_component_testing(hf_model, bridge_model=None)`** — required for RoPE models, to wire the rotary embedding instance through to the attention bridges. Skip for models with standard positional embeddings. +- **`preprocess_weights(state_dict)`** — for arch-specific weight transforms before standard processing (e.g., Gemma scales embeddings by `sqrt(d_model)`). +- **`prepare_loading(model_name, model_kwargs)`** — patch HF model classes before `from_pretrained()`. +- **`prepare_model(hf_model)`** — post-load fixups before bridge creation. + +### Registration + +Two files to update: + +1. `transformer_lens/model_bridge/supported_architectures/__init__.py` — add the import and append to `__all__`. +2. `transformer_lens/factories/architecture_adapter_factory.py` — add to the import block and to `SUPPORTED_ARCHITECTURES`: + + ```python + "": , + ``` + +Forgetting registration is the most common silent failure — the adapter exists but `boot_transformers` can't find it. + +### Tests + +Write tests that exercise actual behavior: + +- Hook names resolve correctly +- Weight shapes match expectations after loading +- Forward pass produces sensible output for a tiny variant + +### New bridge components + +Don't add a new bridge unless the existing ones can't express your model. The bar is: the `forward()` must be fundamentally different from any existing bridge. 
If you do add one: + +- Place it in `transformer_lens/model_bridge/generalized_components/` +- Export it from the package `__init__` +- Write tests covering its forward pass and any state it carries + +## Verify the adapter + +The `verify_models` tool runs a real HF model side-by-side with your bridge and compares activations across four phases. Each phase produces a numeric score; the model passes if all phase scores meet their thresholds. + +### Pick models + +From your representative-models list, take the smallest variants (prefer ≤7B parameters), up to 5, sorted by HuggingFace download count. Verifying multiple sizes catches scaling bugs that single-model verification misses. + +### Run verification + +One model at a time, with float32 by default: + +```bash +uv run python -m transformer_lens.tools.model_registry.verify_models \ + --model \ + --max-memory \ + --device cpu \ + --dtype float32 \ + --no-ht-reference +``` + +If a model OOMs with float32, retry that single model with `--dtype bfloat16`. Set `--max-memory` to roughly 75-85% of your device memory, to ensure adequate space for running the benchmarks. + +### Read the status + +Each model gets a status: + +- **status=1** — passed, move to the next model +- **status=2** — skipped by `verify_models` (e.g., exceeded the memory pre-check). Note it and move on; not an adapter bug. +- **status=3** — phase score failure. Stop and fix. Read the `note` and the per-phase scores, find the root cause, fix the adapter, re-verify. + +### Lint + +After all chosen models pass: + +```bash +uv run mypy . +make check-format +``` + +Both must be clean. Don't paper over mypy errors with `# type: ignore` — fix the underlying type. If mypy is wrong about something, that's a real issue worth investigating, not silencing. + +## Before you open a PR + +`verify_models` will catch most numerical bugs, but a few things are worth a once-over by eye. + +**Sanity-check against the HF source.** Skim your adapter with the HF `modeling_.py` open alongside it. Module paths, config attribute names, and bias presence are the usual suspects — easy to get wrong from memory and easy to spot when you look directly. + +**Watch for the subtle stuff.** When the adapter reimplements a computation or defines weight conversions, the things that bite are operation order (split before or after the layernorm?), dtype upcasting in softmax, and conditional logic that only fires under certain conditions in HF (e.g., flash-attention paths). If something in your code looks like it "probably matches" HF, that's a good place to stop and check. + +**Don't reach for abstraction prematurely.** If you've added a base class or protocol with only one or two concrete uses, you're probably better off without it. The same goes for config knobs that don't have a current consumer. + +**Confirm the boring stuff is done.** Both registration sites (`__init__.py` and `architecture_adapter_factory.py`), `mypy` and format checks clean, tests doing real work rather than asserting mocks return their mock values. + +## Common pitfalls + +- **Wrong `eps_attr` name.** Models that look identical use different attribute names (`variance_epsilon`, `rms_norm_eps`, `eps`). Read the norm class. +- **Forgetting `n_key_value_heads`.** Without it, GQA models silently reshape weights as if they were MHA — verification fails with cryptic shape errors. +- **Missing registration.** Adapter exists but the factory can't find it. Update both `__init__.py` and `architecture_adapter_factory.py`. 
+- **Skipping `setup_component_testing` for RoPE.** Rotary embeddings need to be wired through to each attention bridge or component testing produces nonsense.
+- **Reusing `model.norm` when the path is `model.final_layernorm`.** Module paths look similar across architectures but rarely match exactly — always verify against the actual HF source.
+- **Tautological tests.** "Test that mock returns mock_value" is not a test. Tests should exercise real shapes, real forward passes, real hook resolution.
+- **`# type: ignore` on mypy errors.** Find the root cause; the type error is usually telling you something real about the bridge config.
+- **Coding before the architecture is understood.** The single biggest time-waster. Five pages of code based on a wrong assumption about module paths is worse than no code.
diff --git a/docs/source/content/adapter_development/adapter-specification.md b/docs/source/content/adapter_development/adapter-specification.md
new file mode 100644
index 000000000..17798b6b6
--- /dev/null
+++ b/docs/source/content/adapter_development/adapter-specification.md
@@ -0,0 +1,300 @@
+---
+orphan: true
+---
+
+# Architecture Adapter Specification
+
+This document is the primary reference for building Architecture Adapters for the TransformerLens TransformerBridge system.
+
+## What Is an Architecture Adapter?
+
+An Architecture Adapter is a Python class that extends `ArchitectureAdapter` (from `transformer_lens.model_bridge.architecture_adapter`). It maps between a HuggingFace model's internal structure and TransformerLens's canonical component names. Every adapter must define three things:
+
+1. **Config attributes** — set on `self.cfg` in `__init__`
+2. **Component mapping** — `self.component_mapping` dict mapping TL names to Bridge instances
+3. **Weight processing conversions** — `self.weight_processing_conversions` dict for tensor reshaping
+
+## File Location and Naming
+
+- **Adapter file:** `transformer_lens/model_bridge/supported_architectures/<model_name>.py`
+- **Class name:** `<ModelName>ArchitectureAdapter` (e.g., `LlamaArchitectureAdapter`)
+- **Module name:** lowercase, underscores (e.g., `llama.py`, `qwen2.py`, `granite_moe.py`)
+
+## Registration Checklist
+
+After creating the adapter, register it in these files:
+
+1. **`transformer_lens/model_bridge/supported_architectures/__init__.py`**
+   - Add import: `from transformer_lens.model_bridge.supported_architectures.<model_name> import <ModelName>ArchitectureAdapter`
+   - Add to `__all__` list
+
+2. 
**`transformer_lens/factories/architecture_adapter_factory.py`** + - Add import (in the existing import block from `supported_architectures`) + - Add entry to `SUPPORTED_ARCHITECTURES` dict: `"": ` + +## Config Attributes + +Set these on `self.cfg` in `__init__` before building the component mapping: + +| Attribute | Type | Description | Examples | +|-----------|------|-------------|----------| +| `normalization_type` | `str` | `"RMS"` or `"LN"` | Llama="RMS", GPT2="LN" | +| `positional_embedding_type` | `str` | `"rotary"` or `"standard"` | Llama="rotary", GPT2="standard" | +| `final_rms` | `bool` | Whether final layer norm is RMS | Llama=True, GPT2=False | +| `gated_mlp` | `bool` | Whether MLP uses gate projection | Llama=True, GPT2=False | +| `attn_only` | `bool` | Whether model has no MLP layers | Usually False | +| `uses_rms_norm` | `bool` | Redundant with normalization_type but needed | Match normalization_type | +| `eps_attr` | `str` | Attribute name for norm epsilon | `"variance_epsilon"`, `"layer_norm_eps"` | + +### GQA (Grouped Query Attention) + +If the model uses GQA (n_key_value_heads < n_heads), set: +```python +if hasattr(cfg, "n_key_value_heads") and cfg.n_key_value_heads is not None: + self.cfg.n_key_value_heads = cfg.n_key_value_heads +``` + +## Component Mapping + +`self.component_mapping` is a `dict[str, GeneralizedComponent]` mapping TransformerLens canonical names to Bridge instances. The Bridge `name=` parameter is the HuggingFace module path. + +### Standard Mapping (Llama-style decoder-only) + +```python +self.component_mapping = { + "embed": EmbeddingBridge(name="model.embed_tokens"), + "rotary_emb": RotaryEmbeddingBridge(name="model.rotary_emb"), + "blocks": BlockBridge( + name="model.layers", + submodules={ + "ln1": RMSNormalizationBridge(name="input_layernorm", config=self.cfg), + "ln2": RMSNormalizationBridge(name="post_attention_layernorm", config=self.cfg), + "attn": PositionEmbeddingsAttentionBridge( + name="self_attn", + config=self.cfg, + submodules={ + "q": LinearBridge(name="q_proj"), + "k": LinearBridge(name="k_proj"), + "v": LinearBridge(name="v_proj"), + "o": LinearBridge(name="o_proj"), + }, + requires_attention_mask=True, + requires_position_embeddings=True, + ), + "mlp": GatedMLPBridge( + name="mlp", + config=self.cfg, + submodules={ + "gate": LinearBridge(name="gate_proj"), + "in": LinearBridge(name="up_proj"), + "out": LinearBridge(name="down_proj"), + }, + ), + }, + ), + "ln_final": RMSNormalizationBridge(name="model.norm", config=self.cfg), + "unembed": UnembeddingBridge(name="lm_head", config=self.cfg), +} +``` + +### GPT2-style Mapping (standard positional embeddings, combined QKV) + +```python +self.component_mapping = { + "embed": EmbeddingBridge(name="transformer.wte"), + "pos_embed": PosEmbedBridge(name="transformer.wpe"), + "blocks": BlockBridge( + name="transformer.h", + config=self.cfg, + submodules={ + "ln1": NormalizationBridge(name="ln_1", config=self.cfg), + "attn": JointQKVAttentionBridge( + name="attn", + config=self.cfg, + submodules={ + "qkv": LinearBridge(name="c_attn"), + "o": LinearBridge(name="c_proj"), + }, + ), + "ln2": NormalizationBridge(name="ln_2", config=self.cfg), + "mlp": MLPBridge( + name="mlp", + submodules={ + "in": LinearBridge(name="c_fc"), + "out": LinearBridge(name="c_proj"), + }, + ), + }, + ), + "ln_final": NormalizationBridge(name="transformer.ln_f", config=self.cfg), + "unembed": UnembeddingBridge(name="lm_head"), +} +``` + +> **Note:** GPT2's `MLPBridge` and `UnembeddingBridge` do not pass `config=`. 
The `config` parameter is optional on these bridges — match the existing adapter's pattern. + +## Weight Processing Conversions + +`self.weight_processing_conversions` maps TransformerLens weight paths to `ParamProcessingConversion` instances that handle tensor reshaping during weight loading. + +### Standard QKVO Conversions (most models) + +For models with separate Q/K/V/O projections, use the built-in helper: + +```python +self.weight_processing_conversions = { + **self._qkvo_weight_conversions(), +} +``` + +This generates rearrangement rules for: +- `blocks.{i}.attn.q.weight` — `(n h) m -> n m h` with `n=n_heads` +- `blocks.{i}.attn.k.weight` — `(n h) m -> n m h` with `n=n_kv_heads` +- `blocks.{i}.attn.v.weight` — `(n h) m -> n m h` with `n=n_kv_heads` +- `blocks.{i}.attn.o.weight` — `m (n h) -> n h m` with `n=n_heads` + +### Custom Conversions + +For models with non-standard weight layouts (e.g., combined QKV), define custom `ParamProcessingConversion` or `RearrangeTensorConversion` instances. See `gpt2.py` for the `QKVSplitRearrangeConversion` example. + +## Available Bridge Components + +### Core Components + +| Component | Use When | +|-----------|----------| +| `EmbeddingBridge` | Token embeddings | +| `UnembeddingBridge` | Output head (lm_head) | +| `BlockBridge` | Transformer block container (always named "blocks") | +| `LinearBridge` | Any linear/projection layer | + +### Normalization + +| Component | Use When | +|-----------|----------| +| `NormalizationBridge` | LayerNorm | +| `RMSNormalizationBridge` | RMSNorm | + +### Attention + +| Component | Use When | +|-----------|----------| +| `AttentionBridge` | Basic attention (no positional embeddings passed) | +| `PositionEmbeddingsAttentionBridge` | Attention that receives position embeddings (RoPE models) | +| `JointQKVAttentionBridge` | Combined QKV single linear layer (GPT-2 style) | +| `JointQKVPositionEmbeddingsAttentionBridge` | Combined QKV with position embeddings | + +### MLP + +| Component | Use When | +|-----------|----------| +| `MLPBridge` | Standard 2-layer MLP (in/out) or with separate gate | +| `GatedMLPBridge` | Gated MLP with gate/up/down projections (SwiGLU) | +| `JointGateUpMLPBridge` | MLP where gate and up projections are fused | + +### Position Embeddings + +| Component | Use When | +|-----------|----------| +| `PosEmbedBridge` | Learned positional embeddings (GPT-2 style) | +| `RotaryEmbeddingBridge` | Rotary position embeddings (RoPE) | + +### Specialized + +| Component | Use When | +|-----------|----------| +| `MoEBridge` | Mixture of Experts routing | +| `SymbolicBridge` | Placeholder/container with no direct HF module | +| `Conv1DBridge` | 1D convolution layers | +| `T5BlockBridge` | T5-specific block structure | +| `CLIPVisionEncoderBridge` | CLIP vision encoder (multimodal) | +| `CLIPVisionEncoderLayerBridge` | Individual CLIP vision encoder layer | +| `SiglipVisionEncoderBridge` | Siglip vision encoder (multimodal) | +| `SiglipVisionEncoderLayerBridge` | Individual Siglip vision encoder layer | +| `VisionProjectionBridge` | Vision-to-text projection (multimodal) | + +### Architecture-Specific (Bloom/Falcon) + +These exist for architectures with non-standard internal structures. Discover them by reading the reference adapter. 
+ +| Component | Use When | +|-----------|----------| +| `BloomBlockBridge` | BLOOM transformer blocks | +| `BloomAttentionBridge` | BLOOM attention mechanism | +| `BloomMLPBridge` | BLOOM MLP | +| `AudioFeatureExtractorBridge` | Audio feature extraction (HuBERT) | +| `ConvPosEmbedBridge` | Convolutional positional embeddings (HuBERT) | + +## Optional Overrides + +### `setup_component_testing(hf_model, bridge_model=None)` + +Called after adapter creation. Use to set up model-specific references for component testing. Required for RoPE models to set rotary embedding references: + +```python +def setup_component_testing(self, hf_model, bridge_model=None): + rotary_emb = hf_model.model.rotary_emb + if bridge_model is not None and hasattr(bridge_model, "blocks"): + for block in bridge_model.blocks: + if hasattr(block, "attn"): + block.attn.set_rotary_emb(rotary_emb) + attn_bridge = self.get_generalized_component("blocks.0.attn") + attn_bridge.set_rotary_emb(rotary_emb) +``` + +### `preprocess_weights(state_dict)` + +Apply architecture-specific weight transformations before standard processing. Example: Gemma scales embeddings by `sqrt(d_model)`. + +### `prepare_loading(model_name, model_kwargs)` + +Called before `from_pretrained()`. Use to patch HF model classes. + +### `prepare_model(hf_model)` + +Called after model loading but before bridge creation. Use for post-load fixups. + +## Common Architecture Patterns + +### Pattern 1: Llama-like (most modern models) + +RoPE + RMSNorm + GatedMLP + separate Q/K/V/O. Uses `GatedMLPBridge`. Used by: Llama, Mistral, Gemma, OLMo, Granite, StableLM. + +**Qwen2 variant:** Nearly identical to Llama but uses `MLPBridge` instead of `GatedMLPBridge` (while still setting `gated_mlp = True` and having gate/in/out submodules). Used by: Qwen2, Qwen3. + +### Pattern 2: GPT2-like + +Standard positional embeddings + LayerNorm + standard MLP + combined QKV. Used by: GPT-2, GPT-J, GPT-Neo/NeoX. + +### Pattern 3: MoE (Mixture of Experts) + +Similar to Llama-like but with `MoEBridge` replacing the MLP. Used by: Mixtral, GraniteMoE, OLMoE. + +### Pattern 4: Multimodal + +Extends a text-only pattern with vision encoder and projection bridges. Used by: LLaVA, LLaVA-Next, Gemma3 Multimodal. + +## Imports Template + +```python +from transformer_lens.model_bridge.architecture_adapter import ArchitectureAdapter +from transformer_lens.model_bridge.generalized_components import ( + BlockBridge, + EmbeddingBridge, + GatedMLPBridge, # or MLPBridge for non-gated + LinearBridge, + PositionEmbeddingsAttentionBridge, # or JointQKVAttentionBridge + RMSNormalizationBridge, # or NormalizationBridge for LayerNorm + RotaryEmbeddingBridge, # only for RoPE models + UnembeddingBridge, +) +``` + +## Testing + +After creating an adapter, verify it by: + +1. Running the adapter-specific unit tests +2. Loading a small model variant with `boot_transformers(model_name)` +3. Verifying hook names resolve correctly +4. Checking that weight shapes match expectations diff --git a/docs/source/content/adapter_development/hf-model-analysis-guide.md b/docs/source/content/adapter_development/hf-model-analysis-guide.md new file mode 100644 index 000000000..4885adc98 --- /dev/null +++ b/docs/source/content/adapter_development/hf-model-analysis-guide.md @@ -0,0 +1,168 @@ +# HuggingFace Model Analysis Guide + +This guide explains how to analyze a HuggingFace model to extract the information needed to build a TransformerLens Architecture Adapter. 
+ +## Read the model's config.json + +Every HF model has a `config.json` that contains architecture details. You can access it via: + +```python +from transformers import AutoConfig +config = AutoConfig.from_pretrained("model-name-or-path") +print(config) +``` + +Or via the HuggingFace API: +```bash +curl -s "https://huggingface.co/model-name/resolve/main/config.json" | python -m json.tool +``` + +### Key config fields to extract + +| HF Config Field | TL Config Field | Description | +|-----------------|-----------------|-------------| +| `hidden_size` | `d_model` | Model dimension | +| `num_attention_heads` | `n_heads` | Number of attention heads | +| `num_key_value_heads` | `n_key_value_heads` | KV heads (for GQA; if absent or equal to n_heads, not GQA) | +| `intermediate_size` | `d_mlp` | MLP intermediate dimension | +| `num_hidden_layers` | `n_layers` | Number of transformer blocks | +| `vocab_size` | `d_vocab` | Vocabulary size | +| `max_position_embeddings` | `n_ctx` | Maximum sequence length | +| `rms_norm_eps` | `eps` | Normalization epsilon | +| `model_type` | — | Architecture family (e.g., "llama", "gpt2", "mistral") | +| `architectures` | `architecture` | HF class name (e.g., `["LlamaForCausalLM"]`) | + +## Determine architecture characteristics + +### Normalization type + +Check the model code or config: +- **RMSNorm** → `normalization_type = "RMS"` — Look for `RMSNorm` in the model code, or `rms_norm_eps` in config +- **LayerNorm** → `normalization_type = "LN"` — Look for `LayerNorm`, or `layer_norm_eps` / `layer_norm_epsilon` in config + +Also identify the epsilon attribute name: +- `"variance_epsilon"` (Llama) +- `"rms_norm_eps"` (some models expose this directly) +- `"layer_norm_eps"` (GPT-2, BERT) +- `"eps"` (generic) + +### Positional embedding type + +- **Rotary (RoPE)** → `positional_embedding_type = "rotary"` — Most modern models (Llama, Mistral, Qwen, Gemma) +- **Learned/Standard** → `positional_embedding_type = "standard"` — GPT-2, OPT +- Check for `RotaryEmbedding` class in the model code + +### Attention type + +- **Multi-Head Attention (MHA)** — `n_key_value_heads == n_heads` or field absent +- **Grouped Query Attention (GQA)** — `n_key_value_heads < n_heads` (e.g., Llama 3, Mistral) +- **Multi-Query Attention (MQA)** — `n_key_value_heads == 1` (e.g., Falcon) + +### MLP type + +- **Gated MLP (SwiGLU)** → `gated_mlp = True` — Has gate/up/down projections (Llama, Qwen, Gemma) +- **Standard MLP** → `gated_mlp = False` — Has fc1/fc2 or c_fc/c_proj (GPT-2) + +### QKV layout + +- **Separate Q/K/V** — Most models: `q_proj`, `k_proj`, `v_proj` +- **Combined QKV** — GPT-2 style: single `c_attn` or `query_key_value` linear layer + +## Inspect module names + +To find the exact HuggingFace module paths for the component mapping: + +```python +from transformers import AutoModelForCausalLM +model = AutoModelForCausalLM.from_pretrained("model-name", torch_dtype="auto") + +# Print all named modules +for name, module in model.named_modules(): + print(f"{name}: {type(module).__name__}") +``` + +### What to look for + +Map these HF module paths to TL component mapping entries: + +| TL Name | Look for in HF | Common HF Paths | +|---------|----------------|-----------------| +| `embed` | Token embedding | `model.embed_tokens`, `transformer.wte` | +| `pos_embed` | Position embedding (if standard) | `transformer.wpe` | +| `rotary_emb` | Rotary embedding (if RoPE) | `model.rotary_emb`, `model.layers.0.self_attn.rotary_emb` | +| `blocks` | Layer list | `model.layers`, `transformer.h`, 
`model.decoder.layers` | +| `ln1` | Pre-attention norm | `input_layernorm`, `ln_1` | +| `ln2` | Post-attention norm | `post_attention_layernorm`, `ln_2` | +| `attn` | Self-attention module | `self_attn`, `attn` | +| `attn.q` | Query projection | `q_proj`, `query` | +| `attn.k` | Key projection | `k_proj`, `key` | +| `attn.v` | Value projection | `v_proj`, `value` | +| `attn.o` | Output projection | `o_proj`, `out_proj`, `dense`, `c_proj` | +| `attn.qkv` | Combined QKV (if used) | `c_attn`, `query_key_value` | +| `mlp` | MLP module | `mlp`, `feed_forward` | +| `mlp.gate` | Gate projection (if gated) | `gate_proj`, `w1` | +| `mlp.in` | Up/input projection | `up_proj`, `c_fc`, `fc1`, `w3` | +| `mlp.out` | Down/output projection | `down_proj`, `c_proj`, `fc2`, `w2` | +| `ln_final` | Final layer norm | `model.norm`, `transformer.ln_f`, `model.final_layernorm` | +| `unembed` | LM head | `lm_head`, `embed_out` | + +## Check for biases + +```python +# Check if a specific layer has bias +layer = model.model.layers[0] +print(f"Q bias: {layer.self_attn.q_proj.bias is not None}") +print(f"MLP in bias: {layer.mlp.up_proj.bias is not None}") +``` + +Document which layers lack biases — this affects the "Optional Parameters" section of the adapter docstring. + +## Examine state dict keys + +```python +# Print all parameter names and shapes +for key, param in model.state_dict().items(): + print(f"{key}: {param.shape}") +``` + +This helps verify: +- Weight naming patterns match your component mapping +- Tensor shapes match expected dimensions +- No unexpected parameters that need special handling + +## Find an existing similar adapter + +Check if a similar architecture already has an adapter. Most new models are variants of existing patterns: + +| If your model is like... | Start from adapter... | +|--------------------------|----------------------| +| Llama, Mistral, Qwen2, Gemma | `llama.py` | +| GPT-2, GPT-J | `gpt2.py` | +| BLOOM, Falcon | `bloom.py` or `falcon.py` | +| T5, encoder-decoder | `t5.py` | +| MoE model | `mixtral.py` or `granite_moe.py` | +| Multimodal (vision+text) | `llava.py` or `gemma3_multimodal.py` | + +## Quick reference: decision tree + +``` +1. Does the model use RMSNorm or LayerNorm? + → RMSNorm: normalization_type="RMS", use RMSNormalizationBridge + → LayerNorm: normalization_type="LN", use NormalizationBridge + +2. Does the model use RoPE or learned positional embeddings? + → RoPE: positional_embedding_type="rotary", add RotaryEmbeddingBridge, use PositionEmbeddingsAttentionBridge + → Learned: positional_embedding_type="standard", add PosEmbedBridge + +3. Are Q/K/V separate or combined? + → Separate: use PositionEmbeddingsAttentionBridge with q/k/v/o submodules + → Combined: use JointQKVAttentionBridge with qkv/o submodules + +4. Does the MLP have a gate projection? + → Yes (gate+up+down): gated_mlp=True, use GatedMLPBridge + → No (in+out): gated_mlp=False, use MLPBridge + +5. Is n_key_value_heads < n_heads? + → Yes: GQA — set n_key_value_heads on cfg + → No: standard MHA — no special handling needed +``` diff --git a/docs/source/content/contributing.md b/docs/source/content/contributing.md index df514ab7f..370a87361 100644 --- a/docs/source/content/contributing.md +++ b/docs/source/content/contributing.md @@ -158,3 +158,24 @@ must be repeated (i.e. `\\`). You can write LaTeX inline, or in "display mode". - Numbered items - `1. 
Item` - Quotes - indent one level - External links = ``` `Link text ` ``` + +## Creating Architecture Adapters + +If a HuggingFace model is not yet supported by `TransformerBridge`, you can add support by writing an Architecture Adapter. An adapter is a Python class that tells the bridge how a particular HF model maps to TransformerLens's canonical component names (`embed`, `blocks`, `attn.q`, etc.). Once registered, `TransformerBridge.boot_transformers("")` will load the model end-to-end with full hook support. + +The work is mostly bookkeeping: identify each component on the HF side (embeddings, attention, MLP, normalization), point a Bridge instance at the corresponding HF module path, and supply tensor-reshape rules where the weight layout differs from TransformerLens conventions. Most of the per-architecture decisions are already encoded in the existing adapters under `transformer_lens/model_bridge/supported_architectures/`, which are good starting points to copy from. + +Two guides walk through the process: + +- [Architecture Adapter Creation Guide](adapter_development/adapter-creation-guide.md) — start here. A step-by-step workflow for taking an HF model from unsupported to tested, registered adapter. +- [HuggingFace Model Analysis Guide](adapter_development/hf-model-analysis-guide.md) — a reference for reading an HF model's `config.json` and source files to extract the attributes you'll set on `self.cfg`. + +Adapters live in `transformer_lens/model_bridge/supported_architectures/.py` and are registered in two places: `supported_architectures/__init__.py` and `factories/architecture_adapter_factory.py`. Both steps are covered in the creation guide. If you want a starter file, copy [adapter-template.py](../_static/adapter-template.py) into `supported_architectures/` and rename it. + +```{toctree} +:hidden: +:maxdepth: 1 + +adapter_development/adapter-creation-guide +adapter_development/hf-model-analysis-guide +``` diff --git a/pyproject.toml b/pyproject.toml index 1af32755f..e20e94ce6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -105,11 +105,12 @@ [tool.black] line-length=100 # Set line length to 100 to match other tools - # Exclude snapshot tests & .venv + # Exclude snapshot tests, .venv, and Sphinx build output (autogenerated) exclude=''' ( /snapshots/ -| .venv/ +| \.venv/ +| docs/build/ ) ''' From 0a5218ca38ed55466cacfeac0ce20d40e25a9b0f Mon Sep 17 00:00:00 2001 From: Jonah Larson Date: Wed, 29 Apr 2026 18:06:24 -0500 Subject: [PATCH 21/21] Fixed Quantization bug in TransformerLens 3.0 (#1276) * Fixed Quantization bug in TransformerLens 3.0 * Format fixes --- .../model_bridge/test_bridge_integration.py | 80 +++++++++++++++++++ .../generalized_components/attention.py | 13 ++- .../generalized_components/base.py | 13 ++- 3 files changed, 98 insertions(+), 8 deletions(-) diff --git a/tests/integration/model_bridge/test_bridge_integration.py b/tests/integration/model_bridge/test_bridge_integration.py index d9b11228c..febe910b0 100644 --- a/tests/integration/model_bridge/test_bridge_integration.py +++ b/tests/integration/model_bridge/test_bridge_integration.py @@ -718,6 +718,86 @@ def hook_fn(grad, hook=None): assert hook_called["bridge"], "TransformerBridge backward hook should now be called correctly" +def test_AttentionBridge_preserves_fp_input_when_first_param_is_quantized(): + """Bridge must not cast fp inputs to integer storage dtype. 
+ + Regression for an AttentionBridge / GeneralizedComponent bug where + `target_dtype = next(parameters()).dtype` returned the storage dtype of + quantized weights (uint8 for BnB Params4bit, int32 for GPTQ, etc.). When + the first parameter happened to be quantized, bridge cast fp32 hidden_states + to that integer dtype before passing them to HF — destroying precision and + producing gibberish logits on every quantized model. + + Fakes a "quantized first parameter" by replacing q_proj.weight with a + uint8 tensor, then runs a forward and asserts the input the original + component receives is still floating-point. + """ + from transformer_lens.model_bridge.generalized_components.attention import ( + AttentionBridge, + ) + + # Use tiny Mistral — it's a plain AttentionBridge (not JointQKV). + bridge: TransformerBridge = TransformerBridge.boot_transformers( # type: ignore + "trl-internal-testing/tiny-MistralForCausalLM-0.2", device="cpu" + ) + + attn_bridge = bridge.blocks[0].attn # type: ignore[attr-defined] + assert ( + type(attn_bridge).__name__ == "AttentionBridge" + ), f"Expected plain AttentionBridge, got {type(attn_bridge).__name__}" + assert isinstance(attn_bridge, AttentionBridge) + + original = attn_bridge.original_component + assert original is not None, "AttentionBridge.original_component not set" + + # Fake-quantize q_proj to uint8 storage — mirrors BnB Params4bit. + fp_weight = original.q_proj.weight + original.q_proj.weight = torch.nn.Parameter( + torch.zeros(fp_weight.shape, dtype=torch.uint8), requires_grad=False + ) + assert ( + next(original.parameters()).dtype == torch.uint8 + ), "Test setup: first param should be uint8 to trigger the bug condition" + + # Capture what dtype reaches the original component's forward. + received_dtype: list = [] + orig_forward = original.forward + + def capture(*args, **kwargs): + if "hidden_states" in kwargs: + received_dtype.append(kwargs["hidden_states"].dtype) + elif args: + received_dtype.append(args[0].dtype) + # Don't actually run forward — fake-quantized weight would error. + # Return a shape-compatible dummy. HF Mistral attention returns a tuple. + bsz, seq, d_model = (kwargs.get("hidden_states", args[0] if args else None)).shape + n_heads = bridge.cfg.n_heads # type: ignore[attr-defined] + return ( + torch.zeros(bsz, seq, d_model, dtype=torch.float32), + torch.zeros(bsz, n_heads, seq, seq, dtype=torch.float32), + ) + + original.forward = capture # type: ignore[method-assign] + try: + test_input = torch.tensor([[1, 2, 3, 4, 5]]) + with torch.no_grad(): + try: + bridge(test_input) + except Exception: + pass # downstream may fail; we only care what reached attn forward + finally: + original.forward = orig_forward # type: ignore[method-assign] + original.q_proj.weight = fp_weight + + assert len(received_dtype) > 0, "Original attention forward never called" + for dt in received_dtype: + assert dt.is_floating_point, ( + f"Bridge passed dtype={dt} to original attention forward, but it must be " + f"floating point. Regression of the AttentionBridge dtype-cast bug — " + f"target_dtype must skip non-fp (quantized-storage) parameters." + ) + + @pytest.mark.skipif(bool(os.getenv("CI")), reason="Skip Gemma2 test in CI to avoid timeout") def test_TransformerBridge_gemma2_forward(): """Test that TransformerBridge properly handles Gemma2's position_embeddings. 
diff --git a/transformer_lens/model_bridge/generalized_components/attention.py b/transformer_lens/model_bridge/generalized_components/attention.py index 34da4e280..22504b294 100644 --- a/transformer_lens/model_bridge/generalized_components/attention.py +++ b/transformer_lens/model_bridge/generalized_components/attention.py @@ -622,11 +622,16 @@ def forward(self, *args: Any, **kwargs: Any) -> Any: raise RuntimeError( f"Original component not set for {self.name}. Call set_original_component() first." ) + # Skip non-fp params: quantized weights (bnb uint8/int8, GPTQ/AWQ int32, + # HQQ, torchao) are stored in integer dtypes and dequantized internally + # during matmul. The compute dtype must come from a fp parameter; casting + # fp inputs to an integer storage dtype destroys precision. target_dtype = None - try: - target_dtype = next(self.original_component.parameters()).dtype - except StopIteration: - pass + for p in self.original_component.parameters(): + if not p.dtype.is_floating_point: + continue + target_dtype = p.dtype + break if "query_input" in kwargs: hooked = self.hook_in(kwargs["query_input"]) if ( diff --git a/transformer_lens/model_bridge/generalized_components/base.py b/transformer_lens/model_bridge/generalized_components/base.py index 20e44fbbb..ae6787b67 100644 --- a/transformer_lens/model_bridge/generalized_components/base.py +++ b/transformer_lens/model_bridge/generalized_components/base.py @@ -274,11 +274,16 @@ def forward(self, *args: Any, **kwargs: Any) -> Any: raise RuntimeError( f"Original component not set for {self.name}. Call set_original_component() first." ) + # Skip non-fp params: quantized weights (bnb uint8/int8, GPTQ/AWQ int32, + # HQQ, torchao) are stored in integer dtypes and dequantized internally + # during matmul. The compute dtype must come from a fp parameter; casting + # fp inputs to an integer storage dtype destroys precision. target_dtype = None - try: - target_dtype = next(original_component.parameters()).dtype - except StopIteration: - pass + for p in original_component.parameters(): + if not p.dtype.is_floating_point: + continue + target_dtype = p.dtype + break input_arg_names = [ "input", "hidden_states",