class OutputStreamer(BaseStreamer):
    """
    Base streamer that buffers and streams full generation `Output` objects
    (rather than decoded text).

    Incoming values are passed through `filter_func`, accumulated in
    `self.cache`, and flushed via `on_ready()` once `is_ready()` reports the
    buffer is full enough.
    """

    def __init__(self, filter_func=None, cache=None):
        # Default to the class's pass-through filter when none is supplied.
        if filter_func is None:
            filter_func = self._filter_func
        self.filter_func = filter_func
        # Avoid a shared mutable default: each instance gets its own buffer.
        if cache is None:
            cache = []
        self.cache = cache  # incoming, not-yet-flushed outputs

    def _filter_func(self, value):
        """
        Class-default behavior for `self.filter_func`.

        `self.filter_func` is called on each incoming value. It can be used to
        restrict the stream to a particular attribute of the value object, or
        to drop values that don't meet some criterion (by returning `None`).
        """
        return value

    def process_incoming_value(self, value):
        """Hook applied to every incoming value; delegates to `filter_func`."""
        return self.filter_func(value)

    def is_ready(self):
        """Return `True` when the buffer holds enough values to flush."""
        return len(self.cache) > 1

    def on_ready(self):
        """
        Flush the buffer and hand its contents to `process_outgoing_values`.

        Raises:
            ValueError: if called while the buffer is empty.
        """
        if len(self.cache) > 1:
            values = self.cache[:]
        elif len(self.cache) == 1:
            # Wrap the single value in a list so consumers always see a list.
            values = [self.cache[0]]
        else:
            # FIX: this message was previously split across a raw newline,
            # which made the module unparseable.
            raise ValueError(
                "on_ready() called on an empty buffer. This should not happen. Report this error."
            )
        self.cache = []
        return self.process_outgoing_values(values)

    def process_outgoing_values(self, values):
        """Hook for what to do with flushed values; default is to return them."""
        return values

    def put(self, value):
        """Receive one value from `generate()`, buffer it, and flush if ready."""
        value = self.process_incoming_value(value)
        if value is not None:
            # Lists are merged into the buffer; scalars are appended.
            if isinstance(value, list):
                self.cache.extend(value)
            else:
                self.cache.append(value)

        if self.is_ready():
            return self.on_ready()

    def end(self):
        """
        Signal the end of generation.

        FIX: `BaseStreamer.end()` raises `NotImplementedError`, so without this
        override the base class could not be used with `generate()` and any
        buffered values were silently lost. Flush whatever remains.
        """
        if self.cache:
            return self.on_ready()


class OutputIteratorStreamer(OutputStreamer):
    """
    `OutputStreamer` variant that exposes flushed values through a blocking
    iterator backed by a `Queue`, for consumption from another thread.
    """

    def __init__(self, filter_func=None, cache=None, queue=None, timeout: Optional[float] = None):
        super().__init__(filter_func=filter_func, cache=cache)
        if queue is None:
            queue = Queue()
        self.queue = queue  # outgoing, finalized outputs
        self.timeout = timeout
        # Sentinel object placed on the queue to mark end-of-stream.
        self.stop_signal = None

    def process_outgoing_values(self, values):
        """Forward flushed values to the consumer-side queue."""
        self.queue.put(values)

    def __iter__(self):
        return self

    def __next__(self):
        value = self.queue.get(timeout=self.timeout)
        # FIX: use identity, not equality, for the sentinel check. `==` on an
        # arbitrary streamed object (e.g. a tensor) may be ambiguous or raise.
        if value is self.stop_signal:
            raise StopIteration()
        return value

    def end(self):
        """Flush any remaining buffered values, then signal end-of-stream."""
        if self.cache:
            self.on_ready()
        self.queue.put(self.stop_signal)
""" + def _prepare_output( + self, *, + return_dict_in_generate, + **output_kargs): + if return_dict_in_generate: + if self.config.is_encoder_decoder: + cls = GenerateEncoderDecoderOutput + else: + cls =GenerateDecoderOnlyOutput + if 'decoder_attentions' in output_kargs: + output_kargs['attentions'] = output_kargs.pop('decoder_attentions') + if 'decoder_hidden_states' in output_kargs: + output_kargs['hidden_states'] = output_kargs.pop('decoder_hidden_states') + + if 'encoder_attentions' in output_kargs: + output_kargs.pop('encoder_attentions') + if 'encoder_hidden_states' in output_kargs: + output_kargs.pop('encoder_hidden_states') + if 'cross_attentions' in output_kargs: + output_kargs.pop('cross_attentions') + + outv = cls(**output_kargs) + else: + outv = output_kargs['sequences'] + return outv + + def prepare_inputs_for_generation(self, *args, **kwargs): raise NotImplementedError( "A model class needs to define a `prepare_inputs_for_generation` method in order to use `.generate()`." @@ -1858,7 +1885,12 @@ def generate( input_ids = self.heal_tokens(input_ids, tokenizer) if streamer is not None: - streamer.put(input_ids.cpu()) + output_stub = self._prepare_output( + return_dict_in_generate=generation_config.return_dict_in_generate, + sequences=input_ids, + # no scores/logits/attention/hidden here because they haven't been computed yet. + ) + streamer.put(output_stub) # 6. Prepare `max_length` depending on other stopping criteria. 
input_ids_length = input_ids.shape[-1] @@ -2546,6 +2578,11 @@ def _contrastive_search( cross_attentions = () if (return_dict_in_generate and output_attentions) else None decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None + # initialize variables for self._prepare_output + encoder_attentions = encoder_hidden_states = None + next_step_cross_attentions = () + next_step_decoder_attentions = () + # if model is an encoder-decoder, retrieve encoder attention weights and hidden states if return_dict_in_generate and self.config.is_encoder_decoder: encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None @@ -2781,8 +2818,6 @@ def _contrastive_search( # Rebuilds the relevant parts of the model output for the selected token, for use in the next iteration if self.config.is_encoder_decoder: - next_step_cross_attentions = () - next_step_decoder_attentions = () if output_attentions: for layer in outputs.cross_attentions: layer = torch.stack(torch.split(layer, top_k, dim=0))[range(batch_size), selected_idx, ...] 
@@ -2819,7 +2854,21 @@ def _contrastive_search( # update generated ids, model inputs, and length for next step input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1) if streamer is not None: - streamer.put(next_tokens.cpu()) + output_stub = self._prepare_output( + return_dict_in_generate=return_dict_in_generate, + sequences=next_tokens, + scores=(processed_logit_for_next_step,), # (scores,), + logits=(processed_logit_for_next_step,), + # I think there's an issue with the contrastive sampling implementation that is currently returning the same values for logits as scores #(logits[selected_idx,:],), #(logit_for_next_step,), # `logit_for_next_step`: values don't match, `logits`: shapes don't match + encoder_attentions=None, # probably doesn't make sense to stream this + encoder_hidden_states=None, # probably doesn't make sense to stream this + decoder_attentions=(next_step_decoder_attentions,), + # ([0],),# very concerning that if I set this to `([0],)` my tests don't fail + cross_attentions=(next_step_cross_attentions,), + decoder_hidden_states=(next_decoder_hidden_states,), + past_key_values=None, # probably doesn't make sense to stream this + ) + streamer.put(output_stub) model_kwargs = self._update_model_kwargs_for_generation( outputs, model_kwargs, @@ -2851,29 +2900,18 @@ def _contrastive_search( past_key_values.append(tuple(layer_past_key_values)) model_kwargs["past_key_values"] = tuple(past_key_values) - if self.config.is_encoder_decoder: - return GenerateEncoderDecoderOutput( - sequences=input_ids, - scores=scores, - logits=raw_logits, - encoder_attentions=encoder_attentions, - encoder_hidden_states=encoder_hidden_states, - decoder_attentions=decoder_attentions, - cross_attentions=cross_attentions, - decoder_hidden_states=decoder_hidden_states, - past_key_values=model_kwargs.get("past_key_values"), - ) - else: - return GenerateDecoderOnlyOutput( - sequences=input_ids, - scores=scores, - logits=raw_logits, - attentions=decoder_attentions, - 
hidden_states=decoder_hidden_states, - past_key_values=model_kwargs.get("past_key_values"), - ) - else: - return input_ids + return self._prepare_output( + return_dict_in_generate=return_dict_in_generate, + sequences=input_ids, + scores=scores, + logits=raw_logits, + encoder_attentions=encoder_attentions, + encoder_hidden_states=encoder_hidden_states, + decoder_attentions=decoder_attentions, + cross_attentions=cross_attentions, + decoder_hidden_states=decoder_hidden_states, + past_key_values=model_kwargs.get("past_key_values") + ) def _sample( self, @@ -2934,6 +2972,10 @@ def _sample( cross_attentions = () if (return_dict_in_generate and output_attentions) else None decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None + # initialize variables for self._prepare_output(...) + encoder_attentions = encoder_hidden_states = None + next_decoder_attentions = next_cross_attentions = next_decoder_hidden_states = None + # if model is an encoder-decoder, retrieve encoder attention weights and hidden states if return_dict_in_generate and self.config.is_encoder_decoder: encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None @@ -3006,7 +3048,19 @@ def _sample( # update generated ids, model inputs, and length for next step input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1) if streamer is not None: - streamer.put(next_tokens.cpu()) + output_stub = self._prepare_output( + return_dict_in_generate=return_dict_in_generate, + sequences=next_tokens, + scores=(next_token_scores,), + logits=(next_token_logits,), + encoder_attentions=None, + encoder_hidden_states=None, + decoder_attentions=(next_decoder_attentions,), + cross_attentions=(next_cross_attentions,), + decoder_hidden_states=(next_decoder_hidden_states,), + past_key_values=None, + ) + streamer.put(output_stub) model_kwargs = self._update_model_kwargs_for_generation( outputs, model_kwargs, @@ -3024,30 +3078,18 @@ def _sample( if 
streamer is not None: streamer.end() - if return_dict_in_generate: - if self.config.is_encoder_decoder: - return GenerateEncoderDecoderOutput( - sequences=input_ids, - scores=scores, - logits=raw_logits, - encoder_attentions=encoder_attentions, - encoder_hidden_states=encoder_hidden_states, - decoder_attentions=decoder_attentions, - cross_attentions=cross_attentions, - decoder_hidden_states=decoder_hidden_states, - past_key_values=model_kwargs.get("past_key_values"), - ) - else: - return GenerateDecoderOnlyOutput( - sequences=input_ids, - scores=scores, - logits=raw_logits, - attentions=decoder_attentions, - hidden_states=decoder_hidden_states, - past_key_values=model_kwargs.get("past_key_values"), - ) - else: - return input_ids + return self._prepare_output( + return_dict_in_generate=return_dict_in_generate, + sequences=input_ids, + scores=scores, + logits=raw_logits, + encoder_attentions=encoder_attentions, + encoder_hidden_states=encoder_hidden_states, + decoder_attentions=decoder_attentions, + cross_attentions=cross_attentions, + decoder_hidden_states=decoder_hidden_states, + past_key_values=model_kwargs.get("past_key_values") + ) def _temporary_reorder_cache(self, past_key_values, beam_idx): """ diff --git a/tests/generation/test_streamers.py b/tests/generation/test_streamers.py index c82a5e99e0de..3f076f2d3737 100644 --- a/tests/generation/test_streamers.py +++ b/tests/generation/test_streamers.py @@ -17,11 +17,21 @@ from queue import Empty from threading import Thread +from collections import Counter +import copy +import random +import unittest +import pytest + +from transformers.generation.streamers import OutputIteratorStreamer + from transformers import AutoTokenizer, TextIteratorStreamer, TextStreamer, is_torch_available from transformers.testing_utils import CaptureStdout, require_torch, torch_device from ..test_modeling_common import ids_tensor +import lovely_tensors as lt +lt.monkey_patch() if is_torch_available(): import torch @@ -120,3 +130,285 @@ 
def nested_tensor_equality(left, right):
    """
    Recursively assert equality of tensors nested inside tuples of tuples.

    Returns `True` if every assertion passes (so it can itself be asserted).
    """
    assert type(left) is type(right)
    assert len(left) == len(right)
    if isinstance(left, torch.Tensor):
        assert torch.equal(left, right)
    else:
        for left_item, right_item in zip(left, right):
            assert nested_tensor_equality(left_item, right_item)
    return True


@require_torch
class TestOutputIteratorStreamer:
    """
    Checks that outputs collected through `OutputIteratorStreamer` match the
    outputs of an identical non-streamed `generate()` call.
    """

    def _setup(
        self,
        model="hf-internal-testing/tiny-random-gpt2",
        do_sample=False,
        top_k=None,
        penalty_alpha=None,
        output_scores=False,
        output_logits=False,
        output_attentions=False,
        max_new_tokens=10,
        return_dict_in_generate=True,
        output_hidden_states=False,
    ):
        """
        Run a baseline (non-streamed) generation and kick off an identical
        streamed generation on a background thread.

        Returns:
            (baseline_outputs, outputs, streamer) where `outputs` is a dict
            pre-seeded with an empty accumulator for each field the baseline
            actually populated.
        """
        # FIX: AutoModelForCausalLM was used but never imported at module
        # level; import locally so the test module doesn't NameError.
        from transformers import AutoModelForCausalLM

        model = AutoModelForCausalLM.from_pretrained(model).to(torch_device)
        # Disable EOS so generation always runs for max_new_tokens steps.
        model.config.eos_token_id = -1

        generation_kwargs = dict(
            max_new_tokens=max_new_tokens,
            return_dict_in_generate=return_dict_in_generate,
            do_sample=do_sample,
            top_k=top_k,
            penalty_alpha=penalty_alpha,
            output_scores=output_scores,
            output_logits=output_logits,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
        )

        input_ids = ids_tensor((1, 5), vocab_size=model.config.vocab_size).to(torch_device)
        generation_kwargs["input_ids"] = input_ids

        baseline_kwargs = copy.deepcopy(generation_kwargs)
        test_kwargs = copy.deepcopy(generation_kwargs)

        # Same seed for both runs so sampling paths are comparable.
        seed = random.randint(0, int(1e9))
        torch.manual_seed(seed)
        baseline_outputs = model.generate(**baseline_kwargs)

        streamer = OutputIteratorStreamer()
        test_kwargs["streamer"] = streamer
        torch.manual_seed(seed)
        thread = Thread(target=model.generate, kwargs=test_kwargs)
        thread.start()

        # Seed an accumulator for every field the baseline actually produced.
        outputs = {"sequences": torch.Tensor()}
        for attr_name in (
            "scores",
            "logits",
            "attentions",
            "encoder_attentions",
            "decoder_attentions",
            "cross_attentions",
            "hidden_states",
            "encoder_hidden_states",
            "decoder_hidden_states",
            # past_key_values deliberately unsupported for streaming.
        ):
            if getattr(baseline_outputs, attr_name, None) is not None:
                outputs[attr_name] = ()

        return baseline_outputs, outputs, streamer

    def check_outputs_match(
        self,
        *,
        model="hf-internal-testing/tiny-random-gpt2",
        do_sample=False,
        top_k=None,
        max_new_tokens=10,
        penalty_alpha=None,
        output_scores=False,
        output_logits=False,
        output_attentions=False,
        output_hidden_states=False,
    ):
        """Accumulate streamed output objects and compare field-by-field to the baseline."""
        baseline_outputs, outputs, streamer = self._setup(
            model=model,
            do_sample=do_sample,
            top_k=top_k,
            penalty_alpha=penalty_alpha,
            output_scores=output_scores,
            output_logits=output_logits,
            output_attentions=output_attentions,
            max_new_tokens=max_new_tokens,
            return_dict_in_generate=True,
            output_hidden_states=output_hidden_states,
        )

        for answer in streamer:
            assert isinstance(answer, list)
            for output_object in answer:
                for output_name in outputs:
                    new_values = getattr(output_object, output_name)
                    if new_values is None or len(new_values) == 0:
                        continue
                    if output_name == "sequences":
                        new_values = new_values.cpu()
                        # Streamed token ids may arrive as 1-D; normalize to (batch, seq).
                        if new_values.ndim == 1:
                            new_values = new_values.unsqueeze(0)
                        outputs[output_name] = torch.cat([outputs[output_name], new_values], axis=-1)
                    else:
                        # scores/logits/attentions arrive as per-step tuples; concatenate.
                        outputs[output_name] += new_values

        for output_name, target_values in outputs.items():
            baseline_values = getattr(baseline_outputs, output_name)
            assert baseline_values is not None
            if isinstance(baseline_values, torch.Tensor):
                baseline_values = baseline_values.cpu()
            assert type(baseline_values) is type(target_values)
            assert len(baseline_values) == len(target_values)
            # attention/hidden-state fields are tuples of tuples of tensors
            assert nested_tensor_equality(baseline_values, target_values)

    def check_ids_only_match(
        self,
        do_sample=False,
        top_k=None,
        penalty_alpha=None,
        max_new_tokens=10,
        model="hf-internal-testing/tiny-random-gpt2",
    ):
        """With `return_dict_in_generate=False`, the stream carries bare id tensors."""
        baseline_values, outputs, streamer = self._setup(
            model=model,
            do_sample=do_sample,
            top_k=top_k,
            penalty_alpha=penalty_alpha,
            max_new_tokens=max_new_tokens,
            return_dict_in_generate=False,
            output_scores=False,
            output_logits=False,
            output_attentions=False,
            output_hidden_states=False,
        )

        target_values = torch.Tensor()
        for answer in streamer:
            assert isinstance(answer, list)
            for output_object in answer:
                new_ids = output_object.cpu()
                if new_ids.ndim == 1:
                    new_ids = new_ids.unsqueeze(0)
                target_values = torch.cat([target_values, new_ids], axis=-1)

        assert baseline_values.shape == target_values.shape
        assert baseline_values.tolist() == target_values.tolist()

    def test_greedy_ids_only(self):
        self.check_ids_only_match(do_sample=False)

    def test_multinomial_ids_only(self):
        self.check_ids_only_match(do_sample=True)

    def test_contrastive_ids_only(self):
        self.check_ids_only_match(do_sample=False, penalty_alpha=0.6, top_k=4)

    @pytest.mark.parametrize("output_scores", [False, True])
    @pytest.mark.parametrize("output_logits", [False, True])
    # TODO(review): reactivate attention streaming once supported
    @pytest.mark.parametrize("output_attentions", [False])
    def test_greedy_outputs(self, output_scores, output_logits, output_attentions):
        self.check_outputs_match(
            do_sample=False,
            output_scores=output_scores,
            output_logits=output_logits,
            output_attentions=output_attentions,
        )

    @pytest.mark.parametrize("output_scores", [False, True])
    @pytest.mark.parametrize("output_logits", [False, True])
    @pytest.mark.parametrize("output_attentions", [False])
    def test_multinomial_outputs(self, output_scores, output_logits, output_attentions):
        self.check_outputs_match(
            do_sample=True,
            output_scores=output_scores,
            output_logits=output_logits,
            output_attentions=output_attentions,
        )

    @pytest.mark.parametrize("output_scores", [False, True])
    @pytest.mark.parametrize("output_logits", [False, True])
    @pytest.mark.parametrize("output_attentions", [False])
    def test_contrastive_outputs(self, output_scores, output_logits, output_attentions):
        self.check_outputs_match(
            do_sample=False,
            penalty_alpha=0.6,
            top_k=4,
            output_scores=output_scores,
            output_logits=output_logits,
            output_attentions=output_attentions,
        )