Fix memory regression in Seq2Seq example #9713

Merged
sgugger merged 4 commits into master from fix_memory_regression on Jan 21, 2021
Conversation

sgugger (Collaborator) commented Jan 21, 2021

What does this PR do?

This PR fixes the memory regression introduced when moving the Seq2SeqTrainer inside the main library. The root cause is that, when doing label smoothing, we ended up computing the log softmax of the logits twice: once in the cross entropy loss, and a second time inside the label smoother.
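For context, the double computation can be avoided by deriving both the NLL term and the smoothing term from a single log-softmax pass. A minimal sketch in plain Python (toy `log_softmax` and `label_smoothed_nll` helpers for one token; illustrative only, not the Trainer's actual implementation):

```python
import math

def log_softmax(logits):
    # numerically stable log-softmax over a single logit vector
    m = max(logits)
    lse = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - lse for x in logits]

def label_smoothed_nll(logits, label, epsilon=0.1):
    # ONE log-softmax pass feeds both terms, instead of one pass
    # inside cross_entropy and a second inside the label smoother
    log_probs = log_softmax(logits)
    nll = -log_probs[label]                    # standard NLL term
    smooth = -sum(log_probs) / len(log_probs)  # uniform smoothing term
    return (1.0 - epsilon) * nll + epsilon * smooth
```

With `epsilon=0` this reduces to the plain cross entropy, which is why computing the whole loss in the smoother loses nothing.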

To fix this, the loss computation needs to be entirely done inside the label smoother, so the labels must be extracted from the batch before being passed to the model. As a result, the decoder_input_ids must be computed in the Seq2SeqDataCollator and not the model for this to work. I've just reverted the code from #9343, I don't know if it actually matches what happens inside the models. Maybe we should have a method to compute those decoder_input_ids accessible from inside of those models, or a flag to tell them whether to compute the loss or not (in this case, computing the loss will not only be slower, it will also trigger back the memory regression).
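Computing `decoder_input_ids` in the collator means reproducing the shifting outside the model. A hedged sketch of what such a helper could look like (Bart-style prepend-and-truncate; as noted above, the exact behaviour is model-specific and this may not match every architecture):

```python
def shift_tokens_right(labels, decoder_start_token_id, pad_token_id):
    # prepend the decoder start token and drop the last position
    shifted = [decoder_start_token_id] + labels[:-1]
    # -100 marks ignored label positions; the decoder input needs a
    # real token there, so fall back to the pad token
    return [pad_token_id if t == -100 else t for t in shifted]
```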

The same fix will need to be applied to the Seq2SeqDataCollator now inside the library as well as the new run_seq2seq script, but I will do it once we have agreed on a long-term solution for the decoder input ids above.

Fixes #9261


# set decoder_start_token_id for MBart
- if model.config.decoder_start_token_id is None and isinstance(tokenizer, MBartTokenizer):
+ if model.config.decoder_start_token_id is None and isinstance(tokenizer, (MBartTokenizer, MBartTokenizerFast)):
sgugger (Collaborator, Author):

With the switch to fast tokenizers, this was actually not working.

if isinstance(tokenizer, MBartTokenizer):
model.config.decoder_start_token_id = tokenizer.lang_code_to_id[data_args.tgt_lang]
else:
model.config.decoder_start_token_id = tokenizer.convert_tokens_to_ids(data_args.tgt_lang)
sgugger (Collaborator, Author):

@patrickvonplaten Wondering if this is the right way to do this for the fast tokenizer, as it has no lang_code_to_id.

patil-suraj (Contributor):

This should work: MBartTokenizer inherits from XLMRobertaTokenizer, and in convert_tokens_to_ids it first checks whether the token is in the fairseq_tokens_to_ids dict and returns the id if it is. All language tokens are added to the fairseq_tokens_to_ids dict.
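The lookup order described above can be illustrated with a tiny stand-in (the dict contents here are made up for the example; only the precedence mirrors the description):

```python
def convert_tokens_to_ids(token, fairseq_tokens_to_ids, vocab):
    # special fairseq entries (including the language codes) take
    # precedence over the regular vocabulary
    if token in fairseq_tokens_to_ids:
        return fairseq_tokens_to_ids[token]
    return vocab.get(token)

# toy tables; the real ids live inside the tokenizer
fairseq_tokens_to_ids = {"ro_RO": 9001}
vocab = {"hello": 42}
```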

patrickvonplaten (Contributor):

@patil-suraj, I think @sgugger was wondering about the fast tokenizers, so MBartTokenizerFast, and those don't have the fairseq_tokens_to_ids dict. However, I think your approach is the correct one here, @sgugger. MBartTokenizerFast adds all language codes, called FAIRSEQ_LANGUAGE_CODES, to the special tokens, so that they can be converted just the way you did above -> LGTM!

patil-suraj (Contributor):

Aah, I see. My bad.

Comment thread on examples/seq2seq/utils.py:
labels = trim_batch(labels, self.pad_token_id)
input_ids, attention_mask = trim_batch(input_ids, self.pad_token_id, attention_mask=attention_mask)

if isinstance(self.tokenizer, T5Tokenizer):
patrickvonplaten (Contributor):

Sorry, yeah, this was my bad. I didn't get why we had inputted both decoder_input_ids and labels previously; I should have thought about it a bit more!

It's a tough decision how to solve this long-term, I think:

  • There is no guarantee that shift_tokens_right is the same for all models; we already have three different methods (MBart, EncoderDecoder (not yet implemented, but it will be different), Bart & other enc-dec).
  • We don't really want to start an if isinstance(self.tokenizer)-else cascade in the data collator.
  • I don't really like adding a flag _do_compute_loss to the encoder-decoder models, because it breaks the consistency that only passing labels triggers a loss calculation with all other models, and it adds complexity for the user in that another config param has to be understood.
  • Another option would be to pass the model to the __init__ of the data collator and then call decoder_input_ids = model.shift_right(...), forcing all such functions to have the same name, similar to prepare_inputs_for_generation(...).
  • Overall, I think it would actually be better to fully get rid of the data collators altogether and instead use datasets.map(...) to prepare the inputs, using just the default data collator, as shown in
    def default_data_collator(features: List[InputDataClass]) -> Dict[str, torch.Tensor]:
    => then we wouldn't have any of those problems and would have more control over what to input to the model.

Maybe we should discuss this quickly offline.
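The fourth option above (passing the model to the collator and standardizing on one method name) could be sketched like this; the model here is a dummy duck-typed stand-in, and the `shift_right` name is the proposal under discussion, not an existing API:

```python
class DummyModel:
    # stands in for a seq2seq model exposing the proposed method
    decoder_start_token_id = 2

    def shift_right(self, labels):
        return [[self.decoder_start_token_id] + seq[:-1] for seq in labels]

class ShiftingCollator:
    # the collator receives the model and delegates the shifting,
    # so it never needs an isinstance cascade over tokenizer types
    def __init__(self, model):
        self.model = model

    def __call__(self, features):
        labels = [f["labels"] for f in features]
        return {
            "input_ids": [f["input_ids"] for f in features],
            "labels": labels,
            "decoder_input_ids": self.model.shift_right(labels),
        }
```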

patil-suraj (Contributor):

> I didn't get why we've inputted both decoder_input_ids and labels previously

We prepared both decoder_input_ids and labels because, when doing label smoothing, we need to pass decoder_input_ids to the model explicitly: we don't pass labels, to avoid calculating the loss twice.

patrickvonplaten (Contributor) left a review:

Looks good to me for now! However, the script won't work with ProphetNet and LED at the moment I think.

As discussed a bit offline:

I don't think specific data collators are a good solution. One can fully prepare the data using the datasets.map() function (I always do this in my seq2seq scripts) and then just use the default data collator, which simply stacks the tensors correctly into a batch. We will run into more and more problems using specific data collators like the Seq2Seq ones because they can never handle all cases; e.g., it wouldn't work at the moment for LED either, because LED needs a global_attention_mask as an input. Also, thinking a bit more about the expansion of Transformers to vision and speech, where we will have even more kinds of inputs, I don't think we should rely on a data collator to map the raw data to the model's expected input; it'll end up in a weird if-else cascade anyway. I think the responsibility should shift to the datasets.map(...) function to define all required inputs.

So, regarding this PR: to keep backward compatibility, I'm fine with adding some kind of hack to Seq2SeqDataCollator to make it work, but I'm not really in favor of adding a boolean flag to the model. I'd rather deprecate Seq2SeqDataCollator soon and make all scripts use default_data_collator (I think all scripts already do, but some use more data collators in addition).

Would be keen to hear @LysandreJik @stas00 @patil-suraj @thomwolf opinion here as well
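The datasets.map(...) approach described above moves all model-specific preparation into a preprocessing function, leaving the collator nothing to do but stack. A toy sketch with a hypothetical tokenizer callable, using a plain list comprehension to stand in for Dataset.map:

```python
def preprocess(example, tokenize, max_len=32):
    # all model-specific preparation happens here, up front
    return {
        "input_ids": tokenize(example["src"])[:max_len],
        "labels": tokenize(example["tgt"])[:max_len],
    }

# toy tokenizer: word lengths as ids (a real script would use a
# tokenizer from the library)
toy_tokenize = lambda text: [len(w) for w in text.split()]

# with the real library this would be dataset.map(preprocess, ...)
raw = [{"src": "hello world", "tgt": "bonjour"}]
processed = [preprocess(ex, toy_tokenize) for ex in raw]
```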

patil-suraj (Contributor) commented Jan 21, 2021

I agree with Patrick here. The collator was added to encode the text and to prepare the decoder_input_ids and labels, replace pad tokens with -100, etc. Now we can encode and prepare the labels in datasets.map(...), so the collator won't be needed anymore.

The only thing we need, IMO, is to be able to prepare decoder_input_ids outside of the model for label smoothing, as Sylvain said. Could we maybe add a shift_right method to every seq2seq model, to be able to prepare the decoder_input_ids outside of the model?

sgugger (Collaborator, Author) commented Jan 21, 2021

Note that this is fixing the old script with the old data collator. The new one will be fixed with the proper fix (once we agree on it; there seems to be a consensus on having a shift_right method on the model), but a data collator is still necessary to do dynamic padding. The Dataset.map method is very nice for static things, but when you want to pad to the length of the biggest sample in the batch, you need a special data collator, especially if it has to pad special keys like "labels" or "decoder_input_ids".
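The dynamic padding point above is inherently per-batch work, so it cannot be done in a static Dataset.map pass. A minimal sketch of such a collator, padding labels with -100 so the loss ignores padded positions (illustrative only, not the library's implementation):

```python
def pad_batch(features, pad_token_id, label_pad_id=-100):
    # pad every key to the longest sequence in THIS batch
    def pad(seqs, value):
        width = max(len(s) for s in seqs)
        return [s + [value] * (width - len(s)) for s in seqs]
    return {
        "input_ids": pad([f["input_ids"] for f in features], pad_token_id),
        "labels": pad([f["labels"] for f in features], label_pad_id),
    }
```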

The old Seq2SeqDataCollator in the utils file will be removed in a couple of weeks, when the new seq2seq example is running perfectly, so I think it's fine to merge the quick hack in the meantime :-)

LysandreJik (Member) left a review:

Great, LGTM! Good job finding the source of the issue!

@sgugger sgugger merged commit 5f80c15 into master Jan 21, 2021
@sgugger sgugger deleted the fix_memory_regression branch January 21, 2021 17:05


Successfully merging this pull request may close these issues.

[seq2seq] memory regression

4 participants