Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 15 additions & 12 deletions tests/test_modeling_tf_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -1715,10 +1715,9 @@ def test_dataset_conversion(self):
model.train_on_batch(test_batch, test_batch_labels)

def _test_xla_generate(self, **generate_kwargs):
def _generate_and_check_results(model, inputs_dict):
if "input_ids" in inputs_dict:
inputs = inputs_dict["input_ids"]
# make sure there are no pad tokens in prompt, which may trigger unwanted behavior
def _generate_and_check_results(model, inputs, is_input_ids):
# make sure there are no pad tokens in prompt, which may trigger unwanted behavior
if is_input_ids:
if model.generation_config.pad_token_id is not None:
if config.pad_token_id == 0:
new_pad_token = model.generation_config.pad_token_id + 1
Expand All @@ -1727,10 +1726,6 @@ def _generate_and_check_results(model, inputs_dict):
else:
new_pad_token = None
inputs = tf.where(inputs != model.generation_config.pad_token_id, inputs, new_pad_token)
elif "input_features" in inputs_dict:
inputs = inputs_dict["input_features"]
else:
raise ValueError("No valid generate input found in inputs_dict")

generated = model.generate(inputs, **generate_kwargs).numpy()
generate_xla = tf.function(model.generate, jit_compile=True)
Expand All @@ -1753,12 +1748,20 @@ def _generate_and_check_results(model, inputs_dict):
config.eos_token_id = None # Generate until max length
config.do_sample = False

# extract the input to the model
is_input_ids = "input_ids" in inputs_dict
is_input_features = "input_features" in inputs_dict
if not (is_input_ids or is_input_features):
raise ValueError("No valid generate input found in inputs_dict")
inputs = inputs_dict["input_ids"] if is_input_ids else inputs_dict["input_features"]

# fix config for models with additional sequence-length limiting settings
seq_len = inputs.get_shape()[1]
Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure if input_features also has its sequence length at 1. Just submitting this quickly before I don't have time.

for var_name in ["max_position_embeddings", "max_target_positions"]:
attr = getattr(config, var_name, None)
if attr is not None and attr < generate_kwargs["max_new_tokens"]:
if attr is not None and attr < seq_len + generate_kwargs["max_new_tokens"]:
try:
setattr(config, var_name, generate_kwargs["max_new_tokens"])
setattr(config, var_name, seq_len + generate_kwargs["max_new_tokens"])
except NotImplementedError:
# xlnet will raise an exception when trying to set
# max_position_embeddings.
Expand All @@ -1767,10 +1770,10 @@ def _generate_and_check_results(model, inputs_dict):
model = model_class(config)

if model.supports_xla_generation:
_generate_and_check_results(model, inputs_dict)
_generate_and_check_results(model, inputs, is_input_ids)
else:
with self.assertRaises(ValueError):
_generate_and_check_results(model, inputs_dict)
_generate_and_check_results(model, inputs, is_input_ids)

def test_xla_generate_fast(self):
"""
Expand Down