Merged
21 commits
98811df
3.0 CI Bugs (#1261)
jlarson4 Apr 20, 2026
9ef4e4c
fix: use cfg.dtype instead of torch.get_default_dtype for KV cache in…
davidcyze Apr 20, 2026
c1e5d4b
Fix tests broken by a local GPU (#1219)
brendanlong Apr 20, 2026
c67a0a1
fix: handle LayerNorm folding correctly in load_and_process_state_dic…
VedantMadane Apr 20, 2026
524bca9
Fix HookedTransformerConfig rotary_base types (#1231)
brendanlong Apr 20, 2026
bd67b0f
Fixed Masking in HookedTransformer.generate (#999)
tuomaso Apr 21, 2026
5fe490e
Add hooked transformer generate stream (#908)
anthonyduong9 Apr 21, 2026
0db6d22
Add py.typed for type hints (#760)
UFO-101 Apr 22, 2026
a4379f8
Created Baichuan Architecture adapter (#1262)
jlarson4 Apr 22, 2026
717899e
Make `FactoredMatrix` compatible with tensor-like arguments (#599)
JasonGross Apr 22, 2026
26d51a2
NanoGPT Conversation did not handle case when there were no biases in…
dashstander Apr 22, 2026
2e89f7f
fixed batched generation on run_with_cache and run_with_hooks on tran…
jlarson4 Apr 22, 2026
bbc3ec7
Added 1D tensor handling in line with HookedTransformer (#1266)
jlarson4 Apr 23, 2026
3c5896e
Added n_ctx override to TransformerBridge (#1269)
jlarson4 Apr 24, 2026
e4222a6
Feature/generate stream on bridge (#1268)
jlarson4 Apr 24, 2026
1607ef2
Added warnings for users attempting to use MPS with Torch 2.8 (#1271)
jlarson4 Apr 28, 2026
a92a90a
Documenting 3.1 features, adding additional context to the purpose of…
jlarson4 Apr 28, 2026
ad8e123
Improved Tokenize & Concatenate (#1273)
jlarson4 Apr 29, 2026
d95bd96
Multi-Device Processing on Bridge (#1270)
jlarson4 Apr 29, 2026
fd288dc
Adding Architecture Adapter Creation Guide to Docs (#1274)
jlarson4 Apr 29, 2026
0a5218c
Fixed Quantization bug in TransformerLens 3.0 (#1276)
jlarson4 Apr 29, 2026
15 changes: 12 additions & 3 deletions .github/workflows/checks.yml
@@ -221,6 +221,8 @@ jobs:
     name: Notebook Checks
     runs-on: ubuntu-latest
     timeout-minutes: 30
+    env:
+      HAS_HF_TOKEN: ${{ secrets.HF_TOKEN != '' }}
     strategy:
       fail-fast: false
       matrix:
@@ -233,15 +235,19 @@
           # - "Grokking_Demo"
           - "Head_Detector_Demo"
           # - "Interactive_Neuroscope"
-          - "LLaMA"
-          - "LLaMA2_GPU_Quantized" # Requires quantization libs + too slow for CI timeout
           - "Main_Demo"
           # - "No_Position_Experiment"
           - "Othello_GPT"
           - "Patchscopes_Generation_Demo"
           - "Santa_Coder"
           # - "stable_lm"
           - "T5"
+        requires_hf_token: [false]
+        include:
+          - notebook: "LLaMA"
+            requires_hf_token: true
+          - notebook: "LLaMA2_GPU_Quantized"
+            requires_hf_token: true
     steps:
       - uses: actions/checkout@v3
       - name: Add swap space
@@ -278,10 +284,13 @@
         env:
           HF_TOKEN: ${{ secrets.HF_TOKEN }}
       - name: Check Notebook Output Consistency
-        # Note: currently only checks notebooks we have specifically setup for this
+        if: ${{ !matrix.requires_hf_token || env.HAS_HF_TOKEN == 'true' }}
         run: pytest --nbval-sanitize-with demos/doc_sanitize.cfg demos/${{ matrix.notebook }}.ipynb
         env:
           HF_TOKEN: ${{ secrets.HF_TOKEN }}
+      - name: Skip (HF_TOKEN not available)
+        if: ${{ matrix.requires_hf_token && env.HAS_HF_TOKEN != 'true' }}
+        run: echo "Skipping ${{ matrix.notebook }} — requires HF_TOKEN for gated model access"
 
 
   build-docs:
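The workflow change above maps the presence of the `HF_TOKEN` secret into a job-level env var (`HAS_HF_TOKEN`) and branches each notebook step on it, so gated-model notebooks run only when the token exists. For local runs, a hypothetical pytest-side equivalent of the same gate (the marker name and test body are illustrative, not part of this PR):

```python
import os

import pytest

# Skip gated-model tests when no HF token is present in the environment,
# mirroring the HAS_HF_TOKEN check the workflow performs in CI.
requires_hf_token = pytest.mark.skipif(
    not os.environ.get("HF_TOKEN"),
    reason="requires HF_TOKEN for gated model access",
)


@requires_hf_token
def test_gated_model_notebook():
    # Placeholder body: in CI the real check is
    # `pytest --nbval-sanitize-with demos/doc_sanitize.cfg demos/<notebook>.ipynb`.
    assert True
```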
2 changes: 1 addition & 1 deletion README.md
@@ -50,7 +50,7 @@ bridge = TransformerBridge.boot_transformers("gpt2", device="cpu")
 logits, activations = bridge.run_with_cache("Hello World")
 ```
 
-`TransformerBridge` is the recommended 3.0 path and supports 50+ architectures. The legacy `HookedTransformer.from_pretrained` API is still available through a compatibility layer but is deprecated - see the [Migrating to TransformerLens 3](https://TransformerLensOrg.github.io/TransformerLens/content/migrating_to_v3.html) guide for conversion recipes.
+`TransformerBridge` is the recommended 3.0 path and supports 50+ architectures. By default it preserves raw HuggingFace weights – logits and activations match HF, *not* legacy `HookedTransformer` (which folds LayerNorm and centers weights by default). Call `bridge.enable_compatibility_mode()` after booting for HookedTransformer-equivalent numerics. The legacy `HookedTransformer.from_pretrained` API is still available but deprecated - see the [Migrating to TransformerLens 3](https://TransformerLensOrg.github.io/TransformerLens/content/migrating_to_v3.html) guide.
 
 ## Key Tutorials
 
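A minimal sketch of the two numerics modes the new README paragraph describes, built only from the calls named there (model and prompt are illustrative):

```python
from transformer_lens.model_bridge import TransformerBridge

# Default: raw HuggingFace numerics; logits and activations match HF.
bridge = TransformerBridge.boot_transformers("gpt2", device="cpu")
logits, cache = bridge.run_with_cache("Hello World")

# Opt in to legacy HookedTransformer-equivalent numerics
# (folded LayerNorm, centered weights).
bridge.enable_compatibility_mode()
compat_logits, compat_cache = bridge.run_with_cache("Hello World")
```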
34 changes: 29 additions & 5 deletions demos/Activation_Patching_in_TL_Demo.ipynb
@@ -126,7 +126,7 @@
    "outputs": [],
    "source": [
     "import transformer_lens\n",
-    "import transformer_lens.utils as utils\n",
+    "import transformer_lens.utilities as utils\n",
     "from transformer_lens.model_bridge import TransformerBridge"
    ]
   },
@@ -674,7 +674,7 @@
    ],
    "source": [
     "# NBVAL_SKIP\n",
-    "# Heavy patching computation \u2014 too slow for CI without GPU\n",
+    "# Heavy patching computation — too slow for CI without GPU\n",
     "every_block_result = patching.get_act_patch_block_every(model, corrupted_tokens, clean_cache, ioi_metric)\n",
     "imshow(every_block_result, facet_col=0, facet_labels=[\"Residual Stream\", \"Attn Output\", \"MLP Output\"], title=\"Activation Patching Per Block\", xaxis=\"Position\", yaxis=\"Layer\", zmax=1, zmin=-1, x= [f\"{tok}_{i}\" for i, tok in enumerate(model.to_str_tokens(clean_tokens[0]))])"
    ]
@@ -885,7 +885,21 @@
    "execution_count": null,
    "metadata": {},
    "outputs": [],
-   "source": "# NBVAL_IGNORE_OUTPUT\ndef induction_loss(logits, answer_token_indices=rand_tokens_A):\n    seq_len = answer_token_indices.shape[1]\n\n    # logits: batch x seq_len x vocab_size\n    # Take the logits for the answers, cut off the final element to get the predictions for all but the first element of the answers (which can't be predicted)\n    final_logits = logits[:, -seq_len:-1]\n    final_log_probs = final_logits.log_softmax(-1)\n    return final_log_probs.gather(-1, answer_token_indices[:, 1:].unsqueeze(-1)).mean()\nCLEAN_BASELINE_INDUCTION = induction_loss(clean_logits_induction).item()\nprint(\"Clean baseline:\", CLEAN_BASELINE_INDUCTION)\nCORRUPTED_BASELINE_INDUCTION = induction_loss(corrupted_logits_induction).item()\nprint(\"Corrupted baseline:\", CORRUPTED_BASELINE_INDUCTION)"
+   "source": [
+    "# NBVAL_IGNORE_OUTPUT\n",
+    "def induction_loss(logits, answer_token_indices=rand_tokens_A):\n",
+    "    seq_len = answer_token_indices.shape[1]\n",
+    "\n",
+    "    # logits: batch x seq_len x vocab_size\n",
+    "    # Take the logits for the answers, cut off the final element to get the predictions for all but the first element of the answers (which can't be predicted)\n",
+    "    final_logits = logits[:, -seq_len:-1]\n",
+    "    final_log_probs = final_logits.log_softmax(-1)\n",
+    "    return final_log_probs.gather(-1, answer_token_indices[:, 1:].unsqueeze(-1)).mean()\n",
+    "CLEAN_BASELINE_INDUCTION = induction_loss(clean_logits_induction).item()\n",
+    "print(\"Clean baseline:\", CLEAN_BASELINE_INDUCTION)\n",
+    "CORRUPTED_BASELINE_INDUCTION = induction_loss(corrupted_logits_induction).item()\n",
+    "print(\"Corrupted baseline:\", CORRUPTED_BASELINE_INDUCTION)"
+   ]
  },
  {
   "cell_type": "markdown",
@@ -899,7 +913,17 @@
    "execution_count": null,
    "metadata": {},
    "outputs": [],
-   "source": "# NBVAL_SKIP\n# Heavy patching computation \u2014 too slow for CI without GPU\nevery_head_all_pos_act_patch_result = patching.get_act_patch_attn_head_all_pos_every(attn_only, corrupted_tokens_induction, clean_cache_induction, induction_loss)\nimshow(every_head_all_pos_act_patch_result, facet_col=0, facet_labels=[\"Output\", \"Query\", \"Key\", \"Value\", \"Pattern\"], title=\"Activation Patching Per Head (All Pos)\", xaxis=\"Head\", yaxis=\"Layer\", zmax=CLEAN_BASELINE_INDUCTION)\n\nif DO_SLOW_RUNS:\n    every_head_act_patch_result = patching.get_act_patch_attn_head_by_pos_every(attn_only, corrupted_tokens_induction, clean_cache_induction, induction_loss)\n    every_head_act_patch_result = einops.rearrange(every_head_act_patch_result, \"act_type layer pos head -> act_type (layer head) pos\")\n    imshow(every_head_act_patch_result, facet_col=0, facet_labels=[\"Output\", \"Query\", \"Key\", \"Value\", \"Pattern\"], title=\"Activation Patching Per Head (By Pos)\", xaxis=\"Position\", yaxis=\"Layer & Head\", zmax=CLEAN_BASELINE_INDUCTION, x= [f\"{tok}_{i}\" for i, tok in enumerate(attn_only.to_str_tokens(clean_tokens_induction[0]))], y=[f\"L{l}H{h}\" for l in range(attn_only.cfg.n_layers) for h in range(attn_only.cfg.n_heads)])"
+   "source": [
+    "# NBVAL_SKIP\n",
+    "# Heavy patching computation — too slow for CI without GPU\n",
+    "every_head_all_pos_act_patch_result = patching.get_act_patch_attn_head_all_pos_every(attn_only, corrupted_tokens_induction, clean_cache_induction, induction_loss)\n",
+    "imshow(every_head_all_pos_act_patch_result, facet_col=0, facet_labels=[\"Output\", \"Query\", \"Key\", \"Value\", \"Pattern\"], title=\"Activation Patching Per Head (All Pos)\", xaxis=\"Head\", yaxis=\"Layer\", zmax=CLEAN_BASELINE_INDUCTION)\n",
+    "\n",
+    "if DO_SLOW_RUNS:\n",
+    "    every_head_act_patch_result = patching.get_act_patch_attn_head_by_pos_every(attn_only, corrupted_tokens_induction, clean_cache_induction, induction_loss)\n",
+    "    every_head_act_patch_result = einops.rearrange(every_head_act_patch_result, \"act_type layer pos head -> act_type (layer head) pos\")\n",
+    "    imshow(every_head_act_patch_result, facet_col=0, facet_labels=[\"Output\", \"Query\", \"Key\", \"Value\", \"Pattern\"], title=\"Activation Patching Per Head (By Pos)\", xaxis=\"Position\", yaxis=\"Layer & Head\", zmax=CLEAN_BASELINE_INDUCTION, x= [f\"{tok}_{i}\" for i, tok in enumerate(attn_only.to_str_tokens(clean_tokens_induction[0]))], y=[f\"L{l}H{h}\" for l in range(attn_only.cfg.n_layers) for h in range(attn_only.cfg.n_heads)])"
+   ]
  },
  {
   "cell_type": "markdown",
@@ -1289,4 +1313,4 @@
  },
  "nbformat": 4,
  "nbformat_minor": 2
-}
+}
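For readers skimming the diff: the `induction_loss` metric in the reformatted cell above gathers the log-probability of each repeated token from the position that predicts it, dropping the first answer token (which cannot be predicted). A self-contained toy run of the same logic with random tensors (all shapes are illustrative, not from the notebook):

```python
import torch

batch, seq_len, vocab = 2, 5, 50  # toy sizes, chosen for illustration
# Repeated-random-tokens setup: the second half of the sequence repeats the first.
logits = torch.randn(batch, 2 * seq_len, vocab)
answer_token_indices = torch.randint(0, vocab, (batch, seq_len))

# Same computation as induction_loss:
final_logits = logits[:, -seq_len:-1]          # positions predicting answers 1..seq_len-1
final_log_probs = final_logits.log_softmax(-1)
loss = final_log_probs.gather(-1, answer_token_indices[:, 1:].unsqueeze(-1)).mean()
print(loss.item())  # roughly -log(vocab) for random logits
```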