Describe the bug
If you pass a 1-D tensor of tokens to TransformerBridge.run_with_cache, at least for Gemma models, it crashes with a vague error about mismatched attention shapes. Technically the issue lies in Hugging Face rather than TLens directly, but it is a change to how TLens currently works that will be surprising to users.
Code example
This colab demonstrates the bug:
https://colab.research.google.com/drive/1RtubOgn6mOC8X-Owvb4jFrIxEVHqCJOB?usp=sharing
But it's easy to show as follows:
from transformer_lens.model_bridge import TransformerBridge

model = TransformerBridge.boot_transformers("google/gemma-2-2b")
# .squeeze() drops the batch dimension, leaving a 1-D token tensor
input1d = model.tokenizer.encode("this is an input", return_tensors="pt").squeeze()
model.run_with_cache(input1d)  # crashes with a mismatched-attention-shapes error
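A possible workaround sketch until this is fixed, assuming the crash stems from the missing batch dimension: restore the `[batch, seq]` shape Hugging Face models expect by adding a leading dimension with `unsqueeze(0)` before calling `run_with_cache`. (The token values below are placeholders for illustration, not Gemma's actual encoding.)

```python
import torch

# A 1-D token tensor, as produced by tokenizer.encode(...).squeeze()
input1d = torch.tensor([2, 1593, 603, 671, 3440])

# unsqueeze(0) adds an explicit batch dimension: shape [1, seq_len]
input2d = input1d.unsqueeze(0)

print(tuple(input1d.shape))  # (5,)
print(tuple(input2d.shape))  # (1, 5)
```

With the batch dimension in place, `model.run_with_cache(input2d)` should follow the same code path as a normal batched call.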
Checklist