diff --git a/tests/test_modeling_tf_common.py b/tests/test_modeling_tf_common.py
index 2cf272f4aac1..eb328d83e9e7 100644
--- a/tests/test_modeling_tf_common.py
+++ b/tests/test_modeling_tf_common.py
@@ -1715,10 +1715,9 @@ def test_dataset_conversion(self):
             model.train_on_batch(test_batch, test_batch_labels)
 
     def _test_xla_generate(self, **generate_kwargs):
-        def _generate_and_check_results(model, inputs_dict):
-            if "input_ids" in inputs_dict:
-                inputs = inputs_dict["input_ids"]
-                # make sure there are no pad tokens in prompt, which may trigger unwanted behavior
+        def _generate_and_check_results(model, inputs, is_input_ids):
+            # make sure there are no pad tokens in prompt, which may trigger unwanted behavior
+            if is_input_ids:
                 if model.generation_config.pad_token_id is not None:
                     if config.pad_token_id == 0:
                         new_pad_token = model.generation_config.pad_token_id + 1
@@ -1727,10 +1726,6 @@ def _generate_and_check_results(model, inputs_dict):
                 else:
                     new_pad_token = None
                 inputs = tf.where(inputs != model.generation_config.pad_token_id, inputs, new_pad_token)
-            elif "input_features" in inputs_dict:
-                inputs = inputs_dict["input_features"]
-            else:
-                raise ValueError("No valid generate input found in inputs_dict")
 
             generated = model.generate(inputs, **generate_kwargs).numpy()
             generate_xla = tf.function(model.generate, jit_compile=True)
@@ -1753,12 +1748,20 @@ def _generate_and_check_results(model, inputs_dict):
             config.eos_token_id = None  # Generate until max length
             config.do_sample = False
 
+            # extract the input to the model
+            is_input_ids = "input_ids" in inputs_dict
+            is_input_features = "input_features" in inputs_dict
+            if not (is_input_ids or is_input_features):
+                raise ValueError("No valid generate input found in inputs_dict")
+            inputs = inputs_dict["input_ids"] if is_input_ids else inputs_dict["input_features"]
+
+            # fix config for models with additional sequence-length limiting settings
+            seq_len = inputs.get_shape()[1]
             for var_name in ["max_position_embeddings", "max_target_positions"]:
                 attr = getattr(config, var_name, None)
-                if attr is not None and attr < generate_kwargs["max_new_tokens"]:
+                if attr is not None and attr < seq_len + generate_kwargs["max_new_tokens"]:
                     try:
-                        setattr(config, var_name, generate_kwargs["max_new_tokens"])
+                        setattr(config, var_name, seq_len + generate_kwargs["max_new_tokens"])
                     except NotImplementedError:
                         # xlnet will raise an exception when trying to set
                         # max_position_embeddings.
@@ -1767,10 +1770,10 @@ def _generate_and_check_results(model, inputs_dict):
 
             model = model_class(config)
             if model.supports_xla_generation:
-                _generate_and_check_results(model, inputs_dict)
+                _generate_and_check_results(model, inputs, is_input_ids)
            else:
                 with self.assertRaises(ValueError):
-                    _generate_and_check_results(model, inputs_dict)
+                    _generate_and_check_results(model, inputs, is_input_ids)
 
     def test_xla_generate_fast(self):
         """
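
For reference, the behaviour this test exercises is the agreement between eager and XLA-compiled generation. Below is a minimal standalone sketch of the same pattern, separate from the patch itself; the "gpt2" checkpoint, prompt, and generation settings are illustrative assumptions, not part of the test suite.

import tensorflow as tf
from transformers import AutoTokenizer, TFAutoModelForCausalLM

# Illustrative checkpoint; any TF model whose supports_xla_generation flag is True behaves the same way.
tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = TFAutoModelForCausalLM.from_pretrained("gpt2")
input_ids = tokenizer(["TensorFlow is"], return_tensors="tf").input_ids

# Eager generation (greedy, so the output is deterministic).
generated = model.generate(input_ids, max_new_tokens=8, do_sample=False).numpy()

# XLA path: jit-compile generate with tf.function, exactly as the test does.
generate_xla = tf.function(model.generate, jit_compile=True)
generated_xla = generate_xla(input_ids, max_new_tokens=8, do_sample=False).numpy()

# The test asserts that both paths produce identical token ids.
assert (generated == generated_xla).all()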