Fix memory regression in Seq2Seq example #9713
Conversation
```diff
  # set decoder_start_token_id for MBart
- if model.config.decoder_start_token_id is None and isinstance(tokenizer, MBartTokenizer):
+ if model.config.decoder_start_token_id is None and isinstance(tokenizer, (MBartTokenizer, MBartTokenizerFast)):
```
With the switch to fast tokenizers, this was actually not working.
```python
if isinstance(tokenizer, MBartTokenizer):
    model.config.decoder_start_token_id = tokenizer.lang_code_to_id[data_args.tgt_lang]
else:
    model.config.decoder_start_token_id = tokenizer.convert_tokens_to_ids(data_args.tgt_lang)
```
@patrickvonplaten Wondering if this is the right way to do this for the fast tokenizer, as it has no `lang_code_to_id`.
This should work: `MBartTokenizer` inherits from `XLMRobertaTokenizer`, and in `convert_tokens_to_ids` it first checks whether the token is in the `fairseq_tokens_to_ids` dict and returns the id if it is. All language tokens are added to the `fairseq_tokens_to_ids` dict.
@patil-suraj, I think @sgugger was wondering about the fast tokenizers, so `MBartTokenizerFast`, and those don't have the `fairseq_tokens_to_ids` dict. However, I think your approach is the correct one here @sgugger. `MBartTokenizerFast` adds all language codes (called `FAIRSEQ_LANGUAGE_CODES`) to the special tokens, so they can be converted just the way you did above -> LGTM!
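To illustrate the lookup behavior being discussed, here is a toy sketch (hypothetical classes, not the real transformers API) of why `convert_tokens_to_ids` resolves a language code on both tokenizer flavors, while `lang_code_to_id` only exists on the slow one:

```python
# Toy stand-ins, NOT the real transformers classes; ids are made up.
class ToySlowTokenizer:
    def __init__(self):
        # mimics MBartTokenizer's fairseq_tokens_to_ids special-token table
        self.fairseq_tokens_to_ids = {"ro_RO": 250020}
        self.vocab = {"hello": 42}
        # the slow tokenizer exposes lang_code_to_id directly
        self.lang_code_to_id = self.fairseq_tokens_to_ids

    def convert_tokens_to_ids(self, token):
        # special tokens are checked first, then the regular vocab
        if token in self.fairseq_tokens_to_ids:
            return self.fairseq_tokens_to_ids[token]
        return self.vocab[token]


class ToyFastTokenizer:
    def __init__(self):
        # fast tokenizer: language codes registered as added special tokens,
        # but there is no lang_code_to_id attribute
        self.added_tokens = {"ro_RO": 250020}
        self.vocab = {"hello": 42}

    def convert_tokens_to_ids(self, token):
        return self.added_tokens.get(token, self.vocab.get(token))


slow, fast = ToySlowTokenizer(), ToyFastTokenizer()
# convert_tokens_to_ids works on both, so it is the portable choice
assert slow.lang_code_to_id["ro_RO"] == fast.convert_tokens_to_ids("ro_RO") == 250020
```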
```python
labels = trim_batch(labels, self.pad_token_id)
input_ids, attention_mask = trim_batch(input_ids, self.pad_token_id, attention_mask=attention_mask)
```
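For context, `trim_batch` in the old seq2seq utilities drops the columns that are padding for every sequence in the batch. A list-based sketch of that behavior (the real version operates on torch tensors with boolean column masks):

```python
# Sketch only: list-based stand-in for the tensor version of trim_batch.
def trim_batch(input_ids, pad_token_id, attention_mask=None):
    # keep a column if any sequence has a non-pad token there
    keep = [any(row[j] != pad_token_id for row in input_ids)
            for j in range(len(input_ids[0]))]
    trimmed = [[tok for tok, k in zip(row, keep) if k] for row in input_ids]
    if attention_mask is None:
        return trimmed
    mask = [[m for m, k in zip(row, keep) if k] for row in attention_mask]
    return trimmed, mask

batch = [[5, 6, 1, 1], [7, 1, 1, 1]]  # pad_token_id = 1
# columns 2 and 3 are all-pad, so they are trimmed away
assert trim_batch(batch, 1) == [[5, 6], [7, 1]]
```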
```python
if isinstance(self.tokenizer, T5Tokenizer):
```
Sorry, yeah this was my bad. I didn't get why we've inputted both `decoder_input_ids` and `labels` previously - should have thought about it a bit more!

It's a tough decision on how to solve this long-term I think:

- There is no guarantee that `shift_tokens_right` is the same for all models; we already have three different methods (MBart, EncoderDecoder (not yet implemented but will be different), Bart & other enc-dec).
- We don't really want to start an `if isinstance(self.tokenizer)`-else cascade in the data collator.
- I don't really like adding a flag `_do_compute_loss` to the encoder-decoder models, because it breaks the consistency that only passing `labels` will trigger a loss calculation with all other models, and it adds more complexity for the user in that another config param has to be understood.
- Another option would be to pass the model to the `__init__` of the data collator and then call `decoder_input_ids = model.shift_right(...)`, forcing all models to have a method of the same name, similar to `prepare_inputs_for_generation(...)`.
- I think overall it would actually be better to get rid of the data collators altogether and instead make use of `datasets.map(...)` to prepare the inputs, and just use the default data collator as shown in => then we wouldn't have any of those problems and would have more control over what to input to the model.
Maybe we should discuss this quickly offline
> I didn't get why we've inputted both decoder_input_ids and labels previously

We prepared both `decoder_input_ids` and `labels` because we need to pass `decoder_input_ids` to the model explicitly when doing label smoothing, as we don't pass `labels` to avoid calculating the loss twice.
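As a concrete reference for the `shift_tokens_right` discussion, here is a minimal list-based sketch of the Bart-style shift (illustrative only: the real implementations operate on tensors and also handle pad/`-100` replacement, and the MBart variant instead rotates the trailing language-code token to position 0, which is exactly why the methods differ per model):

```python
# Hedged sketch of a Bart-style shift_tokens_right (not the library code).
def shift_tokens_right(labels, decoder_start_token_id):
    # decoder inputs are the labels shifted one position to the right,
    # with the decoder start token prepended
    return [[decoder_start_token_id] + row[:-1] for row in labels]

# labels [10, 11, 12] with start token 2 become decoder inputs [2, 10, 11]
assert shift_tokens_right([[10, 11, 12]], decoder_start_token_id=2) == [[2, 10, 11]]
```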
patrickvonplaten left a comment
Looks good to me for now! However, I think the script won't work with ProphetNet and LED at the moment.
As discussed a bit offline:
I don't think specific data collators are a good solution. One can fully prepare the data using the `datasets.map()` function (I always do this in my seq2seq scripts) and then just use the default data collator, which simply stacks the tensors correctly into a batch. We will run into more and more problems using specific data collators like the Seq2Seq ones because they can never handle all cases; e.g. it wouldn't work at the moment for LED either, because LED needs a `global_attention_mask` as an input. Also, thinking a bit more about the expansion of Transformers to vision and speech, where we will have even more different kinds of inputs, I don't think we should rely on a data collator to map the raw data to the model's expected input. It'll end up in a weird if-else cascade anyway. I think the responsibility should shift to the `datasets.map(...)` function to define all required inputs.

So, regarding the PR: to keep backward compatibility I'm fine with adding some kind of hack to `Seq2SeqDataCollator` to make it work, but I'm not really in favor of adding a boolean flag to the model. I'd rather deprecate `Seq2SeqDataCollator` soon and just make all scripts use the `default_data_collator` - I think all scripts already do, but some use more data collators in addition.
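A rough sketch of the workflow described here, with a stand-in tokenizer (all names are illustrative; in a real script `preprocess` would be passed to `datasets.map(...)` and the batching would be left to the default data collator):

```python
def toy_tokenize(text):
    # stand-in for a real tokenizer call; vocab is made up
    vocab = {"hello": 3, "world": 4, "bonjour": 5, "monde": 6}
    return {"input_ids": [vocab[w] for w in text.split()]}

def preprocess(example, decoder_start_token_id=2):
    model_inputs = toy_tokenize(example["src"])
    labels = toy_tokenize(example["tgt"])["input_ids"]
    model_inputs["labels"] = labels
    # decoder_input_ids built up front, instead of inside a model-specific collator
    model_inputs["decoder_input_ids"] = [decoder_start_token_id] + labels[:-1]
    return model_inputs

dataset = [{"src": "hello world", "tgt": "bonjour monde"}]
prepared = list(map(preprocess, dataset))  # analogue of dataset.map(preprocess)
```

With all inputs materialized ahead of time, the collator's only remaining job is stacking same-named fields into batch tensors, which is model-agnostic.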
Would be keen to hear @LysandreJik @stas00 @patil-suraj @thomwolf's opinions here as well.
I agree with Patrick here; the collator was added to encode the text and to prepare the … The only thing we need IMO is to be able to prepare …
Note that this is fixing the old script with the old data collator. The new one will be fixed with the proper fix (once we agree on it, and there seems to be a consensus on having a model with a …). The old …
LysandreJik left a comment
Great, LGTM! Good job finding the source of the issue!
What does this PR do?
This PR fixes the memory regression introduced when putting the `Seq2SeqTrainer` inside the main library. The root of the memory regression comes from the fact that when doing label smoothing, we ended up computing the log softmax of the logits twice: once in the cross entropy loss, and a second time inside the label smoother.

To fix this, the loss computation needs to be entirely done inside the label smoother, so the labels must be extracted from the batch before being passed to the model. As a result, the `decoder_input_ids` must be computed in the `Seq2SeqDataCollator` and not the model for this to work. I've just reverted the code from #9343; I don't know if it actually matches what happens inside the models. Maybe we should have a method to compute those `decoder_input_ids` accessible from inside of those models, or a flag to tell them whether to compute the loss or not (in this case, computing the loss would not only be slower, it would also trigger back the memory regression).

The same fix will need to be applied to the `Seq2SeqDataCollator` now inside the library, as well as to the new `run_seq2seq` script, but I will do it once we have agreed on a long-term solution for the decoder input ids above.

Fixes #9261
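To make the regression concrete, here is a minimal sketch of a label-smoothed loss computed from a single log-softmax pass (illustrative, not the library's `LabelSmoother`); the bug was that the model's cross entropy loss had already run its own log softmax on the same logits before the smoother ran a second one:

```python
import math

def log_softmax(logits):
    # numerically stable log softmax over a single logits vector
    m = max(logits)
    z = math.log(sum(math.exp(x - m) for x in logits)) + m
    return [x - z for x in logits]

def label_smoothed_loss(logits, target, epsilon=0.1):
    log_probs = log_softmax(logits)  # computed exactly once, inside the smoother
    nll = -log_probs[target]                    # standard cross-entropy term
    smooth = -sum(log_probs) / len(log_probs)   # uniform-smoothing term
    return (1 - epsilon) * nll + epsilon * smooth
```

With `epsilon=0` this reduces to plain cross entropy; the key point is that the model no longer receives `labels`, so it never computes a competing (and memory-hungry) log softmax of its own.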