426 changes: 172 additions & 254 deletions tests/generation/test_utils.py

Large diffs are not rendered by default.

29 changes: 7 additions & 22 deletions tests/models/bigbird_pegasus/test_modeling_bigbird_pegasus.py
@@ -283,28 +283,6 @@ def is_pipeline_test_to_skip(

return False

# overwrite from GenerationTesterMixin to solve problem
# with conflicting random seeds
def _get_input_ids_and_config(self, batch_size=2):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
config.attention_type = "original_full"

input_ids = inputs_dict.pop(self.input_name)
_ = inputs_dict.pop("attention_mask", None)
_ = inputs_dict.pop("decoder_input_ids", None)
_ = inputs_dict.pop("decoder_attention_mask", None)
attention_mask = torch.ones_like(input_ids, dtype=torch.long)

# cut to half length & take max batch_size 3
sequence_length = input_ids.shape[-1] // 2
input_ids = input_ids[:batch_size, :sequence_length]
attention_mask = attention_mask[:batch_size, :sequence_length]

if config.eos_token_id is not None and config.pad_token_id is None:
# hack to allow generate for models such as GPT2 as is done in `generate()`
config.pad_token_id = config.eos_token_id
return config, input_ids, attention_mask, inputs_dict

def setUp(self):
self.model_tester = BigBirdPegasusModelTester(self)
self.config_tester = ConfigTester(self, config_class=BigBirdPegasusConfig)
@@ -485,6 +463,13 @@ def test_for_change_to_full_attn(self):
def test_load_save_without_tied_weights(self):
pass

def test_generate_with_head_masking(self):
# overwritten to temporarily switch the attention type to `original_full`
original_self_attention_type = self.model_tester.attention_type
self.model_tester.attention_type = "original_full"
super().test_generate_with_head_masking()
self.model_tester.attention_type = original_self_attention_type


@require_torch
@require_sentencepiece
2 changes: 1 addition & 1 deletion tests/models/chameleon/test_modeling_chameleon.py
@@ -116,7 +116,7 @@ def prepare_config_and_inputs(self):

input_mask = None
if self.use_input_mask:
input_mask = torch.tril(torch.ones(self.batch_size, self.seq_length)).to(torch_device)
input_mask = torch.tril(torch.ones_like(input_ids).to(torch_device))
Review comment from @gante (Contributor Author), Sep 25, 2024:
common pattern: input_mask, which was then passed around as attention_mask, was a torch.float32 instead of a torch.long 👀
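A minimal sketch (standalone, not part of the diff) illustrating the dtype issue the comment points out; the shapes and token ids are hypothetical stand-ins for what a model tester would produce:

```python
import torch

# Hypothetical token ids, as a model tester would generate them: integer dtype (torch.int64).
input_ids = torch.randint(0, 100, (2, 7))

# Old pattern: torch.ones defaults to torch.float32, so the resulting mask is float.
old_mask = torch.tril(torch.ones(2, 7))

# New pattern: torch.ones_like inherits the integer dtype of input_ids.
new_mask = torch.tril(torch.ones_like(input_ids))

print(old_mask.dtype)  # torch.float32
print(new_mask.dtype)  # torch.int64
```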


sequence_labels = None
token_labels = None
2 changes: 1 addition & 1 deletion tests/models/cohere/test_modeling_cohere.py
@@ -95,7 +95,7 @@ def prepare_config_and_inputs(self):

input_mask = None
if self.use_input_mask:
input_mask = torch.tril(torch.ones(self.batch_size, self.seq_length)).to(torch_device)
input_mask = torch.tril(torch.ones_like(input_ids).to(torch_device))

token_type_ids = None
if self.use_token_type_ids:
1 change: 0 additions & 1 deletion tests/models/dac/test_modeling_dac.py
@@ -123,7 +123,6 @@ class DacModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
test_headmasking = False
test_resize_embeddings = False
pipeline_model_mapping = {"feature-extraction": DacModel} if is_torch_available() else {}
input_name = "input_values"

def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
# model does not have attention and does not support returning hidden states
1 change: 0 additions & 1 deletion tests/models/encodec/test_modeling_encodec.py
@@ -141,7 +141,6 @@ class EncodecModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase)
test_headmasking = False
test_resize_embeddings = False
pipeline_model_mapping = {"feature-extraction": EncodecModel} if is_torch_available() else {}
input_name = "input_values"

def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
# model does not have attention and does not support returning hidden states
2 changes: 1 addition & 1 deletion tests/models/gemma/test_modeling_gemma.py
@@ -119,7 +119,7 @@ def prepare_config_and_inputs(self):

input_mask = None
if self.use_input_mask:
input_mask = torch.tril(torch.ones(self.batch_size, self.seq_length)).to(torch_device)
input_mask = torch.tril(torch.ones_like(input_ids).to(torch_device))

token_type_ids = None
if self.use_token_type_ids:
2 changes: 1 addition & 1 deletion tests/models/granite/test_modeling_granite.py
@@ -106,7 +106,7 @@ def prepare_config_and_inputs(self):

input_mask = None
if self.use_input_mask:
input_mask = torch.tril(torch.ones(self.batch_size, self.seq_length)).to(torch_device)
input_mask = torch.tril(torch.ones_like(input_ids).to(torch_device))

token_type_ids = None
if self.use_token_type_ids:
2 changes: 1 addition & 1 deletion tests/models/granitemoe/test_modeling_granitemoe.py
@@ -105,7 +105,7 @@ def prepare_config_and_inputs(self):

input_mask = None
if self.use_input_mask:
input_mask = torch.tril(torch.ones(self.batch_size, self.seq_length)).to(torch_device)
input_mask = torch.tril(torch.ones_like(input_ids).to(torch_device))

token_type_ids = None
if self.use_token_type_ids:
8 changes: 3 additions & 5 deletions tests/models/led/test_modeling_led.py
@@ -338,13 +338,11 @@ def test_global_attention(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs_for_common()
self.model_tester.check_global_attention(*config_and_inputs)

def _get_input_ids_and_config(self, batch_size=2):
config, input_ids, attention_mask, inputs_dict = GenerationTesterMixin._get_input_ids_and_config(
self, batch_size=batch_size
)
def prepare_config_and_inputs_for_generate(self, *args, **kwargs):
config, inputs_dict = super().prepare_config_and_inputs_for_generate(*args, **kwargs)
# LED computes attention scores based on mask indices if `is_global`
inputs_dict.pop("global_attention_mask")
return config, input_ids, attention_mask, inputs_dict
return config, inputs_dict

# LEDForSequenceClassification does not support inputs_embeds
def test_inputs_embeds(self):
2 changes: 1 addition & 1 deletion tests/models/llama/test_modeling_llama.py
@@ -112,7 +112,7 @@ def prepare_config_and_inputs(self):

input_mask = None
if self.use_input_mask:
input_mask = torch.tril(torch.ones(self.batch_size, self.seq_length)).to(torch_device)
input_mask = torch.tril(torch.ones_like(input_ids).to(torch_device))

token_type_ids = None
if self.use_token_type_ids:
1 change: 0 additions & 1 deletion tests/models/mimi/test_modeling_mimi.py
@@ -170,7 +170,6 @@ class MimiModelTest(ModelTesterMixin, unittest.TestCase):
test_headmasking = False
test_resize_embeddings = False
test_torchscript = False
input_name = "input_values"

def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
# model does support returning hidden states
2 changes: 1 addition & 1 deletion tests/models/mistral/test_modeling_mistral.py
@@ -112,7 +112,7 @@ def prepare_config_and_inputs(self):

input_mask = None
if self.use_input_mask:
input_mask = torch.tril(torch.ones(self.batch_size, self.seq_length)).to(torch_device)
input_mask = torch.tril(torch.ones_like(input_ids).to(torch_device))

token_type_ids = None
if self.use_token_type_ids:
2 changes: 1 addition & 1 deletion tests/models/mixtral/test_modeling_mixtral.py
@@ -108,7 +108,7 @@ def prepare_config_and_inputs(self):

input_mask = None
if self.use_input_mask:
input_mask = torch.tril(torch.ones(self.batch_size, self.seq_length)).to(torch_device)
input_mask = torch.tril(torch.ones_like(input_ids).to(torch_device))

token_type_ids = None
if self.use_token_type_ids: