[Bug Report] Calling TransformerBridge.run_with_cache with a 1d tensor causes crash #1050

@chanind

Description

Describe the bug

If you pass a 1d tensor of tokens to TransformerBridge.run_with_cache, at least for Gemma models, the call crashes with a vague error about mismatched attention shapes. Technically the issue lies in Huggingface rather than in TLens itself, but it is a change from how TLens currently behaves and will be surprising to users.

Code example

This colab demonstrates the bug:
https://colab.research.google.com/drive/1RtubOgn6mOC8X-Owvb4jFrIxEVHqCJOB?usp=sharing

But it's easy to show as follows:

from transformer_lens.model_bridge import TransformerBridge

model = TransformerBridge.boot_transformers("google/gemma-2-2b")
input1d = model.tokenizer.encode("this is an input", return_tensors="pt").squeeze()
model.run_with_cache(input1d)
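Until this is fixed, a possible workaround (not part of TLens, just a sketch) is to promote 1d token tensors to an explicit batch dimension before calling run_with_cache. The helper name below is hypothetical:

```python
import torch

def ensure_batch_dim(tokens: torch.Tensor) -> torch.Tensor:
    """Return tokens with an explicit batch dimension.

    run_with_cache appears to expect [batch, seq]; a bare 1d [seq]
    tensor seems to trip the Huggingface attention-shape check for
    Gemma models. (Hypothetical helper, not a TLens API.)
    """
    if tokens.dim() == 1:
        return tokens.unsqueeze(0)  # [seq] -> [1, seq]
    return tokens
```

With this, the failing call above can be written as model.run_with_cache(ensure_batch_dim(input1d)), and already-batched inputs pass through unchanged.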

Checklist

  • I have checked that there is no similar issue in the repo (required)

Metadata

    Labels

    TransformerBridge (bug specific to the new TransformerBridge system)
