diff --git a/nematus/transformer_inference.py b/nematus/transformer_inference.py index 3e4b3ce2..83d17c89 100644 --- a/nematus/transformer_inference.py +++ b/nematus/transformer_inference.py @@ -128,11 +128,12 @@ def generate_initial_memories(self, batch_size, beam_size): with tf.compat.v1.name_scope(self._scope): state_size = self.config.state_size memories = {} + temps = tf.zeros([batch_size, 0, state_size] for layer_id in range(1, self.config.transformer_dec_depth + 1): memories['layer_{:d}'.format(layer_id)] = { \ - 'keys': tf.tile(tf.zeros([batch_size, 0, state_size]), + 'keys': tf.tile(temps, [beam_size, 1, 1]), - 'values': tf.tile(tf.zeros([batch_size, 0, state_size]), + 'values': tf.tile(temps, [beam_size, 1, 1]) } return memories