From 22edbb55cbf2fe106752ce0025b37f6ed04cd05c Mon Sep 17 00:00:00 2001 From: Tuomas Oikarinen Date: Mon, 11 Aug 2025 11:08:52 -0700 Subject: [PATCH 1/7] fixed batching in generate --- transformer_lens/HookedTransformer.py | 17 ++++++++++++----- 1 file changed, 12 insertions(+), 5 deletions(-) diff --git a/transformer_lens/HookedTransformer.py b/transformer_lens/HookedTransformer.py index 025b43793..2e55ef194 100644 --- a/transformer_lens/HookedTransformer.py +++ b/transformer_lens/HookedTransformer.py @@ -2095,7 +2095,7 @@ def generate( freq_penalty: float = 0.0, use_past_kv_cache: bool = True, prepend_bos: Optional[bool] = USE_DEFAULT_VALUE, - padding_side: Optional[Literal["left", "right"]] = USE_DEFAULT_VALUE, + padding_side: Optional[Literal["left", "right"]] = "left", return_type: Optional[str] = "input", verbose: bool = True, ) -> Union[ @@ -2139,9 +2139,9 @@ def generate( the BOS token to the input (applicable when input is a string). Defaults to None, implying usage of self.cfg.default_prepend_bos (default is True unless specified otherwise). Pass True or False to override the default. - padding_side (Union[Literal["left", "right"], None], optional): Overrides - self.tokenizer.padding_side. Specifies which side to pad when tokenizing multiple - strings of different lengths. + padding_side (Union[Literal["left", "right"], None], optional): Specifies which side to + pad when tokenizing multiple strings of different lengths. Defaults to left for + correct generation behavior. If None uses self.tokenizer.padding_side. return_type (Optional[str]): The type of the output to return - a string or a list of strings ('str'), a tensor of tokens ('tokens'), a tensor of output embeddings ('embeds') or whatever the format of the input was ('input'). @@ -2240,7 +2240,11 @@ def generate( for index in tqdm.tqdm(range(max_new_tokens), disable=not verbose): pos_offset = self.get_pos_offset(past_kv_cache, batch_size) - tokens = torch.zeros((embeds.size(0), embeds.size(1))).to(torch.int) + if len(sampled_tokens_list) > 0: + sampled_tokens = torch.cat(sampled_tokens_list, dim=1) + tokens = torch.cat((input_tokens, sampled_tokens), dim=1) + else: + tokens = input_tokens attention_mask = utils.get_attention_mask( self.tokenizer, tokens, False if prepend_bos is None else prepend_bos ).to(device) @@ -2267,6 +2271,7 @@ def generate( past_kv_cache=past_kv_cache, start_at_layer=start_at_layer, shortformer_pos_embed=shortformer_pos_embed, + attention_mask=attention_mask, ) else: logits = self.forward( @@ -2277,6 +2282,7 @@ def generate( past_kv_cache=past_kv_cache, start_at_layer=start_at_layer, shortformer_pos_embed=shortformer_pos_embed, + attention_mask=attention_mask, ) else: # We input the entire sequence, as a [batch, pos] tensor, since we aren't using @@ -2288,6 +2294,7 @@ def generate( padding_side=padding_side, start_at_layer=start_at_layer, shortformer_pos_embed=shortformer_pos_embed, + attention_mask=attention_mask, ) final_logits = logits[:, -1, :] From 11c85b520e9767033c611489fdb64932374c679a Mon Sep 17 00:00:00 2001 From: tuomaso Date: Mon, 11 Aug 2025 11:54:32 -0700 Subject: [PATCH 2/7] added test case --- tests/unit/test_generate_batch.py | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) create mode 100644 tests/unit/test_generate_batch.py diff --git a/tests/unit/test_generate_batch.py b/tests/unit/test_generate_batch.py new file mode 100644 index 000000000..a17b6ec59 --- /dev/null +++ b/tests/unit/test_generate_batch.py @@ -0,0 +1,16 @@ +from transformer_lens import HookedTransformer + +def test_generate_batch(): + """ + Test that batched and individual prompt generation produce the same outputs. + """ + model = HookedTransformer.from_pretrained("gpt2") + input_prompts = ["Hello, my dog is cute", "This is a much longer text. Hello, my cat is cute"] + orig_outputs = [] + for prompt in input_prompts: + out = model.generate(prompt, verbose=False, do_sample=False) + orig_outputs.append(out) + + batched_outputs = model.generate(input_prompts, verbose=False, do_sample=False) + for i in range(len(orig_outputs)): + assert orig_outputs[i] == batched_outputs[i] \ No newline at end of file From 1aa008691defed6e94432e056cb59d9d74eec0f3 Mon Sep 17 00:00:00 2001 From: jlarson4 Date: Tue, 21 Apr 2026 11:20:31 -0500 Subject: [PATCH 3/7] Move & improve tests --- tests/acceptance/conftest.py | 15 ++++++++++++ tests/acceptance/test_generate_batch.py | 32 +++++++++++++++++++++++++ tests/unit/test_generate_batch.py | 16 ------------- 3 files changed, 47 insertions(+), 16 deletions(-) create mode 100644 tests/acceptance/conftest.py create mode 100644 tests/acceptance/test_generate_batch.py delete mode 100644 tests/unit/test_generate_batch.py diff --git a/tests/acceptance/conftest.py b/tests/acceptance/conftest.py new file mode 100644 index 000000000..704405562 --- /dev/null +++ b/tests/acceptance/conftest.py @@ -0,0 +1,15 @@ +"""Shared fixtures for acceptance tests. + +Session-scoped fixtures avoid redundant model loads across test files. +All models used here must be in the CI cache (see .github/workflows/checks.yml). +""" + +import pytest + +from transformer_lens import HookedTransformer + + +@pytest.fixture(scope="session") +def gpt2_model(): + """Session-scoped HookedTransformer gpt2 with default weight processing.""" + return HookedTransformer.from_pretrained("gpt2", device="cpu") diff --git a/tests/acceptance/test_generate_batch.py b/tests/acceptance/test_generate_batch.py new file mode 100644 index 000000000..549669861 --- /dev/null +++ b/tests/acceptance/test_generate_batch.py @@ -0,0 +1,32 @@ +"""Tests that batched HookedTransformer generation matches individual generation.""" + + +def test_ht_generate_batch_matches_individual(gpt2_model): + """Batched generate() should match one-by-one generate() for left-padded inputs.""" + prompts = ["Hello, my dog is cute", "This is a much longer text. Hello, my cat is cute"] + individual_outputs = [ + gpt2_model.generate(p, verbose=False, do_sample=False) for p in prompts + ] + + batched_outputs = gpt2_model.generate(prompts, verbose=False, do_sample=False) + for i, prompt in enumerate(prompts): + assert individual_outputs[i] == batched_outputs[i], ( + f"Prompt {i} mismatch:\n individual: {individual_outputs[i]}\n batched: {batched_outputs[i]}" + ) + + +def test_ht_generate_batch_without_kv_cache(gpt2_model): + """Same test with use_past_kv_cache=False.""" + prompts = ["Hello, my dog is cute", "This is a much longer text. Hello, my cat is cute"] + individual_outputs = [ + gpt2_model.generate(p, verbose=False, do_sample=False, use_past_kv_cache=False) + for p in prompts + ] + + batched_outputs = gpt2_model.generate( + prompts, verbose=False, do_sample=False, use_past_kv_cache=False + ) + for i, prompt in enumerate(prompts): + assert individual_outputs[i] == batched_outputs[i], ( + f"Prompt {i} mismatch:\n individual: {individual_outputs[i]}\n batched: {batched_outputs[i]}" + ) diff --git a/tests/unit/test_generate_batch.py b/tests/unit/test_generate_batch.py deleted file mode 100644 index a17b6ec59..000000000 --- a/tests/unit/test_generate_batch.py +++ /dev/null @@ -1,16 +0,0 @@ -from transformer_lens import HookedTransformer - -def test_generate_batch(): - """ - Test that batched and individual prompt generation produce the same outputs. - """ - model = HookedTransformer.from_pretrained("gpt2") - input_prompts = ["Hello, my dog is cute", "This is a much longer text. Hello, my cat is cute"] - orig_outputs = [] - for prompt in input_prompts: - out = model.generate(prompt, verbose=False, do_sample=False) - orig_outputs.append(out) - - batched_outputs = model.generate(input_prompts, verbose=False, do_sample=False) - for i in range(len(orig_outputs)): - assert orig_outputs[i] == batched_outputs[i] \ No newline at end of file From 481040ec9ecb28b5d38ccd9acc9f32d60bf0e94f Mon Sep 17 00:00:00 2001 From: jlarson4 Date: Tue, 21 Apr 2026 11:22:43 -0500 Subject: [PATCH 4/7] make check format and mypy --- tests/acceptance/test_generate_batch.py | 16 +++++++--------- transformer_lens/HookedTransformer.py | 6 +++--- 2 files changed, 10 insertions(+), 12 deletions(-) diff --git a/tests/acceptance/test_generate_batch.py b/tests/acceptance/test_generate_batch.py index 549669861..8b333d5f7 100644 --- a/tests/acceptance/test_generate_batch.py +++ b/tests/acceptance/test_generate_batch.py @@ -4,15 +4,13 @@ def test_ht_generate_batch_matches_individual(gpt2_model): """Batched generate() should match one-by-one generate() for left-padded inputs.""" prompts = ["Hello, my dog is cute", "This is a much longer text. Hello, my cat is cute"] - individual_outputs = [ - gpt2_model.generate(p, verbose=False, do_sample=False) for p in prompts - ] + individual_outputs = [gpt2_model.generate(p, verbose=False, do_sample=False) for p in prompts] batched_outputs = gpt2_model.generate(prompts, verbose=False, do_sample=False) for i, prompt in enumerate(prompts): - assert individual_outputs[i] == batched_outputs[i], ( - f"Prompt {i} mismatch:\n individual: {individual_outputs[i]}\n batched: {batched_outputs[i]}" - ) + assert ( + individual_outputs[i] == batched_outputs[i] + ), f"Prompt {i} mismatch:\n individual: {individual_outputs[i]}\n batched: {batched_outputs[i]}" def test_ht_generate_batch_without_kv_cache(gpt2_model): @@ -27,6 +25,6 @@ def test_ht_generate_batch_without_kv_cache(gpt2_model): prompts, verbose=False, do_sample=False, use_past_kv_cache=False ) for i, prompt in enumerate(prompts): - assert individual_outputs[i] == batched_outputs[i], ( - f"Prompt {i} mismatch:\n individual: {individual_outputs[i]}\n batched: {batched_outputs[i]}" - ) + assert ( + individual_outputs[i] == batched_outputs[i] + ), f"Prompt {i} mismatch:\n individual: {individual_outputs[i]}\n batched: {batched_outputs[i]}" diff --git a/transformer_lens/HookedTransformer.py b/transformer_lens/HookedTransformer.py index 11b6b56b0..1f772b1aa 100644 --- a/transformer_lens/HookedTransformer.py +++ b/transformer_lens/HookedTransformer.py @@ -1935,9 +1935,9 @@ def generate( the BOS token to the input (applicable when input is a string). Defaults to None, implying usage of self.cfg.default_prepend_bos (default is True unless specified otherwise). Pass True or False to override the default. - padding_side (Union[Literal["left", "right"], None], optional): Specifies which side to - pad when tokenizing multiple strings of different lengths. Defaults to left for - correct generation behavior. If None uses self.tokenizer.padding_side. + padding_side (Union[Literal["left", "right"], None], optional): Specifies which side to + pad when tokenizing multiple strings of different lengths. Defaults to left for + correct generation behavior. If None uses self.tokenizer.padding_side. return_type (Optional[str]): The type of the output to return - a string or a list of strings ('str'), a tensor of tokens ('tokens'), a tensor of output embeddings ('embeds') or whatever the format of the input was ('input'). From d780c870aaf163689897f36426a54065cccbf272 Mon Sep 17 00:00:00 2001 From: jlarson4 Date: Tue, 21 Apr 2026 11:34:21 -0500 Subject: [PATCH 5/7] fix mypy errors --- transformer_lens/HookedTransformer.py | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/transformer_lens/HookedTransformer.py b/transformer_lens/HookedTransformer.py index 1f772b1aa..2754e9a12 100644 --- a/transformer_lens/HookedTransformer.py +++ b/transformer_lens/HookedTransformer.py @@ -2062,14 +2062,16 @@ def generate( for index in tqdm.tqdm(range(max_new_tokens), disable=not verbose): pos_offset = self.get_pos_offset(past_kv_cache, batch_size) - if len(sampled_tokens_list) > 0: - sampled_tokens = torch.cat(sampled_tokens_list, dim=1) - tokens = torch.cat((input_tokens, sampled_tokens), dim=1) - else: - tokens = input_tokens - attention_mask = utils.get_attention_mask( - self.tokenizer, tokens, False if prepend_bos is None else prepend_bos - ).to(device) + attention_mask: Optional[torch.Tensor] = None + if input_tokens is not None: + if len(sampled_tokens_list) > 0: + sampled_tokens = torch.cat(sampled_tokens_list, dim=1) + tokens = torch.cat((input_tokens, sampled_tokens), dim=1) + else: + tokens = input_tokens + attention_mask = utils.get_attention_mask( + self.tokenizer, tokens, False if prepend_bos is None else prepend_bos + ).to(device) residual, shortformer_pos_embed = self.get_residual( embeds, pos_offset, From 2f6dc0cc13be02207d73b32d2d3fe9cff395cb17 Mon Sep 17 00:00:00 2001 From: jlarson4 Date: Tue, 21 Apr 2026 12:17:30 -0500 Subject: [PATCH 6/7] Stop jaxtyping failures --- tests/acceptance/conftest.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/acceptance/conftest.py b/tests/acceptance/conftest.py index 704405562..50ad394a6 100644 --- a/tests/acceptance/conftest.py +++ b/tests/acceptance/conftest.py @@ -6,10 +6,10 @@ import pytest -from transformer_lens import HookedTransformer - @pytest.fixture(scope="session") def gpt2_model(): """Session-scoped HookedTransformer gpt2 with default weight processing.""" + from transformer_lens import HookedTransformer + return HookedTransformer.from_pretrained("gpt2", device="cpu") From 29a544a367192e96e7f6deaf65adb50766dd7d5c Mon Sep 17 00:00:00 2001 From: jlarson4 Date: Tue, 21 Apr 2026 13:54:34 -0500 Subject: [PATCH 7/7] Updated to also fix TransformerBridge for the same issue --- transformer_lens/HookedTransformer.py | 66 +++++++++++++++++++------ transformer_lens/model_bridge/bridge.py | 64 ++++++++++++++++++------ 2 files changed, 100 insertions(+), 30 deletions(-) diff --git a/transformer_lens/HookedTransformer.py b/transformer_lens/HookedTransformer.py index 2754e9a12..eaff70094 100644 --- a/transformer_lens/HookedTransformer.py +++ b/transformer_lens/HookedTransformer.py @@ -1887,7 +1887,7 @@ def generate( freq_penalty: float = 0.0, use_past_kv_cache: bool = True, prepend_bos: Optional[bool] = USE_DEFAULT_VALUE, - padding_side: Optional[Literal["left", "right"]] = "left", + padding_side: Optional[Literal["left", "right"]] = USE_DEFAULT_VALUE, return_type: Optional[str] = "input", verbose: bool = True, **generation_kwargs, @@ -1935,9 +1935,10 @@ def generate( the BOS token to the input (applicable when input is a string). Defaults to None, implying usage of self.cfg.default_prepend_bos (default is True unless specified otherwise). Pass True or False to override the default. - padding_side (Union[Literal["left", "right"], None], optional): Specifies which side to - pad when tokenizing multiple strings of different lengths. Defaults to left for - correct generation behavior. If None uses self.tokenizer.padding_side. + padding_side (Union[Literal["left", "right"], None], optional): Overrides + self.tokenizer.padding_side. Specifies which side to pad when tokenizing + multiple strings of different lengths. For batched list inputs, left-padding + is forced internally for correct generation behavior. return_type (Optional[str]): The type of the output to return - a string or a list of strings ('str'), a tensor of tokens ('tokens'), a tensor of output embeddings ('embeds') or whatever the format of the input was ('input'). @@ -1974,13 +1975,25 @@ def generate( else: return_type = "embeds" + # initial_attention_mask is always computed so that single-prompt and + # batched generation go through the same masked code path, producing + # consistent results for the same prompt regardless of batching. + initial_attention_mask: Optional[torch.Tensor] = None + _is_batched_list = isinstance(input, list) and len(input) > 1 + if isinstance(input, (str, list)): input_type = "str" - # If text, convert to tokens (batch_size=1) assert ( self.tokenizer is not None ), "Must provide a tokenizer if passing a string to the model" - input = self.to_tokens(input, prepend_bos=prepend_bos, padding_side=padding_side) + if _is_batched_list: + # Force left-padding for batched generation so real tokens + # are flush-right and logits[:, -1, :] is always correct. + input = self.to_tokens(input, prepend_bos=prepend_bos, padding_side="left") + else: + input = self.to_tokens( + input, prepend_bos=prepend_bos, padding_side=padding_side + ) elif input.ndim == 2: input_type = "tokens" else: @@ -1988,6 +2001,27 @@ def generate( input_tokens = input if input_type in ["str", "tokens"] else None batch_size, ctx_length = input.shape[0], input.shape[1] + + # Compute initial attention mask. For batched inputs with padding, + # this correctly masks pad tokens. For single/unpadded inputs, this + # is all-ones which matches the no-mask code path but ensures both + # go through the same PosEmbed/attention logic for consistency. + if input_tokens is not None and self.tokenizer is not None: + _prepend_bos = ( + self.cfg.default_prepend_bos + if prepend_bos is USE_DEFAULT_VALUE + else (False if prepend_bos is None else prepend_bos) + ) + # Temporarily set padding_side="left" so get_attention_mask + # scans for leading pads (matching the left-padded tokens). + _orig_padding_side = self.tokenizer.padding_side + if _is_batched_list: + self.tokenizer.padding_side = "left" + initial_attention_mask = utils.get_attention_mask( + self.tokenizer, input_tokens, _prepend_bos + ) + if _is_batched_list: + self.tokenizer.padding_side = _orig_padding_side device = get_device_for_block_index(0, self.cfg) input = input.to(device) if use_past_kv_cache: @@ -2062,16 +2096,20 @@ def generate( for index in tqdm.tqdm(range(max_new_tokens), disable=not verbose): pos_offset = self.get_pos_offset(past_kv_cache, batch_size) + # Extend the initial attention mask with 1s for generated tokens. attention_mask: Optional[torch.Tensor] = None - if input_tokens is not None: - if len(sampled_tokens_list) > 0: - sampled_tokens = torch.cat(sampled_tokens_list, dim=1) - tokens = torch.cat((input_tokens, sampled_tokens), dim=1) + if initial_attention_mask is not None: + n_new = len(sampled_tokens_list) + if n_new > 0: + ones = torch.ones( + batch_size, + n_new, + dtype=initial_attention_mask.dtype, + device=device, + ) + attention_mask = torch.cat([initial_attention_mask.to(device), ones], dim=1) else: - tokens = input_tokens - attention_mask = utils.get_attention_mask( - self.tokenizer, tokens, False if prepend_bos is None else prepend_bos - ).to(device) + attention_mask = initial_attention_mask.to(device) residual, shortformer_pos_embed = self.get_residual( embeds, pos_offset, diff --git a/transformer_lens/model_bridge/bridge.py b/transformer_lens/model_bridge/bridge.py index a07a4f330..93778e014 100644 --- a/transformer_lens/model_bridge/bridge.py +++ b/transformer_lens/model_bridge/bridge.py @@ -2121,9 +2121,9 @@ def generate( prepend_bos: Accepted for API compatibility but not applied during generation. The HF model expects tokens in its native format (tokenizer defaults). Overriding BOS can silently degrade generation quality. - padding_side: Accepted for API compatibility but not applied during generation. - The generation loop always extends tokens to the right, so overriding - initial padding_side creates inconsistent token layout. + padding_side: Which side to pad when tokenizing multiple strings of different + lengths. For batched list inputs, left-padding is forced internally for + correct generation behavior. Defaults to None (tokenizer default). return_type: The type of output to return - 'input', 'str', or 'tokens' verbose: Not used in Bridge (kept for API compatibility) output_logits: If True, return a ModelOutput with sequences and logits tuple @@ -2135,10 +2135,9 @@ def generate( Generated sequence as string, list of strings, or tensor depending on input type and return_type. If output_logits=True, returns a ModelOutput-like object with 'sequences' and 'logits' attributes. """ - # prepend_bos and padding_side are intentionally not applied during generation. + # prepend_bos is intentionally not applied during generation. # The HF model expects tokens in its native format. Overriding BOS can silently - # degrade quality, and overriding padding_side conflicts with the generation loop - # which always extends tokens to the right. + # degrade quality. if prepend_bos is not None: import warnings @@ -2149,27 +2148,28 @@ def generate( "resulting tensor to generate().", stacklevel=2, ) - if padding_side is not None: - import warnings - - warnings.warn( - "padding_side is ignored during TransformerBridge.generate(). " - "The generation loop extends tokens to the right regardless of initial " - "padding. To control padding, tokenize with to_tokens(padding_side=...) " - "and pass the resulting tensor to generate().", - stacklevel=2, - ) + # padding_side is handled internally: for batched list inputs, left-padding + # is forced to ensure correct generation. See _is_batched_list logic below. # Stateful dispatch is decided after input parsing so we can fall back # to hf_generate() for input types the stateful loop doesn't handle. is_stateful_model = getattr(self.cfg, "is_stateful", False) + _is_batched_list = isinstance(input, list) and len(input) > 1 + _generate_from_embeds = False if isinstance(input, str): input_tokens = self.to_tokens(input, move_to_device=True, truncate=False) input_type = "str" elif isinstance(input, list): + # Force left-padding for batched generation so real tokens are + # flush-right and logits[:, -1, :] is always the last real token. + if _is_batched_list: + _orig_padding_side = self.tokenizer.padding_side + self.tokenizer.padding_side = "left" input_tokens = self.to_tokens(input, move_to_device=True, truncate=False) + if _is_batched_list: + self.tokenizer.padding_side = _orig_padding_side input_type = "list" elif isinstance(input, torch.Tensor) and input.is_floating_point(): # inputs_embeds: pre-computed embeddings (e.g., from multimodal models) @@ -2307,6 +2307,30 @@ def generate( ) else: forward_kwargs: Dict[str, Any] = {} + # Compute attention mask and position_ids for batched + # inputs with padding. HF models default to all-ones + # when no mask is given, which ignores padding tokens. + if ( + _is_batched_list + and self.tokenizer is not None + and self.tokenizer.pad_token_id is not None + ): + # Temp-swap to "left" so get_attention_mask scans + # for leading pads (matching the left-padded tokens). + _prev_side = self.tokenizer.padding_side + self.tokenizer.padding_side = "left" + attn_mask = utils.get_attention_mask( + self.tokenizer, + current_tokens, + prepend_bos=getattr(self.cfg, "default_prepend_bos", True), + ).to(self.cfg.device) + self.tokenizer.padding_side = _prev_side + forward_kwargs["attention_mask"] = attn_mask + # Adjust position_ids for left-padding so pad + # tokens don't consume real position embeddings. + position_ids = attn_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attn_mask == 0, 1) + forward_kwargs["position_ids"] = position_ids # Pass multimodal inputs only on the first step — the vision # encoder processes the image once, embedding it into the # token sequence. This includes pixel_values plus any extra @@ -2346,6 +2370,10 @@ def generate( [input_seq_pos], device=self.cfg.device ) forward_kwargs["cache_position"] = cache_position + if "position_ids" in forward_kwargs: + forward_kwargs["position_ids"] = forward_kwargs["position_ids"][ + :, -1: + ] logits = self( current_tokens[:, -1:], return_type="logits", @@ -2356,6 +2384,10 @@ def generate( if _hf_kv_cache is not None: # Cached step: pass only the last token + cache forward_kwargs["past_key_values"] = _hf_kv_cache + if "position_ids" in forward_kwargs: + forward_kwargs["position_ids"] = forward_kwargs["position_ids"][ + :, -1: + ] logits = self( current_tokens[:, -1:], return_type="logits",