diff --git a/tests/test_modeling_tf_led.py b/tests/test_modeling_tf_led.py index a6eb83a32676..7e52a4378e9d 100644 --- a/tests/test_modeling_tf_led.py +++ b/tests/test_modeling_tf_led.py @@ -166,7 +166,13 @@ def prepare_led_inputs_dict( if attention_mask is None: attention_mask = tf.cast(tf.math.not_equal(input_ids, config.pad_token_id), tf.int8) if decoder_attention_mask is None: - decoder_attention_mask = tf.cast(tf.math.not_equal(decoder_input_ids, config.pad_token_id), tf.int8) + decoder_attention_mask = tf.concat( + [ + tf.ones(decoder_input_ids[:, :1].shape, dtype=tf.int8), + tf.cast(tf.math.not_equal(decoder_input_ids[:, 1:], config.pad_token_id), tf.int8), + ], + axis=-1, + ) return { "input_ids": input_ids, "attention_mask": attention_mask,