[TFBart-like models] Problem with tf saving #9313

@patrickvonplaten

Description

Context

Usually, encoder-decoder models require both `input_ids` and `decoder_input_ids` in order to do a forward pass. If one passes only `input_ids` to TFT5, the model will complain:

from transformers import TFT5ForConditionalGeneration
import tensorflow as tf
model = TFT5ForConditionalGeneration.from_pretrained("t5-small")

model(input_ids=tf.convert_to_tensor([10 * [2]]))  # => errors, saying `decoder_input_ids` have to be provided, which is expected and correct

Now TFBart is a bit special in that it automatically generates the `decoder_input_ids` if they are not passed, so the example above would not throw an error for TFBartForConditionalGeneration.

The reason for this is this line:

inputs["decoder_input_ids"] = shift_tokens_right(

-> it automatically creates the `decoder_input_ids` from the `input_ids` if they are not provided. However, this is more of a hack than a good solution IMO. Soon we want to decouple the Bart-like models from each other, and it would be good to remove this line from at least the new Bart-like models.
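For context, the fallback behaves roughly like the following. This is a minimal pure-Python sketch of the idea, not the actual transformers implementation; the function names and the start-token value are illustrative only:

```python
def shift_tokens_right_sketch(input_ids, decoder_start_token_id):
    # Prepend the decoder start token and drop the last token of each row,
    # so the decoder input is the target sequence shifted one position right.
    return [[decoder_start_token_id] + row[:-1] for row in input_ids]

def call_sketch(input_ids, decoder_input_ids=None, decoder_start_token_id=2):
    # Illustrates the TFBart-style fallback: fabricate decoder_input_ids
    # from input_ids when the caller did not provide them.
    if decoder_input_ids is None:
        decoder_input_ids = shift_tokens_right_sketch(input_ids, decoder_start_token_id)
    return decoder_input_ids

print(call_sketch([[5, 6, 7, 8]]))  # [[2, 5, 6, 7]]
```

A TFT5-style model would instead raise an error at the `decoder_input_ids is None` check; that difference is exactly what the fallback line papers over.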

Problem:

The problem is that if we delete these lines from Bart, then `tf.saved_model.save(model, tmpdirname)` no longer works. To reproduce, check out master and comment out this if statement in TFBart:

inputs["decoder_input_ids"] = shift_tokens_right(

Then run the following code:

from transformers import TFBartForConditionalGeneration
import tempfile
import tensorflow as tf

model = TFBartForConditionalGeneration.from_pretrained("sshleifer/bart-tiny-random")

input_ids = tf.convert_to_tensor([10 * [1]])
decoder_input_ids = tf.convert_to_tensor([10 * [8]])
inputs_dict = {"input_ids": input_ids, "decoder_input_ids": decoder_input_ids}

logits = model(inputs_dict).logits

# Reset and set the save spec so both inputs are part of the saved signature
model._saved_model_inputs_spec = None
model._set_save_spec(inputs_dict)

with tempfile.TemporaryDirectory() as tmpdirname:
    tf.saved_model.save(model, tmpdirname)
    model = tf.keras.models.load_model(tmpdirname)
    logits_2 = model(inputs_dict)["logits"]

=> the code throws an error, but it should not! There seems to be a naming mismatch between the `input_ids` of TFBartDecoder and the `decoder_input_ids` of TFBartModel... @jplu I'd be thrilled if you could take a look at this and see how it can be solved.
