From b7bce6974805f7fe4926d43e0bc083bca5cab38f Mon Sep 17 00:00:00 2001 From: anthonyduong Date: Wed, 9 Apr 2025 23:16:39 -0700 Subject: [PATCH 1/5] adds HookedTransformer.generate_stream() --- tests/acceptance/test_hooked_transformer.py | 38 ++++ transformer_lens/HookedTransformer.py | 233 ++++++++++++++++++++ 2 files changed, 271 insertions(+) diff --git a/tests/acceptance/test_hooked_transformer.py b/tests/acceptance/test_hooked_transformer.py index 3fd739f94..865827d03 100644 --- a/tests/acceptance/test_hooked_transformer.py +++ b/tests/acceptance/test_hooked_transformer.py @@ -195,6 +195,44 @@ def test_bloom_similarity_with_hf_model_with_kv_cache_activated(): assert output_tf == output_hf_str +def test_bloom_similarity_with_hf_model_with_kv_cache_activated_stream(): + tf_model = HookedTransformer.from_pretrained( + "bigscience/bloom-560m", default_prepend_bos=False, device="cpu" + ) + + hf_model = AutoModelForCausalLM.from_pretrained("bigscience/bloom-560m") + hf_tokenizer = AutoTokenizer.from_pretrained("bigscience/bloom-560m") + + gen = tf_model.generate_stream( + text, + do_sample=False, + use_past_kv_cache=True, + verbose=False, + max_new_tokens=10, + max_tokens_per_yield=10, + ) + + # Exhaust the generator to capture its final return value. + while True: + try: + next(gen) + except StopIteration as e: + final_output = e.value + break + + hf_input_ids = hf_tokenizer(text, return_tensors="pt").input_ids + output_hf_tokens = hf_model.generate( + hf_input_ids, + do_sample=False, + max_new_tokens=10, + ) + output_hf_str = hf_tokenizer.decode(output_hf_tokens[0], skip_special_tokens=True) + + assert ( + final_output == output_hf_str + ), f"\nStreaming output: {final_output}\nHF output: {output_hf_str}" + + def check_norm_folding( model_name, hf_model=None, diff --git a/transformer_lens/HookedTransformer.py b/transformer_lens/HookedTransformer.py index cf4c369ac..549a51dd5 100644 --- a/transformer_lens/HookedTransformer.py +++ b/transformer_lens/HookedTransformer.py @@ -11,6 +11,7 @@ import logging import os +from collections.abc import Generator from typing import ( Dict, List, @@ -2340,6 +2341,238 @@ def generate( else: return embeds + @torch.inference_mode() + def generate_stream( + self, + input: Union[str, Float[torch.Tensor, "batch pos"]] = "", + max_new_tokens: int = 10, + max_tokens_per_yield: int = 25, + stop_at_eos: bool = True, + eos_token_id: Optional[int] = None, + do_sample: bool = True, + top_k: Optional[int] = None, + top_p: Optional[float] = None, + temperature: float = 1.0, + freq_penalty: float = 0.0, + use_past_kv_cache: bool = True, + prepend_bos: Optional[bool] = USE_DEFAULT_VALUE, + padding_side: Optional[Literal["left", "right"]] = USE_DEFAULT_VALUE, + return_type: Optional[str] = "input", + verbose: bool = True, + ) -> Generator[Union[Int[torch.Tensor, "batch"], str], None, None]: + """Stream tokens from the Model as they are generated. + + Sample tokens from the model until the model outputs eos_token or max_new_tokens is reached, + yielding batches of tokens progressively during generation rather than waiting for the entire + sequence to be generated. + + To avoid fiddling with ragged tensors, if we input a batch of text and some sequences finish + (by producing an EOT token), we keep running the model on the entire batch, but throw away + the output for a finished sequence and just keep adding EOTs to pad. + + This supports entering a single string, but not a list of strings - if the strings don't + tokenize to exactly the same length, this gets messy. If that functionality is needed, + convert them to a batch of tokens and input that instead. + + Args: + input (Union[str, Int[torch.Tensor, "batch pos"])]): Either a batch of tokens ([batch, + pos]) or a text string (this will be converted to a batch of tokens with batch size + 1). + max_new_tokens (int): Maximum number of tokens to generate. + max_tokens_per_yield (int): Maximum number of tokens to accumulate before yielding. + Controls how frequently the function yields tokens during generation. + stop_at_eos (bool): If True, stop generating tokens when the model outputs eos_token. + eos_token_id (Optional[Union[int, Sequence]]): The token ID to use for end + of sentence. If None, use the tokenizer's eos_token_id - required if using + stop_at_eos. It's also possible to provide a list of token IDs (not just the + eos_token_id), in which case the generation will stop when any of them are output + (useful e.g. for stable_lm). + do_sample (bool): If True, sample from the model's output distribution. Otherwise, use + greedy search (take the max logit each time). + top_k (int): Number of tokens to sample from. If None, sample from all tokens. + top_p (float): Probability mass to sample from. If 1.0, sample from all tokens. If <1.0, + we take the top tokens with cumulative probability >= top_p. + temperature (float): Temperature for sampling. Higher values will make the model more + random (limit of temp -> 0 is just taking the top token, limit of temp -> inf is + sampling from a uniform distribution). + freq_penalty (float): Frequency penalty for sampling - how much to penalise previous + tokens. Higher values will make the model more random. + use_past_kv_cache (bool): If True, create and use cache to speed up generation. + prepend_bos (bool, optional): Overrides self.cfg.default_prepend_bos. Whether to prepend + the BOS token to the input (applicable when input is a string). Defaults to None, + implying usage of self.cfg.default_prepend_bos (default is True unless specified + otherwise). Pass True or False to override the default. + padding_side (Union[Literal["left", "right"], None], optional): Overrides + self.tokenizer.padding_side. Specifies which side to pad when tokenizing multiple + strings of different lengths. + return_type (Optional[str]): The type of the output to return - either a string (str), + a tensor of tokens (tensor) or whatever the format of the input was (input). + verbose (bool): If True, show tqdm progress bars for generation. + + Yields: + outputs (Union[Int[torch.Tensor, "batch"], str]): Batches of generated tokens, yielded + progressively during generation. Each yield contains accumulated tokens since the last + yield, up to max_tokens_per_yield. + """ + + with utils.LocallyOverridenDefaults( + self, prepend_bos=prepend_bos, padding_side=padding_side + ): + if type(input) == str: + # If text, convert to tokens (batch_size=1) + assert ( + self.tokenizer is not None + ), "Must provide a tokenizer if passing a string to the model" + tokens = self.to_tokens(input, prepend_bos=prepend_bos, padding_side=padding_side) + else: + tokens = input + + if return_type == "input": + if type(input) == str: + return_type = "str" + else: + return_type = "tensor" + + assert isinstance(tokens, torch.Tensor) + batch_size, ctx_length = tokens.shape + device = devices.get_device_for_block_index(0, self.cfg) + tokens = tokens.to(device) + if use_past_kv_cache: + past_kv_cache = HookedTransformerKeyValueCache.init_cache( + self.cfg, self.cfg.device, batch_size + ) + else: + past_kv_cache = None + + stop_tokens: List[int] = [] + eos_token_for_padding = 0 + assert self.tokenizer is not None + if stop_at_eos: + tokenizer_has_eos_token = ( + self.tokenizer is not None and self.tokenizer.eos_token_id is not None + ) + if eos_token_id is None: + assert ( + tokenizer_has_eos_token + ), "Must pass a eos_token_id if stop_at_eos is True and tokenizer is None or has no eos_token_id" + + eos_token_id = self.tokenizer.eos_token_id + + if isinstance(eos_token_id, int): + stop_tokens = [eos_token_id] + eos_token_for_padding = eos_token_id + else: + # eos_token_id is a Sequence (e.g. list or tuple) + stop_tokens = eos_token_id + eos_token_for_padding = ( + self.tokenizer.eos_token_id if tokenizer_has_eos_token else eos_token_id[0] + ) + + # An array to track which sequences in the batch have finished. + finished_sequences = torch.zeros(batch_size, dtype=torch.bool, device=self.cfg.device) + + # Currently nothing in HookedTransformer changes with eval, but this is here in case + # that changes in the future. + self.eval() + for index in tqdm.tqdm(range(max_new_tokens), disable=not verbose): + # While generating, we keep generating logits, throw away all but the final logits, + # and then use those logits to sample from the distribution We keep adding the + # sampled tokens to the end of tokens. + if use_past_kv_cache: + # We just take the final tokens, as a [batch, 1] tensor + if index > 0: + logits = self.forward( + tokens[:, -1:], + return_type="logits", + prepend_bos=prepend_bos, + padding_side=padding_side, + past_kv_cache=past_kv_cache, + ) + else: + logits = self.forward( + tokens, + return_type="logits", + prepend_bos=prepend_bos, + padding_side=padding_side, + past_kv_cache=past_kv_cache, + ) + else: + # We input the entire sequence, as a [batch, pos] tensor, since we aren't using + # the cache. + logits = self.forward( + tokens, + return_type="logits", + prepend_bos=prepend_bos, + padding_side=padding_side, + ) + final_logits = logits[:, -1, :] + + if do_sample: + sampled_tokens = utils.sample_logits( + final_logits, + top_k=top_k, + top_p=top_p, + temperature=temperature, + freq_penalty=freq_penalty, + tokens=tokens, + ).to(devices.get_device_for_block_index(0, self.cfg)) + else: + sampled_tokens = final_logits.argmax(-1).to( + devices.get_device_for_block_index(0, self.cfg) + ) + + if stop_at_eos: + # For all unfinished sequences, add on the next token. If a sequence was + # finished, throw away the generated token and add eos_token_for_padding + # instead. + sampled_tokens[finished_sequences] = eos_token_for_padding + finished_sequences.logical_or_( + torch.isin( + sampled_tokens.to(self.cfg.device), + torch.tensor(stop_tokens).to(self.cfg.device), + ) + ) + + new_tokens = sampled_tokens.unsqueeze(-1) + + # Accumulate tokens until we hit max_tokens_per_yield + if index == 0: + accumulated_tokens = torch.cat([tokens, new_tokens], dim=-1) + tokens_since_last_yield = accumulated_tokens.shape[1] + else: + if accumulated_tokens is None: + accumulated_tokens = new_tokens + else: + accumulated_tokens = torch.cat([accumulated_tokens, new_tokens], dim=-1) + tokens_since_last_yield += 1 + + if tokens_since_last_yield >= max_tokens_per_yield: + yield accumulated_tokens + tokens_since_last_yield = 0 + accumulated_tokens = None + + tokens = torch.cat([tokens, new_tokens], dim=-1) + + if stop_at_eos and finished_sequences.all(): + # Yield any remaining accumulated tokens before breaking + if accumulated_tokens is not None: + yield accumulated_tokens + break + + # Only yield remaining tokens if we didn't already yield them in the break case + if accumulated_tokens is not None and not (stop_at_eos and finished_sequences.all()): + yield accumulated_tokens + + if return_type == "str": + if self.cfg.default_prepend_bos: + # If we prepended a BOS token, remove it when returning output. + return self.tokenizer.decode(tokens[0, 1:]) + else: + return self.tokenizer.decode(tokens[0]) + + else: + return tokens + # Give access to all weights as properties. @property def W_U(self) -> Float[torch.Tensor, "d_model d_vocab"]: From 2fe136b324f04078fe87f43e83149775ee7fd082 Mon Sep 17 00:00:00 2001 From: anthonyduong Date: Tue, 24 Jun 2025 15:23:10 -0700 Subject: [PATCH 2/5] fixes mypy errors --- tests/acceptance/test_hooked_transformer.py | 14 ++++---------- transformer_lens/HookedEncoderDecoder.py | 4 ++-- transformer_lens/HookedTransformer.py | 14 ++++---------- 3 files changed, 10 insertions(+), 22 deletions(-) diff --git a/tests/acceptance/test_hooked_transformer.py b/tests/acceptance/test_hooked_transformer.py index c4ebff123..4f7b792dd 100644 --- a/tests/acceptance/test_hooked_transformer.py +++ b/tests/acceptance/test_hooked_transformer.py @@ -241,22 +241,16 @@ def test_bloom_similarity_with_hf_model_with_kv_cache_activated_stream(): hf_model = AutoModelForCausalLM.from_pretrained("bigscience/bloom-560m") hf_tokenizer = AutoTokenizer.from_pretrained("bigscience/bloom-560m") - gen = tf_model.generate_stream( + final_output = "" + for result in tf_model.generate_stream( text, do_sample=False, use_past_kv_cache=True, verbose=False, max_new_tokens=10, max_tokens_per_yield=10, - ) - - # Exhaust the generator to capture its final return value. - while True: - try: - next(gen) - except StopIteration as e: - final_output = e.value - break + ): + final_output += tf_model.to_string(result[0]) hf_input_ids = hf_tokenizer(text, return_tensors="pt").input_ids output_hf_tokens = hf_model.generate( diff --git a/transformer_lens/HookedEncoderDecoder.py b/transformer_lens/HookedEncoderDecoder.py index 5622c8db9..d0e0e751b 100644 --- a/transformer_lens/HookedEncoderDecoder.py +++ b/transformer_lens/HookedEncoderDecoder.py @@ -486,13 +486,13 @@ def generate( else: return decoder_input - @overload + @overload # type: ignore[overload-overlap] def run_with_cache( self, *model_args: Any, return_cache_object: Literal[True] = True, **kwargs: Any ) -> Tuple[Float[torch.Tensor, "batch pos d_vocab"], ActivationCache]: ... - @overload + @overload # type: ignore[overload-overlap] def run_with_cache( self, *model_args: Any, return_cache_object: Literal[False] = False, **kwargs: Any ) -> Tuple[Float[torch.Tensor, "batch pos d_vocab"], Dict[str, torch.Tensor]]: diff --git a/transformer_lens/HookedTransformer.py b/transformer_lens/HookedTransformer.py index 737f83584..481154436 100644 --- a/transformer_lens/HookedTransformer.py +++ b/transformer_lens/HookedTransformer.py @@ -2438,6 +2438,7 @@ def generate_stream( ), "Must provide a tokenizer if passing a string to the model" tokens = self.to_tokens(input, prepend_bos=prepend_bos, padding_side=padding_side) else: + assert isinstance(input, torch.Tensor), "Input must be a tensor when not a string" tokens = input if return_type == "input": @@ -2484,6 +2485,9 @@ def generate_stream( # An array to track which sequences in the batch have finished. finished_sequences = torch.zeros(batch_size, dtype=torch.bool, device=self.cfg.device) + accumulated_tokens: Optional[torch.Tensor] = None + tokens_since_last_yield = 0 + # Currently nothing in HookedTransformer changes with eval, but this is here in case # that changes in the future. self.eval() @@ -2576,16 +2580,6 @@ def generate_stream( if accumulated_tokens is not None and not (stop_at_eos and finished_sequences.all()): yield accumulated_tokens - if return_type == "str": - if self.cfg.default_prepend_bos: - # If we prepended a BOS token, remove it when returning output. - return self.tokenizer.decode(tokens[0, 1:]) - else: - return self.tokenizer.decode(tokens[0]) - - else: - return tokens - # Give access to all weights as properties. @property def W_U(self) -> Float[torch.Tensor, "d_model d_vocab"]: From 8eecd32bfd3634aaa0ade4ad60103f379fa6bc15 Mon Sep 17 00:00:00 2001 From: jlarson4 Date: Tue, 21 Apr 2026 14:04:56 -0500 Subject: [PATCH 3/5] Adjusted for TransformerLens 3 changes --- transformer_lens/HookedTransformer.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/transformer_lens/HookedTransformer.py b/transformer_lens/HookedTransformer.py index bb4c781e8..6f37bc219 100644 --- a/transformer_lens/HookedTransformer.py +++ b/transformer_lens/HookedTransformer.py @@ -2299,10 +2299,10 @@ def generate_stream( assert isinstance(tokens, torch.Tensor) batch_size, ctx_length = tokens.shape - device = devices.get_device_for_block_index(0, self.cfg) + device = get_device_for_block_index(0, self.cfg) tokens = tokens.to(device) if use_past_kv_cache: - past_kv_cache = HookedTransformerKeyValueCache.init_cache( + past_kv_cache = TransformerLensKeyValueCache.init_cache( self.cfg, self.cfg.device, batch_size ) else: @@ -2382,10 +2382,10 @@ def generate_stream( temperature=temperature, freq_penalty=freq_penalty, tokens=tokens, - ).to(devices.get_device_for_block_index(0, self.cfg)) + ).to(get_device_for_block_index(0, self.cfg)) else: sampled_tokens = final_logits.argmax(-1).to( - devices.get_device_for_block_index(0, self.cfg) + get_device_for_block_index(0, self.cfg) ) if stop_at_eos: From 9e08ecf3a1960ff90059c1be55c48c2c3639652f Mon Sep 17 00:00:00 2001 From: jlarson4 Date: Tue, 21 Apr 2026 14:33:43 -0500 Subject: [PATCH 4/5] Initial bridge generate stream --- .../model_bridge/test_generate_stream.py | 98 ++++ transformer_lens/model_bridge/bridge.py | 537 +++++++++++++----- 2 files changed, 485 insertions(+), 150 deletions(-) create mode 100644 tests/acceptance/model_bridge/test_generate_stream.py diff --git a/tests/acceptance/model_bridge/test_generate_stream.py b/tests/acceptance/model_bridge/test_generate_stream.py new file mode 100644 index 000000000..e6bd1a19f --- /dev/null +++ b/tests/acceptance/model_bridge/test_generate_stream.py @@ -0,0 +1,98 @@ +"""Tests for TransformerBridge.generate_stream().""" + +import torch + + +def test_stream_matches_generate(gpt2_bridge): + """Concatenated stream output should match generate() for the same prompt.""" + prompt = "The future of AI" + expected = gpt2_bridge.generate( + prompt, max_new_tokens=10, do_sample=False, verbose=False + ) + + # Collect all streamed chunks + chunks = list( + gpt2_bridge.generate_stream( + prompt, + max_new_tokens=10, + max_tokens_per_yield=3, + do_sample=False, + verbose=False, + ) + ) + assert len(chunks) >= 1 + + # Reconstruct: first chunk has input+tokens, subsequent have only new tokens + full_tokens = torch.cat(chunks, dim=-1) if len(chunks) > 1 else chunks[0] + # The first chunk includes input tokens, so just take the last chunk's end + # Actually, each chunk is independent — first has input+new, rest have only new + # So concatenating all gives input + all new tokens (with input repeated). + # Instead, compare decoded strings. + expected_text = gpt2_bridge.to_string(expected[0] if isinstance(expected, torch.Tensor) else gpt2_bridge.to_tokens(expected)[0]) + + # Decode last chunk which should have the most recent window of tokens + # Better: decode all chunks and concatenate + stream_texts = [] + for i, chunk in enumerate(chunks): + if i == 0: + stream_texts.append(gpt2_bridge.to_string(chunk[0])) + else: + stream_texts.append(gpt2_bridge.to_string(chunk[0])) + + # The first chunk has input+initial tokens, subsequent have only new tokens. + # The simplest comparison: the final full output should match. + # Reconstruct by taking the first chunk and appending decoded new tokens. + # Actually easier: just compare using the full token sequence. + # First chunk = input + first N tokens, subsequent = next tokens only. + all_tokens = chunks[0] + for chunk in chunks[1:]: + all_tokens = torch.cat([all_tokens, chunk], dim=-1) + + streamed_text = gpt2_bridge.to_string(all_tokens[0]) + assert expected_text == streamed_text, ( + f"Stream output mismatch:\n generate: {expected_text!r}\n stream: {streamed_text!r}" + ) + + +def test_stream_yields_progressively(gpt2_bridge): + """Multiple yields should occur with small max_tokens_per_yield.""" + chunks = list( + gpt2_bridge.generate_stream( + "Hello world", + max_new_tokens=10, + max_tokens_per_yield=3, + do_sample=False, + verbose=False, + ) + ) + assert len(chunks) > 1, f"Expected multiple yields, got {len(chunks)}" + + +def test_stream_single_prompt(gpt2_bridge): + """Basic single-string streaming should produce output.""" + results = list( + gpt2_bridge.generate_stream( + "Test", max_new_tokens=5, do_sample=False, verbose=False + ) + ) + assert len(results) >= 1 + assert results[0].shape[0] == 1 # batch=1 + assert results[0].shape[1] > 1 # has at least input + 1 generated token + + +def test_stream_stops_at_eos(gpt2_bridge): + """Streaming should respect stop_at_eos.""" + results = list( + gpt2_bridge.generate_stream( + "Test", + max_new_tokens=200, + max_tokens_per_yield=5, + stop_at_eos=True, + do_sample=False, + verbose=False, + ) + ) + # Count total generated tokens (first chunk has input, rest are new) + total_tokens = sum(r.shape[1] for r in results) + # Should have stopped well before 200 new tokens for a short prompt + assert total_tokens < 210 diff --git a/transformer_lens/model_bridge/bridge.py b/transformer_lens/model_bridge/bridge.py index a07a4f330..631af94c0 100644 --- a/transformer_lens/model_bridge/bridge.py +++ b/transformer_lens/model_bridge/bridge.py @@ -6,6 +6,7 @@ import logging import re import warnings +from collections.abc import Generator from contextlib import contextmanager from functools import lru_cache from typing import ( @@ -2077,6 +2078,190 @@ def wrapped_hook_fn(tensor, hook, _orig_fn=original_hook_fn): for hook_point, name in added_hooks: hook_point.remove_hooks() + def _generate_tokens( + self, + current_tokens: torch.Tensor, + input_tokens: torch.Tensor, + batch_size: int, + *, + max_new_tokens: int, + do_sample: bool, + top_k: Optional[int], + top_p: Optional[float], + temperature: float, + freq_penalty: float, + repetition_penalty: float, + stop_at_eos: bool, + stop_tokens: List[int], + eos_token_for_padding: int, + finished_sequences: torch.Tensor, + use_past_kv_cache: bool, + use_stateful_cache: bool, + mamba_cache: Any, + mamba_conv_kernel: int, + is_encoder_decoder: bool, + _is_batched_list: bool, + _generate_from_embeds: bool, + encoder_input: Optional[torch.Tensor], + decoder_tokens: Optional[torch.Tensor], + generated_token_ids: Optional[List[torch.Tensor]], + pixel_values: Optional[torch.Tensor], + multimodal_kwargs: Dict[str, Any], + verbose: bool, + ) -> Generator[Tuple[torch.Tensor, torch.Tensor, bool], None, None]: + """Core generation loop. Yields (sampled_tokens, final_logits, all_finished) per step. + + Owns the forward pass, sampling, EOS handling, token accumulation, and + KV cache management. Callers are responsible for try/finally cleanup of + ``_capture_hf_cache``. + """ + _hf_kv_cache = None + + for gen_step_idx in range(max_new_tokens): + with torch.no_grad(): + if is_encoder_decoder: + logits = self( + encoder_input, + return_type="logits", + decoder_input=decoder_tokens, + ) + else: + forward_kwargs: Dict[str, Any] = {} + # Compute attention mask and position_ids for batched + # inputs with padding. + if ( + _is_batched_list + and self.tokenizer is not None + and self.tokenizer.pad_token_id is not None + ): + _prev_side = self.tokenizer.padding_side + self.tokenizer.padding_side = "left" + attn_mask = utils.get_attention_mask( + self.tokenizer, + current_tokens, + prepend_bos=getattr(self.cfg, "default_prepend_bos", True), + ).to(self.cfg.device) + self.tokenizer.padding_side = _prev_side + forward_kwargs["attention_mask"] = attn_mask + position_ids = attn_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attn_mask == 0, 1) + forward_kwargs["position_ids"] = position_ids + if gen_step_idx == 0: + if pixel_values is not None: + forward_kwargs["pixel_values"] = pixel_values + if multimodal_kwargs: + forward_kwargs.update(multimodal_kwargs) + if use_stateful_cache: + forward_kwargs["cache_params"] = mamba_cache + forward_kwargs["use_cache"] = True + if gen_step_idx == 0: + cache_position = torch.arange( + 0, mamba_conv_kernel, device=self.cfg.device + ) + forward_kwargs["cache_position"] = cache_position + logits = self( + current_tokens, + return_type="logits", + **forward_kwargs, + ) + else: + input_seq_pos = input_tokens.shape[1] + gen_step_idx - 1 + cache_position = torch.tensor([input_seq_pos], device=self.cfg.device) + forward_kwargs["cache_position"] = cache_position + if "position_ids" in forward_kwargs: + forward_kwargs["position_ids"] = forward_kwargs["position_ids"][ + :, -1: + ] + logits = self( + current_tokens[:, -1:], + return_type="logits", + **forward_kwargs, + ) + elif use_past_kv_cache: + forward_kwargs["use_cache"] = True + if _hf_kv_cache is not None: + forward_kwargs["past_key_values"] = _hf_kv_cache + if "position_ids" in forward_kwargs: + forward_kwargs["position_ids"] = forward_kwargs["position_ids"][ + :, -1: + ] + logits = self( + current_tokens[:, -1:], + return_type="logits", + **forward_kwargs, + ) + else: + logits = self( + current_tokens, + return_type="logits", + **forward_kwargs, + ) + else: + logits = self(current_tokens, return_type="logits", **forward_kwargs) + if use_past_kv_cache and hasattr(self, "_last_hf_cache"): + _hf_kv_cache = self._last_hf_cache or _hf_kv_cache + del self._last_hf_cache + final_logits = logits[:, -1, :] + + # Sample next token + penalty_tokens = ( + torch.stack(generated_token_ids, dim=1) + if _generate_from_embeds and generated_token_ids + else None + ) + if do_sample: + sampled_tokens = utils.sample_logits( + final_logits, + top_k=top_k, + top_p=top_p, + temperature=temperature, + freq_penalty=freq_penalty, + repetition_penalty=repetition_penalty, + tokens=penalty_tokens + if _generate_from_embeds + else (decoder_tokens if is_encoder_decoder else current_tokens), + ).to(self.cfg.device) + else: + sampled_tokens = utils.sample_logits( + final_logits, + temperature=0.0, + repetition_penalty=repetition_penalty, + tokens=penalty_tokens + if _generate_from_embeds + else (decoder_tokens if is_encoder_decoder else current_tokens), + ).to(self.cfg.device) + + # Handle EOS + if stop_at_eos: + sampled_tokens[finished_sequences] = eos_token_for_padding + finished_sequences.logical_or_( + torch.isin( + sampled_tokens.to(self.cfg.device), + torch.tensor(stop_tokens).to(self.cfg.device), + ) + ) + + # Update token sequences + if is_encoder_decoder: + assert decoder_tokens is not None + decoder_tokens = torch.cat([decoder_tokens, sampled_tokens.unsqueeze(1)], dim=1) + elif _generate_from_embeds: + assert generated_token_ids is not None + generated_token_ids.append(sampled_tokens) + embed_fn = self.original_model.get_input_embeddings() # type: ignore[operator] + assert embed_fn is not None + new_embed = embed_fn(sampled_tokens.unsqueeze(1)).to(current_tokens.dtype) + current_tokens = torch.cat([current_tokens, new_embed], dim=1) + else: + current_tokens = torch.cat([current_tokens, sampled_tokens.unsqueeze(1)], dim=1) + + all_finished = bool(stop_at_eos and finished_sequences.all().item()) + + yield sampled_tokens, final_logits, all_finished + + if all_finished: + return + def generate( self, input: Union[str, List[str], torch.Tensor] = "", @@ -2296,156 +2481,41 @@ def generate( ) try: - for gen_step_idx in range(max_new_tokens): - # Get logits for next token - with torch.no_grad(): - if is_encoder_decoder: - logits = self( - encoder_input, - return_type="logits", - decoder_input=decoder_tokens, - ) - else: - forward_kwargs: Dict[str, Any] = {} - # Pass multimodal inputs only on the first step — the vision - # encoder processes the image once, embedding it into the - # token sequence. This includes pixel_values plus any extra - # processor outputs (e.g. image_sizes for LlavaNext). - if gen_step_idx == 0: - if pixel_values is not None: - forward_kwargs["pixel_values"] = pixel_values - if multimodal_kwargs: - forward_kwargs.update(multimodal_kwargs) - if use_stateful_cache: - # Prefill sends arange(conv_kernel) (which both - # Mamba-1's length check and Mamba-2's value check - # accept as "not decode"). Decode sends the input - # token's actual sequence position — a fixed value - # above conv_kernel-1 silently picks the wrong - # slot for short prompts (see - # test_greedy_matches_hf_across_prompt_lengths). - # conv1d hooks fire only on prefill; HF bypasses - # the conv1d module on decode (see DepthwiseConv1DBridge). - forward_kwargs["cache_params"] = mamba_cache - forward_kwargs["use_cache"] = True - if gen_step_idx == 0: - cache_position = torch.arange( - 0, mamba_conv_kernel, device=self.cfg.device - ) - forward_kwargs["cache_position"] = cache_position - logits = self( - current_tokens, - return_type="logits", - **forward_kwargs, - ) - else: - # Token generated at step N-1 lives at - # sequence position prompt_len + gen_step_idx - 1 - input_seq_pos = input_tokens.shape[1] + gen_step_idx - 1 - cache_position = torch.tensor( - [input_seq_pos], device=self.cfg.device - ) - forward_kwargs["cache_position"] = cache_position - logits = self( - current_tokens[:, -1:], - return_type="logits", - **forward_kwargs, - ) - elif use_past_kv_cache: - forward_kwargs["use_cache"] = True - if _hf_kv_cache is not None: - # Cached step: pass only the last token + cache - forward_kwargs["past_key_values"] = _hf_kv_cache - logits = self( - current_tokens[:, -1:], - return_type="logits", - **forward_kwargs, - ) - else: - # Step 0: full sequence, cache gets populated - logits = self( - current_tokens, - return_type="logits", - **forward_kwargs, - ) - else: - # No cache: full sequence every step - logits = self(current_tokens, return_type="logits", **forward_kwargs) - # Capture HF cache from forward() for next step. - if use_past_kv_cache and hasattr(self, "_last_hf_cache"): - _hf_kv_cache = self._last_hf_cache or _hf_kv_cache - del self._last_hf_cache - final_logits = logits[:, -1, :] - - # Collect logits if requested - if logits_seq_list is not None: - logits_seq_list.append(final_logits.clone()) - - # Sample next token - # For inputs_embeds, we can't pass the embeddings to freq/rep penalty, - # so use the generated_token_ids for penalty tracking - penalty_tokens = ( - torch.stack(generated_token_ids, dim=1) - if _generate_from_embeds and generated_token_ids - else None - ) - if do_sample: - sampled_tokens = utils.sample_logits( - final_logits, - top_k=top_k, - top_p=top_p, - temperature=temperature, - freq_penalty=freq_penalty, - repetition_penalty=repetition_penalty, - tokens=penalty_tokens - if _generate_from_embeds - else (decoder_tokens if is_encoder_decoder else current_tokens), - ).to(self.cfg.device) - else: - sampled_tokens = utils.sample_logits( - final_logits, - temperature=0.0, - repetition_penalty=repetition_penalty, - tokens=penalty_tokens - if _generate_from_embeds - else (decoder_tokens if is_encoder_decoder else current_tokens), - ).to(self.cfg.device) - - sampled_tokens_list.append(sampled_tokens.unsqueeze(1)) - - # Handle EOS tokens for finished sequences - if stop_at_eos: - sampled_tokens[finished_sequences] = eos_token_for_padding - finished_sequences.logical_or_( - torch.isin( - sampled_tokens.to(self.cfg.device), - torch.tensor(stop_tokens).to(self.cfg.device), - ) - ) - - # Append sampled token to current sequence - if is_encoder_decoder: - decoder_tokens = torch.cat( - [decoder_tokens, sampled_tokens.unsqueeze(1)], dim=1 - ) - elif _generate_from_embeds: - # For inputs_embeds: get the embedding of the new token and append - generated_token_ids.append(sampled_tokens) - embed_fn = self.original_model.get_input_embeddings() # type: ignore[operator] - assert embed_fn is not None - new_embed = embed_fn(sampled_tokens.unsqueeze(1)).to(current_tokens.dtype) - current_tokens = torch.cat([current_tokens, new_embed], dim=1) - else: - current_tokens = torch.cat( - [current_tokens, sampled_tokens.unsqueeze(1)], dim=1 - ) - - # Early stopping if all sequences finished - if stop_at_eos and finished_sequences.all(): - break + for sampled_tokens, final_logits, all_finished in self._generate_tokens( + current_tokens, + input_tokens, + batch_size, + max_new_tokens=max_new_tokens, + do_sample=do_sample, + top_k=top_k, + top_p=top_p, + temperature=temperature, + freq_penalty=freq_penalty, + repetition_penalty=repetition_penalty, + stop_at_eos=stop_at_eos, + stop_tokens=stop_tokens, + eos_token_for_padding=eos_token_for_padding, + finished_sequences=finished_sequences, + use_past_kv_cache=use_past_kv_cache, + use_stateful_cache=use_stateful_cache, + mamba_cache=mamba_cache, + mamba_conv_kernel=mamba_conv_kernel, + is_encoder_decoder=is_encoder_decoder, + _is_batched_list=False, + _generate_from_embeds=_generate_from_embeds, + encoder_input=encoder_input if is_encoder_decoder else None, + decoder_tokens=decoder_tokens if is_encoder_decoder else None, + generated_token_ids=generated_token_ids if _generate_from_embeds else None, + pixel_values=pixel_values, + multimodal_kwargs=multimodal_kwargs if multimodal_kwargs else {}, + verbose=verbose, + ): + sampled_tokens_list.append(sampled_tokens.unsqueeze(1)) + if logits_seq_list is not None: + logits_seq_list.append(final_logits.clone()) + if all_finished: + break finally: - # Clean up generate-only state even if an exception occurs, - # so _capture_hf_cache doesn't leak into subsequent forward() calls. self._capture_hf_cache = False if hasattr(self, "_last_hf_cache"): del self._last_hf_cache @@ -2453,7 +2523,8 @@ def generate( # Concatenate all sampled tokens sampled_tokens = torch.cat(sampled_tokens_list, dim=1) if is_encoder_decoder: - output_tokens = decoder_tokens + # Reconstruct full decoder sequence: start token + generated tokens + output_tokens = torch.cat([decoder_tokens[:, :1], sampled_tokens], dim=1) elif _generate_from_embeds: # For inputs_embeds, we only have the generated token IDs (no input token IDs) output_tokens = sampled_tokens @@ -2500,6 +2571,172 @@ def _logits_to_tuple(logits_list: list[torch.Tensor]) -> tuple[torch.Tensor, ... else: # return_type == "tokens" return output_tokens + @torch.no_grad() + def generate_stream( + self, + input: Union[str, List[str], torch.Tensor] = "", + max_new_tokens: int = 10, + max_tokens_per_yield: int = 25, + stop_at_eos: bool = True, + eos_token_id: Optional[int] = None, + do_sample: bool = True, + top_k: Optional[int] = None, + top_p: Optional[float] = None, + temperature: float = 1.0, + freq_penalty: float = 0.0, + repetition_penalty: float = 1.0, + use_past_kv_cache: bool = True, + prepend_bos: Optional[bool] = None, + padding_side: Optional[str] = None, + return_type: Optional[str] = "input", + verbose: bool = True, + ) -> Generator[Union[torch.Tensor, str], None, None]: + """Stream tokens from the model as they are generated. + + Yields batches of tokens progressively during generation rather than + waiting for the entire sequence. Uses the same core loop as generate(). + + Args: + input: Text string, list of strings, or tensor of tokens. + max_new_tokens: Maximum number of tokens to generate. + max_tokens_per_yield: Yield accumulated tokens every this many steps. + stop_at_eos: If True, stop when eos_token is produced. + eos_token_id: Token ID(s) for end of sentence. Defaults to tokenizer's. + do_sample: If True, sample; otherwise greedy. + top_k: Top-k sampling. None means no filtering. + top_p: Nucleus sampling threshold. + temperature: Sampling temperature. + freq_penalty: Frequency penalty for previous tokens. + repetition_penalty: HF-style repetition penalty (>1.0 discourages repeats). + use_past_kv_cache: Use KV caching for faster generation. + prepend_bos: Not applied (API compatibility). See generate() docstring. + padding_side: Which side to pad for batched list inputs. Left-padding + is forced internally for batched generation. + return_type: 'input' (match input type), 'str', or 'tokens'. + verbose: Show progress bar. + + Yields: + Token tensors [batch, seq_len] or strings, accumulated up to + max_tokens_per_yield tokens between yields. First yield includes + the input tokens; subsequent yields contain only new tokens. + """ + # --- Input parsing (mirrors generate()) --- + _is_batched_list = isinstance(input, list) and len(input) > 1 + + if isinstance(input, str): + input_tokens = self.to_tokens(input, move_to_device=True, truncate=False) + input_type = "str" + elif isinstance(input, list): + if _is_batched_list: + _orig_ps = self.tokenizer.padding_side + self.tokenizer.padding_side = "left" + input_tokens = self.to_tokens(input, move_to_device=True, truncate=False) + if _is_batched_list: + self.tokenizer.padding_side = _orig_ps + input_type = "list" + else: + input_tokens = input.to(self.cfg.device) + input_type = "tokens" + + if return_type == "input": + return_type = "str" if input_type in ["str", "list"] else "tokens" + + batch_size = input_tokens.shape[0] + + # --- EOS setup --- + stop_tokens: List[int] = [] + eos_token_for_padding = 0 + if stop_at_eos: + if eos_token_id is None: + assert ( + self.tokenizer.eos_token_id is not None + ), "Must pass eos_token_id if stop_at_eos is True and tokenizer has no eos_token_id" + eos_token_id = self.tokenizer.eos_token_id + if isinstance(eos_token_id, int): + stop_tokens = [eos_token_id] + eos_token_for_padding = eos_token_id + else: + stop_tokens = list(eos_token_id) + eos_token_for_padding = ( + self.tokenizer.eos_token_id + if self.tokenizer.eos_token_id is not None + else eos_token_id[0] + ) + + finished_sequences = torch.zeros(batch_size, dtype=torch.bool, device=self.cfg.device) + + # --- Cache setup --- + if use_past_kv_cache: + self._capture_hf_cache = True + + current_tokens = input_tokens.clone() + + # --- Streaming loop --- + accumulated_tokens: Optional[torch.Tensor] = None + tokens_since_last_yield = 0 + + try: + for step_idx, (sampled_tokens, _, all_finished) in enumerate( + self._generate_tokens( + current_tokens, + input_tokens, + batch_size, + max_new_tokens=max_new_tokens, + do_sample=do_sample, + top_k=top_k, + top_p=top_p, + temperature=temperature, + freq_penalty=freq_penalty, + repetition_penalty=repetition_penalty, + stop_at_eos=stop_at_eos, + stop_tokens=stop_tokens, + eos_token_for_padding=eos_token_for_padding, + finished_sequences=finished_sequences, + use_past_kv_cache=use_past_kv_cache, + use_stateful_cache=False, + mamba_cache=None, + mamba_conv_kernel=0, + is_encoder_decoder=False, + _is_batched_list=_is_batched_list, + _generate_from_embeds=False, + encoder_input=None, + decoder_tokens=None, + generated_token_ids=None, + pixel_values=None, + multimodal_kwargs={}, + verbose=verbose, + ) + ): + new_tokens = sampled_tokens.unsqueeze(-1) + + if step_idx == 0: + accumulated_tokens = torch.cat([input_tokens, new_tokens], dim=-1) + tokens_since_last_yield = accumulated_tokens.shape[1] + else: + if accumulated_tokens is None: + accumulated_tokens = new_tokens + else: + accumulated_tokens = torch.cat([accumulated_tokens, new_tokens], dim=-1) + tokens_since_last_yield += 1 + + if tokens_since_last_yield >= max_tokens_per_yield: + yield accumulated_tokens + tokens_since_last_yield = 0 + accumulated_tokens = None + + if all_finished: + if accumulated_tokens is not None: + yield accumulated_tokens + break + + # Yield remainder after loop completes without break + if accumulated_tokens is not None: + yield accumulated_tokens + finally: + self._capture_hf_cache = False + if hasattr(self, "_last_hf_cache"): + del self._last_hf_cache + def hf_generate( self, input: str | list[str] | torch.Tensor = "", From 4230315c9ca4ee73218fe23bea5dddd8173e52cb Mon Sep 17 00:00:00 2001 From: jlarson4 Date: Tue, 21 Apr 2026 14:46:19 -0500 Subject: [PATCH 5/5] TransformerBridge Generate Stream --- .../model_bridge/test_generate_stream.py | 65 +++++++++---------- transformer_lens/model_bridge/bridge.py | 34 ++++++++-- 2 files changed, 59 insertions(+), 40 deletions(-) diff --git a/tests/acceptance/model_bridge/test_generate_stream.py b/tests/acceptance/model_bridge/test_generate_stream.py index e6bd1a19f..70e7a456a 100644 --- a/tests/acceptance/model_bridge/test_generate_stream.py +++ b/tests/acceptance/model_bridge/test_generate_stream.py @@ -6,11 +6,11 @@ def test_stream_matches_generate(gpt2_bridge): """Concatenated stream output should match generate() for the same prompt.""" prompt = "The future of AI" - expected = gpt2_bridge.generate( - prompt, max_new_tokens=10, do_sample=False, verbose=False - ) + # Get generate() output as string + expected_text = gpt2_bridge.generate(prompt, max_new_tokens=10, do_sample=False, verbose=False) + assert isinstance(expected_text, str) - # Collect all streamed chunks + # Stream as tokens so we can concatenate and compare chunks = list( gpt2_bridge.generate_stream( prompt, @@ -18,40 +18,20 @@ def test_stream_matches_generate(gpt2_bridge): max_tokens_per_yield=3, do_sample=False, verbose=False, + return_type="tokens", ) ) assert len(chunks) >= 1 - # Reconstruct: first chunk has input+tokens, subsequent have only new tokens - full_tokens = torch.cat(chunks, dim=-1) if len(chunks) > 1 else chunks[0] - # The first chunk includes input tokens, so just take the last chunk's end - # Actually, each chunk is independent — first has input+new, rest have only new - # So concatenating all gives input + all new tokens (with input repeated). - # Instead, compare decoded strings. - expected_text = gpt2_bridge.to_string(expected[0] if isinstance(expected, torch.Tensor) else gpt2_bridge.to_tokens(expected)[0]) - - # Decode last chunk which should have the most recent window of tokens - # Better: decode all chunks and concatenate - stream_texts = [] - for i, chunk in enumerate(chunks): - if i == 0: - stream_texts.append(gpt2_bridge.to_string(chunk[0])) - else: - stream_texts.append(gpt2_bridge.to_string(chunk[0])) - - # The first chunk has input+initial tokens, subsequent have only new tokens. - # The simplest comparison: the final full output should match. - # Reconstruct by taking the first chunk and appending decoded new tokens. - # Actually easier: just compare using the full token sequence. - # First chunk = input + first N tokens, subsequent = next tokens only. + # First chunk = input + first tokens, subsequent = new tokens only. all_tokens = chunks[0] for chunk in chunks[1:]: all_tokens = torch.cat([all_tokens, chunk], dim=-1) - streamed_text = gpt2_bridge.to_string(all_tokens[0]) - assert expected_text == streamed_text, ( - f"Stream output mismatch:\n generate: {expected_text!r}\n stream: {streamed_text!r}" - ) + streamed_text = gpt2_bridge.tokenizer.decode(all_tokens[0], skip_special_tokens=True) + assert ( + expected_text == streamed_text + ), f"Stream output mismatch:\n generate: {expected_text!r}\n stream: {streamed_text!r}" def test_stream_yields_progressively(gpt2_bridge): @@ -63,6 +43,7 @@ def test_stream_yields_progressively(gpt2_bridge): max_tokens_per_yield=3, do_sample=False, verbose=False, + return_type="tokens", ) ) assert len(chunks) > 1, f"Expected multiple yields, got {len(chunks)}" @@ -72,7 +53,11 @@ def test_stream_single_prompt(gpt2_bridge): """Basic single-string streaming should produce output.""" results = list( gpt2_bridge.generate_stream( - "Test", max_new_tokens=5, do_sample=False, verbose=False + "Test", + max_new_tokens=5, + do_sample=False, + verbose=False, + return_type="tokens", ) ) assert len(results) >= 1 @@ -90,9 +75,23 @@ def test_stream_stops_at_eos(gpt2_bridge): stop_at_eos=True, do_sample=False, verbose=False, + return_type="tokens", ) ) - # Count total generated tokens (first chunk has input, rest are new) total_tokens = sum(r.shape[1] for r in results) - # Should have stopped well before 200 new tokens for a short prompt assert total_tokens < 210 + + +def test_stream_returns_strings(gpt2_bridge): + """With return_type='str', yields should be strings.""" + results = list( + gpt2_bridge.generate_stream( + "Hello", + max_new_tokens=5, + do_sample=False, + verbose=False, + return_type="str", + ) + ) + assert len(results) >= 1 + assert all(isinstance(r, str) for r in results) diff --git a/transformer_lens/model_bridge/bridge.py b/transformer_lens/model_bridge/bridge.py index 37975743a..84b64186c 100644 --- a/transformer_lens/model_bridge/bridge.py +++ b/transformer_lens/model_bridge/bridge.py @@ -27,6 +27,7 @@ import einops import numpy as np import torch +import tqdm from torch import nn from transformer_lens import utilities as utils @@ -2117,7 +2118,7 @@ def _generate_tokens( """ _hf_kv_cache = None - for gen_step_idx in range(max_new_tokens): + for gen_step_idx in tqdm.tqdm(range(max_new_tokens), disable=not verbose): with torch.no_grad(): if is_encoder_decoder: logits = self( @@ -2620,6 +2621,13 @@ def generate_stream( max_tokens_per_yield tokens between yields. First yield includes the input tokens; subsequent yields contain only new tokens. """ + if prepend_bos is not None: + warnings.warn( + "prepend_bos is ignored during TransformerBridge.generate_stream(). " + "The HF model expects tokens with the tokenizer's default BOS handling.", + stacklevel=2, + ) + # --- Input parsing (mirrors generate()) --- _is_batched_list = isinstance(input, list) and len(input) > 1 @@ -2630,9 +2638,11 @@ def generate_stream( if _is_batched_list: _orig_ps = self.tokenizer.padding_side self.tokenizer.padding_side = "left" - input_tokens = self.to_tokens(input, move_to_device=True, truncate=False) - if _is_batched_list: - self.tokenizer.padding_side = _orig_ps + try: + input_tokens = self.to_tokens(input, move_to_device=True, truncate=False) + finally: + if _is_batched_list: + self.tokenizer.padding_side = _orig_ps input_type = "list" else: input_tokens = input.to(self.cfg.device) @@ -2672,9 +2682,19 @@ def generate_stream( current_tokens = input_tokens.clone() # --- Streaming loop --- + # All yields are token tensors [batch, seq_len]. Each yield contains + # only the newly generated tokens since the previous yield (the first + # yield additionally prepends the input tokens for context). accumulated_tokens: Optional[torch.Tensor] = None tokens_since_last_yield = 0 + def _maybe_decode( + tokens: torch.Tensor, + ) -> Union[torch.Tensor, str]: + if return_type == "str": + return self.tokenizer.decode(tokens[0], skip_special_tokens=True) + return tokens + try: for step_idx, (sampled_tokens, _, all_finished) in enumerate( self._generate_tokens( @@ -2720,18 +2740,18 @@ def generate_stream( tokens_since_last_yield += 1 if tokens_since_last_yield >= max_tokens_per_yield: - yield accumulated_tokens + yield _maybe_decode(accumulated_tokens) tokens_since_last_yield = 0 accumulated_tokens = None if all_finished: if accumulated_tokens is not None: - yield accumulated_tokens + yield _maybe_decode(accumulated_tokens) break # Yield remainder after loop completes without break if accumulated_tokens is not None: - yield accumulated_tokens + yield _maybe_decode(accumulated_tokens) finally: self._capture_hf_cache = False if hasattr(self, "_last_hf_cache"):