diff --git a/tests/acceptance/model_bridge/test_run_with_cache_batch.py b/tests/acceptance/model_bridge/test_run_with_cache_batch.py
new file mode 100644
index 000000000..9e163e525
--- /dev/null
+++ b/tests/acceptance/model_bridge/test_run_with_cache_batch.py
@@ -0,0 +1,80 @@
+"""Tests that batched run_with_cache and run_with_hooks produce correct results.
+
+Without an attention mask, HF models attend to padding tokens and contaminate
+both logits and cached activations for shorter sequences in a batch. These
+tests guard against that regression.
+"""
+
+import torch
+
+
+def _last_real_token_idx(bridge, tokens):
+    """Find the index of the last real token for each sequence in a batch."""
+    if bridge.tokenizer.pad_token_id is None:
+        return torch.full((tokens.shape[0],), tokens.shape[1] - 1)
+    # With left-padding, the last real token is always at position -1
+    return torch.full((tokens.shape[0],), tokens.shape[1] - 1)
+
+
+def test_run_with_cache_batch_matches_individual(gpt2_bridge):
+    """Batched run_with_cache logits at the last real token should match per-prompt runs."""
+    prompts = [
+        "Hello, my dog is cute",
+        "This is a much longer text. Hello, my cat is cute",
+    ]
+
+    # Individual runs
+    individual_logits = []
+    for p in prompts:
+        logits, _ = gpt2_bridge.run_with_cache(p)
+        individual_logits.append(logits[0, -1, :])
+
+    # Batched run
+    batched_logits, _ = gpt2_bridge.run_with_cache(prompts)
+    # With left-padding forced internally, position -1 is the last real token
+    for i in range(len(prompts)):
+        batched_last = batched_logits[i, -1, :]
+        assert torch.allclose(
+            individual_logits[i], batched_last, atol=1e-4
+        ), f"Prompt {i} logit mismatch between individual and batched run_with_cache"
+
+
+def test_run_with_hooks_batch_matches_individual(gpt2_bridge):
+    """Batched run_with_hooks should produce the same hook values as per-prompt runs
+    (for the last real token position of each sequence)."""
+    prompts = [
+        "Hello, my dog is cute",
+        "This is a much longer text. Hello, my cat is cute",
+    ]
+
+    # Capture resid_post at last layer for last token
+    captured_individual = []
+
+    def capture_individual(tensor, hook):
+        # Last token's residual
+        captured_individual.append(tensor[0, -1, :].detach().clone())
+
+    for p in prompts:
+        gpt2_bridge.run_with_hooks(
+            p,
+            fwd_hooks=[("blocks.11.hook_resid_post", capture_individual)],
+        )
+
+    # Batched run
+    captured_batched = []
+
+    def capture_batched(tensor, hook):
+        # For left-padded batch, last real token is at position -1 for all
+        for i in range(tensor.shape[0]):
+            captured_batched.append(tensor[i, -1, :].detach().clone())
+
+    gpt2_bridge.run_with_hooks(
+        prompts,
+        fwd_hooks=[("blocks.11.hook_resid_post", capture_batched)],
+    )
+
+    assert len(captured_individual) == len(captured_batched) == len(prompts)
+    for i in range(len(prompts)):
+        assert torch.allclose(
+            captured_individual[i], captured_batched[i], atol=1e-4
+        ), f"Prompt {i} hook value mismatch between individual and batched run_with_hooks"
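The tests above index position -1 because left-padding puts every pad token before the real tokens, so each sequence's final real token always lands in the last column. A minimal standalone sketch of that invariant (assuming a GPT-2 tokenizer from `transformers`; illustration only, not part of the diff):

from transformers import AutoTokenizer

# GPT-2 has no pad token by default; reusing EOS is common practice.
tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "left"

batch = tokenizer(
    ["Hello, my dog is cute", "This is a much longer text. Hello, my cat is cute"],
    padding=True,
    return_tensors="pt",
)

# Pads are flush-left, so column -1 holds a real token for every row.
assert bool(batch["attention_mask"][:, -1].all())
print(tokenizer.convert_ids_to_tokens(batch["input_ids"][0].tolist()))
# Pad tokens first, real tokens flush-right, ending in 'Ġcute'.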
diff --git a/transformer_lens/model_bridge/bridge.py b/transformer_lens/model_bridge/bridge.py
index 93778e014..cfac0b51c 100644
--- a/transformer_lens/model_bridge/bridge.py
+++ b/transformer_lens/model_bridge/bridge.py
@@ -1534,6 +1534,16 @@ def forward(
         else:
             kwargs.pop("one_zero_attention_mask")
 
+        # Detect batched list input that will need padding. For this case we force
+        # left-padding internally and auto-compute attention_mask + position_ids
+        # (unless the caller passed them explicitly) so pad tokens don't contaminate
+        # attention or position embeddings.
+        _is_batched_list = (
+            isinstance(input, list)
+            and len(input) > 1
+            and not getattr(self.cfg, "is_audio_model", False)
+        )
+
         try:
             if isinstance(input, (str, list)):
                 if getattr(self.cfg, "is_audio_model", False):
@@ -1541,9 +1551,20 @@ def forward(
                         "Audio models require tensor input (raw waveform), not text. "
                         "Pass a torch.Tensor or use the input_values parameter."
                     )
-                input_ids = self.to_tokens(
-                    input, prepend_bos=prepend_bos, padding_side=padding_side
-                )
+                if _is_batched_list and padding_side is None:
+                    # Force left-padding so real tokens are flush-right.
+                    _orig_padding_side = self.tokenizer.padding_side
+                    self.tokenizer.padding_side = "left"
+                    try:
+                        input_ids = self.to_tokens(
+                            input, prepend_bos=prepend_bos, padding_side=padding_side
+                        )
+                    finally:
+                        self.tokenizer.padding_side = _orig_padding_side
+                else:
+                    input_ids = self.to_tokens(
+                        input, prepend_bos=prepend_bos, padding_side=padding_side
+                    )
             else:
                 input_ids = input
 
@@ -1553,6 +1574,30 @@ def forward(
             isinstance(input_ids, torch.Tensor) and input_ids.is_floating_point()
         )
 
+        # Auto-compute attention_mask + position_ids for batched list input
+        # when the caller didn't supply them. Matches HF generation convention.
+        if (
+            _is_batched_list
+            and attention_mask is None
+            and self.tokenizer is not None
+            and self.tokenizer.pad_token_id is not None
+            and not _is_inputs_embeds
+        ):
+            _prev_side = self.tokenizer.padding_side
+            self.tokenizer.padding_side = "left"
+            try:
+                attention_mask = utils.get_attention_mask(
+                    self.tokenizer,
+                    input_ids,
+                    prepend_bos=getattr(self.cfg, "default_prepend_bos", True),
+                ).to(self.cfg.device)
+            finally:
+                self.tokenizer.padding_side = _prev_side
+            if "position_ids" not in kwargs:
+                position_ids = attention_mask.long().cumsum(-1) - 1
+                position_ids.masked_fill_(attention_mask == 0, 1)
+                kwargs["position_ids"] = position_ids
+
         if attention_mask is not None:
             kwargs["attention_mask"] = attention_mask
         if kwargs.pop("use_past_kv_cache", False) or kwargs.get("use_cache", False):
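For reference, the position-ID convention the new branch mirrors from HF generation, as a self-contained sketch (plain torch; the mask below is invented for illustration): real tokens are numbered from 0 starting at the first unmasked position, and pad slots get a dummy value of 1, which is harmless because the attention mask already excludes them.

import torch

# Left-padded batch: 0 = pad, 1 = real token. Row 0 has two pads.
attention_mask = torch.tensor([
    [0, 0, 1, 1, 1],
    [1, 1, 1, 1, 1],
])

# Running positions over real tokens; pads filled with a dummy 1.
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)

print(position_ids)
# tensor([[1, 1, 0, 1, 2],
#         [0, 1, 2, 3, 4]])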