diff --git a/tests/acceptance/model_bridge/test_generate_stream.py b/tests/acceptance/model_bridge/test_generate_stream.py
new file mode 100644
index 000000000..70e7a456a
--- /dev/null
+++ b/tests/acceptance/model_bridge/test_generate_stream.py
@@ -0,0 +1,97 @@
+"""Tests for TransformerBridge.generate_stream()."""
+
+import torch
+
+
+def test_stream_matches_generate(gpt2_bridge):
+    """Concatenated stream output should match generate() for the same prompt."""
+    prompt = "The future of AI"
+    # Get generate() output as a string
+    expected_text = gpt2_bridge.generate(prompt, max_new_tokens=10, do_sample=False, verbose=False)
+    assert isinstance(expected_text, str)
+
+    # Stream as tokens so we can concatenate and compare
+    chunks = list(
+        gpt2_bridge.generate_stream(
+            prompt,
+            max_new_tokens=10,
+            max_tokens_per_yield=3,
+            do_sample=False,
+            verbose=False,
+            return_type="tokens",
+        )
+    )
+    assert len(chunks) >= 1
+
+    # First chunk = input + first tokens, subsequent chunks = new tokens only.
+    all_tokens = chunks[0]
+    for chunk in chunks[1:]:
+        all_tokens = torch.cat([all_tokens, chunk], dim=-1)
+
+    streamed_text = gpt2_bridge.tokenizer.decode(all_tokens[0], skip_special_tokens=True)
+    assert (
+        expected_text == streamed_text
+    ), f"Stream output mismatch:\n generate: {expected_text!r}\n stream: {streamed_text!r}"
+
+
+def test_stream_yields_progressively(gpt2_bridge):
+    """Multiple yields should occur with a small max_tokens_per_yield."""
+    chunks = list(
+        gpt2_bridge.generate_stream(
+            "Hello world",
+            max_new_tokens=10,
+            max_tokens_per_yield=3,
+            do_sample=False,
+            verbose=False,
+            return_type="tokens",
+        )
+    )
+    assert len(chunks) > 1, f"Expected multiple yields, got {len(chunks)}"
+
+
+def test_stream_single_prompt(gpt2_bridge):
+    """Basic single-string streaming should produce output."""
+    results = list(
+        gpt2_bridge.generate_stream(
+            "Test",
+            max_new_tokens=5,
+            do_sample=False,
+            verbose=False,
+            return_type="tokens",
+        )
+    )
+    assert len(results) >= 1
+    assert results[0].shape[0] == 1  # batch=1
+    assert results[0].shape[1] > 1  # has at least input + 1 generated token
+
+
+def test_stream_stops_at_eos(gpt2_bridge):
+    """Streaming should respect stop_at_eos."""
+    results = list(
+        gpt2_bridge.generate_stream(
+            "Test",
+            max_new_tokens=200,
+            max_tokens_per_yield=5,
+            stop_at_eos=True,
+            do_sample=False,
+            verbose=False,
+            return_type="tokens",
+        )
+    )
+    # Loose upper bound: prompt tokens plus at most 200 generated tokens.
+    total_tokens = sum(r.shape[1] for r in results)
+    assert total_tokens < 210
+
+
+def test_stream_returns_strings(gpt2_bridge):
+    """With return_type='str', yields should be strings."""
+    results = list(
+        gpt2_bridge.generate_stream(
+            "Hello",
+            max_new_tokens=5,
+            do_sample=False,
+            verbose=False,
+            return_type="str",
+        )
+    )
+    assert len(results) >= 1
+    assert all(isinstance(r, str) for r in results)
diff --git a/transformer_lens/model_bridge/bridge.py b/transformer_lens/model_bridge/bridge.py
index 7930ac08a..9e24fa9fc 100644
--- a/transformer_lens/model_bridge/bridge.py
+++ b/transformer_lens/model_bridge/bridge.py
@@ -6,6 +6,7 @@
 import logging
 import re
 import warnings
+from collections.abc import Generator
 from contextlib import contextmanager
 from functools import lru_cache
 from typing import (
@@ -26,6 +27,7 @@
 import einops
 import numpy as np
 import torch
+import tqdm
 from torch import nn
 
 from transformer_lens import utilities as utils
@@ -2136,6 +2138,190 @@ def wrapped_hook_fn(tensor, hook, _orig_fn=original_hook_fn):
         for hook_point, name in added_hooks:
             hook_point.remove_hooks()
 
+    def _generate_tokens(
+        self,
+        current_tokens: torch.Tensor,
+        input_tokens: torch.Tensor,
+        batch_size: int,
+        *,
+        max_new_tokens: int,
+        do_sample: bool,
+        top_k: Optional[int],
+        top_p: Optional[float],
+        temperature: float,
+        freq_penalty: float,
+        repetition_penalty: float,
+        stop_at_eos: bool,
+        stop_tokens: List[int],
+        eos_token_for_padding: int,
+        finished_sequences: torch.Tensor,
+        use_past_kv_cache: bool,
+        use_stateful_cache: bool,
+        mamba_cache: Any,
+        mamba_conv_kernel: int,
+        is_encoder_decoder: bool,
+        _is_batched_list: bool,
+        _generate_from_embeds: bool,
+        encoder_input: Optional[torch.Tensor],
+        decoder_tokens: Optional[torch.Tensor],
+        generated_token_ids: Optional[List[torch.Tensor]],
+        pixel_values: Optional[torch.Tensor],
+        multimodal_kwargs: Dict[str, Any],
+        verbose: bool,
+    ) -> Generator[Tuple[torch.Tensor, torch.Tensor, bool], None, None]:
+        """Core generation loop. Yields (sampled_tokens, final_logits, all_finished) per step.
+
+        Owns the forward pass, sampling, EOS handling, token accumulation, and
+        KV cache management. Callers are responsible for try/finally cleanup of
+        ``_capture_hf_cache``.
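+
+        Example (consumption pattern used by generate() and generate_stream()):
+
+            for sampled, logits, done in self._generate_tokens(...):
+                # accumulate `sampled` and, if needed, `logits`
+                if done:
+                    break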
+        """
+        _hf_kv_cache = None
+
+        for gen_step_idx in tqdm.tqdm(range(max_new_tokens), disable=not verbose):
+            with torch.no_grad():
+                if is_encoder_decoder:
+                    logits = self(
+                        encoder_input,
+                        return_type="logits",
+                        decoder_input=decoder_tokens,
+                    )
+                else:
+                    forward_kwargs: Dict[str, Any] = {}
+                    # Compute attention mask and position_ids for batched
+                    # inputs with padding. HF models default to all-ones
+                    # when no mask is given, which ignores padding tokens.
+                    if (
+                        _is_batched_list
+                        and self.tokenizer is not None
+                        and self.tokenizer.pad_token_id is not None
+                    ):
+                        # Temp-swap to "left" so get_attention_mask scans
+                        # for leading pads (matching the left-padded tokens).
+                        _prev_side = self.tokenizer.padding_side
+                        self.tokenizer.padding_side = "left"
+                        attn_mask = utils.get_attention_mask(
+                            self.tokenizer,
+                            current_tokens,
+                            prepend_bos=getattr(self.cfg, "default_prepend_bos", True),
+                        ).to(self.cfg.device)
+                        self.tokenizer.padding_side = _prev_side
+                        forward_kwargs["attention_mask"] = attn_mask
+                        # Adjust position_ids for left-padding so pad
+                        # tokens don't consume real position embeddings.
+                        position_ids = attn_mask.long().cumsum(-1) - 1
+                        position_ids.masked_fill_(attn_mask == 0, 1)
+                        forward_kwargs["position_ids"] = position_ids
+                    # Pass multimodal inputs only on the first step: the vision
+                    # encoder processes the image once, embedding it into the
+                    # token sequence.
+                    if gen_step_idx == 0:
+                        if pixel_values is not None:
+                            forward_kwargs["pixel_values"] = pixel_values
+                        if multimodal_kwargs:
+                            forward_kwargs.update(multimodal_kwargs)
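+                    # Prefill sends arange(conv_kernel) (which both Mamba-1's
+                    # length check and Mamba-2's value check accept as "not
+                    # decode"). Decode sends the input token's actual sequence
+                    # position: a fixed value above conv_kernel-1 silently picks
+                    # the wrong slot for short prompts. conv1d hooks fire only
+                    # on prefill; HF bypasses the conv1d module on decode (see
+                    # DepthwiseConv1DBridge).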
hasattr(self, "_last_hf_cache"): + _hf_kv_cache = self._last_hf_cache or _hf_kv_cache + del self._last_hf_cache + final_logits = logits[:, -1, :] + + # Sample next token + penalty_tokens = ( + torch.stack(generated_token_ids, dim=1) + if _generate_from_embeds and generated_token_ids + else None + ) + if do_sample: + sampled_tokens = utils.sample_logits( + final_logits, + top_k=top_k, + top_p=top_p, + temperature=temperature, + freq_penalty=freq_penalty, + repetition_penalty=repetition_penalty, + tokens=penalty_tokens + if _generate_from_embeds + else (decoder_tokens if is_encoder_decoder else current_tokens), + ).to(self.cfg.device) + else: + sampled_tokens = utils.sample_logits( + final_logits, + temperature=0.0, + repetition_penalty=repetition_penalty, + tokens=penalty_tokens + if _generate_from_embeds + else (decoder_tokens if is_encoder_decoder else current_tokens), + ).to(self.cfg.device) + + # Handle EOS + if stop_at_eos: + sampled_tokens[finished_sequences] = eos_token_for_padding + finished_sequences.logical_or_( + torch.isin( + sampled_tokens.to(self.cfg.device), + torch.tensor(stop_tokens).to(self.cfg.device), + ) + ) + + # Update token sequences + if is_encoder_decoder: + assert decoder_tokens is not None + decoder_tokens = torch.cat([decoder_tokens, sampled_tokens.unsqueeze(1)], dim=1) + elif _generate_from_embeds: + assert generated_token_ids is not None + generated_token_ids.append(sampled_tokens) + embed_fn = self.original_model.get_input_embeddings() # type: ignore[operator] + assert embed_fn is not None + new_embed = embed_fn(sampled_tokens.unsqueeze(1)).to(current_tokens.dtype) + current_tokens = torch.cat([current_tokens, new_embed], dim=1) + else: + current_tokens = torch.cat([current_tokens, sampled_tokens.unsqueeze(1)], dim=1) + + all_finished = bool(stop_at_eos and finished_sequences.all().item()) + + yield sampled_tokens, final_logits, all_finished + + if all_finished: + return + def generate( self, input: Union[str, List[str], torch.Tensor] = "", @@ -2355,188 +2541,41 @@ def generate( ) try: - for gen_step_idx in range(max_new_tokens): - # Get logits for next token - with torch.no_grad(): - if is_encoder_decoder: - logits = self( - encoder_input, - return_type="logits", - decoder_input=decoder_tokens, - ) - else: - forward_kwargs: Dict[str, Any] = {} - # Compute attention mask and position_ids for batched - # inputs with padding. HF models default to all-ones - # when no mask is given, which ignores padding tokens. - if ( - _is_batched_list - and self.tokenizer is not None - and self.tokenizer.pad_token_id is not None - ): - # Temp-swap to "left" so get_attention_mask scans - # for leading pads (matching the left-padded tokens). - _prev_side = self.tokenizer.padding_side - self.tokenizer.padding_side = "left" - attn_mask = utils.get_attention_mask( - self.tokenizer, - current_tokens, - prepend_bos=getattr(self.cfg, "default_prepend_bos", True), - ).to(self.cfg.device) - self.tokenizer.padding_side = _prev_side - forward_kwargs["attention_mask"] = attn_mask - # Adjust position_ids for left-padding so pad - # tokens don't consume real position embeddings. - position_ids = attn_mask.long().cumsum(-1) - 1 - position_ids.masked_fill_(attn_mask == 0, 1) - forward_kwargs["position_ids"] = position_ids - # Pass multimodal inputs only on the first step — the vision - # encoder processes the image once, embedding it into the - # token sequence. This includes pixel_values plus any extra - # processor outputs (e.g. image_sizes for LlavaNext). 
+                            forward_kwargs["past_key_values"] = _hf_kv_cache
+                            if "position_ids" in forward_kwargs:
+                                forward_kwargs["position_ids"] = forward_kwargs["position_ids"][
+                                    :, -1:
+                                ]
+                            logits = self(
+                                current_tokens[:, -1:],
+                                return_type="logits",
+                                **forward_kwargs,
+                            )
+                        else:
+                            # Step 0: full sequence, cache gets populated
+                            logits = self(
+                                current_tokens,
+                                return_type="logits",
+                                **forward_kwargs,
+                            )
+                    else:
+                        # No cache: full sequence every step
+                        logits = self(current_tokens, return_type="logits", **forward_kwargs)
+                    # Capture HF cache from forward() for next step.
+                    if use_past_kv_cache and hasattr(self, "_last_hf_cache"):
+                        _hf_kv_cache = self._last_hf_cache or _hf_kv_cache
+                        del self._last_hf_cache
+            final_logits = logits[:, -1, :]
+
+            # Sample next token
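+            # For inputs_embeds, we can't pass the embeddings to freq/rep
+            # penalty, so use the generated_token_ids for penalty tracking.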
+            penalty_tokens = (
+                torch.stack(generated_token_ids, dim=1)
+                if _generate_from_embeds and generated_token_ids
+                else None
+            )
+            if do_sample:
+                sampled_tokens = utils.sample_logits(
+                    final_logits,
+                    top_k=top_k,
+                    top_p=top_p,
+                    temperature=temperature,
+                    freq_penalty=freq_penalty,
+                    repetition_penalty=repetition_penalty,
+                    tokens=penalty_tokens
+                    if _generate_from_embeds
+                    else (decoder_tokens if is_encoder_decoder else current_tokens),
+                ).to(self.cfg.device)
+            else:
+                sampled_tokens = utils.sample_logits(
+                    final_logits,
+                    temperature=0.0,
+                    repetition_penalty=repetition_penalty,
+                    tokens=penalty_tokens
+                    if _generate_from_embeds
+                    else (decoder_tokens if is_encoder_decoder else current_tokens),
+                ).to(self.cfg.device)
+
+            # Handle EOS
+            if stop_at_eos:
+                sampled_tokens[finished_sequences] = eos_token_for_padding
+                finished_sequences.logical_or_(
+                    torch.isin(
+                        sampled_tokens.to(self.cfg.device),
+                        torch.tensor(stop_tokens).to(self.cfg.device),
+                    )
+                )
+
+            # Update token sequences
+            if is_encoder_decoder:
+                assert decoder_tokens is not None
+                decoder_tokens = torch.cat([decoder_tokens, sampled_tokens.unsqueeze(1)], dim=1)
+            elif _generate_from_embeds:
+                assert generated_token_ids is not None
+                generated_token_ids.append(sampled_tokens)
+                embed_fn = self.original_model.get_input_embeddings()  # type: ignore[operator]
+                assert embed_fn is not None
+                new_embed = embed_fn(sampled_tokens.unsqueeze(1)).to(current_tokens.dtype)
+                current_tokens = torch.cat([current_tokens, new_embed], dim=1)
+            else:
+                current_tokens = torch.cat([current_tokens, sampled_tokens.unsqueeze(1)], dim=1)
+
+            all_finished = bool(stop_at_eos and finished_sequences.all().item())
+
+            yield sampled_tokens, final_logits, all_finished
+
+            if all_finished:
+                return
+
     def generate(
         self,
         input: Union[str, List[str], torch.Tensor] = "",
@@ -2355,188 +2541,41 @@ def generate(
         )
 
         try:
-            for gen_step_idx in range(max_new_tokens):
-                # Get logits for next token
-                with torch.no_grad():
-                    if is_encoder_decoder:
-                        logits = self(
-                            encoder_input,
-                            return_type="logits",
-                            decoder_input=decoder_tokens,
-                        )
-                    else:
-                        forward_kwargs: Dict[str, Any] = {}
-                        # Compute attention mask and position_ids for batched
-                        # inputs with padding. HF models default to all-ones
-                        # when no mask is given, which ignores padding tokens.
-                        if (
-                            _is_batched_list
-                            and self.tokenizer is not None
-                            and self.tokenizer.pad_token_id is not None
-                        ):
-                            # Temp-swap to "left" so get_attention_mask scans
-                            # for leading pads (matching the left-padded tokens).
-                            _prev_side = self.tokenizer.padding_side
-                            self.tokenizer.padding_side = "left"
-                            attn_mask = utils.get_attention_mask(
-                                self.tokenizer,
-                                current_tokens,
-                                prepend_bos=getattr(self.cfg, "default_prepend_bos", True),
-                            ).to(self.cfg.device)
-                            self.tokenizer.padding_side = _prev_side
-                            forward_kwargs["attention_mask"] = attn_mask
-                            # Adjust position_ids for left-padding so pad
-                            # tokens don't consume real position embeddings.
-                            position_ids = attn_mask.long().cumsum(-1) - 1
-                            position_ids.masked_fill_(attn_mask == 0, 1)
-                            forward_kwargs["position_ids"] = position_ids
-                        # Pass multimodal inputs only on the first step — the vision
-                        # encoder processes the image once, embedding it into the
-                        # token sequence. This includes pixel_values plus any extra
-                        # processor outputs (e.g. image_sizes for LlavaNext).
-                        if gen_step_idx == 0:
-                            if pixel_values is not None:
-                                forward_kwargs["pixel_values"] = pixel_values
-                            if multimodal_kwargs:
-                                forward_kwargs.update(multimodal_kwargs)
-                        if use_stateful_cache:
-                            # Prefill sends arange(conv_kernel) (which both
-                            # Mamba-1's length check and Mamba-2's value check
-                            # accept as "not decode"). Decode sends the input
-                            # token's actual sequence position — a fixed value
-                            # above conv_kernel-1 silently picks the wrong
-                            # slot for short prompts (see
-                            # test_greedy_matches_hf_across_prompt_lengths).
-                            # conv1d hooks fire only on prefill; HF bypasses
-                            # the conv1d module on decode (see DepthwiseConv1DBridge).
-                            forward_kwargs["cache_params"] = mamba_cache
-                            forward_kwargs["use_cache"] = True
-                            if gen_step_idx == 0:
-                                cache_position = torch.arange(
-                                    0, mamba_conv_kernel, device=self.cfg.device
-                                )
-                                forward_kwargs["cache_position"] = cache_position
-                                logits = self(
-                                    current_tokens,
-                                    return_type="logits",
-                                    **forward_kwargs,
-                                )
-                            else:
-                                # Token generated at step N-1 lives at
-                                # sequence position prompt_len + gen_step_idx - 1
-                                input_seq_pos = input_tokens.shape[1] + gen_step_idx - 1
-                                cache_position = torch.tensor(
-                                    [input_seq_pos], device=self.cfg.device
-                                )
-                                forward_kwargs["cache_position"] = cache_position
-                                if "position_ids" in forward_kwargs:
-                                    forward_kwargs["position_ids"] = forward_kwargs["position_ids"][
-                                        :, -1:
-                                    ]
-                                logits = self(
-                                    current_tokens[:, -1:],
-                                    return_type="logits",
-                                    **forward_kwargs,
-                                )
-                        elif use_past_kv_cache:
-                            forward_kwargs["use_cache"] = True
-                            if _hf_kv_cache is not None:
-                                # Cached step: pass only the last token + cache
-                                forward_kwargs["past_key_values"] = _hf_kv_cache
-                                if "position_ids" in forward_kwargs:
-                                    forward_kwargs["position_ids"] = forward_kwargs["position_ids"][
-                                        :, -1:
-                                    ]
-                                logits = self(
-                                    current_tokens[:, -1:],
-                                    return_type="logits",
-                                    **forward_kwargs,
-                                )
-                            else:
-                                # Step 0: full sequence, cache gets populated
-                                logits = self(
-                                    current_tokens,
-                                    return_type="logits",
-                                    **forward_kwargs,
-                                )
-                        else:
-                            # No cache: full sequence every step
-                            logits = self(current_tokens, return_type="logits", **forward_kwargs)
-                        # Capture HF cache from forward() for next step.
-                        if use_past_kv_cache and hasattr(self, "_last_hf_cache"):
-                            _hf_kv_cache = self._last_hf_cache or _hf_kv_cache
-                            del self._last_hf_cache
-                final_logits = logits[:, -1, :]
-
-                # Collect logits if requested
-                if logits_seq_list is not None:
-                    logits_seq_list.append(final_logits.clone())
-
-                # Sample next token
-                # For inputs_embeds, we can't pass the embeddings to freq/rep penalty,
-                # so use the generated_token_ids for penalty tracking
-                penalty_tokens = (
-                    torch.stack(generated_token_ids, dim=1)
-                    if _generate_from_embeds and generated_token_ids
-                    else None
-                )
-                if do_sample:
-                    sampled_tokens = utils.sample_logits(
-                        final_logits,
-                        top_k=top_k,
-                        top_p=top_p,
-                        temperature=temperature,
-                        freq_penalty=freq_penalty,
-                        repetition_penalty=repetition_penalty,
-                        tokens=penalty_tokens
-                        if _generate_from_embeds
-                        else (decoder_tokens if is_encoder_decoder else current_tokens),
-                    ).to(self.cfg.device)
-                else:
-                    sampled_tokens = utils.sample_logits(
-                        final_logits,
-                        temperature=0.0,
-                        repetition_penalty=repetition_penalty,
-                        tokens=penalty_tokens
-                        if _generate_from_embeds
-                        else (decoder_tokens if is_encoder_decoder else current_tokens),
-                    ).to(self.cfg.device)
-
-                sampled_tokens_list.append(sampled_tokens.unsqueeze(1))
-
-                # Handle EOS tokens for finished sequences
-                if stop_at_eos:
-                    sampled_tokens[finished_sequences] = eos_token_for_padding
-                    finished_sequences.logical_or_(
-                        torch.isin(
-                            sampled_tokens.to(self.cfg.device),
-                            torch.tensor(stop_tokens).to(self.cfg.device),
-                        )
-                    )
-
-                # Append sampled token to current sequence
-                if is_encoder_decoder:
-                    decoder_tokens = torch.cat(
-                        [decoder_tokens, sampled_tokens.unsqueeze(1)], dim=1
-                    )
-                elif _generate_from_embeds:
-                    # For inputs_embeds: get the embedding of the new token and append
-                    generated_token_ids.append(sampled_tokens)
-                    embed_fn = self.original_model.get_input_embeddings()  # type: ignore[operator]
-                    assert embed_fn is not None
-                    new_embed = embed_fn(sampled_tokens.unsqueeze(1)).to(current_tokens.dtype)
-                    current_tokens = torch.cat([current_tokens, new_embed], dim=1)
-                else:
-                    current_tokens = torch.cat(
-                        [current_tokens, sampled_tokens.unsqueeze(1)], dim=1
-                    )
-
-                # Early stopping if all sequences finished
-                if stop_at_eos and finished_sequences.all():
-                    break
+            for sampled_tokens, final_logits, all_finished in self._generate_tokens(
+                current_tokens,
+                input_tokens,
+                batch_size,
+                max_new_tokens=max_new_tokens,
+                do_sample=do_sample,
+                top_k=top_k,
+                top_p=top_p,
+                temperature=temperature,
+                freq_penalty=freq_penalty,
+                repetition_penalty=repetition_penalty,
+                stop_at_eos=stop_at_eos,
+                stop_tokens=stop_tokens,
+                eos_token_for_padding=eos_token_for_padding,
+                finished_sequences=finished_sequences,
+                use_past_kv_cache=use_past_kv_cache,
+                use_stateful_cache=use_stateful_cache,
+                mamba_cache=mamba_cache,
+                mamba_conv_kernel=mamba_conv_kernel,
+                is_encoder_decoder=is_encoder_decoder,
+                _is_batched_list=_is_batched_list,
+                _generate_from_embeds=_generate_from_embeds,
+                encoder_input=encoder_input if is_encoder_decoder else None,
+                decoder_tokens=decoder_tokens if is_encoder_decoder else None,
+                generated_token_ids=generated_token_ids if _generate_from_embeds else None,
+                pixel_values=pixel_values,
+                multimodal_kwargs=multimodal_kwargs if multimodal_kwargs else {},
+                verbose=verbose,
+            ):
+                sampled_tokens_list.append(sampled_tokens.unsqueeze(1))
+                if logits_seq_list is not None:
+                    logits_seq_list.append(final_logits.clone())
+                if all_finished:
+                    break
         finally:
-            # Clean up generate-only state even if an exception occurs,
-            # so _capture_hf_cache doesn't leak into subsequent forward() calls.
             self._capture_hf_cache = False
             if hasattr(self, "_last_hf_cache"):
                 del self._last_hf_cache
@@ -2544,7 +2583,8 @@ def generate(
         # Concatenate all sampled tokens
         sampled_tokens = torch.cat(sampled_tokens_list, dim=1)
         if is_encoder_decoder:
-            output_tokens = decoder_tokens
+            # Reconstruct the full decoder sequence (start token + generated
+            # tokens): _generate_tokens() rebinds its local decoder_tokens,
+            # so the caller's copy never grows.
+            output_tokens = torch.cat([decoder_tokens[:, :1], sampled_tokens], dim=1)
         elif _generate_from_embeds:
             # For inputs_embeds, we only have the generated token IDs (no input token IDs)
             output_tokens = sampled_tokens
@@ -2591,6 +2631,191 @@ def _logits_to_tuple(logits_list: list[torch.Tensor]) -> tuple[torch.Tensor, ...
         else:  # return_type == "tokens"
             return output_tokens
 
+    @torch.no_grad()
+    def generate_stream(
+        self,
+        input: Union[str, List[str], torch.Tensor] = "",
+        max_new_tokens: int = 10,
+        max_tokens_per_yield: int = 25,
+        stop_at_eos: bool = True,
+        eos_token_id: Optional[int] = None,
+        do_sample: bool = True,
+        top_k: Optional[int] = None,
+        top_p: Optional[float] = None,
+        temperature: float = 1.0,
+        freq_penalty: float = 0.0,
+        repetition_penalty: float = 1.0,
+        use_past_kv_cache: bool = True,
+        prepend_bos: Optional[bool] = None,
+        padding_side: Optional[str] = None,
+        return_type: Optional[str] = "input",
+        verbose: bool = True,
+    ) -> Generator[Union[torch.Tensor, str], None, None]:
+        """Stream tokens from the model as they are generated.
+
+        Yields batches of tokens progressively during generation rather than
+        waiting for the entire sequence. Uses the same core loop as generate().
+
+        Args:
+            input: Text string, list of strings, or tensor of tokens.
+            max_new_tokens: Maximum number of tokens to generate.
+            max_tokens_per_yield: Maximum number of tokens to accumulate
+                before yielding.
+            stop_at_eos: If True, stop when an EOS token is produced.
+            eos_token_id: Token ID(s) marking end of sequence. Defaults to the
+                tokenizer's eos_token_id.
+            do_sample: If True, sample; otherwise greedy.
+            top_k: Top-k sampling. None means no filtering.
+            top_p: Nucleus sampling threshold.
+            temperature: Sampling temperature.
+            freq_penalty: Frequency penalty for previous tokens.
+            repetition_penalty: HF-style repetition penalty (>1.0 discourages repeats).
+            use_past_kv_cache: Use KV caching for faster generation.
+            prepend_bos: Not applied (API compatibility). See generate() docstring.
+            padding_side: Which side to pad for batched list inputs. Left-padding
+                is forced internally for batched generation.
+            return_type: 'input' (match input type), 'str', or 'tokens'.
+            verbose: Show progress bar.
+
+        Yields:
+            Token tensors [batch, seq_len] or strings, accumulated up to
+            max_tokens_per_yield tokens between yields. First yield includes
+            the input tokens; subsequent yields contain only new tokens.
+        """
+        if prepend_bos is not None:
+            warnings.warn(
+                "prepend_bos is ignored during TransformerBridge.generate_stream(). "
+                "The HF model expects tokens with the tokenizer's default BOS handling.",
+                stacklevel=2,
+            )
+
+        # --- Input parsing (mirrors generate()) ---
+        _is_batched_list = isinstance(input, list) and len(input) > 1
+
+        if isinstance(input, str):
+            input_tokens = self.to_tokens(input, move_to_device=True, truncate=False)
+            input_type = "str"
+        elif isinstance(input, list):
+            if _is_batched_list:
+                _orig_ps = self.tokenizer.padding_side
+                self.tokenizer.padding_side = "left"
+            try:
+                input_tokens = self.to_tokens(input, move_to_device=True, truncate=False)
+            finally:
+                if _is_batched_list:
+                    self.tokenizer.padding_side = _orig_ps
+            input_type = "list"
+        else:
+            input_tokens = input.to(self.cfg.device)
+            input_type = "tokens"
+
+        if return_type == "input":
+            return_type = "str" if input_type in ["str", "list"] else "tokens"
+
+        batch_size = input_tokens.shape[0]
+
+        # --- EOS setup ---
+        stop_tokens: List[int] = []
+        eos_token_for_padding = 0
+        if stop_at_eos:
+            if eos_token_id is None:
+                assert (
+                    self.tokenizer.eos_token_id is not None
+                ), "Must pass eos_token_id if stop_at_eos is True and tokenizer has no eos_token_id"
+                eos_token_id = self.tokenizer.eos_token_id
+            if isinstance(eos_token_id, int):
+                stop_tokens = [eos_token_id]
+                eos_token_for_padding = eos_token_id
+            else:
+                stop_tokens = list(eos_token_id)
+                eos_token_for_padding = (
+                    self.tokenizer.eos_token_id
+                    if self.tokenizer.eos_token_id is not None
+                    else eos_token_id[0]
+                )
+
+        finished_sequences = torch.zeros(batch_size, dtype=torch.bool, device=self.cfg.device)
+
+        # --- Cache setup ---
+        if use_past_kv_cache:
+            self._capture_hf_cache = True
+
+        current_tokens = input_tokens.clone()
+
+        # --- Streaming loop ---
+        # Tokens are accumulated as [batch, seq_len] tensors. Each yield
+        # contains only the tokens generated since the previous yield (the
+        # first yield additionally prepends the input tokens), decoded to a
+        # string when return_type == "str".
+        accumulated_tokens: Optional[torch.Tensor] = None
+        tokens_since_last_yield = 0
+
+        def _maybe_decode(
+            tokens: torch.Tensor,
+        ) -> Union[torch.Tensor, str]:
+            # Note: with return_type == "str", only the first batch element
+            # is decoded.
+            if return_type == "str":
+                return self.tokenizer.decode(tokens[0], skip_special_tokens=True)
+            return tokens
+
+        try:
+            for step_idx, (sampled_tokens, _, all_finished) in enumerate(
+                self._generate_tokens(
+                    current_tokens,
+                    input_tokens,
+                    batch_size,
+                    max_new_tokens=max_new_tokens,
+                    do_sample=do_sample,
+                    top_k=top_k,
+                    top_p=top_p,
+                    temperature=temperature,
+                    freq_penalty=freq_penalty,
+                    repetition_penalty=repetition_penalty,
+                    stop_at_eos=stop_at_eos,
+                    stop_tokens=stop_tokens,
+                    eos_token_for_padding=eos_token_for_padding,
+                    finished_sequences=finished_sequences,
+                    use_past_kv_cache=use_past_kv_cache,
+                    use_stateful_cache=False,
+                    mamba_cache=None,
+                    mamba_conv_kernel=0,
+                    is_encoder_decoder=False,
+                    _is_batched_list=_is_batched_list,
+                    _generate_from_embeds=False,
+                    encoder_input=None,
+                    decoder_tokens=None,
+                    generated_token_ids=None,
+                    pixel_values=None,
+                    multimodal_kwargs={},
+                    verbose=verbose,
+                )
+            ):
+                new_tokens = sampled_tokens.unsqueeze(-1)
+
+                if step_idx == 0:
+                    accumulated_tokens = torch.cat([input_tokens, new_tokens], dim=-1)
+                    tokens_since_last_yield = accumulated_tokens.shape[1]
+                else:
+                    if accumulated_tokens is None:
+                        accumulated_tokens = new_tokens
+                    else:
+                        accumulated_tokens = torch.cat([accumulated_tokens, new_tokens], dim=-1)
+                    tokens_since_last_yield += 1
+
+                if tokens_since_last_yield >= max_tokens_per_yield:
+                    yield _maybe_decode(accumulated_tokens)
+                    tokens_since_last_yield = 0
+                    accumulated_tokens = None
+
+                if all_finished:
+                    break
+
+            # Flush the remainder. This single post-loop yield covers both
+            # normal completion and early stop via all_finished; yielding
+            # inside the loop as well would emit the final chunk twice.
+            if accumulated_tokens is not None:
+                yield _maybe_decode(accumulated_tokens)
+        finally:
+            self._capture_hf_cache = False
+            if hasattr(self, "_last_hf_cache"):
+                del self._last_hf_cache
+
     def hf_generate(
         self,
         input: str | list[str] | torch.Tensor = "",