80 changes: 80 additions & 0 deletions tests/acceptance/model_bridge/test_run_with_cache_batch.py
@@ -0,0 +1,80 @@
"""Tests that batched run_with_cache and run_with_hooks produce correct results.

Without an attention mask, HF models attend to padding tokens and contaminate
both logits and cached activations for shorter sequences in a batch. These
tests guard against that regression.
"""

import torch


def _last_real_token_idx(bridge, tokens):
    """Find the index of the last real token for each sequence in a batch.

    With left-padding, the last real token is always at position -1; the same
    holds when the tokenizer has no pad token, since nothing gets padded.
    """
    return torch.full((tokens.shape[0],), tokens.shape[1] - 1)


def test_run_with_cache_batch_matches_individual(gpt2_bridge):
    """Batched run_with_cache logits at the last real token should match per-prompt runs."""
    prompts = [
        "Hello, my dog is cute",
        "This is a much longer text. Hello, my cat is cute",
    ]

    # Individual runs
    individual_logits = []
    for p in prompts:
        logits, _ = gpt2_bridge.run_with_cache(p)
        individual_logits.append(logits[0, -1, :])

    # Batched run
    batched_logits, _ = gpt2_bridge.run_with_cache(prompts)
    # With left-padding forced internally, position -1 is the last real token
    for i in range(len(prompts)):
        batched_last = batched_logits[i, -1, :]
        assert torch.allclose(
            individual_logits[i], batched_last, atol=1e-4
        ), f"Prompt {i} logit mismatch between individual and batched run_with_cache"


def test_run_with_hooks_batch_matches_individual(gpt2_bridge):
    """Batched run_with_hooks should produce the same hook values as per-prompt runs
    (for the last real token position of each sequence)."""
    prompts = [
        "Hello, my dog is cute",
        "This is a much longer text. Hello, my cat is cute",
    ]

    # Capture resid_post at the last layer for the last token
    captured_individual = []

    def capture_individual(tensor, hook):
        # Last token's residual
        captured_individual.append(tensor[0, -1, :].detach().clone())

    for p in prompts:
        gpt2_bridge.run_with_hooks(
            p,
            fwd_hooks=[("blocks.11.hook_resid_post", capture_individual)],
        )

    # Batched run
    captured_batched = []

    def capture_batched(tensor, hook):
        # For a left-padded batch, the last real token is at position -1 for all
        for i in range(tensor.shape[0]):
            captured_batched.append(tensor[i, -1, :].detach().clone())

    gpt2_bridge.run_with_hooks(
        prompts,
        fwd_hooks=[("blocks.11.hook_resid_post", capture_batched)],
    )

    assert len(captured_individual) == len(captured_batched) == len(prompts)
    for i in range(len(prompts)):
        assert torch.allclose(
            captured_individual[i], captured_batched[i], atol=1e-4
        ), f"Prompt {i} hook value mismatch between individual and batched run_with_hooks"
51 changes: 48 additions & 3 deletions transformer_lens/model_bridge/bridge.py
@@ -1534,16 +1534,37 @@ def forward(
        else:
            kwargs.pop("one_zero_attention_mask")

        # Detect batched list input that will need padding. For this case we force
        # left-padding internally and auto-compute attention_mask + position_ids
        # (unless the caller passed them explicitly) so pad tokens don't contaminate
        # attention or position embeddings.
        _is_batched_list = (
            isinstance(input, list)
            and len(input) > 1
            and not getattr(self.cfg, "is_audio_model", False)
        )
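        # e.g. input=["short", "a much longer prompt"] -> True, while a single
        # string, a one-element list, or a pre-tokenized tensor -> False.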

        try:
            if isinstance(input, (str, list)):
                if getattr(self.cfg, "is_audio_model", False):
                    raise ValueError(
                        "Audio models require tensor input (raw waveform), not text. "
                        "Pass a torch.Tensor or use the input_values parameter."
                    )
                if _is_batched_list and padding_side is None:
                    # Force left-padding so real tokens are flush-right.
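                    # e.g. ["Hi there", "A much longer prompt"] tokenizes to
                    # [[pad, ..., Hi, there], [A, much, longer, prompt]] (token
                    # boundaries illustrative), so position -1 is the last real
                    # token in every row.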
                    _orig_padding_side = self.tokenizer.padding_side
                    self.tokenizer.padding_side = "left"
                    try:
                        input_ids = self.to_tokens(
                            input, prepend_bos=prepend_bos, padding_side=padding_side
                        )
                    finally:
                        self.tokenizer.padding_side = _orig_padding_side
                else:
                    input_ids = self.to_tokens(
                        input, prepend_bos=prepend_bos, padding_side=padding_side
                    )
            else:
                input_ids = input

@@ -1553,6 +1574,30 @@
                isinstance(input_ids, torch.Tensor) and input_ids.is_floating_point()
            )

            # Auto-compute attention_mask + position_ids for batched list input
            # when the caller didn't supply them. Matches HF generation convention.
            if (
                _is_batched_list
                and attention_mask is None
                and self.tokenizer is not None
                and self.tokenizer.pad_token_id is not None
                and not _is_inputs_embeds
            ):
                _prev_side = self.tokenizer.padding_side
                self.tokenizer.padding_side = "left"
                try:
                    attention_mask = utils.get_attention_mask(
                        self.tokenizer,
                        input_ids,
                        prepend_bos=getattr(self.cfg, "default_prepend_bos", True),
                    ).to(self.cfg.device)
                finally:
                    self.tokenizer.padding_side = _prev_side
                if "position_ids" not in kwargs:
                    position_ids = attention_mask.long().cumsum(-1) - 1
                    position_ids.masked_fill_(attention_mask == 0, 1)
                    kwargs["position_ids"] = position_ids

            if attention_mask is not None:
                kwargs["attention_mask"] = attention_mask
            if kwargs.pop("use_past_kv_cache", False) or kwargs.get("use_cache", False):