Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions tests/acceptance/model_bridge/compatibility/test_run_with_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,3 +61,19 @@ def hook_fn(acts, hook):
f"TransformerBridge run_with_cache should match manual hooks. "
f"Max difference: {cache_diff:.6f}"
)

def test_run_with_cache_accepts_1d_tensor(self, gpt2_bridge_compat_no_processing):
    """A 1D token tensor must behave exactly like its [1, seq] counterpart.

    HookedTransformer auto-promotes unbatched token tensors to a batch of
    one; TransformerBridge.run_with_cache should do the same, producing
    matching logits and matching cached activations.
    """
    model = gpt2_bridge_compat_no_processing

    flat_tokens = torch.tensor([1, 2, 3])
    batched_tokens = flat_tokens.unsqueeze(0)

    flat_logits, flat_cache = model.run_with_cache(flat_tokens)
    batched_logits, batched_cache = model.run_with_cache(batched_tokens)

    # The unbatched call must yield the same batched output shape and values.
    assert flat_logits.shape == batched_logits.shape
    assert torch.allclose(flat_logits, batched_logits, atol=1e-5)

    # Spot-check one cached activation to confirm the cache matches too.
    hook_name = "blocks.0.hook_mlp_out"
    assert torch.allclose(flat_cache[hook_name], batched_cache[hook_name], atol=1e-5)
9 changes: 9 additions & 0 deletions transformer_lens/model_bridge/bridge.py
Original file line number Diff line number Diff line change
Expand Up @@ -1546,6 +1546,15 @@ def forward(
)
else:
input_ids = input
# Promote 1D integer token tensors to 2D [batch=1, seq] to match
# HookedTransformer's contract. Float tensors (inputs_embeds,
# audio waveforms) are passed through unchanged.
if (
isinstance(input_ids, torch.Tensor)
and input_ids.ndim == 1
and not input_ids.is_floating_point()
):
input_ids = input_ids.unsqueeze(0)

# Detect inputs_embeds: if the tensor is floating point, it's pre-computed
# embeddings (e.g., from multimodal models) rather than token IDs.
Expand Down
Loading