diff --git a/tests/acceptance/model_bridge/compatibility/test_run_with_cache.py b/tests/acceptance/model_bridge/compatibility/test_run_with_cache.py
index 7a2e6a169..f653de956 100644
--- a/tests/acceptance/model_bridge/compatibility/test_run_with_cache.py
+++ b/tests/acceptance/model_bridge/compatibility/test_run_with_cache.py
@@ -61,3 +61,19 @@ def hook_fn(acts, hook):
             f"TransformerBridge run_with_cache should match manual hooks. "
             f"Max difference: {cache_diff:.6f}"
         )
+
+    def test_run_with_cache_accepts_1d_tensor(self, gpt2_bridge_compat_no_processing):
+        """1D token tensors should be auto-promoted to [1, seq], matching HookedTransformer."""
+        bridge_model = gpt2_bridge_compat_no_processing
+
+        tokens_1d = torch.tensor([1, 2, 3])
+        tokens_2d = tokens_1d.unsqueeze(0)
+
+        logits_1d, cache_1d = bridge_model.run_with_cache(tokens_1d)
+        logits_2d, cache_2d = bridge_model.run_with_cache(tokens_2d)
+
+        assert logits_1d.shape == logits_2d.shape
+        assert torch.allclose(logits_1d, logits_2d, atol=1e-5)
+        assert torch.allclose(
+            cache_1d["blocks.0.hook_mlp_out"], cache_2d["blocks.0.hook_mlp_out"], atol=1e-5
+        )
diff --git a/transformer_lens/model_bridge/bridge.py b/transformer_lens/model_bridge/bridge.py
index 93778e014..60a5e73f2 100644
--- a/transformer_lens/model_bridge/bridge.py
+++ b/transformer_lens/model_bridge/bridge.py
@@ -1546,6 +1546,15 @@ def forward(
             )
         else:
             input_ids = input
+        # Promote 1D integer token tensors to 2D [batch=1, seq] to match
+        # HookedTransformer's contract. Float tensors (inputs_embeds,
+        # audio waveforms) are passed through unchanged.
+        if (
+            isinstance(input_ids, torch.Tensor)
+            and input_ids.ndim == 1
+            and not input_ids.is_floating_point()
+        ):
+            input_ids = input_ids.unsqueeze(0)
         # Detect inputs_embeds: if the tensor is floating point, it's pre-computed
         # embeddings (e.g., from multimodal models) rather than token IDs.
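
Usage sketch (not part of the diff): the snippet below illustrates the new behavior from a caller's point of view. `bridge_model` is a placeholder for any loaded `TransformerBridge`, such as the object provided by the `gpt2_bridge_compat_no_processing` fixture used in the test above.

```python
import torch

# Placeholder: assume `bridge_model` is a loaded TransformerBridge instance
# (e.g. the gpt2_bridge_compat_no_processing fixture from the test above).
tokens_1d = torch.tensor([1, 2, 3])  # shape [seq], integer dtype

# With this change, a 1D integer tensor is promoted to [1, seq] internally,
# so this call behaves the same as passing tokens_1d.unsqueeze(0) explicitly.
logits, cache = bridge_model.run_with_cache(tokens_1d)
assert logits.shape[0] == 1  # a batch dimension of 1 was added

# Float tensors (pre-computed embeddings, audio waveforms) are passed through
# unchanged, so a 1D float tensor would not gain a batch dimension here.
```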