Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 12 additions & 3 deletions .github/workflows/checks.yml
Original file line number Diff line number Diff line change
Expand Up @@ -221,6 +221,8 @@ jobs:
name: Notebook Checks
runs-on: ubuntu-latest
timeout-minutes: 30
env:
HAS_HF_TOKEN: ${{ secrets.HF_TOKEN != '' }}
strategy:
fail-fast: false
matrix:
Expand All @@ -233,15 +235,19 @@ jobs:
# - "Grokking_Demo"
- "Head_Detector_Demo"
# - "Interactive_Neuroscope"
- "LLaMA"
- "LLaMA2_GPU_Quantized" # Requires quantization libs + too slow for CI timeout
- "Main_Demo"
# - "No_Position_Experiment"
- "Othello_GPT"
- "Patchscopes_Generation_Demo"
- "Santa_Coder"
# - "stable_lm"
- "T5"
requires_hf_token: [false]
include:
- notebook: "LLaMA"
requires_hf_token: true
- notebook: "LLaMA2_GPU_Quantized"
requires_hf_token: true
steps:
- uses: actions/checkout@v3
- name: Add swap space
Expand Down Expand Up @@ -278,10 +284,13 @@ jobs:
env:
HF_TOKEN: ${{ secrets.HF_TOKEN }}
- name: Check Notebook Output Consistency
# Note: currently only checks notebooks we have specifically setup for this
if: ${{ !matrix.requires_hf_token || env.HAS_HF_TOKEN == 'true' }}
run: pytest --nbval-sanitize-with demos/doc_sanitize.cfg demos/${{ matrix.notebook }}.ipynb
env:
HF_TOKEN: ${{ secrets.HF_TOKEN }}
- name: Skip (HF_TOKEN not available)
if: ${{ matrix.requires_hf_token && env.HAS_HF_TOKEN != 'true' }}
run: echo "Skipping ${{ matrix.notebook }} — requires HF_TOKEN for gated model access"


build-docs:
Expand Down
34 changes: 29 additions & 5 deletions demos/Activation_Patching_in_TL_Demo.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@
"outputs": [],
"source": [
"import transformer_lens\n",
"import transformer_lens.utils as utils\n",
"import transformer_lens.utilities as utils\n",
"from transformer_lens.model_bridge import TransformerBridge"
]
},
Expand Down Expand Up @@ -674,7 +674,7 @@
],
"source": [
"# NBVAL_SKIP\n",
"# Heavy patching computation \u2014 too slow for CI without GPU\n",
"# Heavy patching computation — too slow for CI without GPU\n",
"every_block_result = patching.get_act_patch_block_every(model, corrupted_tokens, clean_cache, ioi_metric)\n",
"imshow(every_block_result, facet_col=0, facet_labels=[\"Residual Stream\", \"Attn Output\", \"MLP Output\"], title=\"Activation Patching Per Block\", xaxis=\"Position\", yaxis=\"Layer\", zmax=1, zmin=-1, x= [f\"{tok}_{i}\" for i, tok in enumerate(model.to_str_tokens(clean_tokens[0]))])"
]
Expand Down Expand Up @@ -885,7 +885,21 @@
"execution_count": null,
"metadata": {},
"outputs": [],
"source": "# NBVAL_IGNORE_OUTPUT\ndef induction_loss(logits, answer_token_indices=rand_tokens_A):\n seq_len = answer_token_indices.shape[1]\n\n # logits: batch x seq_len x vocab_size\n # Take the logits for the answers, cut off the final element to get the predictions for all but the first element of the answers (which can't be predicted)\n final_logits = logits[:, -seq_len:-1]\n final_log_probs = final_logits.log_softmax(-1)\n return final_log_probs.gather(-1, answer_token_indices[:, 1:].unsqueeze(-1)).mean()\nCLEAN_BASELINE_INDUCTION = induction_loss(clean_logits_induction).item()\nprint(\"Clean baseline:\", CLEAN_BASELINE_INDUCTION)\nCORRUPTED_BASELINE_INDUCTION = induction_loss(corrupted_logits_induction).item()\nprint(\"Corrupted baseline:\", CORRUPTED_BASELINE_INDUCTION)"
"source": [
"# NBVAL_IGNORE_OUTPUT\n",
"def induction_loss(logits, answer_token_indices=rand_tokens_A):\n",
" seq_len = answer_token_indices.shape[1]\n",
"\n",
" # logits: batch x seq_len x vocab_size\n",
" # Take the logits for the answers, cut off the final element to get the predictions for all but the first element of the answers (which can't be predicted)\n",
" final_logits = logits[:, -seq_len:-1]\n",
" final_log_probs = final_logits.log_softmax(-1)\n",
" return final_log_probs.gather(-1, answer_token_indices[:, 1:].unsqueeze(-1)).mean()\n",
"CLEAN_BASELINE_INDUCTION = induction_loss(clean_logits_induction).item()\n",
"print(\"Clean baseline:\", CLEAN_BASELINE_INDUCTION)\n",
"CORRUPTED_BASELINE_INDUCTION = induction_loss(corrupted_logits_induction).item()\n",
"print(\"Corrupted baseline:\", CORRUPTED_BASELINE_INDUCTION)"
]
},
{
"cell_type": "markdown",
Expand All @@ -899,7 +913,17 @@
"execution_count": null,
"metadata": {},
"outputs": [],
"source": "# NBVAL_SKIP\n# Heavy patching computation \u2014 too slow for CI without GPU\nevery_head_all_pos_act_patch_result = patching.get_act_patch_attn_head_all_pos_every(attn_only, corrupted_tokens_induction, clean_cache_induction, induction_loss)\nimshow(every_head_all_pos_act_patch_result, facet_col=0, facet_labels=[\"Output\", \"Query\", \"Key\", \"Value\", \"Pattern\"], title=\"Activation Patching Per Head (All Pos)\", xaxis=\"Head\", yaxis=\"Layer\", zmax=CLEAN_BASELINE_INDUCTION)\n\nif DO_SLOW_RUNS:\n every_head_act_patch_result = patching.get_act_patch_attn_head_by_pos_every(attn_only, corrupted_tokens_induction, clean_cache_induction, induction_loss)\n every_head_act_patch_result = einops.rearrange(every_head_act_patch_result, \"act_type layer pos head -> act_type (layer head) pos\")\n imshow(every_head_act_patch_result, facet_col=0, facet_labels=[\"Output\", \"Query\", \"Key\", \"Value\", \"Pattern\"], title=\"Activation Patching Per Head (By Pos)\", xaxis=\"Position\", yaxis=\"Layer & Head\", zmax=CLEAN_BASELINE_INDUCTION, x= [f\"{tok}_{i}\" for i, tok in enumerate(attn_only.to_str_tokens(clean_tokens_induction[0]))], y=[f\"L{l}H{h}\" for l in range(attn_only.cfg.n_layers) for h in range(attn_only.cfg.n_heads)])"
"source": [
"# NBVAL_SKIP\n",
"# Heavy patching computation — too slow for CI without GPU\n",
"every_head_all_pos_act_patch_result = patching.get_act_patch_attn_head_all_pos_every(attn_only, corrupted_tokens_induction, clean_cache_induction, induction_loss)\n",
"imshow(every_head_all_pos_act_patch_result, facet_col=0, facet_labels=[\"Output\", \"Query\", \"Key\", \"Value\", \"Pattern\"], title=\"Activation Patching Per Head (All Pos)\", xaxis=\"Head\", yaxis=\"Layer\", zmax=CLEAN_BASELINE_INDUCTION)\n",
"\n",
"if DO_SLOW_RUNS:\n",
" every_head_act_patch_result = patching.get_act_patch_attn_head_by_pos_every(attn_only, corrupted_tokens_induction, clean_cache_induction, induction_loss)\n",
" every_head_act_patch_result = einops.rearrange(every_head_act_patch_result, \"act_type layer pos head -> act_type (layer head) pos\")\n",
" imshow(every_head_act_patch_result, facet_col=0, facet_labels=[\"Output\", \"Query\", \"Key\", \"Value\", \"Pattern\"], title=\"Activation Patching Per Head (By Pos)\", xaxis=\"Position\", yaxis=\"Layer & Head\", zmax=CLEAN_BASELINE_INDUCTION, x= [f\"{tok}_{i}\" for i, tok in enumerate(attn_only.to_str_tokens(clean_tokens_induction[0]))], y=[f\"L{l}H{h}\" for l in range(attn_only.cfg.n_layers) for h in range(attn_only.cfg.n_heads)])"
]
},
{
"cell_type": "markdown",
Expand Down Expand Up @@ -1289,4 +1313,4 @@
},
"nbformat": 4,
"nbformat_minor": 2
}
}
Loading
Loading