diff --git a/tests/acceptance/conftest.py b/tests/acceptance/conftest.py
new file mode 100644
index 000000000..50ad394a6
--- /dev/null
+++ b/tests/acceptance/conftest.py
@@ -0,0 +1,15 @@
+"""Shared fixtures for acceptance tests.
+
+Session-scoped fixtures avoid redundant model loads across test files.
+All models used here must be in the CI cache (see .github/workflows/checks.yml).
+"""
+
+import pytest
+
+
+@pytest.fixture(scope="session")
+def gpt2_model():
+    """Session-scoped HookedTransformer gpt2 with default weight processing."""
+    from transformer_lens import HookedTransformer
+
+    return HookedTransformer.from_pretrained("gpt2", device="cpu")
diff --git a/tests/acceptance/test_generate_batch.py b/tests/acceptance/test_generate_batch.py
new file mode 100644
index 000000000..8b333d5f7
--- /dev/null
+++ b/tests/acceptance/test_generate_batch.py
@@ -0,0 +1,30 @@
+"""Tests that batched HookedTransformer generation matches individual generation."""
+
+
+def test_ht_generate_batch_matches_individual(gpt2_model):
+    """Batched generate() should match one-by-one generate() for left-padded inputs."""
+    prompts = ["Hello, my dog is cute", "This is a much longer text. Hello, my cat is cute"]
+    individual_outputs = [gpt2_model.generate(p, verbose=False, do_sample=False) for p in prompts]
+
+    batched_outputs = gpt2_model.generate(prompts, verbose=False, do_sample=False)
+    for i, prompt in enumerate(prompts):
+        assert (
+            individual_outputs[i] == batched_outputs[i]
+        ), f"Prompt {i} mismatch:\n  individual: {individual_outputs[i]}\n  batched: {batched_outputs[i]}"
+
+
+def test_ht_generate_batch_without_kv_cache(gpt2_model):
+    """Same test with use_past_kv_cache=False."""
+    prompts = ["Hello, my dog is cute", "This is a much longer text. Hello, my cat is cute"]
+    individual_outputs = [
+        gpt2_model.generate(p, verbose=False, do_sample=False, use_past_kv_cache=False)
+        for p in prompts
+    ]
+
+    batched_outputs = gpt2_model.generate(
+        prompts, verbose=False, do_sample=False, use_past_kv_cache=False
+    )
+    for i, prompt in enumerate(prompts):
+        assert (
+            individual_outputs[i] == batched_outputs[i]
+        ), f"Prompt {i} mismatch:\n  individual: {individual_outputs[i]}\n  batched: {batched_outputs[i]}"
diff --git a/transformer_lens/HookedTransformer.py b/transformer_lens/HookedTransformer.py
index 7a5001948..eaff70094 100644
--- a/transformer_lens/HookedTransformer.py
+++ b/transformer_lens/HookedTransformer.py
@@ -1936,8 +1936,9 @@ def generate(
                 implying usage of self.cfg.default_prepend_bos (default is True unless specified
                 otherwise). Pass True or False to override the default.
             padding_side (Union[Literal["left", "right"], None], optional): Overrides
-                self.tokenizer.padding_side. Specifies which side to pad when tokenizing multiple
-                strings of different lengths.
+                self.tokenizer.padding_side. Specifies which side to pad when tokenizing
+                multiple strings of different lengths. For batched list inputs, left-padding
+                is forced internally for correct generation behavior.
             return_type (Optional[str]): The type of the output to return - a string or a list
                 of strings ('str'), a tensor of tokens ('tokens'), a tensor of output
                 embeddings ('embeds') or whatever the format of the input was ('input').
@@ -1974,13 +1975,25 @@
         else:
             return_type = "embeds"
 
+        # initial_attention_mask is always computed so that single-prompt and
+        # batched generation go through the same masked code path, producing
+        # consistent results for the same prompt regardless of batching.
+        initial_attention_mask: Optional[torch.Tensor] = None
+        _is_batched_list = isinstance(input, list) and len(input) > 1
+
         if isinstance(input, (str, list)):
             input_type = "str"
-            # If text, convert to tokens (batch_size=1)
             assert (
                 self.tokenizer is not None
             ), "Must provide a tokenizer if passing a string to the model"
-            input = self.to_tokens(input, prepend_bos=prepend_bos, padding_side=padding_side)
+            if _is_batched_list:
+                # Force left-padding for batched generation so real tokens
+                # are flush-right and logits[:, -1, :] is always correct.
+                input = self.to_tokens(input, prepend_bos=prepend_bos, padding_side="left")
+            else:
+                input = self.to_tokens(
+                    input, prepend_bos=prepend_bos, padding_side=padding_side
+                )
         elif input.ndim == 2:
             input_type = "tokens"
         else:
@@ -1988,6 +2001,27 @@
         input_tokens = input if input_type in ["str", "tokens"] else None
         batch_size, ctx_length = input.shape[0], input.shape[1]
+
+        # Compute initial attention mask. For batched inputs with padding,
+        # this correctly masks pad tokens. For single/unpadded inputs, this
+        # is all-ones which matches the no-mask code path but ensures both
+        # go through the same PosEmbed/attention logic for consistency.
+        if input_tokens is not None and self.tokenizer is not None:
+            _prepend_bos = (
+                self.cfg.default_prepend_bos
+                if prepend_bos is USE_DEFAULT_VALUE
+                else (False if prepend_bos is None else prepend_bos)
+            )
+            # Temporarily set padding_side="left" so get_attention_mask
+            # scans for leading pads (matching the left-padded tokens).
+            _orig_padding_side = self.tokenizer.padding_side
+            if _is_batched_list:
+                self.tokenizer.padding_side = "left"
+            initial_attention_mask = utils.get_attention_mask(
+                self.tokenizer, input_tokens, _prepend_bos
+            )
+            if _is_batched_list:
+                self.tokenizer.padding_side = _orig_padding_side
         device = get_device_for_block_index(0, self.cfg)
         input = input.to(device)
         if use_past_kv_cache:
@@ -2062,10 +2096,20 @@
         for index in tqdm.tqdm(range(max_new_tokens), disable=not verbose):
             pos_offset = self.get_pos_offset(past_kv_cache, batch_size)
-            tokens = torch.zeros((embeds.size(0), embeds.size(1))).to(torch.int)
-            attention_mask = utils.get_attention_mask(
-                self.tokenizer, tokens, False if prepend_bos is None else prepend_bos
-            ).to(device)
+            # Extend the initial attention mask with 1s for generated tokens.
+            attention_mask: Optional[torch.Tensor] = None
+            if initial_attention_mask is not None:
+                n_new = len(sampled_tokens_list)
+                if n_new > 0:
+                    ones = torch.ones(
+                        batch_size,
+                        n_new,
+                        dtype=initial_attention_mask.dtype,
+                        device=device,
+                    )
+                    attention_mask = torch.cat([initial_attention_mask.to(device), ones], dim=1)
+                else:
+                    attention_mask = initial_attention_mask.to(device)
             residual, shortformer_pos_embed = self.get_residual(
                 embeds,
                 pos_offset,
@@ -2089,6 +2133,7 @@
                         past_kv_cache=past_kv_cache,
                         start_at_layer=start_at_layer,
                         shortformer_pos_embed=shortformer_pos_embed,
+                        attention_mask=attention_mask,
                     )
                 else:
                     logits = self.forward(
@@ -2099,6 +2144,7 @@
                         past_kv_cache=past_kv_cache,
                         start_at_layer=start_at_layer,
                         shortformer_pos_embed=shortformer_pos_embed,
+                        attention_mask=attention_mask,
                     )
             else:
                 # We input the entire sequence, as a [batch, pos] tensor, since we aren't using
@@ -2110,6 +2156,7 @@
                     padding_side=padding_side,
                     start_at_layer=start_at_layer,
                     shortformer_pos_embed=shortformer_pos_embed,
+                    attention_mask=attention_mask,
                 )
 
             final_logits = logits[:, -1, :]
diff --git a/transformer_lens/model_bridge/bridge.py b/transformer_lens/model_bridge/bridge.py
index a07a4f330..93778e014 100644
--- a/transformer_lens/model_bridge/bridge.py
+++ b/transformer_lens/model_bridge/bridge.py
@@ -2121,9 +2121,9 @@ def generate(
             prepend_bos: Accepted for API compatibility but not applied during generation.
                 The HF model expects tokens in its native format (tokenizer defaults).
                 Overriding BOS can silently degrade generation quality.
-            padding_side: Accepted for API compatibility but not applied during generation.
-                The generation loop always extends tokens to the right, so overriding
-                initial padding_side creates inconsistent token layout.
+            padding_side: Which side to pad when tokenizing multiple strings of different
+                lengths. For batched list inputs, left-padding is forced internally for
+                correct generation behavior. Defaults to None (tokenizer default).
             return_type: The type of output to return - 'input', 'str', or 'tokens'
             verbose: Not used in Bridge (kept for API compatibility)
             output_logits: If True, return a ModelOutput with sequences and logits tuple
@@ -2135,10 +2135,9 @@
         Returns:
             Generated sequence as string, list of strings, or tensor depending on input
             type and return_type. If output_logits=True, returns a ModelOutput-like
            object with 'sequences' and 'logits' attributes.
         """
-        # prepend_bos and padding_side are intentionally not applied during generation.
+        # prepend_bos is intentionally not applied during generation.
         # The HF model expects tokens in its native format. Overriding BOS can silently
-        # degrade quality, and overriding padding_side conflicts with the generation loop
-        # which always extends tokens to the right.
+        # degrade quality.
         if prepend_bos is not None:
             import warnings
@@ -2149,27 +2148,28 @@
             warnings.warn(
                 "prepend_bos is ignored during TransformerBridge.generate(). "
                 "resulting tensor to generate().",
                 stacklevel=2,
             )
-        if padding_side is not None:
-            import warnings
-
-            warnings.warn(
-                "padding_side is ignored during TransformerBridge.generate(). "
-                "The generation loop extends tokens to the right regardless of initial "
-                "padding. To control padding, tokenize with to_tokens(padding_side=...) "
-                "and pass the resulting tensor to generate().",
-                stacklevel=2,
-            )
+        # padding_side is handled internally: for batched list inputs, left-padding
+        # is forced to ensure correct generation. See _is_batched_list logic below.
         # Stateful dispatch is decided after input parsing so we can fall back
         # to hf_generate() for input types the stateful loop doesn't handle.
         is_stateful_model = getattr(self.cfg, "is_stateful", False)
+        _is_batched_list = isinstance(input, list) and len(input) > 1
+        _generate_from_embeds = False
         if isinstance(input, str):
             input_tokens = self.to_tokens(input, move_to_device=True, truncate=False)
             input_type = "str"
         elif isinstance(input, list):
+            # Force left-padding for batched generation so real tokens are
+            # flush-right and logits[:, -1, :] is always the last real token.
+            if _is_batched_list:
+                _orig_padding_side = self.tokenizer.padding_side
+                self.tokenizer.padding_side = "left"
             input_tokens = self.to_tokens(input, move_to_device=True, truncate=False)
+            if _is_batched_list:
+                self.tokenizer.padding_side = _orig_padding_side
             input_type = "list"
         elif isinstance(input, torch.Tensor) and input.is_floating_point():
             # inputs_embeds: pre-computed embeddings (e.g., from multimodal models)
@@ -2307,6 +2307,30 @@
                 else:
                     forward_kwargs: Dict[str, Any] = {}
+                    # Compute attention mask and position_ids for batched
+                    # inputs with padding. HF models default to all-ones
+                    # when no mask is given, which ignores padding tokens.
+                    if (
+                        _is_batched_list
+                        and self.tokenizer is not None
+                        and self.tokenizer.pad_token_id is not None
+                    ):
+                        # Temp-swap to "left" so get_attention_mask scans
+                        # for leading pads (matching the left-padded tokens).
+                        _prev_side = self.tokenizer.padding_side
+                        self.tokenizer.padding_side = "left"
+                        attn_mask = utils.get_attention_mask(
+                            self.tokenizer,
+                            current_tokens,
+                            prepend_bos=getattr(self.cfg, "default_prepend_bos", True),
+                        ).to(self.cfg.device)
+                        self.tokenizer.padding_side = _prev_side
+                        forward_kwargs["attention_mask"] = attn_mask
+                        # Adjust position_ids for left-padding so pad
+                        # tokens don't consume real position embeddings.
+                        position_ids = attn_mask.long().cumsum(-1) - 1
+                        position_ids.masked_fill_(attn_mask == 0, 1)
+                        forward_kwargs["position_ids"] = position_ids
                     # Pass multimodal inputs only on the first step — the vision
                     # encoder processes the image once, embedding it into the
                     # token sequence. This includes pixel_values plus any extra
@@ -2346,6 +2370,10 @@
                             [input_seq_pos], device=self.cfg.device
                         )
                         forward_kwargs["cache_position"] = cache_position
+                        if "position_ids" in forward_kwargs:
+                            forward_kwargs["position_ids"] = forward_kwargs["position_ids"][
+                                :, -1:
+                            ]
                         logits = self(
                             current_tokens[:, -1:],
                             return_type="logits",
@@ -2356,6 +2384,10 @@
                     if _hf_kv_cache is not None:
                         # Cached step: pass only the last token + cache
                         forward_kwargs["past_key_values"] = _hf_kv_cache
+                        if "position_ids" in forward_kwargs:
+                            forward_kwargs["position_ids"] = forward_kwargs["position_ids"][
+                                :, -1:
+                            ]
                         logits = self(
                             current_tokens[:, -1:],
                             return_type="logits",
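
The HookedTransformer change keeps the prompt's (possibly left-padded) attention mask for the whole generation and appends a 1 for every token sampled so far. Below is a minimal standalone sketch of that bookkeeping; the helper name `extend_attention_mask` and the variable `n_generated` are illustrative and not part of the diff.

    import torch

    def extend_attention_mask(initial_mask: torch.Tensor, n_generated: int) -> torch.Tensor:
        """Append one 1 per generated token to a [batch, prompt_len] mask."""
        if n_generated == 0:
            return initial_mask
        ones = torch.ones(
            initial_mask.shape[0],
            n_generated,
            dtype=initial_mask.dtype,
            device=initial_mask.device,
        )
        return torch.cat([initial_mask, ones], dim=1)

    # A batch of two prompts, the first left-padded by one token, after 2 generated tokens:
    mask = torch.tensor([[0, 1, 1, 1], [1, 1, 1, 1]])
    print(extend_attention_mask(mask, 2))
    # tensor([[0, 1, 1, 1, 1, 1],
    #         [1, 1, 1, 1, 1, 1]])

Because generated tokens are always appended on the right, only the prompt portion of the mask ever contains zeros, which is why left padding keeps the last position of every row a real token.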
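The bridge-side change uses the standard Hugging Face idiom for batched decoding with left padding: pass an attention mask and derive position_ids from it so pad tokens neither attend nor consume positions, and so logits[:, -1, :] is the next-token distribution for every row. Here is a sketch of the same idea written directly against the plain transformers API rather than TransformerLens internals; the model, tokenizer, and variable names are assumptions for illustration only.

    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.padding_side = "left"  # pads go on the left, real tokens end at index -1

    model = AutoModelForCausalLM.from_pretrained("gpt2")

    prompts = ["Hello, my dog is cute", "This is a much longer text. Hello, my cat is cute"]
    batch = tokenizer(prompts, return_tensors="pt", padding=True)

    # Derive position_ids from the mask: real tokens get positions 0..n-1
    # regardless of how many left pads precede them; pad positions are filled in.
    attention_mask = batch["attention_mask"]
    position_ids = attention_mask.long().cumsum(-1) - 1
    position_ids.masked_fill_(attention_mask == 0, 1)

    with torch.no_grad():
        logits = model(
            input_ids=batch["input_ids"],
            attention_mask=attention_mask,
            position_ids=position_ids,
        ).logits

    # With left padding, the last position of every row is a real token.
    next_token = logits[:, -1, :].argmax(dim=-1)
    print(tokenizer.batch_decode(next_token.unsqueeze(-1)))

This is why the diff forces padding_side="left" for batched list inputs: with right padding, logits[:, -1, :] would point at a pad token for the shorter prompts and batched results would diverge from one-by-one generation, which is exactly what tests/acceptance/test_generate_batch.py checks.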