From 0e10961ae01d00d3752f4cbdfb4b06e0a7854a47 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Mon, 11 Jan 2021 13:48:37 +0100 Subject: [PATCH] fix tf led pt test --- tests/test_modeling_tf_led.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/tests/test_modeling_tf_led.py b/tests/test_modeling_tf_led.py index a6eb83a32676..7e52a4378e9d 100644 --- a/tests/test_modeling_tf_led.py +++ b/tests/test_modeling_tf_led.py @@ -166,7 +166,13 @@ def prepare_led_inputs_dict( if attention_mask is None: attention_mask = tf.cast(tf.math.not_equal(input_ids, config.pad_token_id), tf.int8) if decoder_attention_mask is None: - decoder_attention_mask = tf.cast(tf.math.not_equal(decoder_input_ids, config.pad_token_id), tf.int8) + decoder_attention_mask = tf.concat( + [ + tf.ones(decoder_input_ids[:, :1].shape, dtype=tf.int8), + tf.cast(tf.math.not_equal(decoder_input_ids[:, 1:], config.pad_token_id), tf.int8), + ], + axis=-1, + ) return { "input_ids": input_ids, "attention_mask": attention_mask,