From 58b007f8bb0be10c056a82b6e673c98ff679bc2c Mon Sep 17 00:00:00 2001 From: Brendan Long Date: Mon, 20 Apr 2026 12:50:01 -0700 Subject: [PATCH 01/15] Fix type of HookedTransformerConfig.device (#1230) * Fix type of HookedTransformerConfig.device This is typed as `Optional[str]` but sometimes returns `torch.device`. Updated the code to just return the `str` instead of wrapping with a device. I'm not confident that every function which takes a device will always be passed a string, so I didn't change functions like warn_if_mps. Found while working on #1219 * more cleanup * 3.0 CI Bugs (#1261) * Fixing `utils` imports * skip gated notebooks on PR from forks * Updating notebooks * Ensure LLaMA only runs when HF_TOKEN is available --------- Co-authored-by: jlarson4 --- .github/workflows/checks.yml | 15 +- demos/Activation_Patching_in_TL_Demo.ipynb | 34 +- demos/Attribution_Patching_Demo.ipynb | 216 +- demos/Exploratory_Analysis_Demo.ipynb | 4 +- demos/Grokking_Demo.ipynb | 20 +- demos/Head_Detector_Demo.ipynb | 80 +- demos/Interactive_Neuroscope.ipynb | 9 +- demos/LLaMA.ipynb | 2 +- demos/LLaMA2_GPU_Quantized.ipynb | 2 +- demos/Main_Demo.ipynb | 2 +- demos/No_Position_Experiment.ipynb | 2760 ++++++++--------- demos/Othello_GPT.ipynb | 4 +- demos/SVD_Interpreter_Demo.ipynb | 2 +- demos/Santa_Coder.ipynb | 2 +- .../test_apertus.py | 2 +- tests/unit/utilities/test_devices.py | 12 +- transformer_lens/ActivationCache.py | 4 +- transformer_lens/HookedEncoderDecoder.py | 2 +- transformer_lens/HookedTransformer.py | 8 +- .../components/abstract_attention.py | 2 +- transformer_lens/components/bert_block.py | 2 +- transformer_lens/components/pos_embed.py | 2 +- transformer_lens/components/t5_block.py | 2 +- .../components/transformer_block.py | 2 +- .../config/HookedTransformerConfig.py | 4 +- transformer_lens/evals.py | 4 +- transformer_lens/head_detector.py | 2 +- transformer_lens/hook_points.py | 2 +- transformer_lens/loading_from_pretrained.py | 2 +- transformer_lens/model_bridge/bridge.py | 2 +- .../model_bridge/sources/transformers.py | 2 +- transformer_lens/patching.py | 2 +- transformer_lens/train.py | 8 +- transformer_lens/utilities/devices.py | 14 +- .../utilities/exploratory_utils.py | 4 +- transformer_lens/weight_processing.py | 2 +- 36 files changed, 1636 insertions(+), 1602 deletions(-) diff --git a/.github/workflows/checks.yml b/.github/workflows/checks.yml index 86d809f4a..3a2e1389c 100644 --- a/.github/workflows/checks.yml +++ b/.github/workflows/checks.yml @@ -221,6 +221,8 @@ jobs: name: Notebook Checks runs-on: ubuntu-latest timeout-minutes: 30 + env: + HAS_HF_TOKEN: ${{ secrets.HF_TOKEN != '' }} strategy: fail-fast: false matrix: @@ -233,8 +235,6 @@ jobs: # - "Grokking_Demo" - "Head_Detector_Demo" # - "Interactive_Neuroscope" - - "LLaMA" - - "LLaMA2_GPU_Quantized" # Requires quantization libs + too slow for CI timeout - "Main_Demo" # - "No_Position_Experiment" - "Othello_GPT" @@ -242,6 +242,12 @@ jobs: - "Santa_Coder" # - "stable_lm" - "T5" + requires_hf_token: [false] + include: + - notebook: "LLaMA" + requires_hf_token: true + - notebook: "LLaMA2_GPU_Quantized" + requires_hf_token: true steps: - uses: actions/checkout@v3 - name: Add swap space @@ -278,10 +284,13 @@ jobs: env: HF_TOKEN: ${{ secrets.HF_TOKEN }} - name: Check Notebook Output Consistency - # Note: currently only checks notebooks we have specifically setup for this + if: ${{ !matrix.requires_hf_token || env.HAS_HF_TOKEN == 'true' }} run: pytest --nbval-sanitize-with demos/doc_sanitize.cfg demos/${{ matrix.notebook 
}}.ipynb env: HF_TOKEN: ${{ secrets.HF_TOKEN }} + - name: Skip (HF_TOKEN not available) + if: ${{ matrix.requires_hf_token && env.HAS_HF_TOKEN != 'true' }} + run: echo "Skipping ${{ matrix.notebook }} — requires HF_TOKEN for gated model access" build-docs: diff --git a/demos/Activation_Patching_in_TL_Demo.ipynb b/demos/Activation_Patching_in_TL_Demo.ipynb index 17fc8f004..cbd7378cb 100644 --- a/demos/Activation_Patching_in_TL_Demo.ipynb +++ b/demos/Activation_Patching_in_TL_Demo.ipynb @@ -126,7 +126,7 @@ "outputs": [], "source": [ "import transformer_lens\n", - "import transformer_lens.utils as utils\n", + "import transformer_lens.utilities as utils\n", "from transformer_lens.model_bridge import TransformerBridge" ] }, @@ -674,7 +674,7 @@ ], "source": [ "# NBVAL_SKIP\n", - "# Heavy patching computation \u2014 too slow for CI without GPU\n", + "# Heavy patching computation — too slow for CI without GPU\n", "every_block_result = patching.get_act_patch_block_every(model, corrupted_tokens, clean_cache, ioi_metric)\n", "imshow(every_block_result, facet_col=0, facet_labels=[\"Residual Stream\", \"Attn Output\", \"MLP Output\"], title=\"Activation Patching Per Block\", xaxis=\"Position\", yaxis=\"Layer\", zmax=1, zmin=-1, x= [f\"{tok}_{i}\" for i, tok in enumerate(model.to_str_tokens(clean_tokens[0]))])" ] @@ -885,7 +885,21 @@ "execution_count": null, "metadata": {}, "outputs": [], - "source": "# NBVAL_IGNORE_OUTPUT\ndef induction_loss(logits, answer_token_indices=rand_tokens_A):\n seq_len = answer_token_indices.shape[1]\n\n # logits: batch x seq_len x vocab_size\n # Take the logits for the answers, cut off the final element to get the predictions for all but the first element of the answers (which can't be predicted)\n final_logits = logits[:, -seq_len:-1]\n final_log_probs = final_logits.log_softmax(-1)\n return final_log_probs.gather(-1, answer_token_indices[:, 1:].unsqueeze(-1)).mean()\nCLEAN_BASELINE_INDUCTION = induction_loss(clean_logits_induction).item()\nprint(\"Clean baseline:\", CLEAN_BASELINE_INDUCTION)\nCORRUPTED_BASELINE_INDUCTION = induction_loss(corrupted_logits_induction).item()\nprint(\"Corrupted baseline:\", CORRUPTED_BASELINE_INDUCTION)" + "source": [ + "# NBVAL_IGNORE_OUTPUT\n", + "def induction_loss(logits, answer_token_indices=rand_tokens_A):\n", + " seq_len = answer_token_indices.shape[1]\n", + "\n", + " # logits: batch x seq_len x vocab_size\n", + " # Take the logits for the answers, cut off the final element to get the predictions for all but the first element of the answers (which can't be predicted)\n", + " final_logits = logits[:, -seq_len:-1]\n", + " final_log_probs = final_logits.log_softmax(-1)\n", + " return final_log_probs.gather(-1, answer_token_indices[:, 1:].unsqueeze(-1)).mean()\n", + "CLEAN_BASELINE_INDUCTION = induction_loss(clean_logits_induction).item()\n", + "print(\"Clean baseline:\", CLEAN_BASELINE_INDUCTION)\n", + "CORRUPTED_BASELINE_INDUCTION = induction_loss(corrupted_logits_induction).item()\n", + "print(\"Corrupted baseline:\", CORRUPTED_BASELINE_INDUCTION)" + ] }, { "cell_type": "markdown", @@ -899,7 +913,17 @@ "execution_count": null, "metadata": {}, "outputs": [], - "source": "# NBVAL_SKIP\n# Heavy patching computation \u2014 too slow for CI without GPU\nevery_head_all_pos_act_patch_result = patching.get_act_patch_attn_head_all_pos_every(attn_only, corrupted_tokens_induction, clean_cache_induction, induction_loss)\nimshow(every_head_all_pos_act_patch_result, facet_col=0, facet_labels=[\"Output\", \"Query\", \"Key\", \"Value\", \"Pattern\"], 
title=\"Activation Patching Per Head (All Pos)\", xaxis=\"Head\", yaxis=\"Layer\", zmax=CLEAN_BASELINE_INDUCTION)\n\nif DO_SLOW_RUNS:\n every_head_act_patch_result = patching.get_act_patch_attn_head_by_pos_every(attn_only, corrupted_tokens_induction, clean_cache_induction, induction_loss)\n every_head_act_patch_result = einops.rearrange(every_head_act_patch_result, \"act_type layer pos head -> act_type (layer head) pos\")\n imshow(every_head_act_patch_result, facet_col=0, facet_labels=[\"Output\", \"Query\", \"Key\", \"Value\", \"Pattern\"], title=\"Activation Patching Per Head (By Pos)\", xaxis=\"Position\", yaxis=\"Layer & Head\", zmax=CLEAN_BASELINE_INDUCTION, x= [f\"{tok}_{i}\" for i, tok in enumerate(attn_only.to_str_tokens(clean_tokens_induction[0]))], y=[f\"L{l}H{h}\" for l in range(attn_only.cfg.n_layers) for h in range(attn_only.cfg.n_heads)])" + "source": [ + "# NBVAL_SKIP\n", + "# Heavy patching computation — too slow for CI without GPU\n", + "every_head_all_pos_act_patch_result = patching.get_act_patch_attn_head_all_pos_every(attn_only, corrupted_tokens_induction, clean_cache_induction, induction_loss)\n", + "imshow(every_head_all_pos_act_patch_result, facet_col=0, facet_labels=[\"Output\", \"Query\", \"Key\", \"Value\", \"Pattern\"], title=\"Activation Patching Per Head (All Pos)\", xaxis=\"Head\", yaxis=\"Layer\", zmax=CLEAN_BASELINE_INDUCTION)\n", + "\n", + "if DO_SLOW_RUNS:\n", + " every_head_act_patch_result = patching.get_act_patch_attn_head_by_pos_every(attn_only, corrupted_tokens_induction, clean_cache_induction, induction_loss)\n", + " every_head_act_patch_result = einops.rearrange(every_head_act_patch_result, \"act_type layer pos head -> act_type (layer head) pos\")\n", + " imshow(every_head_act_patch_result, facet_col=0, facet_labels=[\"Output\", \"Query\", \"Key\", \"Value\", \"Pattern\"], title=\"Activation Patching Per Head (By Pos)\", xaxis=\"Position\", yaxis=\"Layer & Head\", zmax=CLEAN_BASELINE_INDUCTION, x= [f\"{tok}_{i}\" for i, tok in enumerate(attn_only.to_str_tokens(clean_tokens_induction[0]))], y=[f\"L{l}H{h}\" for l in range(attn_only.cfg.n_layers) for h in range(attn_only.cfg.n_heads)])" + ] }, { "cell_type": "markdown", @@ -1289,4 +1313,4 @@ }, "nbformat": 4, "nbformat_minor": 2 -} \ No newline at end of file +} diff --git a/demos/Attribution_Patching_Demo.ipynb b/demos/Attribution_Patching_Demo.ipynb index 7cc944ea7..8d9fce790 100644 --- a/demos/Attribution_Patching_Demo.ipynb +++ b/demos/Attribution_Patching_Demo.ipynb @@ -179,7 +179,7 @@ "outputs": [], "source": [ "import transformer_lens\n", - "import transformer_lens.utils as utils\n", + "import transformer_lens.utilities as utils\n", "from transformer_lens import (\n", " ActivationCache,\n", ")\n", @@ -3806,7 +3806,7 @@ "Top 0th token. Logit: 20.73 Prob: 95.80% Token: | Paris|\n", "Top 1th token. Logit: 16.49 Prob: 1.39% Token: | E|\n", "Top 2th token. Logit: 14.69 Prob: 0.23% Token: | the|\n", - "Top 3th token. Logit: 14.58 Prob: 0.21% Token: | \u00c9|\n", + "Top 3th token. Logit: 14.58 Prob: 0.21% Token: | É|\n", "Top 4th token. Logit: 14.44 Prob: 0.18% Token: | France|\n", "Top 5th token. Logit: 14.36 Prob: 0.16% Token: | Mont|\n", "Top 6th token. 
Logit: 13.77 Prob: 0.09% Token: | Le|\n", @@ -4236,11 +4236,11 @@ "description": "", "description_allow_html": false, "layout": "IPY_MODEL_d230304b88114f2a9b85f5a48f441ce6", - "placeholder": "\u200b", + "placeholder": "​", "style": "IPY_MODEL_ee535e2cfa694be1a7857b1867b8b608", "tabbable": null, "tooltip": null, - "value": "\u2007456k/?\u2007[00:00<00:00,\u200712.7MB/s]" + "value": " 456k/? [00:00<00:00, 12.7MB/s]" } }, "020cf001eb7d496295a325cbc0ee8718": { @@ -4277,7 +4277,7 @@ "description": "", "description_allow_html": false, "layout": "IPY_MODEL_af67816d50074ae498ef9b600b4175ed", - "placeholder": "\u200b", + "placeholder": "​", "style": "IPY_MODEL_7ac7b11ef3e34bf1a12926c745e08707", "tabbable": null, "tooltip": null, @@ -4300,11 +4300,11 @@ "description": "", "description_allow_html": false, "layout": "IPY_MODEL_7a32c6104cdb4eee8dadc248c129040c", - "placeholder": "\u200b", + "placeholder": "​", "style": "IPY_MODEL_8c0e7d4b46c14e2bb2167752820a9274", "tabbable": null, "tooltip": null, - "value": "\u20071.36M/?\u2007[00:00<00:00,\u200717.1MB/s]" + "value": " 1.36M/? [00:00<00:00, 17.1MB/s]" } }, "043e0e7fe43744589b7bad2527c2eac0": { @@ -4323,11 +4323,11 @@ "description": "", "description_allow_html": false, "layout": "IPY_MODEL_d3dab66a1c254f07afa02e73e6fd121d", - "placeholder": "\u200b", + "placeholder": "​", "style": "IPY_MODEL_a386fa811d524ae08ac67cce5ebf3a15", "tabbable": null, "tooltip": null, - "value": "\u20071.36M/?\u2007[00:00<00:00,\u200720.0MB/s]" + "value": " 1.36M/? [00:00<00:00, 20.0MB/s]" } }, "04d1b3296c75497bb314206d6c7d5341": { @@ -4346,11 +4346,11 @@ "description": "", "description_allow_html": false, "layout": "IPY_MODEL_8e6cf78296b14bc381f13658ebf99912", - "placeholder": "\u200b", + "placeholder": "​", "style": "IPY_MODEL_0ebc4d3f1e94415086e749f4cd41b783", "tabbable": null, "tooltip": null, - "value": "\u20071.04M/?\u2007[00:00<00:00,\u200710.6MB/s]" + "value": " 1.04M/? 
[00:00<00:00, 10.6MB/s]" } }, "0582e71e725a4851a1905aceaa3c36ae": { @@ -4669,11 +4669,11 @@ "description": "", "description_allow_html": false, "layout": "IPY_MODEL_2569c461e9144c4c82e856e4533449ba", - "placeholder": "\u200b", + "placeholder": "​", "style": "IPY_MODEL_2f332b079a3044fa8ab87f028a7b80b0", "tabbable": null, "tooltip": null, - "value": "\u200748/48\u2007[01:10<00:00,\u2007\u20071.46s/it]" + "value": " 48/48 [01:10<00:00,  1.46s/it]" } }, "0fd7652c5e624ef7b2a36a0b0397f51d": { @@ -4805,11 +4805,11 @@ "description": "", "description_allow_html": false, "layout": "IPY_MODEL_5cc6434927224f72aadd34dc0e0c2894", - "placeholder": "\u200b", + "placeholder": "​", "style": "IPY_MODEL_432ff9b9f3574d69b47e8272dd762923", "tabbable": null, "tooltip": null, - "value": "merges.txt:\u2007" + "value": "merges.txt: " } }, "12f960167a1c417aacdd77ce3a997e35": { @@ -4828,7 +4828,7 @@ "description": "", "description_allow_html": false, "layout": "IPY_MODEL_8dd28e04200641a9a2a4e5ed241db518", - "placeholder": "\u200b", + "placeholder": "​", "style": "IPY_MODEL_912fc8506a6f45e38f1573d27eff6457", "tabbable": null, "tooltip": null, @@ -4875,11 +4875,11 @@ "description": "", "description_allow_html": false, "layout": "IPY_MODEL_ae5d013bec884f4b97d1852b1fb52432", - "placeholder": "\u200b", + "placeholder": "​", "style": "IPY_MODEL_87c101f3cc0f4553870ca9de688b9e83", "tabbable": null, "tooltip": null, - "value": "vocab.json:\u2007" + "value": "vocab.json: " } }, "1775bd14b2104a078aa63991cc11ba85": { @@ -4967,11 +4967,11 @@ "description": "", "description_allow_html": false, "layout": "IPY_MODEL_6dd11b3c888a461aaa85372b044ccd53", - "placeholder": "\u200b", + "placeholder": "​", "style": "IPY_MODEL_dc90adc8272e4e3e929844e2ceef149b", "tabbable": null, "tooltip": null, - "value": "\u2007124/124\u2007[00:00<00:00,\u200754.7kB/s]" + "value": " 124/124 [00:00<00:00, 54.7kB/s]" } }, "180d2ba6e10e4e808eba69a8517d5080": { @@ -4990,7 +4990,7 @@ "description": "", "description_allow_html": false, "layout": "IPY_MODEL_de109b45a18f42b5aa83f63ef683379f", - "placeholder": "\u200b", + "placeholder": "​", "style": "IPY_MODEL_4e44c9999b664ea9bad1b4e360ee76c7", "tabbable": null, "tooltip": null, @@ -5686,7 +5686,7 @@ "description": "", "description_allow_html": false, "layout": "IPY_MODEL_59190b0bd8e74ee1bab6aea2f931856d", - "placeholder": "\u200b", + "placeholder": "​", "style": "IPY_MODEL_c0aa3c04c0a74717b8fc6700213bf579", "tabbable": null, "tooltip": null, @@ -5796,11 +5796,11 @@ "description": "", "description_allow_html": false, "layout": "IPY_MODEL_700b88d4443848bab341b9d7b00cab54", - "placeholder": "\u200b", + "placeholder": "​", "style": "IPY_MODEL_b2e784a339524df682698e606959668e", "tabbable": null, "tooltip": null, - "value": "tokenizer.json:\u2007" + "value": "tokenizer.json: " } }, "31a28b69348b40bfbd14a54380bfb766": { @@ -5819,7 +5819,7 @@ "description": "", "description_allow_html": false, "layout": "IPY_MODEL_eddca0196bdf4e1eb5605e557bfe597b", - "placeholder": "\u200b", + "placeholder": "​", "style": "IPY_MODEL_c762da254b434ffea5b7c35e73009302", "tabbable": null, "tooltip": null, @@ -6314,7 +6314,7 @@ "description": "", "description_allow_html": false, "layout": "IPY_MODEL_5ed82326612a4505a34bc16d6b0b5fa8", - "placeholder": "\u200b", + "placeholder": "​", "style": "IPY_MODEL_69d19b1cf82443ff8994ffd7b156921c", "tabbable": null, "tooltip": null, @@ -6337,11 +6337,11 @@ "description": "", "description_allow_html": false, "layout": "IPY_MODEL_c9571f91e4894ac0ab6f9433d6dd7258", - "placeholder": "\u200b", + "placeholder": 
"​", "style": "IPY_MODEL_c08b955fa9494958bee9f565c568fc31", "tabbable": null, "tooltip": null, - "value": "tokenizer.json:\u2007" + "value": "tokenizer.json: " } }, "3c377339930f49a5891caeb0639a8360": { @@ -6561,11 +6561,11 @@ "description": "", "description_allow_html": false, "layout": "IPY_MODEL_68f6906692b447f1acec3cef5772fe5a", - "placeholder": "\u200b", + "placeholder": "​", "style": "IPY_MODEL_67cda37b9ae74d3484442b7d3bb19a26", "tabbable": null, "tooltip": null, - "value": "merges.txt:\u2007" + "value": "merges.txt: " } }, "4000e5115c6d48d687ab9b9695a0d826": { @@ -6600,7 +6600,7 @@ "description": "", "description_allow_html": false, "layout": "IPY_MODEL_d94d7a4f5ab34fc9a4a4ee0a07764461", - "placeholder": "\u200b", + "placeholder": "​", "style": "IPY_MODEL_0fd7652c5e624ef7b2a36a0b0397f51d", "tabbable": null, "tooltip": null, @@ -6821,11 +6821,11 @@ "description": "", "description_allow_html": false, "layout": "IPY_MODEL_212c5660764844db8cdc6e3a16099521", - "placeholder": "\u200b", + "placeholder": "​", "style": "IPY_MODEL_7de33c6a558d40a69a85b3db9e203ae8", "tabbable": null, "tooltip": null, - "value": "model.safetensors:\u2007100%" + "value": "model.safetensors: 100%" } }, "46c907a0ac31481f9147bf22e2ac5864": { @@ -6897,11 +6897,11 @@ "description": "", "description_allow_html": false, "layout": "IPY_MODEL_91db9856db97451196e16d433896af48", - "placeholder": "\u200b", + "placeholder": "​", "style": "IPY_MODEL_4e68d17b4b7a45369d6917887ddd7a28", "tabbable": null, "tooltip": null, - "value": "vocab.json:\u2007" + "value": "vocab.json: " } }, "49be2b480d5847a3af7835c317236280": { @@ -7253,11 +7253,11 @@ "description": "", "description_allow_html": false, "layout": "IPY_MODEL_d7e105c660824d349c4ee17006f04437", - "placeholder": "\u200b", + "placeholder": "​", "style": "IPY_MODEL_2848d61b3098431baa1d82fb85f469fe", "tabbable": null, "tooltip": null, - "value": "\u2007548M/548M\u2007[00:18<00:00,\u200760.2MB/s]" + "value": " 548M/548M [00:18<00:00, 60.2MB/s]" } }, "54af6102260d458db54e634c9814aa6f": { @@ -7294,11 +7294,11 @@ "description": "", "description_allow_html": false, "layout": "IPY_MODEL_6d3de9443ae74b75930d8397e9c7ed9a", - "placeholder": "\u200b", + "placeholder": "​", "style": "IPY_MODEL_63c8f7fdcf6346d9b3ca140d0f63f8e4", "tabbable": null, "tooltip": null, - "value": "\u2007144/144\u2007[00:07<00:00,\u200718.78it/s]" + "value": " 144/144 [00:07<00:00, 18.78it/s]" } }, "54daaf323029464b9b67f8a4f53b3002": { @@ -7333,7 +7333,7 @@ "description": "", "description_allow_html": false, "layout": "IPY_MODEL_39f8ccf31f6c49c7a95c59989236a3cf", - "placeholder": "\u200b", + "placeholder": "​", "style": "IPY_MODEL_54af6102260d458db54e634c9814aa6f", "tabbable": null, "tooltip": null, @@ -7653,11 +7653,11 @@ "description": "", "description_allow_html": false, "layout": "IPY_MODEL_7ffa006c49564aa8ad58f08f48b98955", - "placeholder": "\u200b", + "placeholder": "​", "style": "IPY_MODEL_357dcb56edba4564b6ec3051f3e977a5", "tabbable": null, "tooltip": null, - "value": "\u2007144/144\u2007[00:08<00:00,\u200715.78it/s]" + "value": " 144/144 [00:08<00:00, 15.78it/s]" } }, "5f730b9ef10e4b8bb97c4fef1bd7cbb2": { @@ -7789,7 +7789,7 @@ "description": "", "description_allow_html": false, "layout": "IPY_MODEL_4b378e2fe92a4bb5a0cc2adee8a9372d", - "placeholder": "\u200b", + "placeholder": "​", "style": "IPY_MODEL_8e4b9fcabfbc4a37a54f08a99e28b220", "tabbable": null, "tooltip": null, @@ -7838,11 +7838,11 @@ "description": "", "description_allow_html": false, "layout": "IPY_MODEL_d7699d95a0ab4240bfa2754ac81a4dea", - 
"placeholder": "\u200b", + "placeholder": "​", "style": "IPY_MODEL_b92108f127ec4341af59d110b5f991c4", "tabbable": null, "tooltip": null, - "value": "Loading\u2007weights:\u2007100%" + "value": "Loading weights: 100%" } }, "61ffe9ae7ab44e94b5729cae80e49437": { @@ -7998,11 +7998,11 @@ "description": "", "description_allow_html": false, "layout": "IPY_MODEL_38e4ab89d9ed4f16a2a481054a18977f", - "placeholder": "\u200b", + "placeholder": "​", "style": "IPY_MODEL_3673315e0fe741048d2ba304360be671", "tabbable": null, "tooltip": null, - "value": "Loading\u2007weights:\u2007100%" + "value": "Loading weights: 100%" } }, "67cda37b9ae74d3484442b7d3bb19a26": { @@ -8203,11 +8203,11 @@ "description": "", "description_allow_html": false, "layout": "IPY_MODEL_e3abfde7cfd24e938684e179059edd9d", - "placeholder": "\u200b", + "placeholder": "​", "style": "IPY_MODEL_293230277eac481e84ab200e7cd5bdc1", "tabbable": null, "tooltip": null, - "value": "\u200726.0/26.0\u2007[00:00<00:00,\u20073.64kB/s]" + "value": " 26.0/26.0 [00:00<00:00, 3.64kB/s]" } }, "6ca82062740348f2bee12629de7f8e2f": { @@ -8350,11 +8350,11 @@ "description": "", "description_allow_html": false, "layout": "IPY_MODEL_39409ff188e6463bab5bf783828cdbd6", - "placeholder": "\u200b", + "placeholder": "​", "style": "IPY_MODEL_7bd58f2c5b444ed9b5f21c6b364c9dce", "tabbable": null, "tooltip": null, - "value": "\u2007456k/?\u2007[00:00<00:00,\u20079.34MB/s]" + "value": " 456k/? [00:00<00:00, 9.34MB/s]" } }, "6dd11b3c888a461aaa85372b044ccd53": { @@ -8550,11 +8550,11 @@ "description": "", "description_allow_html": false, "layout": "IPY_MODEL_c08dfc48a3574ac1ba2e416960d1d3ea", - "placeholder": "\u200b", + "placeholder": "​", "style": "IPY_MODEL_3d29acbe122346388e66e0741971a810", "tabbable": null, "tooltip": null, - "value": "generation_config.json:\u2007100%" + "value": "generation_config.json: 100%" } }, "737d22cc16184d6a92cf045c476c7a01": { @@ -8573,11 +8573,11 @@ "description": "", "description_allow_html": false, "layout": "IPY_MODEL_8f5f5df1b0314449a113f7bb959fa273", - "placeholder": "\u200b", + "placeholder": "​", "style": "IPY_MODEL_7671d9fee96d4e9f921046a1fb092672", "tabbable": null, "tooltip": null, - "value": "\u2007689/689\u2007[00:00<00:00,\u2007131kB/s]" + "value": " 689/689 [00:00<00:00, 131kB/s]" } }, "73cd80c764784b4197af01198ba6b886": { @@ -8596,7 +8596,7 @@ "description": "", "description_allow_html": false, "layout": "IPY_MODEL_f796802e863c48138fdcec92f546a372", - "placeholder": "\u200b", + "placeholder": "​", "style": "IPY_MODEL_22e77d7546334e038b64cc2c856a6a13", "tabbable": null, "tooltip": null, @@ -8619,11 +8619,11 @@ "description": "", "description_allow_html": false, "layout": "IPY_MODEL_cb86f659e5ad4c5296c32b97d99d357c", - "placeholder": "\u200b", + "placeholder": "​", "style": "IPY_MODEL_8dacdd0b8b5e4433ae4511433eb7df1d", "tabbable": null, "tooltip": null, - "value": "generation_config.json:\u2007100%" + "value": "generation_config.json: 100%" } }, "741e4c3c5c56426c91955f6b0622f629": { @@ -8642,11 +8642,11 @@ "description": "", "description_allow_html": false, "layout": "IPY_MODEL_5e617862d88745c792f610f6651662f6", - "placeholder": "\u200b", + "placeholder": "​", "style": "IPY_MODEL_60e5aae936bb41faad191cf2155f78d3", "tabbable": null, "tooltip": null, - "value": "tokenizer_config.json:\u2007100%" + "value": "tokenizer_config.json: 100%" } }, "75b76bb0ff2f491a8e5febeb8166cbd2": { @@ -8813,11 +8813,11 @@ "description": "", "description_allow_html": false, "layout": "IPY_MODEL_07149da010c5489696b03653df25cdd2", - "placeholder": "\u200b", + 
"placeholder": "​", "style": "IPY_MODEL_8ccfc9585b794ba293cbe376311c42ba", "tabbable": null, "tooltip": null, - "value": "\u2007144/144\u2007[00:07<00:00,\u200717.62it/s]" + "value": " 144/144 [00:07<00:00, 17.62it/s]" } }, "775073e43a7a4b2f970c810d2a05c73e": { @@ -9414,11 +9414,11 @@ "description": "", "description_allow_html": false, "layout": "IPY_MODEL_cfee0a94009c461dbdca5b73961f7fbe", - "placeholder": "\u200b", + "placeholder": "​", "style": "IPY_MODEL_f732f9391dd14bcaace6fd5c27a8335a", "tabbable": null, "tooltip": null, - "value": "\u2007144/144\u2007[00:07<00:00,\u200718.66it/s]" + "value": " 144/144 [00:07<00:00, 18.66it/s]" } }, "84867a129d0043b4910ac244e8a984df": { @@ -10089,11 +10089,11 @@ "description": "", "description_allow_html": false, "layout": "IPY_MODEL_a996f052d0ed4d938f0c71fc72e8c1b6", - "placeholder": "\u200b", + "placeholder": "​", "style": "IPY_MODEL_2a6fd373a6524eb3ae033ea78d3cb61e", "tabbable": null, "tooltip": null, - "value": "config.json:\u2007100%" + "value": "config.json: 100%" } }, "9157c241f7064a8596e1ffaeb850e59c": { @@ -10205,7 +10205,7 @@ "description": "", "description_allow_html": false, "layout": "IPY_MODEL_c51496d64234439ebbfec98b59f44803", - "placeholder": "\u200b", + "placeholder": "​", "style": "IPY_MODEL_29bf4e8f0e6042b498c6ac50d8fedf68", "tabbable": null, "tooltip": null, @@ -10508,11 +10508,11 @@ "description": "", "description_allow_html": false, "layout": "IPY_MODEL_d95b7b9cf6914acb9d1152502d2ba41b", - "placeholder": "\u200b", + "placeholder": "​", "style": "IPY_MODEL_d7c69daa5fa44a6487f5dc66380ec31a", "tabbable": null, "tooltip": null, - "value": "\u2007180/180\u2007[00:10<00:00,\u200718.55it/s]" + "value": " 180/180 [00:10<00:00, 18.55it/s]" } }, "9b50599c4a084d9eb5b8755040c9bd32": { @@ -10547,11 +10547,11 @@ "description": "", "description_allow_html": false, "layout": "IPY_MODEL_f5bbf314d840422b9e486386de3f5bb6", - "placeholder": "\u200b", + "placeholder": "​", "style": "IPY_MODEL_dc68cf2bf3e94df881377a106754a350", "tabbable": null, "tooltip": null, - "value": "\u20071.04M/?\u2007[00:00<00:00,\u200710.5MB/s]" + "value": " 1.04M/? 
[00:00<00:00, 10.5MB/s]" } }, "9ba87779605c4969bf48b4071a94c630": { @@ -10671,11 +10671,11 @@ "description": "", "description_allow_html": false, "layout": "IPY_MODEL_363c2c2e96624875af87c420c7e2cf95", - "placeholder": "\u200b", + "placeholder": "​", "style": "IPY_MODEL_b732ea0e03674d4384ac0d2dbf2a5f69", "tabbable": null, "tooltip": null, - "value": "config.json:\u2007100%" + "value": "config.json: 100%" } }, "a0df7ae7fcc1441a8c9cca5a80b539b0": { @@ -10694,11 +10694,11 @@ "description": "", "description_allow_html": false, "layout": "IPY_MODEL_bcaf8ebca41240bd86dde0c617b68f01", - "placeholder": "\u200b", + "placeholder": "​", "style": "IPY_MODEL_dd6bf89931a64c63bbcd2cf526835c2d", "tabbable": null, "tooltip": null, - "value": "\u20072160/2160\u2007[02:00<00:00,\u200718.64it/s]" + "value": " 2160/2160 [02:00<00:00, 18.64it/s]" } }, "a1f007e8fb68491daa3ba444ca49f505": { @@ -10809,11 +10809,11 @@ "description": "", "description_allow_html": false, "layout": "IPY_MODEL_7655d70259dd4dbc8bd4d288f8850b7c", - "placeholder": "\u200b", + "placeholder": "​", "style": "IPY_MODEL_eb49abf5e8ea41e69c1307f78fda4a90", "tabbable": null, "tooltip": null, - "value": "\u2007180/180\u2007[00:09<00:00,\u200718.56it/s]" + "value": " 180/180 [00:09<00:00, 18.56it/s]" } }, "a711143026bc46a5b1b7bc3dccca1850": { @@ -10858,11 +10858,11 @@ "description": "", "description_allow_html": false, "layout": "IPY_MODEL_2970477ecd6545b2bb698748d4019dac", - "placeholder": "\u200b", + "placeholder": "​", "style": "IPY_MODEL_7c2906dcc3d34ae6846332eb5375cc58", "tabbable": null, "tooltip": null, - "value": "\u2007124/124\u2007[00:00<00:00,\u200773.2kB/s]" + "value": " 124/124 [00:00<00:00, 73.2kB/s]" } }, "a92811787eb84dd19d9ec2fb2eab7eee": { @@ -11381,11 +11381,11 @@ "description": "", "description_allow_html": false, "layout": "IPY_MODEL_b75438a235a243d49404467d54b63373", - "placeholder": "\u200b", + "placeholder": "​", "style": "IPY_MODEL_8f8ffe07f8314800ad637194c8c7d10f", "tabbable": null, "tooltip": null, - "value": "\u2007665/665\u2007[00:00<00:00,\u2007122kB/s]" + "value": " 665/665 [00:00<00:00, 122kB/s]" } }, "b59d1c8089e04592a1b87a7c198d1f6c": { @@ -11481,7 +11481,7 @@ "description": "", "description_allow_html": false, "layout": "IPY_MODEL_97f6ad1918334a3ab503d4a5da11c9ef", - "placeholder": "\u200b", + "placeholder": "​", "style": "IPY_MODEL_0e453235a18e4f5c9008040b1420f718", "tabbable": null, "tooltip": null, @@ -11609,11 +11609,11 @@ "description": "", "description_allow_html": false, "layout": "IPY_MODEL_45e840f42f8c46d491678fad7820e835", - "placeholder": "\u200b", + "placeholder": "​", "style": "IPY_MODEL_e29dd892780a4208b61593998e588e1f", "tabbable": null, "tooltip": null, - "value": "model.safetensors:\u2007100%" + "value": "model.safetensors: 100%" } }, "bb91ef19e455404aac3d283f868f9687": { @@ -12050,11 +12050,11 @@ "description": "", "description_allow_html": false, "layout": "IPY_MODEL_2452dd39a79742b29964500360f4a478", - "placeholder": "\u200b", + "placeholder": "​", "style": "IPY_MODEL_6ca82062740348f2bee12629de7f8e2f", "tabbable": null, "tooltip": null, - "value": "\u2007580/580\u2007[00:00<00:00,\u20075679.21it/s,\u2007Materializing\u2007param=transformer.wte.weight]" + "value": " 580/580 [00:00<00:00, 5679.21it/s, Materializing param=transformer.wte.weight]" } }, "c455478a557645b29777950e364a5006": { @@ -12258,11 +12258,11 @@ "description": "", "description_allow_html": false, "layout": "IPY_MODEL_cb791fd8c7ae49079f9162125e98ff79", - "placeholder": "\u200b", + "placeholder": "​", "style": 
"IPY_MODEL_52cd22d9c796437692165d3f3ed48e82", "tabbable": null, "tooltip": null, - "value": "\u20072160/2160\u2007[02:00<00:00,\u200715.47it/s]" + "value": " 2160/2160 [02:00<00:00, 15.47it/s]" } }, "c5bca44eefb940d39fba70d4fff71571": { @@ -12358,11 +12358,11 @@ "description": "", "description_allow_html": false, "layout": "IPY_MODEL_f44f01a296fb4d198718574f1f802ba2", - "placeholder": "\u200b", + "placeholder": "​", "style": "IPY_MODEL_c6b1b39c22564a59bc5b950d7a46708f", "tabbable": null, "tooltip": null, - "value": "\u200726.0/26.0\u2007[00:00<00:00,\u20074.24kB/s]" + "value": " 26.0/26.0 [00:00<00:00, 4.24kB/s]" } }, "c6b1b39c22564a59bc5b950d7a46708f": { @@ -12417,11 +12417,11 @@ "description": "", "description_allow_html": false, "layout": "IPY_MODEL_7d0ba24e89554742a562ce50135447f0", - "placeholder": "\u200b", + "placeholder": "​", "style": "IPY_MODEL_502e8cbf7b7c47b480b8a85405fc24cf", "tabbable": null, "tooltip": null, - "value": "\u20072160/2160\u2007[02:00<00:00,\u200717.95it/s]" + "value": " 2160/2160 [02:00<00:00, 17.95it/s]" } }, "c836782daaa847248009a626db347182": { @@ -12562,11 +12562,11 @@ "description": "", "description_allow_html": false, "layout": "IPY_MODEL_3d83725b10254139a51cb688a495459d", - "placeholder": "\u200b", + "placeholder": "​", "style": "IPY_MODEL_020cf001eb7d496295a325cbc0ee8718", "tabbable": null, "tooltip": null, - "value": "\u20072160/2160\u2007[02:00<00:00,\u200718.75it/s]" + "value": " 2160/2160 [02:00<00:00, 18.75it/s]" } }, "cac816368dee481a9fee6b196a2b16d6": { @@ -12760,7 +12760,7 @@ "description": "", "description_allow_html": false, "layout": "IPY_MODEL_72035846c38c4a439583cc5e974c0a52", - "placeholder": "\u200b", + "placeholder": "​", "style": "IPY_MODEL_7aebe1ba464949849da0e182d90d0669", "tabbable": null, "tooltip": null, @@ -12986,11 +12986,11 @@ "description": "", "description_allow_html": false, "layout": "IPY_MODEL_8ab4ec47b46e401883c185417c452f17", - "placeholder": "\u200b", + "placeholder": "​", "style": "IPY_MODEL_fdb4d541f8414fe4bee9265554cd7522", "tabbable": null, "tooltip": null, - "value": "tokenizer_config.json:\u2007100%" + "value": "tokenizer_config.json: 100%" } }, "d3dab66a1c254f07afa02e73e6fd121d": { @@ -13692,11 +13692,11 @@ "description": "", "description_allow_html": false, "layout": "IPY_MODEL_b41e54c58689400b86fc0dbf18e4bbaa", - "placeholder": "\u200b", + "placeholder": "​", "style": "IPY_MODEL_98bce027ddf046c29bbbb03c6e9b1de3", "tabbable": null, "tooltip": null, - "value": "\u2007144/144\u2007[00:08<00:00,\u200718.70it/s]" + "value": " 144/144 [00:08<00:00, 18.70it/s]" } }, "df2f48a5055b4cb7a9db8115c407fd8c": { @@ -13892,11 +13892,11 @@ "description": "", "description_allow_html": false, "layout": "IPY_MODEL_29376b858cd4489a8fcefc2b096df1e5", - "placeholder": "\u200b", + "placeholder": "​", "style": "IPY_MODEL_6f7575088f10441c87ded2f68ac37e9f", "tabbable": null, "tooltip": null, - "value": "\u20072160/2160\u2007[01:59<00:00,\u200718.72it/s]" + "value": " 2160/2160 [01:59<00:00, 18.72it/s]" } }, "e3abfde7cfd24e938684e179059edd9d": { @@ -14219,11 +14219,11 @@ "description": "", "description_allow_html": false, "layout": "IPY_MODEL_db91637067cb428c94674237eabda8f7", - "placeholder": "\u200b", + "placeholder": "​", "style": "IPY_MODEL_dd98180b8cba40aa82fc6221dd4676d0", "tabbable": null, "tooltip": null, - "value": "\u2007180/180\u2007[00:09<00:00,\u200717.32it/s]" + "value": " 180/180 [00:09<00:00, 17.32it/s]" } }, "eb49abf5e8ea41e69c1307f78fda4a90": { @@ -14764,7 +14764,7 @@ "description": "", "description_allow_html": false, 
"layout": "IPY_MODEL_e2c844c07e434e718186afaf72312371", - "placeholder": "\u200b", + "placeholder": "​", "style": "IPY_MODEL_3704e0be72444fd9a07038fbe1b19156", "tabbable": null, "tooltip": null, @@ -14840,11 +14840,11 @@ "description": "", "description_allow_html": false, "layout": "IPY_MODEL_61ffe9ae7ab44e94b5729cae80e49437", - "placeholder": "\u200b", + "placeholder": "​", "style": "IPY_MODEL_a314648aea3d40f89ac931c47f81c9e6", "tabbable": null, "tooltip": null, - "value": "\u2007148/148\u2007[00:00<00:00,\u20075498.19it/s,\u2007Materializing\u2007param=transformer.wte.weight]" + "value": " 148/148 [00:00<00:00, 5498.19it/s, Materializing param=transformer.wte.weight]" } }, "fd150a5176074e959dfa52a35770b5f0": { @@ -14863,11 +14863,11 @@ "description": "", "description_allow_html": false, "layout": "IPY_MODEL_a05625d67e634f2a80c65cdcfcbe8f8c", - "placeholder": "\u200b", + "placeholder": "​", "style": "IPY_MODEL_d1e752f79bbc40ddbae7a02895c9b74e", "tabbable": null, "tooltip": null, - "value": "\u20076.43G/6.43G\u2007[04:52<00:00,\u2007111MB/s]" + "value": " 6.43G/6.43G [04:52<00:00, 111MB/s]" } }, "fdb4d541f8414fe4bee9265554cd7522": { diff --git a/demos/Exploratory_Analysis_Demo.ipynb b/demos/Exploratory_Analysis_Demo.ipynb index c2c86cea8..097f25c10 100644 --- a/demos/Exploratory_Analysis_Demo.ipynb +++ b/demos/Exploratory_Analysis_Demo.ipynb @@ -117,7 +117,7 @@ "from IPython.display import HTML, IFrame\n", "from jaxtyping import Float\n", "\n", - "import transformer_lens.utils as utils\n", + "import transformer_lens.utilities as utils\n", "from transformer_lens import ActivationCache\n", "from transformer_lens.model_bridge import TransformerBridge" ] @@ -1883,4 +1883,4 @@ }, "nbformat": 4, "nbformat_minor": 2 -} \ No newline at end of file +} diff --git a/demos/Grokking_Demo.ipynb b/demos/Grokking_Demo.ipynb index 7b7fe5243..09af51e9a 100644 --- a/demos/Grokking_Demo.ipynb +++ b/demos/Grokking_Demo.ipynb @@ -138,7 +138,7 @@ "outputs": [], "source": [ "import transformer_lens\n", - "import transformer_lens.utils as utils\n", + "import transformer_lens.utilities as utils\n", "from transformer_lens.hook_points import (\n", " HookedRootModule,\n", " HookPoint,\n", @@ -2921,10 +2921,10 @@ "evalue": "name 'train_losses' is not defined", "output_type": "error", "traceback": [ - "\u001B[0;31m---------------------------------------------------------------------------\u001B[0m", - "\u001B[0;31mNameError\u001B[0m Traceback (most recent call last)", - "\u001B[0;32m/tmp/ipykernel_1229617/2975677256.py\u001B[0m in \u001B[0;36m\u001B[0;34m\u001B[0m\n\u001B[1;32m 1\u001B[0m \u001B[0;32mimport\u001B[0m \u001B[0mneel\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0mplot\u001B[0m \u001B[0;32mas\u001B[0m \u001B[0mnpx\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[0;32m----> 2\u001B[0;31m \u001B[0mfig\u001B[0m \u001B[0;34m=\u001B[0m \u001B[0mnpx\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0mline\u001B[0m\u001B[0;34m(\u001B[0m\u001B[0;34m[\u001B[0m\u001B[0mtrain_losses\u001B[0m\u001B[0;34m[\u001B[0m\u001B[0;34m:\u001B[0m\u001B[0;34m:\u001B[0m\u001B[0;36m100\u001B[0m\u001B[0;34m]\u001B[0m\u001B[0;34m,\u001B[0m \u001B[0mtest_losses\u001B[0m\u001B[0;34m[\u001B[0m\u001B[0;34m:\u001B[0m\u001B[0;34m:\u001B[0m\u001B[0;36m100\u001B[0m\u001B[0;34m]\u001B[0m\u001B[0;34m]\u001B[0m\u001B[0;34m,\u001B[0m \u001B[0mx\u001B[0m\u001B[0;34m=\u001B[0m\u001B[0mnp\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0marange\u001B[0m\u001B[0;34m(\u001B[0m\u001B[0;36m0\u001B[0m\u001B[0;34m,\u001B[0m 
\u001B[0mlen\u001B[0m\u001B[0;34m(\u001B[0m\u001B[0mtrain_losses\u001B[0m\u001B[0;34m)\u001B[0m\u001B[0;34m,\u001B[0m \u001B[0;36m100\u001B[0m\u001B[0;34m)\u001B[0m\u001B[0;34m,\u001B[0m \u001B[0mxaxis\u001B[0m\u001B[0;34m=\u001B[0m\u001B[0;34m\"Epoch\"\u001B[0m\u001B[0;34m,\u001B[0m \u001B[0myaxis\u001B[0m\u001B[0;34m=\u001B[0m\u001B[0;34m\"Loss\"\u001B[0m\u001B[0;34m,\u001B[0m \u001B[0mlog_y\u001B[0m\u001B[0;34m=\u001B[0m\u001B[0;32mTrue\u001B[0m\u001B[0;34m,\u001B[0m \u001B[0mtitle\u001B[0m\u001B[0;34m=\u001B[0m\u001B[0;34m\"Training Curve for Modular Addition\"\u001B[0m\u001B[0;34m,\u001B[0m \u001B[0mline_labels\u001B[0m\u001B[0;34m=\u001B[0m\u001B[0;34m[\u001B[0m\u001B[0;34m'train'\u001B[0m\u001B[0;34m,\u001B[0m \u001B[0;34m'test'\u001B[0m\u001B[0;34m]\u001B[0m\u001B[0;34m,\u001B[0m \u001B[0mtoggle_x\u001B[0m\u001B[0;34m=\u001B[0m\u001B[0;32mTrue\u001B[0m\u001B[0;34m,\u001B[0m \u001B[0mtoggle_y\u001B[0m\u001B[0;34m=\u001B[0m\u001B[0;32mTrue\u001B[0m\u001B[0;34m,\u001B[0m \u001B[0mreturn_fig\u001B[0m\u001B[0;34m=\u001B[0m\u001B[0;32mTrue\u001B[0m\u001B[0;34m)\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[0m\u001B[1;32m 3\u001B[0m \u001B[0madd_lines\u001B[0m\u001B[0;34m(\u001B[0m\u001B[0mfig\u001B[0m\u001B[0;34m)\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n", - "\u001B[0;31mNameError\u001B[0m: name 'train_losses' is not defined" + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)", + "\u001b[0;32m/tmp/ipykernel_1229617/2975677256.py\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mneel\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mplot\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0mnpx\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 2\u001b[0;31m \u001b[0mfig\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mnpx\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mline\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mtrain_losses\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;36m100\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtest_losses\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;36m100\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0marange\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtrain_losses\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m100\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mxaxis\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m\"Epoch\"\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0myaxis\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m\"Loss\"\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlog_y\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mTrue\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtitle\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m\"Training Curve for Modular Addition\"\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mline_labels\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'train'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'test'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtoggle_x\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mTrue\u001b[0m\u001b[0;34m,\u001b[0m 
\u001b[0mtoggle_y\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mTrue\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mreturn_fig\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mTrue\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 3\u001b[0m \u001b[0madd_lines\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfig\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;31mNameError\u001b[0m: name 'train_losses' is not defined" ] } ], @@ -3500,11 +3500,11 @@ "evalue": "Size does not match at dimension 0 expected index [12769, 1] to be smaller than self [113, 113] apart from dimension 1", "output_type": "error", "traceback": [ - "\u001B[0;31m---------------------------------------------------------------------------\u001B[0m", - "\u001B[0;31mRuntimeError\u001B[0m Traceback (most recent call last)", - "\u001B[0;32m/tmp/ipykernel_1215793/3004607503.py\u001B[0m in \u001B[0;36m\u001B[0;34m\u001B[0m\n\u001B[0;32m----> 1\u001B[0;31m \u001B[0mprint\u001B[0m\u001B[0;34m(\u001B[0m\u001B[0mloss_fn\u001B[0m\u001B[0;34m(\u001B[0m\u001B[0mall_logits\u001B[0m\u001B[0;34m,\u001B[0m \u001B[0mlabels\u001B[0m\u001B[0;34m)\u001B[0m\u001B[0;34m)\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[0m", - "\u001B[0;32m/tmp/ipykernel_1215793/4096650173.py\u001B[0m in \u001B[0;36mloss_fn\u001B[0;34m(logits, labels)\u001B[0m\n\u001B[1;32m 4\u001B[0m \u001B[0mlogits\u001B[0m \u001B[0;34m=\u001B[0m \u001B[0mlogits\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0mto\u001B[0m\u001B[0;34m(\u001B[0m\u001B[0mtorch\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0mfloat64\u001B[0m\u001B[0;34m)\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[1;32m 5\u001B[0m \u001B[0mlog_probs\u001B[0m \u001B[0;34m=\u001B[0m \u001B[0mlogits\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0mlog_softmax\u001B[0m\u001B[0;34m(\u001B[0m\u001B[0mdim\u001B[0m\u001B[0;34m=\u001B[0m\u001B[0;34m-\u001B[0m\u001B[0;36m1\u001B[0m\u001B[0;34m)\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[0;32m----> 6\u001B[0;31m \u001B[0mcorrect_log_probs\u001B[0m \u001B[0;34m=\u001B[0m \u001B[0mlog_probs\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0mgather\u001B[0m\u001B[0;34m(\u001B[0m\u001B[0mdim\u001B[0m\u001B[0;34m=\u001B[0m\u001B[0;34m-\u001B[0m\u001B[0;36m1\u001B[0m\u001B[0;34m,\u001B[0m \u001B[0mindex\u001B[0m\u001B[0;34m=\u001B[0m\u001B[0mlabels\u001B[0m\u001B[0;34m[\u001B[0m\u001B[0;34m:\u001B[0m\u001B[0;34m,\u001B[0m \u001B[0;32mNone\u001B[0m\u001B[0;34m]\u001B[0m\u001B[0;34m)\u001B[0m\u001B[0;34m[\u001B[0m\u001B[0;34m:\u001B[0m\u001B[0;34m,\u001B[0m \u001B[0;36m0\u001B[0m\u001B[0;34m]\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[0m\u001B[1;32m 7\u001B[0m \u001B[0;32mreturn\u001B[0m \u001B[0;34m-\u001B[0m\u001B[0mcorrect_log_probs\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0mmean\u001B[0m\u001B[0;34m(\u001B[0m\u001B[0;34m)\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[1;32m 8\u001B[0m \u001B[0mtrain_logits\u001B[0m \u001B[0;34m=\u001B[0m \u001B[0mmodel\u001B[0m\u001B[0;34m(\u001B[0m\u001B[0mtrain_data\u001B[0m\u001B[0;34m)\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n", - "\u001B[0;31mRuntimeError\u001B[0m: Size does not match at dimension 0 expected index [12769, 1] to be smaller than self [113, 113] apart from dimension 1" + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mRuntimeError\u001b[0m Traceback (most recent call last)", + 
"\u001b[0;32m/tmp/ipykernel_1215793/3004607503.py\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mloss_fn\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mall_logits\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlabels\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", + "\u001b[0;32m/tmp/ipykernel_1215793/4096650173.py\u001b[0m in \u001b[0;36mloss_fn\u001b[0;34m(logits, labels)\u001b[0m\n\u001b[1;32m 4\u001b[0m \u001b[0mlogits\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mlogits\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mto\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfloat64\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 5\u001b[0m \u001b[0mlog_probs\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mlogits\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mlog_softmax\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdim\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m-\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 6\u001b[0;31m \u001b[0mcorrect_log_probs\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mlog_probs\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mgather\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdim\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m-\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mindex\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mlabels\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 7\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0;34m-\u001b[0m\u001b[0mcorrect_log_probs\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmean\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 8\u001b[0m \u001b[0mtrain_logits\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mmodel\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtrain_data\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;31mRuntimeError\u001b[0m: Size does not match at dimension 0 expected index [12769, 1] to be smaller than self [113, 113] apart from dimension 1" ] } ], diff --git a/demos/Head_Detector_Demo.ipynb b/demos/Head_Detector_Demo.ipynb index 5354abeb8..2caa4231f 100644 --- a/demos/Head_Detector_Demo.ipynb +++ b/demos/Head_Detector_Demo.ipynb @@ -321,7 +321,7 @@ "import numpy as np\n", "import torch\n", "\n", - "# from transformer_lens.utils import is_lower_triangular, is_square\n", + "# from transformer_lens.utilities import is_lower_triangular, is_square\n", "\n", "HeadName = Literal[\"previous_token_head\", \"duplicate_token_head\", \"induction_head\"]\n", "HEAD_NAMES = cast(List[HeadName], get_args(HeadName))\n", @@ -393,7 +393,7 @@ " --------\n", " .. 
code-block:: python\n", "\n", - " >>> from transformer_lens import utils\n", + " >>> from transformer_lens import utilities as utils\n", " >>> from transformer_lens.model_bridge import TransformerBridge\n", " >>> from transformer_lens.head_detector import detect_head\n", " >>> import plotly.express as px\n", @@ -1476,7 +1476,7 @@ "output_type": "stream", "text": [ "\r", - " 3%|\u258e | 3/100 [00:00<00:03, 27.52it/s]" + " 3%|▎ | 3/100 [00:00<00:03, 27.52it/s]" ] }, { @@ -1484,7 +1484,7 @@ "output_type": "stream", "text": [ "\r", - " 6%|\u258c | 6/100 [00:00<00:03, 27.79it/s]" + " 6%|▌ | 6/100 [00:00<00:03, 27.79it/s]" ] }, { @@ -1492,7 +1492,7 @@ "output_type": "stream", "text": [ "\r", - " 9%|\u2589 | 9/100 [00:00<00:03, 27.62it/s]" + " 9%|▉ | 9/100 [00:00<00:03, 27.62it/s]" ] }, { @@ -1500,7 +1500,7 @@ "output_type": "stream", "text": [ "\r", - " 12%|\u2588\u258f | 12/100 [00:00<00:03, 27.87it/s]" + " 12%|█▏ | 12/100 [00:00<00:03, 27.87it/s]" ] }, { @@ -1508,7 +1508,7 @@ "output_type": "stream", "text": [ "\r", - " 15%|\u2588\u258c | 15/100 [00:00<00:03, 22.61it/s]" + " 15%|█▌ | 15/100 [00:00<00:03, 22.61it/s]" ] }, { @@ -1516,7 +1516,7 @@ "output_type": "stream", "text": [ "\r", - " 18%|\u2588\u258a | 18/100 [00:00<00:03, 23.94it/s]" + " 18%|█▊ | 18/100 [00:00<00:03, 23.94it/s]" ] }, { @@ -1524,7 +1524,7 @@ "output_type": "stream", "text": [ "\r", - " 21%|\u2588\u2588 | 21/100 [00:00<00:03, 25.03it/s]" + " 21%|██ | 21/100 [00:00<00:03, 25.03it/s]" ] }, { @@ -1532,7 +1532,7 @@ "output_type": "stream", "text": [ "\r", - " 24%|\u2588\u2588\u258d | 24/100 [00:00<00:02, 26.01it/s]" + " 24%|██▍ | 24/100 [00:00<00:02, 26.01it/s]" ] }, { @@ -1540,7 +1540,7 @@ "output_type": "stream", "text": [ "\r", - " 27%|\u2588\u2588\u258b | 27/100 [00:01<00:02, 26.76it/s]" + " 27%|██▋ | 27/100 [00:01<00:02, 26.76it/s]" ] }, { @@ -1548,7 +1548,7 @@ "output_type": "stream", "text": [ "\r", - " 30%|\u2588\u2588\u2588 | 30/100 [00:01<00:02, 27.30it/s]" + " 30%|███ | 30/100 [00:01<00:02, 27.30it/s]" ] }, { @@ -1556,7 +1556,7 @@ "output_type": "stream", "text": [ "\r", - " 33%|\u2588\u2588\u2588\u258e | 33/100 [00:01<00:02, 27.58it/s]" + " 33%|███▎ | 33/100 [00:01<00:02, 27.58it/s]" ] }, { @@ -1564,7 +1564,7 @@ "output_type": "stream", "text": [ "\r", - " 36%|\u2588\u2588\u2588\u258c | 36/100 [00:01<00:02, 27.98it/s]" + " 36%|███▌ | 36/100 [00:01<00:02, 27.98it/s]" ] }, { @@ -1572,7 +1572,7 @@ "output_type": "stream", "text": [ "\r", - " 39%|\u2588\u2588\u2588\u2589 | 39/100 [00:01<00:02, 23.28it/s]" + " 39%|███▉ | 39/100 [00:01<00:02, 23.28it/s]" ] }, { @@ -1580,7 +1580,7 @@ "output_type": "stream", "text": [ "\r", - " 42%|\u2588\u2588\u2588\u2588\u258f | 42/100 [00:01<00:02, 24.58it/s]" + " 42%|████▏ | 42/100 [00:01<00:02, 24.58it/s]" ] }, { @@ -1588,7 +1588,7 @@ "output_type": "stream", "text": [ "\r", - " 45%|\u2588\u2588\u2588\u2588\u258c | 45/100 [00:01<00:02, 25.73it/s]" + " 45%|████▌ | 45/100 [00:01<00:02, 25.73it/s]" ] }, { @@ -1596,7 +1596,7 @@ "output_type": "stream", "text": [ "\r", - " 48%|\u2588\u2588\u2588\u2588\u258a | 48/100 [00:01<00:01, 26.74it/s]" + " 48%|████▊ | 48/100 [00:01<00:01, 26.74it/s]" ] }, { @@ -1604,7 +1604,7 @@ "output_type": "stream", "text": [ "\r", - " 51%|\u2588\u2588\u2588\u2588\u2588 | 51/100 [00:01<00:01, 27.57it/s]" + " 51%|█████ | 51/100 [00:01<00:01, 27.57it/s]" ] }, { @@ -1612,7 +1612,7 @@ "output_type": "stream", "text": [ "\r", - " 54%|\u2588\u2588\u2588\u2588\u2588\u258d | 54/100 [00:02<00:01, 27.88it/s]" + " 54%|█████▍ | 54/100 [00:02<00:01, 27.88it/s]" ] }, { @@ 
-1620,7 +1620,7 @@ "output_type": "stream", "text": [ "\r", - " 57%|\u2588\u2588\u2588\u2588\u2588\u258b | 57/100 [00:02<00:01, 28.25it/s]" + " 57%|█████▋ | 57/100 [00:02<00:01, 28.25it/s]" ] }, { @@ -1628,7 +1628,7 @@ "output_type": "stream", "text": [ "\r", - " 60%|\u2588\u2588\u2588\u2588\u2588\u2588 | 60/100 [00:02<00:01, 28.54it/s]" + " 60%|██████ | 60/100 [00:02<00:01, 28.54it/s]" ] }, { @@ -1636,7 +1636,7 @@ "output_type": "stream", "text": [ "\r", - " 63%|\u2588\u2588\u2588\u2588\u2588\u2588\u258e | 63/100 [00:02<00:01, 28.49it/s]" + " 63%|██████▎ | 63/100 [00:02<00:01, 28.49it/s]" ] }, { @@ -1644,7 +1644,7 @@ "output_type": "stream", "text": [ "\r", - " 66%|\u2588\u2588\u2588\u2588\u2588\u2588\u258c | 66/100 [00:02<00:01, 24.36it/s]" + " 66%|██████▌ | 66/100 [00:02<00:01, 24.36it/s]" ] }, { @@ -1652,7 +1652,7 @@ "output_type": "stream", "text": [ "\r", - " 69%|\u2588\u2588\u2588\u2588\u2588\u2588\u2589 | 69/100 [00:02<00:01, 25.76it/s]" + " 69%|██████▉ | 69/100 [00:02<00:01, 25.76it/s]" ] }, { @@ -1660,7 +1660,7 @@ "output_type": "stream", "text": [ "\r", - " 72%|\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u258f | 72/100 [00:02<00:01, 26.88it/s]" + " 72%|███████▏ | 72/100 [00:02<00:01, 26.88it/s]" ] }, { @@ -1668,7 +1668,7 @@ "output_type": "stream", "text": [ "\r", - " 75%|\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u258c | 75/100 [00:02<00:00, 27.68it/s]" + " 75%|███████▌ | 75/100 [00:02<00:00, 27.68it/s]" ] }, { @@ -1676,7 +1676,7 @@ "output_type": "stream", "text": [ "\r", - " 78%|\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u258a | 78/100 [00:02<00:00, 28.29it/s]" + " 78%|███████▊ | 78/100 [00:02<00:00, 28.29it/s]" ] }, { @@ -1684,7 +1684,7 @@ "output_type": "stream", "text": [ "\r", - " 81%|\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588 | 81/100 [00:03<00:00, 28.48it/s]" + " 81%|████████ | 81/100 [00:03<00:00, 28.48it/s]" ] }, { @@ -1692,7 +1692,7 @@ "output_type": "stream", "text": [ "\r", - " 84%|\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u258d | 84/100 [00:03<00:00, 27.51it/s]" + " 84%|████████▍ | 84/100 [00:03<00:00, 27.51it/s]" ] }, { @@ -1700,7 +1700,7 @@ "output_type": "stream", "text": [ "\r", - " 87%|\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u258b | 87/100 [00:03<00:00, 27.90it/s]" + " 87%|████████▋ | 87/100 [00:03<00:00, 27.90it/s]" ] }, { @@ -1708,7 +1708,7 @@ "output_type": "stream", "text": [ "\r", - " 90%|\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588 | 90/100 [00:03<00:00, 23.68it/s]" + " 90%|█████████ | 90/100 [00:03<00:00, 23.68it/s]" ] }, { @@ -1716,7 +1716,7 @@ "output_type": "stream", "text": [ "\r", - " 93%|\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u258e| 93/100 [00:03<00:00, 25.09it/s]" + " 93%|█████████▎| 93/100 [00:03<00:00, 25.09it/s]" ] }, { @@ -1724,7 +1724,7 @@ "output_type": "stream", "text": [ "\r", - " 96%|\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u258c| 96/100 [00:03<00:00, 26.30it/s]" + " 96%|█████████▌| 96/100 [00:03<00:00, 26.30it/s]" ] }, { @@ -1732,7 +1732,7 @@ "output_type": "stream", "text": [ "\r", - " 99%|\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2589| 99/100 [00:03<00:00, 27.08it/s]" + " 99%|█████████▉| 99/100 [00:03<00:00, 27.08it/s]" ] }, { @@ -1740,7 +1740,7 @@ "output_type": "stream", "text": [ "\r", - "100%|\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588| 100/100 [00:03<00:00, 26.54it/s]" + "100%|██████████| 100/100 [00:03<00:00, 26.54it/s]" ] }, { @@ -2217,11 +2217,11 @@ "description": "", "description_allow_html": false, "layout": 
"IPY_MODEL_e16620f692214737b99cd956859b0c34", - "placeholder": "\u200b", + "placeholder": "​", "style": "IPY_MODEL_fc412d50f34a4568a7c286ed2be51fc9", "tabbable": null, "tooltip": null, - "value": "\u2007148/148\u2007[00:00<00:00,\u20075290.11it/s,\u2007Materializing\u2007param=transformer.wte.weight]" + "value": " 148/148 [00:00<00:00, 5290.11it/s, Materializing param=transformer.wte.weight]" } }, "6662c4f913f74e6181469daa1b9e5ed3": { @@ -2293,11 +2293,11 @@ "description": "", "description_allow_html": false, "layout": "IPY_MODEL_e3770c75a2f74f2e966c9b7b8b397bf2", - "placeholder": "\u200b", + "placeholder": "​", "style": "IPY_MODEL_535051899ad84087afbaf824a2f7bb73", "tabbable": null, "tooltip": null, - "value": "Loading\u2007weights:\u2007100%" + "value": "Loading weights: 100%" } }, "ded58218d5854ad7be3ffcc9f0165c31": { diff --git a/demos/Interactive_Neuroscope.ipynb b/demos/Interactive_Neuroscope.ipynb index eb4931de1..c372aa4b0 100644 --- a/demos/Interactive_Neuroscope.ipynb +++ b/demos/Interactive_Neuroscope.ipynb @@ -75,7 +75,7 @@ "source": [ "import gradio as gr\n", "from transformer_lens.model_bridge import TransformerBridge\n", - "from transformer_lens.utils import to_numpy\n", + "from transformer_lens.utilities import to_numpy\n", "from IPython.display import HTML" ] }, @@ -446,7 +446,10 @@ "execution_count": null, "metadata": {}, "outputs": [], - "source": "# NBVAL_SKIP\ndemo.launch(share=True, height=1000)" + "source": [ + "# NBVAL_SKIP\n", + "demo.launch(share=True, height=1000)" + ] } ], "metadata": { @@ -471,4 +474,4 @@ }, "nbformat": 4, "nbformat_minor": 2 -} \ No newline at end of file +} diff --git a/demos/LLaMA.ipynb b/demos/LLaMA.ipynb index faf1fed52..61d0e54d3 100644 --- a/demos/LLaMA.ipynb +++ b/demos/LLaMA.ipynb @@ -99,7 +99,7 @@ "from jaxtyping import Float\n", "\n", "import transformer_lens\n", - "import transformer_lens.utils as utils\n", + "import transformer_lens.utilities as utils\n", "from transformer_lens.hook_points import (\n", " HookPoint,\n", ") # Hooking utilities\n", diff --git a/demos/LLaMA2_GPU_Quantized.ipynb b/demos/LLaMA2_GPU_Quantized.ipynb index 3129e6489..4722428a9 100644 --- a/demos/LLaMA2_GPU_Quantized.ipynb +++ b/demos/LLaMA2_GPU_Quantized.ipynb @@ -123,7 +123,7 @@ "from jaxtyping import Float\n", "\n", "import transformer_lens\n", - "import transformer_lens.utils as utils\n", + "import transformer_lens.utilities as utils\n", "from transformer_lens.hook_points import (\n", " HookPoint,\n", ") # Hooking utilities\n", diff --git a/demos/Main_Demo.ipynb b/demos/Main_Demo.ipynb index 543d3dcac..3c4e72dc5 100644 --- a/demos/Main_Demo.ipynb +++ b/demos/Main_Demo.ipynb @@ -133,7 +133,7 @@ "outputs": [], "source": [ "# import transformer_lens\n", - "import transformer_lens.utils as utils\n", + "import transformer_lens.utilities as utils\n", "from transformer_lens.hook_points import (\n", " HookPoint,\n", ") # Hooking utilities\n", diff --git a/demos/No_Position_Experiment.ipynb b/demos/No_Position_Experiment.ipynb index 95fe03659..979c99e42 100644 --- a/demos/No_Position_Experiment.ipynb +++ b/demos/No_Position_Experiment.ipynb @@ -1,1383 +1,1383 @@ { - "cells": [ - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "\n", - " \"Open\n", - "" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Introduction\n", - "\n", - "The accompanying notebook to my [real-time research](https://www.youtube.com/watch?v=yo4QvDn-vsU) video. 
Trains a model with no positional embeddings to predict the previous token, and makes a start at analysing what's going on there!\n", - "\n", - "EDIT: The loss spikes were due to the learning rate being max(step/100, 1.0) not min! Thanks to MadHatter for catching that." - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Setup" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [] - }, - { - "cell_type": "code", - "execution_count": 33, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Running as a Jupyter notebook - intended for development only!\n" - ] - } - ], - "source": [ - "# NBVAL_IGNORE_OUTPUT\n", - "import os\n", - "\n", - "# Janky code to do different setup when run in a Colab notebook vs VSCode\n", - "DEVELOPMENT_MODE = False\n", - "IN_GITHUB = os.getenv(\"GITHUB_ACTIONS\") == \"true\"\n", - "try:\n", - " import google.colab\n", - "\n", - " IN_COLAB = True\n", - " print(\"Running as a Colab notebook\")\n", - "except:\n", - " IN_COLAB = False\n", - " print(\"Running as a Jupyter notebook - intended for development only!\")\n", - "\n", - "if IN_COLAB or IN_GITHUB:\n", - " %pip install einops\n", - " %pip install transformer_lens@v1.15.0\n", - "\n", - " # PySvelte is an unmaintained visualization library, use it as a backup if circuitsvis isn't working\n", - " # # Install another version of node that makes PySvelte work way faster\n", - " # !curl -fsSL https://deb.nodesource.com/setup_16.x | sudo -E bash -; sudo apt-get install -y nodejs\n", - " # %pip install git+https://github.com/neelnanda-io/PySvelte.git\n", - "\n", - "from transformer_lens import HookedTransformer, HookedTransformerConfig\n", - "import torch\n", - "import numpy as np\n", - "import plotly.express as px\n", - "import plotly.io as pio\n", - "\n", - "pio.renderers.default = \"colab\"\n", - "import tqdm.auto as tqdm\n", - "import einops\n", - "from transformer_lens.utils import to_numpy\n", - "\n", - "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Some plotting code. Wrappers around Plotly, not important to understand." 
- ] - }, - { - "cell_type": "code", - "execution_count": 34, - "metadata": {}, - "outputs": [], - "source": [ - "def line(tensor, line_labels=None, yaxis=\"\", xaxis=\"\", **kwargs):\n", - " tensor = to_numpy(tensor)\n", - " labels = {\"y\": yaxis, \"x\": xaxis}\n", - " fig = px.line(tensor, labels=labels, **kwargs)\n", - " if line_labels:\n", - " for c, label in enumerate(line_labels):\n", - " fig.data[c].name = label\n", - " fig.show()\n", - "\n", - "\n", - "def imshow(tensor, yaxis=\"\", xaxis=\"\", **kwargs):\n", - " tensor = to_numpy(tensor)\n", - " plot_kwargs = {\n", - " \"color_continuous_scale\": \"RdBu\",\n", - " \"color_continuous_midpoint\": 0.0,\n", - " \"labels\": {\"x\": xaxis, \"y\": yaxis},\n", - " }\n", - " plot_kwargs.update(kwargs)\n", - " px.imshow(tensor, **plot_kwargs).show()" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Model Training" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Setup" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Defining the Model" - ] - }, - { - "cell_type": "code", - "execution_count": 35, - "metadata": {}, - "outputs": [], - "source": [ - "cfg = HookedTransformerConfig(\n", - " n_layers=2,\n", - " d_model=64,\n", - " d_head=64,\n", - " n_heads=1,\n", - " d_mlp=256,\n", - " d_vocab=300,\n", - " n_ctx=50,\n", - " act_fn=\"relu\",\n", - " normalization_type=\"LN\",\n", - " device=device,\n", - ")\n", - "model = HookedTransformer(cfg)" - ] - }, - { - "cell_type": "code", - "execution_count": 36, - "metadata": {}, - "outputs": [], - "source": [ - "def deactivate_position(model):\n", - " model.pos_embed.W_pos.data[:] = 0.0\n", - " model.pos_embed.W_pos.requires_grad = False\n", - "\n", - "\n", - "deactivate_position(model)" - ] - }, - { - "cell_type": "code", - "execution_count": 37, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "HookedTransformer(\n", - " (embed): Embed()\n", - " (hook_embed): HookPoint()\n", - " (pos_embed): PosEmbed()\n", - " (hook_pos_embed): HookPoint()\n", - " (blocks): ModuleList(\n", - " (0-1): 2 x TransformerBlock(\n", - " (ln1): LayerNorm(\n", - " (hook_scale): HookPoint()\n", - " (hook_normalized): HookPoint()\n", - " )\n", - " (ln2): LayerNorm(\n", - " (hook_scale): HookPoint()\n", - " (hook_normalized): HookPoint()\n", - " )\n", - " (attn): Attention(\n", - " (hook_k): HookPoint()\n", - " (hook_q): HookPoint()\n", - " (hook_v): HookPoint()\n", - " (hook_z): HookPoint()\n", - " (hook_attn_scores): HookPoint()\n", - " (hook_pattern): HookPoint()\n", - " (hook_result): HookPoint()\n", - " )\n", - " (mlp): MLP(\n", - " (hook_pre): HookPoint()\n", - " (hook_post): HookPoint()\n", - " )\n", - " (hook_attn_in): HookPoint()\n", - " (hook_q_input): HookPoint()\n", - " (hook_k_input): HookPoint()\n", - " (hook_v_input): HookPoint()\n", - " (hook_mlp_in): HookPoint()\n", - " (hook_attn_out): HookPoint()\n", - " (hook_mlp_out): HookPoint()\n", - " (hook_resid_pre): HookPoint()\n", - " (hook_resid_mid): HookPoint()\n", - " (hook_resid_post): HookPoint()\n", - " )\n", - " )\n", - " (ln_final): LayerNorm(\n", - " (hook_scale): HookPoint()\n", - " (hook_normalized): HookPoint()\n", - " )\n", - " (unembed): Unembed()\n", - ")\n" - ] - } - ], - "source": [ - "print(model)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Define data + Loss function" - ] - }, - { - "cell_type": "code", - "execution_count": 38, - "metadata": {}, - "outputs": [ - { - "name": 
"stdout", - "output_type": "stream", - "text": [ - "tensor([[ 0, 93, 34, 155, 274, 116, 114, 248, 68, 3, 298, 83, 194, 20,\n", - " 8, 133, 32, 66, 62, 73, 210, 273, 46, 243, 104, 232, 161, 125,\n", - " 123, 251, 7, 4, 115, 127, 21, 1, 89, 142, 6, 15, 298, 251,\n", - " 88, 229, 108, 114, 23, 88, 3, 265],\n", - " [ 0, 118, 46, 274, 105, 268, 131, 35, 19, 58, 226, 278, 27, 25,\n", - " 276, 180, 164, 4, 95, 27, 74, 201, 105, 65, 80, 185, 44, 258,\n", - " 105, 60, 58, 47, 126, 60, 294, 253, 258, 136, 29, 101, 258, 77,\n", - " 80, 180, 159, 169, 122, 117, 27, 194]])\n" - ] - } - ], - "source": [ - "def make_data_generator(cfg, batch_size, seed=123, incl_bos_token=True):\n", - " torch.manual_seed(seed)\n", - " while True:\n", - " x = torch.randint(1, cfg.d_vocab, (batch_size, cfg.n_ctx))\n", - " if incl_bos_token:\n", - " x[:, 0] = 0\n", - " yield x\n", - "\n", - "\n", - "data_generator = make_data_generator(cfg, 2)\n", - "print(next(data_generator))" - ] - }, - { - "cell_type": "code", - "execution_count": 39, - "metadata": {}, - "outputs": [], - "source": [ - "def loss_fn(logits, tokens, per_token=False):\n", - " # logit shape: [batch, pos, vocab]\n", - " # token shape: [batch, pos]\n", - " logits = logits[:, 1:]\n", - " tokens = tokens[:, :-1]\n", - " log_probs = logits.log_softmax(-1)\n", - " correct_log_probs = log_probs.gather(-1, tokens[..., None])[..., 0]\n", - " if per_token:\n", - " return -correct_log_probs\n", - " else:\n", - " return -correct_log_probs.mean()" - ] - }, - { - "cell_type": "code", - "execution_count": 40, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "tensor([[0.0004, 0.0003, 0.0031, 0.0005]])\n", - "tensor(0.0011)\n" - ] - } - ], - "source": [ - "# Test the loss function works\n", - "test_tokens = torch.arange(5)[None, :]\n", - "test_logits = torch.randn(1, 5, 10)\n", - "test_logits[:, 1, 0] = 10.0\n", - "test_logits[:, 2, 1] = 10.0\n", - "test_logits[:, 3, 2] = 10.0\n", - "test_logits[:, 4, 3] = 10.0\n", - "print(loss_fn(test_logits, test_tokens, per_token=True))\n", - "print(loss_fn(test_logits, test_tokens, per_token=False))" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Setup Optimizer\n" - ] - }, - { - "cell_type": "code", - "execution_count": 41, - "metadata": {}, - "outputs": [], - "source": [ - "batch_size = 256\n", - "num_epochs = 4000\n", - "lr = 1e-4\n", - "betas = (0.9, 0.95)\n", - "max_grad_norm = 1.0\n", - "wd = 0.1\n", - "optimizer = torch.optim.AdamW(model.parameters(), lr=lr, betas=betas, weight_decay=wd)\n", - "scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lambda i: min(i / 100, 1.0))\n", - "\n", - "data_loader = make_data_generator(cfg, batch_size)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Model Training" - ] - }, - { - "cell_type": "code", - "execution_count": 42, - "metadata": {}, - "outputs": [ - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "122c183908104b04a600bfe4aca9f009", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - " 0%| | 0/4000 [00:00\n", - "\n", - "\n", - "
\n", - "
\n", - "\n", - "" - ] - }, - "metadata": {}, - "output_type": "display_data" - } - ], - "source": [ - "losses = []\n", - "for epoch in tqdm.tqdm(range(num_epochs)):\n", - " tokens = next(data_loader)\n", - " tokens = tokens.to(device)\n", - " logits = model(tokens)\n", - " loss = loss_fn(logits, tokens)\n", - " loss.backward()\n", - " if max_grad_norm is not None:\n", - " torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)\n", - " optimizer.step()\n", - " optimizer.zero_grad()\n", - " scheduler.step()\n", - " losses.append(loss.item())\n", - " if epoch % 100 == 0:\n", - " print(f\"Epoch {epoch}: {loss.item()}\")\n", - "px.line(losses, labels={\"x\": \"Epoch\", \"y\": \"Loss\"})" - ] - }, - { - "cell_type": "code", - "execution_count": 43, - "metadata": {}, - "outputs": [], - "source": [ - "# torch.save(model.state_dict(), \"no_pos_experiment_state_dict_v0.pth\")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Model Interpretability" - ] - }, - { - "cell_type": "code", - "execution_count": 44, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "tensor(0.)" - ] - }, - "execution_count": 44, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "model.pos_embed.W_pos.norm()" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Look at attention patterns" - ] - }, - { - "cell_type": "code", - "execution_count": 45, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Loss: 0.003689224598929286\n" - ] - } - ], - "source": [ - "big_data_loader = make_data_generator(cfg, 4000)\n", - "big_tokens = next(big_data_loader)\n", - "big_tokens = big_tokens.to(device)\n", - "logits, cache = model.run_with_cache(big_tokens)\n", - "print(\"Loss:\", loss_fn(logits, big_tokens).item())" - ] - }, - { - "cell_type": "code", - "execution_count": 46, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "ActivationCache with keys ['hook_embed', 'hook_pos_embed', 'blocks.0.hook_resid_pre', 'blocks.0.ln1.hook_scale', 'blocks.0.ln1.hook_normalized', 'blocks.0.attn.hook_q', 'blocks.0.attn.hook_k', 'blocks.0.attn.hook_v', 'blocks.0.attn.hook_attn_scores', 'blocks.0.attn.hook_pattern', 'blocks.0.attn.hook_z', 'blocks.0.hook_attn_out', 'blocks.0.hook_resid_mid', 'blocks.0.ln2.hook_scale', 'blocks.0.ln2.hook_normalized', 'blocks.0.mlp.hook_pre', 'blocks.0.mlp.hook_post', 'blocks.0.hook_mlp_out', 'blocks.0.hook_resid_post', 'blocks.1.hook_resid_pre', 'blocks.1.ln1.hook_scale', 'blocks.1.ln1.hook_normalized', 'blocks.1.attn.hook_q', 'blocks.1.attn.hook_k', 'blocks.1.attn.hook_v', 'blocks.1.attn.hook_attn_scores', 'blocks.1.attn.hook_pattern', 'blocks.1.attn.hook_z', 'blocks.1.hook_attn_out', 'blocks.1.hook_resid_mid', 'blocks.1.ln2.hook_scale', 'blocks.1.ln2.hook_normalized', 'blocks.1.mlp.hook_pre', 'blocks.1.mlp.hook_post', 'blocks.1.hook_mlp_out', 'blocks.1.hook_resid_post', 'ln_final.hook_scale', 'ln_final.hook_normalized']\n" - ] - } - ], - "source": [ - "print(cache)" - ] - }, - { - "cell_type": "code", - "execution_count": 47, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "torch.Size([4000, 1, 50, 50])" - ] - }, - "execution_count": 47, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "cache[\"blocks.0.attn.hook_pattern\"].shape" - ] - }, - { - "cell_type": "code", - "execution_count": 48, - "metadata": {}, - "outputs": [ - { - "data": { - "text/html": [ - "\n", - "\n", 
- "\n", - "
\n", - "
\n", - "\n", - "" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "\n", - "\n", - "\n", - "
\n", - "
\n", - "\n", - "" - ] - }, - "metadata": {}, - "output_type": "display_data" - } - ], - "source": [ - "batch_index = 0\n", - "tokens = big_tokens[batch_index]\n", - "imshow(\n", - " to_numpy(cache[\"attn\", 0].mean([0, 1])),\n", - " title=\"Layer 0 Attention Pattern\",\n", - " height=500,\n", - " width=500,\n", - ")\n", - "imshow(\n", - " to_numpy(cache[\"attn\", 1].mean([0, 1])),\n", - " title=\"Layer 1 Attention Pattern\",\n", - " height=500,\n", - " width=500,\n", - ")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Look at how different bits of the model directly contribute to the logits" - ] - }, - { - "cell_type": "code", - "execution_count": 49, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "torch.Size([5, 4000, 50, 64])\n" - ] - } - ], - "source": [ - "resid_components = [\n", - " cache[\"embed\"],\n", - " cache[\"attn_out\", 0],\n", - " cache[\"mlp_out\", 0],\n", - " cache[\"attn_out\", 1],\n", - " cache[\"mlp_out\", 1],\n", - "]\n", - "labels = [\"embed\", \"A0\", \"M0\", \"A1\", \"M2\"]\n", - "resid_stack = torch.stack(resid_components, 0)\n", - "resid_stack = resid_stack - resid_stack.mean(-1, keepdim=True)\n", - "print(resid_stack.shape)" - ] - }, - { - "cell_type": "code", - "execution_count": 50, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "torch.Size([5, 50, 300])\n" - ] - } - ], - "source": [ - "fold_W_U = model.ln_final.w[:, None] * model.unembed.W_U\n", - "logit_components = resid_stack[:, batch_index] @ fold_W_U / cache[\"scale\"][batch_index]\n", - "print(logit_components.shape)" - ] - }, - { - "cell_type": "code", - "execution_count": 51, - "metadata": {}, - "outputs": [ - { - "data": { - "text/html": [ - "\n", - "\n", - "\n", - "
\n", - "
\n", - "\n", - "" - ] - }, - "metadata": {}, - "output_type": "display_data" - } - ], - "source": [ - "logit_components = logit_components - logit_components.mean(-1, keepdim=True)\n", - "line(\n", - " logit_components[:, torch.arange(1, model.cfg.n_ctx).to(device), tokens[:-1]].T,\n", - " line_labels=labels,\n", - ")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Folding In LayerNorm" - ] - }, - { - "cell_type": "code", - "execution_count": 52, - "metadata": {}, - "outputs": [], - "source": [ - "analysis_cfg = HookedTransformerConfig(\n", - " n_layers=2,\n", - " d_model=64,\n", - " d_head=64,\n", - " n_heads=1,\n", - " d_mlp=256,\n", - " d_vocab=300,\n", - " n_ctx=50,\n", - " act_fn=\"relu\",\n", - " normalization_type=\"LNPre\",\n", - " init_weights=False,\n", - ")\n", - "analysis_model = HookedTransformer(analysis_cfg)\n", - "state_dict = model.state_dict()\n", - "analysis_model.load_and_process_state_dict(\n", - " state_dict, fold_ln=True, center_writing_weights=True, center_unembed=True\n", - ")\n", - "deactivate_position(analysis_model)" - ] - }, - { - "cell_type": "code", - "execution_count": 53, - "metadata": {}, - "outputs": [], - "source": [ - "# analysis_model()" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Understand Attn 0\n" - ] - }, - { - "cell_type": "code", - "execution_count": 54, - "metadata": {}, - "outputs": [ - { - "data": { - "text/html": [ - "\n", - "\n", - "\n", - "
\n", - "
\n", - "\n", - "" - ] - }, - "metadata": {}, - "output_type": "display_data" - } - ], - "source": [ - "QK = model.W_E @ model.W_Q[0, 0] @ model.W_K[0, 0].T @ model.W_E.T\n", - "imshow(QK, yaxis=\"Query\", xaxis=\"Key\")" - ] - }, - { - "cell_type": "code", - "execution_count": 55, - "metadata": {}, - "outputs": [ - { - "data": { - "text/html": [ - "\n", - "\n", - "\n", - "
\n", - "
\n", - "\n", - "" - ] - }, - "metadata": {}, - "output_type": "display_data" - } - ], - "source": [ - "OV = model.W_E @ model.W_V[0, 0] @ model.W_O[0, 0] @ model.W_in[0]\n", - "imshow(OV, yaxis=\"Input Vocab\", xaxis=\"Neuron\")" - ] - }, - { - "cell_type": "code", - "execution_count": 56, - "metadata": {}, - "outputs": [ - { - "data": { - "text/html": [ - "\n", - "\n", - "\n", - "
\n", - "
\n", - "\n", - "" - ] - }, - "metadata": {}, - "output_type": "display_data" - } - ], - "source": [ - "line(OV[:, torch.randint(0, 256, (5,))])" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Understand MLP 0" - ] - }, - { - "cell_type": "code", - "execution_count": 57, - "metadata": {}, - "outputs": [ - { - "data": { - "text/html": [ - "\n", - "\n", - "\n", - "
\n", - "
\n", - "\n", - "" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "\n", - "\n", - "\n", - "
\n", - "
\n", - "\n", - "" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "\n", - "\n", - "\n", - "
\n", - "
\n", - "\n", - "" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "\n", - "\n", - "\n", - "
\n", - "
\n", - "\n", - "" - ] - }, - "metadata": {}, - "output_type": "display_data" - } - ], - "source": [ - "imshow(cache[\"post\", 0][batch_index], yaxis=\"Pos\", xaxis=\"Neuron\")\n", - "imshow(cache[\"post\", 0].mean(0), yaxis=\"Pos\", xaxis=\"Neuron\")\n", - "imshow((cache[\"post\", 0] > 0).float()[batch_index], yaxis=\"Pos\", xaxis=\"Neuron\")\n", - "imshow((cache[\"post\", 0] > 0).float().mean(0), yaxis=\"Pos\", xaxis=\"Neuron\")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Understand Attn 1" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Understand MLP 1" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Experiment" - ] - }, - { - "cell_type": "code", - "execution_count": 58, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Baseline loss: 0.0036276562605053186\n" - ] - } - ], - "source": [ - "new_token_batch = next(big_data_loader).to(device)\n", - "baseline_loss = loss_fn(model(new_token_batch), new_token_batch).item()\n", - "print(\"Baseline loss:\", baseline_loss)" - ] - }, - { - "cell_type": "code", - "execution_count": 59, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "hook_embed 10.975448608398438\n", - "blocks.0.ln1.hook_scale 0.4518754482269287\n", - "blocks.0.ln1.hook_normalized 2.4589312076568604\n", - "blocks.0.ln2.hook_scale 0.02511436492204666\n", - "blocks.0.ln2.hook_normalized 9.506545066833496\n", - "blocks.0.attn.hook_k 0.026024164631962776\n", - "blocks.0.attn.hook_q 1.9935917854309082\n", - "blocks.0.attn.hook_v 0.19012247025966644\n", - "blocks.0.attn.hook_z 2.4572794437408447\n", - "blocks.0.attn.hook_attn_scores 1.9351273775100708\n", - "blocks.0.attn.hook_pattern 1.9483344554901123\n", - "blocks.0.mlp.hook_pre 9.506546020507812\n", - "blocks.0.mlp.hook_post 9.526301383972168\n", - "blocks.0.hook_attn_out 2.4572784900665283\n", - "blocks.0.hook_mlp_out 9.526301383972168\n", - "blocks.0.hook_resid_pre 10.975448608398438\n", - "blocks.0.hook_resid_mid 11.21129035949707\n", - "blocks.0.hook_resid_post 10.834088325500488\n", - "blocks.1.ln1.hook_scale 0.021276870742440224\n", - "blocks.1.ln1.hook_normalized 9.080503463745117\n", - "blocks.1.ln2.hook_scale 0.003745849709957838\n", - "blocks.1.ln2.hook_normalized 4.472580432891846\n", - "blocks.1.attn.hook_k 0.0021377610974013805\n", - "blocks.1.attn.hook_q 0.016889303922653198\n", - "blocks.1.attn.hook_v 9.080748558044434\n", - "blocks.1.attn.hook_z 9.080577850341797\n", - "blocks.1.attn.hook_attn_scores 0.0017388787819072604\n", - "blocks.1.attn.hook_pattern 0.0018037232803180814\n", - "blocks.1.mlp.hook_pre 4.472580432891846\n", - "blocks.1.mlp.hook_post 4.4686713218688965\n", - "blocks.1.hook_attn_out 9.080577850341797\n", - "blocks.1.hook_mlp_out 4.4686713218688965\n", - "blocks.1.hook_resid_pre 10.834088325500488\n", - "blocks.1.hook_resid_mid 10.262277603149414\n", - "blocks.1.hook_resid_post 11.917344093322754\n", - "ln_final.hook_scale 0.009998265653848648\n", - "ln_final.hook_normalized 5.719696521759033\n" - ] - } - ], - "source": [ - "hook_list = list(model.hook_dict.keys())\n", - "losses = []\n", - "loss_labels = []\n", - "for hook_name in hook_list:\n", - " if (\n", - " hook_name in cache\n", - " and hook_name != \"hook_pos_embed\"\n", - " and \"result\" not in hook_name\n", - " ):\n", - " average_act = 
cache[hook_name].mean(0)\n", - "\n", - " def replacing_with_average_act(activation, hook):\n", - " activation[:] = einops.repeat(\n", - " average_act, \"... -> batch ...\", batch=new_token_batch.size(0)\n", - " )\n", - " return activation\n", - "\n", - " logits = model.run_with_hooks(\n", - " new_token_batch, fwd_hooks=[(hook_name, replacing_with_average_act)]\n", - " )\n", - " loss = loss_fn(logits, new_token_batch)\n", - " print(hook_name, loss.item())\n", - " losses.append(loss.item())\n", - " loss_labels.append(hook_name)" - ] - }, - { - "cell_type": "code", - "execution_count": 60, - "metadata": {}, - "outputs": [ - { - "data": { - "text/html": [ - "\n", - "\n", - "\n", - "
\n", - "
\n", - "\n", - "" - ] - }, - "metadata": {}, - "output_type": "display_data" - } - ], - "source": [ - "line(losses, hover_name=loss_labels)" - ] - }, - { - "cell_type": "code", - "execution_count": 61, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "dict_keys(['hook_embed', 'hook_pos_embed', 'blocks.0.hook_resid_pre', 'blocks.0.ln1.hook_scale', 'blocks.0.ln1.hook_normalized', 'blocks.0.attn.hook_q', 'blocks.0.attn.hook_k', 'blocks.0.attn.hook_v', 'blocks.0.attn.hook_attn_scores', 'blocks.0.attn.hook_pattern', 'blocks.0.attn.hook_z', 'blocks.0.hook_attn_out', 'blocks.0.hook_resid_mid', 'blocks.0.ln2.hook_scale', 'blocks.0.ln2.hook_normalized', 'blocks.0.mlp.hook_pre', 'blocks.0.mlp.hook_post', 'blocks.0.hook_mlp_out', 'blocks.0.hook_resid_post', 'blocks.1.hook_resid_pre', 'blocks.1.ln1.hook_scale', 'blocks.1.ln1.hook_normalized', 'blocks.1.attn.hook_q', 'blocks.1.attn.hook_k', 'blocks.1.attn.hook_v', 'blocks.1.attn.hook_attn_scores', 'blocks.1.attn.hook_pattern', 'blocks.1.attn.hook_z', 'blocks.1.hook_attn_out', 'blocks.1.hook_resid_mid', 'blocks.1.ln2.hook_scale', 'blocks.1.ln2.hook_normalized', 'blocks.1.mlp.hook_pre', 'blocks.1.mlp.hook_post', 'blocks.1.hook_mlp_out', 'blocks.1.hook_resid_post', 'ln_final.hook_scale', 'ln_final.hook_normalized'])" - ] - }, - "execution_count": 61, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "cache.cache_dict.keys()" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3.7.13 ('base')", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.11.8" - }, - "vscode": { - "interpreter": { - "hash": "d4d1e4263499bec80672ea0156c357c1ee493ec2b1c70f0acce89fc37c4a6abe" - } - } + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "\n", + " \"Open\n", + "" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Introduction\n", + "\n", + "The accompanying notebook to my [real-time research](https://www.youtube.com/watch?v=yo4QvDn-vsU) video. Trains a model with no positional embeddings to predict the previous token, and makes a start at analysing what's going on there!\n", + "\n", + "EDIT: The loss spikes were due to the learning rate being max(step/100, 1.0) not min! Thanks to MadHatter for catching that." 
+ ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Setup" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [] + }, + { + "cell_type": "code", + "execution_count": 33, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Running as a Jupyter notebook - intended for development only!\n" + ] + } + ], + "source": [ + "# NBVAL_IGNORE_OUTPUT\n", + "import os\n", + "\n", + "# Janky code to do different setup when run in a Colab notebook vs VSCode\n", + "DEVELOPMENT_MODE = False\n", + "IN_GITHUB = os.getenv(\"GITHUB_ACTIONS\") == \"true\"\n", + "try:\n", + " import google.colab\n", + "\n", + " IN_COLAB = True\n", + " print(\"Running as a Colab notebook\")\n", + "except:\n", + " IN_COLAB = False\n", + " print(\"Running as a Jupyter notebook - intended for development only!\")\n", + "\n", + "if IN_COLAB or IN_GITHUB:\n", + " %pip install einops\n", + " %pip install transformer_lens@v1.15.0\n", + "\n", + " # PySvelte is an unmaintained visualization library, use it as a backup if circuitsvis isn't working\n", + " # # Install another version of node that makes PySvelte work way faster\n", + " # !curl -fsSL https://deb.nodesource.com/setup_16.x | sudo -E bash -; sudo apt-get install -y nodejs\n", + " # %pip install git+https://github.com/neelnanda-io/PySvelte.git\n", + "\n", + "from transformer_lens import HookedTransformer, HookedTransformerConfig\n", + "import torch\n", + "import numpy as np\n", + "import plotly.express as px\n", + "import plotly.io as pio\n", + "\n", + "pio.renderers.default = \"colab\"\n", + "import tqdm.auto as tqdm\n", + "import einops\n", + "from transformer_lens.utilities import to_numpy\n", + "\n", + "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Some plotting code. Wrappers around Plotly, not important to understand." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 34, + "metadata": {}, + "outputs": [], + "source": [ + "def line(tensor, line_labels=None, yaxis=\"\", xaxis=\"\", **kwargs):\n", + " tensor = to_numpy(tensor)\n", + " labels = {\"y\": yaxis, \"x\": xaxis}\n", + " fig = px.line(tensor, labels=labels, **kwargs)\n", + " if line_labels:\n", + " for c, label in enumerate(line_labels):\n", + " fig.data[c].name = label\n", + " fig.show()\n", + "\n", + "\n", + "def imshow(tensor, yaxis=\"\", xaxis=\"\", **kwargs):\n", + " tensor = to_numpy(tensor)\n", + " plot_kwargs = {\n", + " \"color_continuous_scale\": \"RdBu\",\n", + " \"color_continuous_midpoint\": 0.0,\n", + " \"labels\": {\"x\": xaxis, \"y\": yaxis},\n", + " }\n", + " plot_kwargs.update(kwargs)\n", + " px.imshow(tensor, **plot_kwargs).show()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Model Training" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Setup" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Defining the Model" + ] + }, + { + "cell_type": "code", + "execution_count": 35, + "metadata": {}, + "outputs": [], + "source": [ + "cfg = HookedTransformerConfig(\n", + " n_layers=2,\n", + " d_model=64,\n", + " d_head=64,\n", + " n_heads=1,\n", + " d_mlp=256,\n", + " d_vocab=300,\n", + " n_ctx=50,\n", + " act_fn=\"relu\",\n", + " normalization_type=\"LN\",\n", + " device=device,\n", + ")\n", + "model = HookedTransformer(cfg)" + ] + }, + { + "cell_type": "code", + "execution_count": 36, + "metadata": {}, + "outputs": [], + "source": [ + "def deactivate_position(model):\n", + " model.pos_embed.W_pos.data[:] = 0.0\n", + " model.pos_embed.W_pos.requires_grad = False\n", + "\n", + "\n", + "deactivate_position(model)" + ] + }, + { + "cell_type": "code", + "execution_count": 37, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "HookedTransformer(\n", + " (embed): Embed()\n", + " (hook_embed): HookPoint()\n", + " (pos_embed): PosEmbed()\n", + " (hook_pos_embed): HookPoint()\n", + " (blocks): ModuleList(\n", + " (0-1): 2 x TransformerBlock(\n", + " (ln1): LayerNorm(\n", + " (hook_scale): HookPoint()\n", + " (hook_normalized): HookPoint()\n", + " )\n", + " (ln2): LayerNorm(\n", + " (hook_scale): HookPoint()\n", + " (hook_normalized): HookPoint()\n", + " )\n", + " (attn): Attention(\n", + " (hook_k): HookPoint()\n", + " (hook_q): HookPoint()\n", + " (hook_v): HookPoint()\n", + " (hook_z): HookPoint()\n", + " (hook_attn_scores): HookPoint()\n", + " (hook_pattern): HookPoint()\n", + " (hook_result): HookPoint()\n", + " )\n", + " (mlp): MLP(\n", + " (hook_pre): HookPoint()\n", + " (hook_post): HookPoint()\n", + " )\n", + " (hook_attn_in): HookPoint()\n", + " (hook_q_input): HookPoint()\n", + " (hook_k_input): HookPoint()\n", + " (hook_v_input): HookPoint()\n", + " (hook_mlp_in): HookPoint()\n", + " (hook_attn_out): HookPoint()\n", + " (hook_mlp_out): HookPoint()\n", + " (hook_resid_pre): HookPoint()\n", + " (hook_resid_mid): HookPoint()\n", + " (hook_resid_post): HookPoint()\n", + " )\n", + " )\n", + " (ln_final): LayerNorm(\n", + " (hook_scale): HookPoint()\n", + " (hook_normalized): HookPoint()\n", + " )\n", + " (unembed): Unembed()\n", + ")\n" + ] + } + ], + "source": [ + "print(model)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Define data + Loss function" + ] + }, + { + "cell_type": "code", + "execution_count": 38, + "metadata": {}, + "outputs": [ + { + "name": 
"stdout", + "output_type": "stream", + "text": [ + "tensor([[ 0, 93, 34, 155, 274, 116, 114, 248, 68, 3, 298, 83, 194, 20,\n", + " 8, 133, 32, 66, 62, 73, 210, 273, 46, 243, 104, 232, 161, 125,\n", + " 123, 251, 7, 4, 115, 127, 21, 1, 89, 142, 6, 15, 298, 251,\n", + " 88, 229, 108, 114, 23, 88, 3, 265],\n", + " [ 0, 118, 46, 274, 105, 268, 131, 35, 19, 58, 226, 278, 27, 25,\n", + " 276, 180, 164, 4, 95, 27, 74, 201, 105, 65, 80, 185, 44, 258,\n", + " 105, 60, 58, 47, 126, 60, 294, 253, 258, 136, 29, 101, 258, 77,\n", + " 80, 180, 159, 169, 122, 117, 27, 194]])\n" + ] + } + ], + "source": [ + "def make_data_generator(cfg, batch_size, seed=123, incl_bos_token=True):\n", + " torch.manual_seed(seed)\n", + " while True:\n", + " x = torch.randint(1, cfg.d_vocab, (batch_size, cfg.n_ctx))\n", + " if incl_bos_token:\n", + " x[:, 0] = 0\n", + " yield x\n", + "\n", + "\n", + "data_generator = make_data_generator(cfg, 2)\n", + "print(next(data_generator))" + ] + }, + { + "cell_type": "code", + "execution_count": 39, + "metadata": {}, + "outputs": [], + "source": [ + "def loss_fn(logits, tokens, per_token=False):\n", + " # logit shape: [batch, pos, vocab]\n", + " # token shape: [batch, pos]\n", + " logits = logits[:, 1:]\n", + " tokens = tokens[:, :-1]\n", + " log_probs = logits.log_softmax(-1)\n", + " correct_log_probs = log_probs.gather(-1, tokens[..., None])[..., 0]\n", + " if per_token:\n", + " return -correct_log_probs\n", + " else:\n", + " return -correct_log_probs.mean()" + ] + }, + { + "cell_type": "code", + "execution_count": 40, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "tensor([[0.0004, 0.0003, 0.0031, 0.0005]])\n", + "tensor(0.0011)\n" + ] + } + ], + "source": [ + "# Test the loss function works\n", + "test_tokens = torch.arange(5)[None, :]\n", + "test_logits = torch.randn(1, 5, 10)\n", + "test_logits[:, 1, 0] = 10.0\n", + "test_logits[:, 2, 1] = 10.0\n", + "test_logits[:, 3, 2] = 10.0\n", + "test_logits[:, 4, 3] = 10.0\n", + "print(loss_fn(test_logits, test_tokens, per_token=True))\n", + "print(loss_fn(test_logits, test_tokens, per_token=False))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Setup Optimizer\n" + ] + }, + { + "cell_type": "code", + "execution_count": 41, + "metadata": {}, + "outputs": [], + "source": [ + "batch_size = 256\n", + "num_epochs = 4000\n", + "lr = 1e-4\n", + "betas = (0.9, 0.95)\n", + "max_grad_norm = 1.0\n", + "wd = 0.1\n", + "optimizer = torch.optim.AdamW(model.parameters(), lr=lr, betas=betas, weight_decay=wd)\n", + "scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lambda i: min(i / 100, 1.0))\n", + "\n", + "data_loader = make_data_generator(cfg, batch_size)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Model Training" + ] + }, + { + "cell_type": "code", + "execution_count": 42, + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "122c183908104b04a600bfe4aca9f009", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + " 0%| | 0/4000 [00:00\n", + "\n", + "\n", + "
\n", + "
\n", + "\n", + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "losses = []\n", + "for epoch in tqdm.tqdm(range(num_epochs)):\n", + " tokens = next(data_loader)\n", + " tokens = tokens.to(device)\n", + " logits = model(tokens)\n", + " loss = loss_fn(logits, tokens)\n", + " loss.backward()\n", + " if max_grad_norm is not None:\n", + " torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)\n", + " optimizer.step()\n", + " optimizer.zero_grad()\n", + " scheduler.step()\n", + " losses.append(loss.item())\n", + " if epoch % 100 == 0:\n", + " print(f\"Epoch {epoch}: {loss.item()}\")\n", + "px.line(losses, labels={\"x\": \"Epoch\", \"y\": \"Loss\"})" + ] + }, + { + "cell_type": "code", + "execution_count": 43, + "metadata": {}, + "outputs": [], + "source": [ + "# torch.save(model.state_dict(), \"no_pos_experiment_state_dict_v0.pth\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Model Interpretability" + ] + }, + { + "cell_type": "code", + "execution_count": 44, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor(0.)" + ] + }, + "execution_count": 44, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "model.pos_embed.W_pos.norm()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Look at attention patterns" + ] + }, + { + "cell_type": "code", + "execution_count": 45, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Loss: 0.003689224598929286\n" + ] + } + ], + "source": [ + "big_data_loader = make_data_generator(cfg, 4000)\n", + "big_tokens = next(big_data_loader)\n", + "big_tokens = big_tokens.to(device)\n", + "logits, cache = model.run_with_cache(big_tokens)\n", + "print(\"Loss:\", loss_fn(logits, big_tokens).item())" + ] + }, + { + "cell_type": "code", + "execution_count": 46, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "ActivationCache with keys ['hook_embed', 'hook_pos_embed', 'blocks.0.hook_resid_pre', 'blocks.0.ln1.hook_scale', 'blocks.0.ln1.hook_normalized', 'blocks.0.attn.hook_q', 'blocks.0.attn.hook_k', 'blocks.0.attn.hook_v', 'blocks.0.attn.hook_attn_scores', 'blocks.0.attn.hook_pattern', 'blocks.0.attn.hook_z', 'blocks.0.hook_attn_out', 'blocks.0.hook_resid_mid', 'blocks.0.ln2.hook_scale', 'blocks.0.ln2.hook_normalized', 'blocks.0.mlp.hook_pre', 'blocks.0.mlp.hook_post', 'blocks.0.hook_mlp_out', 'blocks.0.hook_resid_post', 'blocks.1.hook_resid_pre', 'blocks.1.ln1.hook_scale', 'blocks.1.ln1.hook_normalized', 'blocks.1.attn.hook_q', 'blocks.1.attn.hook_k', 'blocks.1.attn.hook_v', 'blocks.1.attn.hook_attn_scores', 'blocks.1.attn.hook_pattern', 'blocks.1.attn.hook_z', 'blocks.1.hook_attn_out', 'blocks.1.hook_resid_mid', 'blocks.1.ln2.hook_scale', 'blocks.1.ln2.hook_normalized', 'blocks.1.mlp.hook_pre', 'blocks.1.mlp.hook_post', 'blocks.1.hook_mlp_out', 'blocks.1.hook_resid_post', 'ln_final.hook_scale', 'ln_final.hook_normalized']\n" + ] + } + ], + "source": [ + "print(cache)" + ] + }, + { + "cell_type": "code", + "execution_count": 47, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "torch.Size([4000, 1, 50, 50])" + ] + }, + "execution_count": 47, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "cache[\"blocks.0.attn.hook_pattern\"].shape" + ] + }, + { + "cell_type": "code", + "execution_count": 48, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "\n", + "\n", 
+ "\n", + "
\n", + "
\n", + "\n", + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "\n", + "\n", + "\n", + "
\n", + "
\n", + "\n", + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "batch_index = 0\n", + "tokens = big_tokens[batch_index]\n", + "imshow(\n", + " to_numpy(cache[\"attn\", 0].mean([0, 1])),\n", + " title=\"Layer 0 Attention Pattern\",\n", + " height=500,\n", + " width=500,\n", + ")\n", + "imshow(\n", + " to_numpy(cache[\"attn\", 1].mean([0, 1])),\n", + " title=\"Layer 1 Attention Pattern\",\n", + " height=500,\n", + " width=500,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Look at how different bits of the model directly contribute to the logits" + ] + }, + { + "cell_type": "code", + "execution_count": 49, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "torch.Size([5, 4000, 50, 64])\n" + ] + } + ], + "source": [ + "resid_components = [\n", + " cache[\"embed\"],\n", + " cache[\"attn_out\", 0],\n", + " cache[\"mlp_out\", 0],\n", + " cache[\"attn_out\", 1],\n", + " cache[\"mlp_out\", 1],\n", + "]\n", + "labels = [\"embed\", \"A0\", \"M0\", \"A1\", \"M2\"]\n", + "resid_stack = torch.stack(resid_components, 0)\n", + "resid_stack = resid_stack - resid_stack.mean(-1, keepdim=True)\n", + "print(resid_stack.shape)" + ] + }, + { + "cell_type": "code", + "execution_count": 50, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "torch.Size([5, 50, 300])\n" + ] + } + ], + "source": [ + "fold_W_U = model.ln_final.w[:, None] * model.unembed.W_U\n", + "logit_components = resid_stack[:, batch_index] @ fold_W_U / cache[\"scale\"][batch_index]\n", + "print(logit_components.shape)" + ] + }, + { + "cell_type": "code", + "execution_count": 51, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "\n", + "\n", + "\n", + "
\n", + "
\n", + "\n", + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "logit_components = logit_components - logit_components.mean(-1, keepdim=True)\n", + "line(\n", + " logit_components[:, torch.arange(1, model.cfg.n_ctx).to(device), tokens[:-1]].T,\n", + " line_labels=labels,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Folding In LayerNorm" + ] + }, + { + "cell_type": "code", + "execution_count": 52, + "metadata": {}, + "outputs": [], + "source": [ + "analysis_cfg = HookedTransformerConfig(\n", + " n_layers=2,\n", + " d_model=64,\n", + " d_head=64,\n", + " n_heads=1,\n", + " d_mlp=256,\n", + " d_vocab=300,\n", + " n_ctx=50,\n", + " act_fn=\"relu\",\n", + " normalization_type=\"LNPre\",\n", + " init_weights=False,\n", + ")\n", + "analysis_model = HookedTransformer(analysis_cfg)\n", + "state_dict = model.state_dict()\n", + "analysis_model.load_and_process_state_dict(\n", + " state_dict, fold_ln=True, center_writing_weights=True, center_unembed=True\n", + ")\n", + "deactivate_position(analysis_model)" + ] + }, + { + "cell_type": "code", + "execution_count": 53, + "metadata": {}, + "outputs": [], + "source": [ + "# analysis_model()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Understand Attn 0\n" + ] + }, + { + "cell_type": "code", + "execution_count": 54, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "\n", + "\n", + "\n", + "
\n", + "
\n", + "\n", + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "QK = model.W_E @ model.W_Q[0, 0] @ model.W_K[0, 0].T @ model.W_E.T\n", + "imshow(QK, yaxis=\"Query\", xaxis=\"Key\")" + ] + }, + { + "cell_type": "code", + "execution_count": 55, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "\n", + "\n", + "\n", + "
\n", + "
\n", + "\n", + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "OV = model.W_E @ model.W_V[0, 0] @ model.W_O[0, 0] @ model.W_in[0]\n", + "imshow(OV, yaxis=\"Input Vocab\", xaxis=\"Neuron\")" + ] + }, + { + "cell_type": "code", + "execution_count": 56, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "\n", + "\n", + "\n", + "
\n", + "
\n", + "\n", + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "line(OV[:, torch.randint(0, 256, (5,))])" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Understand MLP 0" + ] + }, + { + "cell_type": "code", + "execution_count": 57, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "\n", + "\n", + "\n", + "
\n", + "
\n", + "\n", + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "\n", + "\n", + "\n", + "
\n", + "
\n", + "\n", + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "\n", + "\n", + "\n", + "
\n", + "
\n", + "\n", + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "\n", + "\n", + "\n", + "
\n", + "
\n", + "\n", + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "imshow(cache[\"post\", 0][batch_index], yaxis=\"Pos\", xaxis=\"Neuron\")\n", + "imshow(cache[\"post\", 0].mean(0), yaxis=\"Pos\", xaxis=\"Neuron\")\n", + "imshow((cache[\"post\", 0] > 0).float()[batch_index], yaxis=\"Pos\", xaxis=\"Neuron\")\n", + "imshow((cache[\"post\", 0] > 0).float().mean(0), yaxis=\"Pos\", xaxis=\"Neuron\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Understand Attn 1" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Understand MLP 1" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Experiment" + ] + }, + { + "cell_type": "code", + "execution_count": 58, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Baseline loss: 0.0036276562605053186\n" + ] + } + ], + "source": [ + "new_token_batch = next(big_data_loader).to(device)\n", + "baseline_loss = loss_fn(model(new_token_batch), new_token_batch).item()\n", + "print(\"Baseline loss:\", baseline_loss)" + ] + }, + { + "cell_type": "code", + "execution_count": 59, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "hook_embed 10.975448608398438\n", + "blocks.0.ln1.hook_scale 0.4518754482269287\n", + "blocks.0.ln1.hook_normalized 2.4589312076568604\n", + "blocks.0.ln2.hook_scale 0.02511436492204666\n", + "blocks.0.ln2.hook_normalized 9.506545066833496\n", + "blocks.0.attn.hook_k 0.026024164631962776\n", + "blocks.0.attn.hook_q 1.9935917854309082\n", + "blocks.0.attn.hook_v 0.19012247025966644\n", + "blocks.0.attn.hook_z 2.4572794437408447\n", + "blocks.0.attn.hook_attn_scores 1.9351273775100708\n", + "blocks.0.attn.hook_pattern 1.9483344554901123\n", + "blocks.0.mlp.hook_pre 9.506546020507812\n", + "blocks.0.mlp.hook_post 9.526301383972168\n", + "blocks.0.hook_attn_out 2.4572784900665283\n", + "blocks.0.hook_mlp_out 9.526301383972168\n", + "blocks.0.hook_resid_pre 10.975448608398438\n", + "blocks.0.hook_resid_mid 11.21129035949707\n", + "blocks.0.hook_resid_post 10.834088325500488\n", + "blocks.1.ln1.hook_scale 0.021276870742440224\n", + "blocks.1.ln1.hook_normalized 9.080503463745117\n", + "blocks.1.ln2.hook_scale 0.003745849709957838\n", + "blocks.1.ln2.hook_normalized 4.472580432891846\n", + "blocks.1.attn.hook_k 0.0021377610974013805\n", + "blocks.1.attn.hook_q 0.016889303922653198\n", + "blocks.1.attn.hook_v 9.080748558044434\n", + "blocks.1.attn.hook_z 9.080577850341797\n", + "blocks.1.attn.hook_attn_scores 0.0017388787819072604\n", + "blocks.1.attn.hook_pattern 0.0018037232803180814\n", + "blocks.1.mlp.hook_pre 4.472580432891846\n", + "blocks.1.mlp.hook_post 4.4686713218688965\n", + "blocks.1.hook_attn_out 9.080577850341797\n", + "blocks.1.hook_mlp_out 4.4686713218688965\n", + "blocks.1.hook_resid_pre 10.834088325500488\n", + "blocks.1.hook_resid_mid 10.262277603149414\n", + "blocks.1.hook_resid_post 11.917344093322754\n", + "ln_final.hook_scale 0.009998265653848648\n", + "ln_final.hook_normalized 5.719696521759033\n" + ] + } + ], + "source": [ + "hook_list = list(model.hook_dict.keys())\n", + "losses = []\n", + "loss_labels = []\n", + "for hook_name in hook_list:\n", + " if (\n", + " hook_name in cache\n", + " and hook_name != \"hook_pos_embed\"\n", + " and \"result\" not in hook_name\n", + " ):\n", + " average_act = 
cache[hook_name].mean(0)\n", + "\n", + " def replacing_with_average_act(activation, hook):\n", + " activation[:] = einops.repeat(\n", + " average_act, \"... -> batch ...\", batch=new_token_batch.size(0)\n", + " )\n", + " return activation\n", + "\n", + " logits = model.run_with_hooks(\n", + " new_token_batch, fwd_hooks=[(hook_name, replacing_with_average_act)]\n", + " )\n", + " loss = loss_fn(logits, new_token_batch)\n", + " print(hook_name, loss.item())\n", + " losses.append(loss.item())\n", + " loss_labels.append(hook_name)" + ] + }, + { + "cell_type": "code", + "execution_count": 60, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "\n", + "\n", + "\n", + "
\n", + "
\n", + "\n", + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "line(losses, hover_name=loss_labels)" + ] + }, + { + "cell_type": "code", + "execution_count": 61, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "dict_keys(['hook_embed', 'hook_pos_embed', 'blocks.0.hook_resid_pre', 'blocks.0.ln1.hook_scale', 'blocks.0.ln1.hook_normalized', 'blocks.0.attn.hook_q', 'blocks.0.attn.hook_k', 'blocks.0.attn.hook_v', 'blocks.0.attn.hook_attn_scores', 'blocks.0.attn.hook_pattern', 'blocks.0.attn.hook_z', 'blocks.0.hook_attn_out', 'blocks.0.hook_resid_mid', 'blocks.0.ln2.hook_scale', 'blocks.0.ln2.hook_normalized', 'blocks.0.mlp.hook_pre', 'blocks.0.mlp.hook_post', 'blocks.0.hook_mlp_out', 'blocks.0.hook_resid_post', 'blocks.1.hook_resid_pre', 'blocks.1.ln1.hook_scale', 'blocks.1.ln1.hook_normalized', 'blocks.1.attn.hook_q', 'blocks.1.attn.hook_k', 'blocks.1.attn.hook_v', 'blocks.1.attn.hook_attn_scores', 'blocks.1.attn.hook_pattern', 'blocks.1.attn.hook_z', 'blocks.1.hook_attn_out', 'blocks.1.hook_resid_mid', 'blocks.1.ln2.hook_scale', 'blocks.1.ln2.hook_normalized', 'blocks.1.mlp.hook_pre', 'blocks.1.mlp.hook_post', 'blocks.1.hook_mlp_out', 'blocks.1.hook_resid_post', 'ln_final.hook_scale', 'ln_final.hook_normalized'])" + ] + }, + "execution_count": 61, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "cache.cache_dict.keys()" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3.7.13 ('base')", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.8" + }, + "vscode": { + "interpreter": { + "hash": "d4d1e4263499bec80672ea0156c357c1ee493ec2b1c70f0acce89fc37c4a6abe" + } + } + }, + "nbformat": 4, + "nbformat_minor": 2 } diff --git a/demos/Othello_GPT.ipynb b/demos/Othello_GPT.ipynb index d69f1a166..40469f357 100644 --- a/demos/Othello_GPT.ipynb +++ b/demos/Othello_GPT.ipynb @@ -180,7 +180,7 @@ "outputs": [], "source": [ "import transformer_lens\n", - "import transformer_lens.utils as utils\n", + "import transformer_lens.utilities as utils\n", "from transformer_lens.hook_points import (\n", " HookedRootModule,\n", " HookPoint,\n", @@ -281,7 +281,7 @@ "metadata": {}, "outputs": [], "source": [ - "import transformer_lens.utils as utils\n", + "import transformer_lens.utilities as utils\n", "\n", "cfg = HookedTransformerConfig(\n", " n_layers=8,\n", diff --git a/demos/SVD_Interpreter_Demo.ipynb b/demos/SVD_Interpreter_Demo.ipynb index 82b85a06e..4a2f93a69 100644 --- a/demos/SVD_Interpreter_Demo.ipynb +++ b/demos/SVD_Interpreter_Demo.ipynb @@ -119,7 +119,7 @@ "import pysvelte\n", "import numpy as np\n", "import transformer_lens\n", - "import transformer_lens.utils as utils\n", + "import transformer_lens.utilities as utils\n", "from transformer_lens import HookedTransformer, SVDInterpreter" ] }, diff --git a/demos/Santa_Coder.ipynb b/demos/Santa_Coder.ipynb index af5a5b0cb..99f7dcab0 100644 --- a/demos/Santa_Coder.ipynb +++ b/demos/Santa_Coder.ipynb @@ -91,7 +91,7 @@ "# import circuitsvis as cv\n", "\n", "import transformer_lens\n", - "import transformer_lens.utils as utils\n", + "import transformer_lens.utilities as utils\n", "from transformer_lens.model_bridge import TransformerBridge\n", "\n", "_ = torch.set_grad_enabled(False)\n", diff --git 
a/tests/unit/pretrained_weight_conversions/test_apertus.py b/tests/unit/pretrained_weight_conversions/test_apertus.py index d7e5760e1..4e57d3757 100644 --- a/tests/unit/pretrained_weight_conversions/test_apertus.py +++ b/tests/unit/pretrained_weight_conversions/test_apertus.py @@ -183,7 +183,7 @@ def test_zero_biases_have_correct_device(self): "blocks.0.mlp.b_out", "unembed.b_U", ]: - assert sd[key].device.type == str(cfg.device), f"{key} on wrong device" + assert sd[key].device.type == cfg.device, f"{key} on wrong device" def test_unembed_shapes(self): cfg = make_cfg() diff --git a/tests/unit/utilities/test_devices.py b/tests/unit/utilities/test_devices.py index 5e1af5632..e04cedc80 100644 --- a/tests/unit/utilities/test_devices.py +++ b/tests/unit/utilities/test_devices.py @@ -43,7 +43,7 @@ def test_get_device_cuda_available(): with patch("torch.cuda.is_available", return_value=True): with patch("torch.backends.mps.is_available", return_value=False): device = get_device() - assert device == torch.device("cuda") + assert device == "cuda" @patch.dict("os.environ", {"TRANSFORMERLENS_ALLOW_MPS": "1"}) @@ -54,7 +54,7 @@ def test_get_device_mps_available(): with patch("torch.backends.mps.is_built", return_value=True): with patch("torch.__version__", "2.0.0"): device = get_device() - assert device == torch.device("mps") + assert device == "mps" def test_get_device_mps_pytorch_1x(): @@ -64,7 +64,7 @@ def test_get_device_mps_pytorch_1x(): with patch("torch.backends.mps.is_built", return_value=True): with patch("torch.__version__", "1.13.0"): device = get_device() - assert device == torch.device("cpu") + assert device == "cpu" def test_get_device_cpu_fallback(): @@ -72,7 +72,7 @@ def test_get_device_cpu_fallback(): with patch("torch.cuda.is_available", return_value=False): with patch("torch.backends.mps.is_available", return_value=False): device = get_device() - assert device == torch.device("cpu") + assert device == "cpu" def test_model_with_cfg_protocol(): @@ -176,7 +176,7 @@ def test_get_device_returns_cpu_when_mps_available(mock_built, mock_avail, mock_ os.environ.pop("TRANSFORMERLENS_ALLOW_MPS", None) result = get_device() - assert result == torch.device("cpu") + assert result == "cpu" @patch.dict("os.environ", {"TRANSFORMERLENS_ALLOW_MPS": "1"}) @@ -186,7 +186,7 @@ def test_get_device_returns_cpu_when_mps_available(mock_built, mock_avail, mock_ def test_get_device_returns_mps_when_env_var_set(mock_built, mock_avail, mock_cuda): """get_device() should return MPS when TRANSFORMERLENS_ALLOW_MPS=1 is set.""" result = get_device() - assert result == torch.device("mps") + assert result == "mps" @patch.dict("os.environ", {}, clear=False) diff --git a/transformer_lens/ActivationCache.py b/transformer_lens/ActivationCache.py index 17d8bc2fd..06e99c511 100644 --- a/transformer_lens/ActivationCache.py +++ b/transformer_lens/ActivationCache.py @@ -23,8 +23,8 @@ class first, including the examples, and then skimming the available methods. 
Yo from jaxtyping import Float, Int from typing_extensions import Literal -import transformer_lens.utils as utils -from transformer_lens.utils import Slice, SliceInput, warn_if_mps +import transformer_lens.utilities as utils +from transformer_lens.utilities import Slice, SliceInput, warn_if_mps class ActivationCache: diff --git a/transformer_lens/HookedEncoderDecoder.py b/transformer_lens/HookedEncoderDecoder.py index 89ad258fb..6b753690a 100644 --- a/transformer_lens/HookedEncoderDecoder.py +++ b/transformer_lens/HookedEncoderDecoder.py @@ -37,8 +37,8 @@ from transformer_lens.config.HookedTransformerConfig import HookedTransformerConfig from transformer_lens.FactoredMatrix import FactoredMatrix from transformer_lens.hook_points import HookedRootModule, HookPoint +from transformer_lens.utilities import sample_logits, warn_if_mps from transformer_lens.utilities.multi_gpu import get_device_for_block_index -from transformer_lens.utils import sample_logits, warn_if_mps T = TypeVar("T", bound="HookedEncoderDecoder") diff --git a/transformer_lens/HookedTransformer.py b/transformer_lens/HookedTransformer.py index 17816b9c3..bb65f5f5c 100644 --- a/transformer_lens/HookedTransformer.py +++ b/transformer_lens/HookedTransformer.py @@ -41,7 +41,7 @@ from typing_extensions import Literal import transformer_lens.loading_from_pretrained as loading -import transformer_lens.utils as utils +import transformer_lens.utilities as utils from transformer_lens.ActivationCache import ActivationCache # Activation cache for run_with_cache; KV cache for generation @@ -63,17 +63,15 @@ from transformer_lens.hook_points import HookedRootModule, HookPoint from transformer_lens.loading_from_pretrained import NON_HF_HOSTED_MODEL_NAMES from transformer_lens.utilities import ( + USE_DEFAULT_VALUE, get_best_available_device, get_device_for_block_index, -) -from transformer_lens.utilities.devices import move_to_and_update_config -from transformer_lens.utils import ( - USE_DEFAULT_VALUE, init_kaiming_normal_, init_kaiming_uniform_, init_xavier_normal_, init_xavier_uniform_, ) +from transformer_lens.utilities.devices import move_to_and_update_config from transformer_lens.weight_processing import ProcessWeights SingleLoss = Float[torch.Tensor, ""] # Type alias for a single element tensor diff --git a/transformer_lens/components/abstract_attention.py b/transformer_lens/components/abstract_attention.py index 5b744ca0f..470c21231 100644 --- a/transformer_lens/components/abstract_attention.py +++ b/transformer_lens/components/abstract_attention.py @@ -18,8 +18,8 @@ from transformer_lens.config.HookedTransformerConfig import HookedTransformerConfig from transformer_lens.FactoredMatrix import FactoredMatrix from transformer_lens.hook_points import HookPoint +from transformer_lens.utilities import get_offset_position_ids from transformer_lens.utilities.attention import complex_attn_linear, simple_attn_linear -from transformer_lens.utils import get_offset_position_ids if is_bitsandbytes_available(): import bitsandbytes as bnb diff --git a/transformer_lens/components/bert_block.py b/transformer_lens/components/bert_block.py index 98fd7c563..8eb0e45ac 100644 --- a/transformer_lens/components/bert_block.py +++ b/transformer_lens/components/bert_block.py @@ -13,7 +13,7 @@ from transformer_lens.config.HookedTransformerConfig import HookedTransformerConfig from transformer_lens.factories.mlp_factory import MLPFactory from transformer_lens.hook_points import HookPoint -from transformer_lens.utils import repeat_along_head_dimension +from 
transformer_lens.utilities import repeat_along_head_dimension class BertBlock(nn.Module): diff --git a/transformer_lens/components/pos_embed.py b/transformer_lens/components/pos_embed.py index 91d6237e5..e6319e1b3 100644 --- a/transformer_lens/components/pos_embed.py +++ b/transformer_lens/components/pos_embed.py @@ -11,7 +11,7 @@ from jaxtyping import Float, Int from transformer_lens.config.HookedTransformerConfig import HookedTransformerConfig -from transformer_lens.utils import get_offset_position_ids +from transformer_lens.utilities import get_offset_position_ids # Positional Embeddings diff --git a/transformer_lens/components/t5_block.py b/transformer_lens/components/t5_block.py index 461ac053d..e83fc7b7f 100644 --- a/transformer_lens/components/t5_block.py +++ b/transformer_lens/components/t5_block.py @@ -11,7 +11,7 @@ from transformer_lens.config.HookedTransformerConfig import HookedTransformerConfig from transformer_lens.factories.mlp_factory import MLPFactory from transformer_lens.hook_points import HookPoint -from transformer_lens.utils import repeat_along_head_dimension +from transformer_lens.utilities import repeat_along_head_dimension class T5Block(nn.Module): diff --git a/transformer_lens/components/transformer_block.py b/transformer_lens/components/transformer_block.py index 020e654d1..deff95b85 100644 --- a/transformer_lens/components/transformer_block.py +++ b/transformer_lens/components/transformer_block.py @@ -24,7 +24,7 @@ from transformer_lens.config.HookedTransformerConfig import HookedTransformerConfig from transformer_lens.factories.mlp_factory import MLPFactory from transformer_lens.hook_points import HookPoint -from transformer_lens.utils import repeat_along_head_dimension +from transformer_lens.utilities import repeat_along_head_dimension # Transformer Block diff --git a/transformer_lens/config/HookedTransformerConfig.py b/transformer_lens/config/HookedTransformerConfig.py index cacf49180..3d18cf2a7 100644 --- a/transformer_lens/config/HookedTransformerConfig.py +++ b/transformer_lens/config/HookedTransformerConfig.py @@ -334,9 +334,9 @@ def __post_init__(self): self.n_params += self.n_layers * mlp_params_per_layer if self.device is None: - self.device = str(get_device()) + self.device = get_device() else: - from transformer_lens.utils import warn_if_mps + from transformer_lens.utilities import warn_if_mps warn_if_mps(self.device) diff --git a/transformer_lens/evals.py b/transformer_lens/evals.py index b4ed61a30..5c318d013 100644 --- a/transformer_lens/evals.py +++ b/transformer_lens/evals.py @@ -14,8 +14,8 @@ from datasets import load_dataset from torch.utils.data import DataLoader, Dataset -from transformer_lens import utils -from transformer_lens.utils import warn_if_mps +from transformer_lens import utilities as utils +from transformer_lens.utilities import warn_if_mps # %% diff --git a/transformer_lens/head_detector.py b/transformer_lens/head_detector.py index fbb50fae8..9efd237ff 100644 --- a/transformer_lens/head_detector.py +++ b/transformer_lens/head_detector.py @@ -13,7 +13,7 @@ from transformer_lens.ActivationCache import ActivationCache from transformer_lens.HookedTransformer import HookedTransformer -from transformer_lens.utils import is_lower_triangular, is_square +from transformer_lens.utilities import is_lower_triangular, is_square HeadName = Literal["previous_token_head", "duplicate_token_head", "induction_head"] HEAD_NAMES = cast(List[HeadName], get_args(HeadName)) diff --git a/transformer_lens/hook_points.py b/transformer_lens/hook_points.py 
index a675cb736..a942ec39d 100644 --- a/transformer_lens/hook_points.py +++ b/transformer_lens/hook_points.py @@ -32,7 +32,7 @@ from transformer_lens.conversion_utils.conversion_steps.base_tensor_conversion import ( BaseTensorConversion, ) -from transformer_lens.utils import Slice, SliceInput, warn_if_mps +from transformer_lens.utilities import Slice, SliceInput, warn_if_mps @dataclass diff --git a/transformer_lens/loading_from_pretrained.py b/transformer_lens/loading_from_pretrained.py index 30184ec85..6c703e754 100644 --- a/transformer_lens/loading_from_pretrained.py +++ b/transformer_lens/loading_from_pretrained.py @@ -24,7 +24,7 @@ Wav2Vec2Model, ) -import transformer_lens.utils as utils +import transformer_lens.utilities as utils from transformer_lens.config.HookedTransformerConfig import HookedTransformerConfig from transformer_lens.pretrained.weight_conversions import ( convert_apertus_weights, diff --git a/transformer_lens/model_bridge/bridge.py b/transformer_lens/model_bridge/bridge.py index 050cf485a..a07a4f330 100644 --- a/transformer_lens/model_bridge/bridge.py +++ b/transformer_lens/model_bridge/bridge.py @@ -28,7 +28,7 @@ import torch from torch import nn -from transformer_lens import utils +from transformer_lens import utilities as utils from transformer_lens.ActivationCache import ActivationCache from transformer_lens.FactoredMatrix import FactoredMatrix from transformer_lens.hook_points import HookPoint diff --git a/transformer_lens/model_bridge/sources/transformers.py b/transformer_lens/model_bridge/sources/transformers.py index ce9907384..0169da4dc 100644 --- a/transformer_lens/model_bridge/sources/transformers.py +++ b/transformer_lens/model_bridge/sources/transformers.py @@ -26,7 +26,7 @@ ) from transformer_lens.model_bridge.bridge import TransformerBridge from transformer_lens.supported_models import MODEL_ALIASES -from transformer_lens.utils import get_device, get_tokenizer_with_bos +from transformer_lens.utilities import get_device, get_tokenizer_with_bos # Suppress transformers warnings that go to stderr # This prevents notebook tests from failing due to unexpected stderr output diff --git a/transformer_lens/patching.py b/transformer_lens/patching.py index ac95bf441..868b1f215 100644 --- a/transformer_lens/patching.py +++ b/transformer_lens/patching.py @@ -60,7 +60,7 @@ from tqdm.auto import tqdm from typing_extensions import Literal -import transformer_lens.utils as utils +import transformer_lens.utilities as utils from transformer_lens.ActivationCache import ActivationCache from transformer_lens.HookedTransformer import HookedTransformer diff --git a/transformer_lens/train.py b/transformer_lens/train.py index d1eecfc12..f54f93b57 100644 --- a/transformer_lens/train.py +++ b/transformer_lens/train.py @@ -5,7 +5,7 @@ """ from dataclasses import dataclass -from typing import Optional +from typing import Optional, Union import torch import torch.optim as optim @@ -13,7 +13,7 @@ from torch.utils.data import DataLoader, Dataset from tqdm.auto import tqdm -from transformer_lens import utils +from transformer_lens import utilities as utils from transformer_lens.HookedTransformer import HookedTransformer from transformer_lens.utilities.library_utils import is_library_available @@ -50,7 +50,7 @@ class HookedTransformerTrainConfig: max_grad_norm: Optional[float] = None weight_decay: Optional[float] = None optimizer_name: str = "Adam" - device: Optional[str] = None + device: Optional[Union[str, torch.device]] = None warmup_steps: int = 0 save_every: Optional[int] = None 
save_dir: Optional[str] = None @@ -89,7 +89,7 @@ def train( wandb.init(project=config.wandb_project_name, config=vars(config)) if config.device is None: - config.device = str(utils.get_device()) + config.device = utils.get_device() optimizer: Optimizer if config.optimizer_name in ["Adam", "AdamW"]: diff --git a/transformer_lens/utilities/devices.py b/transformer_lens/utilities/devices.py index d0b14e4d1..ac06bd692 100644 --- a/transformer_lens/utilities/devices.py +++ b/transformer_lens/utilities/devices.py @@ -35,7 +35,7 @@ def _torch_version_tuple() -> tuple[int, ...]: # --------------------------------------------------------------------------- -def get_device() -> torch.device: +def get_device() -> str: """Get the best available device, with MPS safety checks. MPS is only auto-selected when the environment variable @@ -43,17 +43,17 @@ def get_device() -> torch.device: version meets or exceeds ``_MPS_MIN_SAFE_TORCH_VERSION``. Returns: - torch.device: The best available device (cuda, mps, or cpu) + str: The best available device name (cuda, mps, or cpu) """ if torch.cuda.is_available(): - return torch.device("cuda") + return "cuda" if torch.backends.mps.is_available() and torch.backends.mps.is_built(): major_version = int(torch.__version__.split(".")[0]) if major_version >= 2: # Only auto-select MPS when explicitly opted-in via env var if os.environ.get("TRANSFORMERLENS_ALLOW_MPS", "") == "1": - return torch.device("mps") + return "mps" logging.info( "MPS device available but not auto-selected due to known correctness issues " "(PyTorch %s). Set TRANSFORMERLENS_ALLOW_MPS=1 to override. See: " @@ -61,10 +61,10 @@ def get_device() -> torch.device: torch.__version__, ) - return torch.device("cpu") + return "cpu" -def warn_if_mps(device): +def warn_if_mps(device: Union[str, torch.device]) -> None: """Emit a one-time warning if device is MPS and TRANSFORMERLENS_ALLOW_MPS is not set. 
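Accepts either a plain string or a ``torch.device``, matching the ``Union[str, torch.device]`` hint; for example, ``warn_if_mps("mps")`` (exercised by tests/mps/test_mps_basic.py later in this series) and ``warn_if_mps(torch.device("mps"))`` should behave identically, while non-MPS devices are expected to pass through without warning.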
Automatically suppressed when the installed PyTorch version meets or exceeds @@ -129,7 +129,7 @@ def move_to_and_update_config( Returns: The model after the operation """ - from transformer_lens.utils import warn_if_mps + from transformer_lens.utilities import warn_if_mps if isinstance(device_or_dtype, torch.device): warn_if_mps(device_or_dtype) diff --git a/transformer_lens/utilities/exploratory_utils.py b/transformer_lens/utilities/exploratory_utils.py index 6cb2eb992..426aee6a0 100644 --- a/transformer_lens/utilities/exploratory_utils.py +++ b/transformer_lens/utilities/exploratory_utils.py @@ -31,13 +31,13 @@ def test_prompt( Examples: - >>> from transformer_lens import HookedTransformer, utils + >>> from transformer_lens import HookedTransformer, utilities >>> model = HookedTransformer.from_pretrained("tiny-stories-1M") Loaded pretrained model tiny-stories-1M into HookedTransformer >>> prompt = "Why did the elephant cross the" >>> answer = "road" - >>> utils.test_prompt(prompt, answer, model) + >>> utilities.test_prompt(prompt, answer, model) Tokenized prompt: ['<|endoftext|>', 'Why', ' did', ' the', ' elephant', ' cross', ' the'] Tokenized answer: [' road'] Performance on answer token: diff --git a/transformer_lens/weight_processing.py b/transformer_lens/weight_processing.py index 1d219973a..d9de9b481 100644 --- a/transformer_lens/weight_processing.py +++ b/transformer_lens/weight_processing.py @@ -11,7 +11,7 @@ import einops import torch -import transformer_lens.utils as utils +import transformer_lens.utilities as utils from transformer_lens.config.TransformerLensConfig import TransformerLensConfig from transformer_lens.FactoredMatrix import FactoredMatrix from transformer_lens.model_bridge.architecture_adapter import ArchitectureAdapter From 5d412ca287fc9d433fb8f174c64e0ef985fc9324 Mon Sep 17 00:00:00 2001 From: huseyincavusbi Date: Sat, 2 May 2026 12:40:51 +0200 Subject: [PATCH 02/15] feat: Add MPS CI runner support (#1264) --- pyproject.toml | 5 +- tests/conftest.py | 8 +- tests/mps/__init__.py | 1 + tests/mps/test_mps_basic.py | 246 ++++++++++++++++++++++++++++++++++++ 4 files changed, 258 insertions(+), 2 deletions(-) create mode 100644 tests/mps/__init__.py create mode 100644 tests/mps/test_mps_basic.py diff --git a/pyproject.toml b/pyproject.toml index e20e94ce6..da70ba811 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -90,7 +90,10 @@ "ignore:distutils Version classes are deprecated:DeprecationWarning", "ignore:pkg_resources is deprecated as an API:DeprecationWarning", ] - markers=["slow: marks tests as slow (deselect with '-m \"not slow\"')"] + markers=[ + "slow: marks tests as slow (deselect with '-m \"not slow\"')", + "no_mps: marks test as incompatible with MPS device (deselect with '-m \"not no_mps\"')", + ] pythonpath=["."] testpaths=["tests", "transformer_lens"] # Only test these directories diff --git a/tests/conftest.py b/tests/conftest.py index 4bb009c8e..3229823af 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -14,9 +14,11 @@ def cleanup_memory(): """Automatically clean up memory after each test.""" yield - # Clear torch cache + # Clear torch cache for all accelerators if torch.cuda.is_available(): torch.cuda.empty_cache() + if torch.backends.mps.is_available(): + torch.mps.empty_cache() # Force garbage collection for cleanup gc.collect() @@ -28,6 +30,8 @@ def cleanup_class_memory(): # More aggressive cleanup after test classes if torch.cuda.is_available(): torch.cuda.empty_cache() + if torch.backends.mps.is_available(): + torch.mps.empty_cache() 
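# Note: this guard-then-release pair now appears in cleanup_memory,
# cleanup_class_memory, and pytest_sessionfinish; a hedged refactor sketch
# (hypothetical helper, not part of this patch):
#
#     def _empty_accelerator_caches() -> None:
#         """Release cached allocator memory on whichever accelerator is present."""
#         if torch.cuda.is_available():
#             torch.cuda.empty_cache()
#         if torch.backends.mps.is_available():
#             torch.mps.empty_cache()
#         gc.collect()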
gc.collect() @@ -50,6 +54,8 @@ def pytest_sessionfinish(session, exitstatus): """Clean up at the end of test session.""" if torch.cuda.is_available(): torch.cuda.empty_cache() + if torch.backends.mps.is_available(): + torch.mps.empty_cache() gc.collect() diff --git a/tests/mps/__init__.py b/tests/mps/__init__.py new file mode 100644 index 000000000..d319c86cd --- /dev/null +++ b/tests/mps/__init__.py @@ -0,0 +1 @@ +# MPS (Apple Silicon) test package diff --git a/tests/mps/test_mps_basic.py b/tests/mps/test_mps_basic.py new file mode 100644 index 000000000..7114611a4 --- /dev/null +++ b/tests/mps/test_mps_basic.py @@ -0,0 +1,246 @@ +"""Apple Silicon MPS smoke tests for TransformerLens. + +Design principles: +- All tests skip automatically on non-MPS runners (Linux, Windows, CPU-only Macs) +- Only float32 is used (bfloat16 is unsupported on MPS) +- Only small models are loaded (roneneldan/TinyStories-1M, ~50MB) +- torch.mps.empty_cache() + gc.collect() between tests to stay within memory budget +- TRANSFORMERLENS_ALLOW_MPS=1 must be set for get_device() to return "mps" + +CI: These tests are run via the `mps-checks` job in .github/workflows/checks.yml +which sets TRANSFORMERLENS_ALLOW_MPS=1 and runs on macos-latest. +""" + +import gc +import os +import warnings + +import pytest +import torch + +# Skip the entire module on non-MPS runners (Linux CI, CPU-only Macs) +pytestmark = pytest.mark.skipif( + not torch.backends.mps.is_available(), + reason="MPS not available on this runner — skipping Apple Silicon tests", +) + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +SMALL_MODEL = "roneneldan/TinyStories-1M" # ~50MB, safe for 1GB runner budget + + +def _load_tiny_model(device: str = "mps"): + """Load TinyStories-1M on the given device with float32 (bfloat16 unsupported on MPS).""" + from transformer_lens import HookedTransformer + + return HookedTransformer.from_pretrained(SMALL_MODEL, device=device, dtype=torch.float32) + + +def _cleanup(model=None): + """Free GPU memory between tests.""" + if model is not None: + del model + torch.mps.empty_cache() + gc.collect() + + +# --------------------------------------------------------------------------- +# 1. 
Device detection (no model load — instant) +# --------------------------------------------------------------------------- + + +def test_mps_device_available(): + """Sanity check: MPS backend is present and built on this runner.""" + assert torch.backends.mps.is_available(), "MPS not available" + assert torch.backends.mps.is_built(), "MPS not built into this PyTorch" + + +def test_mps_get_device_returns_mps_with_env_var(): + """get_device() auto-selects MPS when TRANSFORMERLENS_ALLOW_MPS=1 is set.""" + from transformer_lens.utilities.devices import get_device + + original = os.environ.get("TRANSFORMERLENS_ALLOW_MPS", "") + try: + os.environ["TRANSFORMERLENS_ALLOW_MPS"] = "1" + device = get_device() + assert device == "mps", f"Expected 'mps', got '{device}'" + finally: + if original: + os.environ["TRANSFORMERLENS_ALLOW_MPS"] = original + else: + os.environ.pop("TRANSFORMERLENS_ALLOW_MPS", None) + + +def test_mps_get_device_falls_back_to_cpu_without_env_var(): + """get_device() falls back to CPU when TRANSFORMERLENS_ALLOW_MPS is unset (safety default).""" + from transformer_lens.utilities.devices import get_device + + original = os.environ.get("TRANSFORMERLENS_ALLOW_MPS", "") + try: + os.environ.pop("TRANSFORMERLENS_ALLOW_MPS", None) + device = get_device() + # On a Mac with no CUDA, should return cpu (safe default) + assert device in ("cpu", "mps"), f"Unexpected device: {device}" + if original == "": # env var was not set originally + assert device == "cpu", ( + "Without TRANSFORMERLENS_ALLOW_MPS=1, get_device() should return 'cpu' not 'mps'" + ) + finally: + if original: + os.environ["TRANSFORMERLENS_ALLOW_MPS"] = original + + +def test_mps_warn_if_mps_emits_warning_without_env_var(): + """warn_if_mps() emits a UserWarning when MPS is used without the env var.""" + from transformer_lens.utilities import warn_if_mps + import transformer_lens.utilities.devices as devices_module + + original = os.environ.get("TRANSFORMERLENS_ALLOW_MPS", "") + original_warned = devices_module._mps_warned + try: + os.environ.pop("TRANSFORMERLENS_ALLOW_MPS", None) + devices_module._mps_warned = False # reset so warning fires + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + warn_if_mps("mps") + assert any("MPS backend" in str(warning.message) for warning in w), ( + "Expected MPS warning but got: " + str([str(x.message) for x in w]) + ) + finally: + if original: + os.environ["TRANSFORMERLENS_ALLOW_MPS"] = original + devices_module._mps_warned = original_warned + + +# --------------------------------------------------------------------------- +# 2. 
Raw tensor operations on Metal (no model load) +# --------------------------------------------------------------------------- + + +def test_mps_tensor_basic_operations(): + """Basic tensor arithmetic runs on the Metal GPU without errors.""" + x = torch.randn(16, 32, device="mps", dtype=torch.float32) + y = torch.randn(16, 32, device="mps", dtype=torch.float32) + + z = x + y + assert z.device.type == "mps" + + w = torch.matmul(x, y.T) + assert w.device.type == "mps" + assert w.shape == (16, 16) + + # Verify result comes back to CPU correctly + z_cpu = z.cpu() + assert z_cpu.device.type == "cpu" + + _cleanup() + + +def test_mps_softmax_and_layernorm(): + """Softmax and LayerNorm — core transformer ops — work on MPS.""" + x = torch.randn(4, 16, 64, device="mps", dtype=torch.float32) + + softmax_out = torch.nn.functional.softmax(x, dim=-1) + assert softmax_out.device.type == "mps" + assert torch.allclose(softmax_out.sum(dim=-1), torch.ones(4, 16, device="mps"), atol=1e-5) + + ln = torch.nn.LayerNorm(64).to("mps") + ln_out = ln(x) + assert ln_out.device.type == "mps" + + _cleanup() + + +# --------------------------------------------------------------------------- +# 3. Model loading and forward pass on Metal +# --------------------------------------------------------------------------- + + +def test_mps_model_forward_pass(): + """TinyStories-1M loads and runs a forward pass on the Metal GPU.""" + model = _load_tiny_model(device="mps") + + tokens = model.to_tokens("Once upon a time") + assert tokens.device.type == "mps", f"Tokens should be on MPS, got {tokens.device}" + + logits = model(tokens) + assert logits.device.type == "mps", f"Logits should be on MPS, got {logits.device}" + assert logits.shape[-1] == model.cfg.d_vocab + assert not torch.isnan(logits).any(), "NaN values in logits — possible MPS compute error" + + _cleanup(model) + + +def test_mps_run_with_cache(): + """run_with_cache() returns cache tensors on the Metal GPU.""" + model = _load_tiny_model(device="mps") + tokens = model.to_tokens("The quick brown fox") + + logits, cache = model.run_with_cache(tokens) + + assert logits.device.type == "mps" + + # Check a representative set of cache keys + hook_q = cache["blocks.0.attn.hook_q"] + assert hook_q.device.type == "mps", f"Cache tensor not on MPS: {hook_q.device}" + assert not torch.isnan(hook_q).any(), "NaN in attention query cache" + + _cleanup(model) + + +def test_mps_activation_hook_fires_on_metal(): + """run_with_hooks() fires hooks and hook tensors are on the Metal GPU.""" + model = _load_tiny_model(device="mps") + tokens = model.to_tokens("Apple Silicon rocks") + + hook_devices = [] + hook_shapes = [] + + def capture_hook(value, hook): + hook_devices.append(value.device.type) + hook_shapes.append(value.shape) + return value + + model.run_with_hooks( + tokens, + fwd_hooks=[ + ("blocks.0.attn.hook_q", capture_hook), + ("blocks.0.mlp.hook_post", capture_hook), + ], + ) + + assert len(hook_devices) == 2, f"Expected 2 hooks to fire, got {len(hook_devices)}" + for device in hook_devices: + assert device == "mps", f"Hook tensor not on MPS: {device}" + + _cleanup(model) + + +def test_mps_float32_inference(): + """Explicit float32 model loads and infers correctly on MPS.""" + model = _load_tiny_model(device="mps") + + # Verify all parameters are float32 + for name, param in model.named_parameters(): + assert param.dtype == torch.float32, f"Parameter {name} has wrong dtype: {param.dtype}" + + tokens = model.to_tokens("Testing float32 on Metal") + logits = model(tokens) + assert 
logits.dtype == torch.float32 + + _cleanup(model) + + +def test_mps_loss_computation(): + """Loss computation (return_type='loss') works on MPS.""" + model = _load_tiny_model(device="mps") + + loss = model("Once upon a time in a land", return_type="loss") + assert isinstance(loss, torch.Tensor) + assert loss.device.type == "mps" + assert not torch.isnan(loss), f"NaN loss — possible MPS compute error: {loss}" + assert loss.item() > 0, "Loss should be positive" + + _cleanup(model) From 7d9d6df929aa07487886ccb7289e46193a311cd0 Mon Sep 17 00:00:00 2001 From: huseyincavusbi Date: Sat, 2 May 2026 12:41:52 +0200 Subject: [PATCH 03/15] ci: Enable runs on feature branch --- .github/workflows/checks.yml | 53 ++++++++++++++++++++++++++++++++++++ 1 file changed, 53 insertions(+) diff --git a/.github/workflows/checks.yml b/.github/workflows/checks.yml index 3a2e1389c..ab4359c76 100644 --- a/.github/workflows/checks.yml +++ b/.github/workflows/checks.yml @@ -6,6 +6,7 @@ on: - main - dev* - refactor* + - feat/mps-ci-support tags: - "v*.*.*" paths-ignore: @@ -103,6 +104,58 @@ jobs: - name: Build check run: uv build + mps-checks: + name: MPS Checks (Apple Silicon) + runs-on: macos-latest + # Only run on PRs merging to main or pushes directly to main + # Temporarily disabled to verify on feature branch + # if: > + # (github.event_name == 'pull_request' && github.base_ref == 'main') || + # (github.event_name == 'push' && github.ref == 'refs/heads/main') + steps: + - uses: actions/checkout@v4 + - name: Install uv + uses: astral-sh/setup-uv@v6 + with: + python-version: "3.11" + activate-environment: true + enable-cache: true + - name: Cache Models used with MPS Tests + uses: actions/cache@v3 + with: + path: | + ~/.cache/huggingface/hub/models--roneneldan--TinyStories-1M* + key: ${{ runner.os }}-huggingface-models-mps-v1 + - name: Install dependencies + run: | + uv lock --check + uv sync + - name: Verify MPS availability + run: | + uv run python -c " + import torch + print(f'PyTorch: {torch.__version__}') + print(f'MPS available: {torch.backends.mps.is_available()}') + print(f'MPS built: {torch.backends.mps.is_built()}') + assert torch.backends.mps.is_available(), 'MPS not available on this runner!' + " + - name: Unit Test (MPS) + run: make unit-test + - name: Integration Test (MPS) + run: > + uv run pytest tests/integration -v + --ignore=tests/integration/model_bridge/ + --ignore=tests/integration/test_prepend_bos.py + --ignore=tests/integration/test_generation_compatibility.py + --ignore=tests/integration/test_match_huggingface.py + --ignore=tests/integration/test_fold_layer_integration.py + --ignore=tests/integration/test_centralized_weight_processing.py + --ignore=tests/integration/test_create_hooked_encoder.py + - name: MPS Smoke Tests + run: uv run pytest tests/mps -v + env: + TRANSFORMERLENS_ALLOW_MPS: "1" + format-check: name: Format Check runs-on: ubuntu-latest From 64539130b5d3e8b513c4a22684899b428eee39c1 Mon Sep 17 00:00:00 2001 From: huseyincavusbi Date: Sat, 2 May 2026 13:35:52 +0200 Subject: [PATCH 04/15] fix: Skip heavy model_bridge unit tests on MPS runner due to memory limits --- .github/workflows/checks.yml | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/.github/workflows/checks.yml b/.github/workflows/checks.yml index ab4359c76..ac76badd0 100644 --- a/.github/workflows/checks.yml +++ b/.github/workflows/checks.yml @@ -140,7 +140,9 @@ jobs: assert torch.backends.mps.is_available(), 'MPS not available on this runner!' 
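# Hedged extra diagnostic (editor's sketch, standard library only): echo the
# opt-in flag that gates MPS auto-selection in get_device()
import os
print('TRANSFORMERLENS_ALLOW_MPS: ' + os.environ.get('TRANSFORMERLENS_ALLOW_MPS', 'unset'))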
" - name: Unit Test (MPS) - run: make unit-test + run: > + uv run pytest tests/unit -v + --ignore=tests/unit/model_bridge/ - name: Integration Test (MPS) run: > uv run pytest tests/integration -v From 18b2f57f0e4658a6573f5d6a4a52d069e8258df0 Mon Sep 17 00:00:00 2001 From: huseyincavusbi Date: Sat, 2 May 2026 14:15:31 +0200 Subject: [PATCH 05/15] fix: Ignore flaky grouped query attention tests on Mac runner --- .github/workflows/checks.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/workflows/checks.yml b/.github/workflows/checks.yml index ac76badd0..5de3efd49 100644 --- a/.github/workflows/checks.yml +++ b/.github/workflows/checks.yml @@ -149,6 +149,7 @@ jobs: --ignore=tests/integration/model_bridge/ --ignore=tests/integration/test_prepend_bos.py --ignore=tests/integration/test_generation_compatibility.py + --ignore=tests/integration/test_grouped_query_attention.py --ignore=tests/integration/test_match_huggingface.py --ignore=tests/integration/test_fold_layer_integration.py --ignore=tests/integration/test_centralized_weight_processing.py From b2787d241b7607a4827087a4e5b6906648251963 Mon Sep 17 00:00:00 2001 From: huseyincavusbi Date: Sat, 2 May 2026 14:34:02 +0200 Subject: [PATCH 06/15] style: Run make format on test_mps_basic.py --- tests/mps/test_mps_basic.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/tests/mps/test_mps_basic.py b/tests/mps/test_mps_basic.py index 7114611a4..d7a0ed0f0 100644 --- a/tests/mps/test_mps_basic.py +++ b/tests/mps/test_mps_basic.py @@ -84,9 +84,9 @@ def test_mps_get_device_falls_back_to_cpu_without_env_var(): # On a Mac with no CUDA, should return cpu (safe default) assert device in ("cpu", "mps"), f"Unexpected device: {device}" if original == "": # env var was not set originally - assert device == "cpu", ( - "Without TRANSFORMERLENS_ALLOW_MPS=1, get_device() should return 'cpu' not 'mps'" - ) + assert ( + device == "cpu" + ), "Without TRANSFORMERLENS_ALLOW_MPS=1, get_device() should return 'cpu' not 'mps'" finally: if original: os.environ["TRANSFORMERLENS_ALLOW_MPS"] = original @@ -94,8 +94,8 @@ def test_mps_get_device_falls_back_to_cpu_without_env_var(): def test_mps_warn_if_mps_emits_warning_without_env_var(): """warn_if_mps() emits a UserWarning when MPS is used without the env var.""" - from transformer_lens.utilities import warn_if_mps import transformer_lens.utilities.devices as devices_module + from transformer_lens.utilities import warn_if_mps original = os.environ.get("TRANSFORMERLENS_ALLOW_MPS", "") original_warned = devices_module._mps_warned @@ -105,9 +105,9 @@ def test_mps_warn_if_mps_emits_warning_without_env_var(): with warnings.catch_warnings(record=True) as w: warnings.simplefilter("always") warn_if_mps("mps") - assert any("MPS backend" in str(warning.message) for warning in w), ( - "Expected MPS warning but got: " + str([str(x.message) for x in w]) - ) + assert any( + "MPS backend" in str(warning.message) for warning in w + ), "Expected MPS warning but got: " + str([str(x.message) for x in w]) finally: if original: os.environ["TRANSFORMERLENS_ALLOW_MPS"] = original From 5931009d2c36ffdced4a7247d5c34dd19ee76f81 Mon Sep 17 00:00:00 2001 From: huseyincavusbi Date: Sat, 2 May 2026 14:40:15 +0200 Subject: [PATCH 07/15] style: Standardize MPS step naming convention in CI --- .github/workflows/checks.yml | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/.github/workflows/checks.yml b/.github/workflows/checks.yml index 5de3efd49..351813a61 100644 --- 
a/.github/workflows/checks.yml +++ b/.github/workflows/checks.yml @@ -105,7 +105,7 @@ jobs: run: uv build mps-checks: - name: MPS Checks (Apple Silicon) + name: MPS Checks runs-on: macos-latest # Only run on PRs merging to main or pushes directly to main # Temporarily disabled to verify on feature branch @@ -120,7 +120,7 @@ jobs: python-version: "3.11" activate-environment: true enable-cache: true - - name: Cache Models used with MPS Tests + - name: MPS Cache Models uses: actions/cache@v3 with: path: | @@ -130,7 +130,7 @@ jobs: run: | uv lock --check uv sync - - name: Verify MPS availability + - name: MPS Availability Check run: | uv run python -c " import torch @@ -139,11 +139,11 @@ jobs: print(f'MPS built: {torch.backends.mps.is_built()}') assert torch.backends.mps.is_available(), 'MPS not available on this runner!' " - - name: Unit Test (MPS) + - name: MPS Unit Tests run: > uv run pytest tests/unit -v --ignore=tests/unit/model_bridge/ - - name: Integration Test (MPS) + - name: MPS Integration Tests run: > uv run pytest tests/integration -v --ignore=tests/integration/model_bridge/ From be012404867dff41f063de05c406459cb76da4c2 Mon Sep 17 00:00:00 2001 From: huseyincavusbi Date: Sat, 2 May 2026 17:21:22 +0200 Subject: [PATCH 08/15] ci: Revert MPS trigger to run only on main PRs and pushes --- .github/workflows/checks.yml | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/.github/workflows/checks.yml b/.github/workflows/checks.yml index 351813a61..3fba72f5e 100644 --- a/.github/workflows/checks.yml +++ b/.github/workflows/checks.yml @@ -108,10 +108,9 @@ jobs: name: MPS Checks runs-on: macos-latest # Only run on PRs merging to main or pushes directly to main - # Temporarily disabled to verify on feature branch - # if: > - # (github.event_name == 'pull_request' && github.base_ref == 'main') || - # (github.event_name == 'push' && github.ref == 'refs/heads/main') + if: > + (github.event_name == 'pull_request' && github.base_ref == 'main') || + (github.event_name == 'push' && github.ref == 'refs/heads/main') steps: - uses: actions/checkout@v4 - name: Install uv From ed30beb8eea07a93241fa462d7f500da3e4a9e19 Mon Sep 17 00:00:00 2001 From: huseyincavusbi Date: Sat, 2 May 2026 17:26:23 +0200 Subject: [PATCH 09/15] ci: Remove feature branch from global workflow triggers --- .github/workflows/checks.yml | 1 - 1 file changed, 1 deletion(-) diff --git a/.github/workflows/checks.yml b/.github/workflows/checks.yml index 3fba72f5e..262be82d5 100644 --- a/.github/workflows/checks.yml +++ b/.github/workflows/checks.yml @@ -6,7 +6,6 @@ on: - main - dev* - refactor* - - feat/mps-ci-support tags: - "v*.*.*" paths-ignore: From 58ebdd3faf7abf2cbe5b2fcb807b81d5b49279b1 Mon Sep 17 00:00:00 2001 From: huseyincavusbi Date: Sat, 2 May 2026 23:14:45 +0200 Subject: [PATCH 10/15] fix: Restore torch.device return type in get_device for API stability --- tests/mps/test_mps_basic.py | 10 ++++------ tests/unit/utilities/test_devices.py | 22 ++++++++++++---------- transformer_lens/utilities/devices.py | 12 ++++++------ 3 files changed, 22 insertions(+), 22 deletions(-) diff --git a/tests/mps/test_mps_basic.py b/tests/mps/test_mps_basic.py index d7a0ed0f0..7e4f74615 100644 --- a/tests/mps/test_mps_basic.py +++ b/tests/mps/test_mps_basic.py @@ -65,7 +65,8 @@ def test_mps_get_device_returns_mps_with_env_var(): try: os.environ["TRANSFORMERLENS_ALLOW_MPS"] = "1" device = get_device() - assert device == "mps", f"Expected 'mps', got '{device}'" + assert isinstance(device, torch.device) + assert device.type == 
"mps", f"Expected 'mps', got '{device.type}'" finally: if original: os.environ["TRANSFORMERLENS_ALLOW_MPS"] = original @@ -82,11 +83,8 @@ def test_mps_get_device_falls_back_to_cpu_without_env_var(): os.environ.pop("TRANSFORMERLENS_ALLOW_MPS", None) device = get_device() # On a Mac with no CUDA, should return cpu (safe default) - assert device in ("cpu", "mps"), f"Unexpected device: {device}" - if original == "": # env var was not set originally - assert ( - device == "cpu" - ), "Without TRANSFORMERLENS_ALLOW_MPS=1, get_device() should return 'cpu' not 'mps'" + assert isinstance(device, torch.device) + assert device.type == "cpu", f"Without TRANSFORMERLENS_ALLOW_MPS=1, get_device() should return 'cpu' not '{device.type}'" finally: if original: os.environ["TRANSFORMERLENS_ALLOW_MPS"] = original diff --git a/tests/unit/utilities/test_devices.py b/tests/unit/utilities/test_devices.py index 3116c6c04..99a963696 100644 --- a/tests/unit/utilities/test_devices.py +++ b/tests/unit/utilities/test_devices.py @@ -43,8 +43,8 @@ def test_get_device_cuda_available(): with patch("torch.cuda.is_available", return_value=True): with patch("torch.backends.mps.is_available", return_value=False): device = get_device() - assert device == "cuda" - + assert isinstance(device, torch.device) + assert device.type == "cuda" @patch.dict("os.environ", {"TRANSFORMERLENS_ALLOW_MPS": "1"}) def test_get_device_mps_available(): @@ -54,8 +54,8 @@ def test_get_device_mps_available(): with patch("torch.backends.mps.is_built", return_value=True): with patch("torch.__version__", "2.0.0"): device = get_device() - assert device == "mps" - + assert isinstance(device, torch.device) + assert device.type == "mps" def test_get_device_mps_pytorch_1x(): """Test get_device when MPS is available but PyTorch version < 2.0.""" @@ -64,16 +64,16 @@ def test_get_device_mps_pytorch_1x(): with patch("torch.backends.mps.is_built", return_value=True): with patch("torch.__version__", "1.13.0"): device = get_device() - assert device == "cpu" - + assert isinstance(device, torch.device) + assert device.type == "cpu" def test_get_device_cpu_fallback(): """Test get_device falls back to CPU when no GPU available.""" with patch("torch.cuda.is_available", return_value=False): with patch("torch.backends.mps.is_available", return_value=False): device = get_device() - assert device == "cpu" - + assert isinstance(device, torch.device) + assert device.type == "cpu" def test_model_with_cfg_protocol(): """Test that ModelWithCfg protocol is runtime checkable.""" @@ -178,7 +178,8 @@ def test_get_device_returns_cpu_when_mps_available(mock_built, mock_avail, mock_ os.environ.pop("TRANSFORMERLENS_ALLOW_MPS", None) result = get_device() - assert result == "cpu" + assert isinstance(result, torch.device) + assert result.type == "cpu" @patch.dict("os.environ", {"TRANSFORMERLENS_ALLOW_MPS": "1"}) @@ -188,7 +189,8 @@ def test_get_device_returns_cpu_when_mps_available(mock_built, mock_avail, mock_ def test_get_device_returns_mps_when_env_var_set(mock_built, mock_avail, mock_cuda): """get_device() should return MPS when TRANSFORMERLENS_ALLOW_MPS=1 is set.""" result = get_device() - assert result == "mps" + assert isinstance(result, torch.device) + assert result.type == "mps" @patch.dict("os.environ", {}, clear=False) diff --git a/transformer_lens/utilities/devices.py b/transformer_lens/utilities/devices.py index b923fa21e..6b23e1821 100644 --- a/transformer_lens/utilities/devices.py +++ b/transformer_lens/utilities/devices.py @@ -53,25 +53,25 @@ def 
_torch_mps_has_known_broken_bug() -> bool: # --------------------------------------------------------------------------- -def get_device() -> str: +def get_device() -> torch.device: """Get the best available device, with MPS safety checks. MPS is only auto-selected when the environment variable ``TRANSFORMERLENS_ALLOW_MPS=1`` is set **and** the installed PyTorch - version meets or exceeds ``_MPS_MIN_SAFE_TORCH_VERSION``. + version is 2.0 or higher. Returns: - str: The best available device name (cuda, mps, or cpu) + torch.device: The best available device (cuda, mps, or cpu) """ if torch.cuda.is_available(): - return "cuda" + return torch.device("cuda") if torch.backends.mps.is_available() and torch.backends.mps.is_built(): major_version = int(torch.__version__.split(".")[0]) if major_version >= 2: # Only auto-select MPS when explicitly opted-in via env var if os.environ.get("TRANSFORMERLENS_ALLOW_MPS", "") == "1": - return "mps" + return torch.device("mps") logging.info( "MPS device available but not auto-selected due to known correctness issues " "(PyTorch %s). Set TRANSFORMERLENS_ALLOW_MPS=1 to override. See: " @@ -79,7 +79,7 @@ def get_device() -> str: torch.__version__, ) - return "cpu" + return torch.device("cpu") def warn_if_mps(device: Union[str, torch.device]) -> None: From 9e816fe479353cf90c7ec0050ea991f93d97597b Mon Sep 17 00:00:00 2001 From: huseyincavusbi Date: Sat, 2 May 2026 23:14:51 +0200 Subject: [PATCH 11/15] docs: Align train and device config docstrings with implementation --- transformer_lens/train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/transformer_lens/train.py b/transformer_lens/train.py index f54f93b57..db6f5d769 100644 --- a/transformer_lens/train.py +++ b/transformer_lens/train.py @@ -32,7 +32,7 @@ class HookedTransformerTrainConfig: max_grad_norm (float, *optional*): Maximum gradient norm to use for weight_decay (float, *optional*): Weight decay to use for training optimizer_name (str): The name of the optimizer to use - device (str, *optional*): Device to use for training + device (str or torch.device, *optional*): Device to use for training warmup_steps (int, *optional*): Number of warmup steps to use for training save_every (int, *optional*): After how many batches should a checkpoint be saved save_dir, (str, *optional*): Where to save checkpoints From 159bd8d35fa253305cdf7e96270168b3dc74ec7c Mon Sep 17 00:00:00 2001 From: huseyincavusbi Date: Sat, 2 May 2026 23:33:08 +0200 Subject: [PATCH 12/15] fix: Update device type hints to Union[str, torch.device] for consistency --- transformer_lens/HookedAudioEncoder.py | 2 +- transformer_lens/HookedEncoder.py | 2 +- transformer_lens/HookedEncoderDecoder.py | 2 +- transformer_lens/config/TransformerLensConfig.py | 2 +- transformer_lens/lit/model.py | 4 ++-- 5 files changed, 6 insertions(+), 6 deletions(-) diff --git a/transformer_lens/HookedAudioEncoder.py b/transformer_lens/HookedAudioEncoder.py index c76f9c7b7..95494ad23 100644 --- a/transformer_lens/HookedAudioEncoder.py +++ b/transformer_lens/HookedAudioEncoder.py @@ -356,7 +356,7 @@ def from_pretrained( checkpoint_index: Optional[int] = None, checkpoint_value: Optional[int] = None, hf_model: Optional[Any] = None, - device: Optional[str] = None, + device: Optional[Union[str, torch.device]] = None, move_to_device: bool = True, dtype: torch.dtype = torch.float32, **from_pretrained_kwargs: Any, diff --git a/transformer_lens/HookedEncoder.py b/transformer_lens/HookedEncoder.py index 4c239f3d8..e6c29ce18 100644 --- 
a/transformer_lens/HookedEncoder.py +++ b/transformer_lens/HookedEncoder.py @@ -377,7 +377,7 @@ def from_pretrained( checkpoint_index: Optional[int] = None, checkpoint_value: Optional[int] = None, hf_model: Optional[Any] = None, - device: Optional[str] = None, + device: Optional[Union[str, torch.device]] = None, tokenizer: Optional[Any] = None, move_to_device: bool = True, dtype: torch.dtype = torch.float32, diff --git a/transformer_lens/HookedEncoderDecoder.py b/transformer_lens/HookedEncoderDecoder.py index e683d2f91..f3e4a9402 100644 --- a/transformer_lens/HookedEncoderDecoder.py +++ b/transformer_lens/HookedEncoderDecoder.py @@ -544,7 +544,7 @@ def from_pretrained( checkpoint_index: Optional[int] = None, checkpoint_value: Optional[int] = None, hf_model: Optional[Any] = None, - device: Optional[str] = None, + device: Optional[Union[str, torch.device]] = None, tokenizer: Optional[Any] = None, move_to_device: bool = True, dtype: Optional[torch.dtype] = torch.float32, diff --git a/transformer_lens/config/TransformerLensConfig.py b/transformer_lens/config/TransformerLensConfig.py index fb1f5f045..df7fd352d 100644 --- a/transformer_lens/config/TransformerLensConfig.py +++ b/transformer_lens/config/TransformerLensConfig.py @@ -59,7 +59,7 @@ class TransformerLensConfig: d_vocab: int = -1 # Device configuration - device: Optional[str] = None + device: Optional[Union[str, torch.device]] = None # Attention configuration use_attn_result: bool = False diff --git a/transformer_lens/lit/model.py b/transformer_lens/lit/model.py index b66e3ba72..42e27a1f6 100644 --- a/transformer_lens/lit/model.py +++ b/transformer_lens/lit/model.py @@ -34,7 +34,7 @@ import logging from dataclasses import dataclass -from typing import TYPE_CHECKING, Any, Dict, Iterable, Iterator, List, Optional +from typing import TYPE_CHECKING, Any, Dict, Iterable, Iterator, List, Optional, Union import torch @@ -86,7 +86,7 @@ class HookedTransformerLITConfig: output_all_layers: bool = DEFAULTS.OUTPUT_ALL_LAYERS embedding_layers: Optional[List[int]] = None prepend_bos: bool = DEFAULTS.PREPEND_BOS - device: Optional[str] = None + device: Optional[Union[str, torch.device]] = None def _ensure_lit_available(): From 81d387c100749f9527d279f09a98ab7e4cd8cd6b Mon Sep 17 00:00:00 2001 From: huseyincavusbi Date: Sat, 2 May 2026 23:33:16 +0200 Subject: [PATCH 13/15] ci: Update mps-checks trigger to include dev branch --- .github/workflows/checks.yml | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/.github/workflows/checks.yml b/.github/workflows/checks.yml index 262be82d5..c033fc713 100644 --- a/.github/workflows/checks.yml +++ b/.github/workflows/checks.yml @@ -106,10 +106,10 @@ jobs: mps-checks: name: MPS Checks runs-on: macos-latest - # Only run on PRs merging to main or pushes directly to main + # Only run on PRs merging to main/dev or pushes directly to main/dev if: > - (github.event_name == 'pull_request' && github.base_ref == 'main') || - (github.event_name == 'push' && github.ref == 'refs/heads/main') + (github.event_name == 'pull_request' && (github.base_ref == 'main' || github.base_ref == 'dev')) || + (github.event_name == 'push' && (github.ref == 'refs/heads/main' || github.ref == 'refs/heads/dev')) steps: - uses: actions/checkout@v4 - name: Install uv From 02ae6a0896ab650ac808c5d38830624462aea00e Mon Sep 17 00:00:00 2001 From: huseyincavusbi Date: Sat, 2 May 2026 23:33:16 +0200 Subject: [PATCH 14/15] test: Update device tests for torch.device compatibility and robustness --- tests/mps/test_mps_basic.py | 4 +++- 
tests/unit/utilities/test_devices.py | 4 ++++ 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/tests/mps/test_mps_basic.py b/tests/mps/test_mps_basic.py index 7e4f74615..efe12aada 100644 --- a/tests/mps/test_mps_basic.py +++ b/tests/mps/test_mps_basic.py @@ -84,7 +84,9 @@ def test_mps_get_device_falls_back_to_cpu_without_env_var(): device = get_device() # On a Mac with no CUDA, should return cpu (safe default) assert isinstance(device, torch.device) - assert device.type == "cpu", f"Without TRANSFORMERLENS_ALLOW_MPS=1, get_device() should return 'cpu' not '{device.type}'" + assert ( + device.type == "cpu" + ), f"Without TRANSFORMERLENS_ALLOW_MPS=1, get_device() should return 'cpu' not '{device.type}'" finally: if original: os.environ["TRANSFORMERLENS_ALLOW_MPS"] = original diff --git a/tests/unit/utilities/test_devices.py b/tests/unit/utilities/test_devices.py index 99a963696..ef8bfb227 100644 --- a/tests/unit/utilities/test_devices.py +++ b/tests/unit/utilities/test_devices.py @@ -46,6 +46,7 @@ def test_get_device_cuda_available(): assert isinstance(device, torch.device) assert device.type == "cuda" + @patch.dict("os.environ", {"TRANSFORMERLENS_ALLOW_MPS": "1"}) def test_get_device_mps_available(): """Test get_device when MPS is available, PyTorch version >= 2.0, and env var set.""" @@ -57,6 +58,7 @@ def test_get_device_mps_available(): assert isinstance(device, torch.device) assert device.type == "mps" + def test_get_device_mps_pytorch_1x(): """Test get_device when MPS is available but PyTorch version < 2.0.""" with patch("torch.cuda.is_available", return_value=False): @@ -67,6 +69,7 @@ def test_get_device_mps_pytorch_1x(): assert isinstance(device, torch.device) assert device.type == "cpu" + def test_get_device_cpu_fallback(): """Test get_device falls back to CPU when no GPU available.""" with patch("torch.cuda.is_available", return_value=False): @@ -75,6 +78,7 @@ def test_get_device_cpu_fallback(): assert isinstance(device, torch.device) assert device.type == "cpu" + def test_model_with_cfg_protocol(): """Test that ModelWithCfg protocol is runtime checkable.""" model = MockModelWithCfg() From 341ab712ef83b2579c2e7a69f924320618f03f77 Mon Sep 17 00:00:00 2001 From: huseyincavusbi Date: Sun, 3 May 2026 00:46:47 +0200 Subject: [PATCH 15/15] ci: Restrict MPS trigger to main branch only --- .github/workflows/checks.yml | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/.github/workflows/checks.yml b/.github/workflows/checks.yml index c033fc713..262be82d5 100644 --- a/.github/workflows/checks.yml +++ b/.github/workflows/checks.yml @@ -106,10 +106,10 @@ jobs: mps-checks: name: MPS Checks runs-on: macos-latest - # Only run on PRs merging to main/dev or pushes directly to main/dev + # Only run on PRs merging to main or pushes directly to main if: > - (github.event_name == 'pull_request' && (github.base_ref == 'main' || github.base_ref == 'dev')) || - (github.event_name == 'push' && (github.ref == 'refs/heads/main' || github.ref == 'refs/heads/dev')) + (github.event_name == 'pull_request' && github.base_ref == 'main') || + (github.event_name == 'push' && github.ref == 'refs/heads/main') steps: - uses: actions/checkout@v4 - name: Install uv