Fix memory regression in Seq2Seq example #9713
```diff
@@ -26,6 +26,7 @@
     AutoTokenizer,
     HfArgumentParser,
     MBartTokenizer,
+    MBartTokenizerFast,
     Seq2SeqTrainer,
     Seq2SeqTrainingArguments,
     set_seed,
```
```diff
@@ -220,11 +221,14 @@ def main():
         data_args.eval_beams = model.config.num_beams

     # set decoder_start_token_id for MBart
-    if model.config.decoder_start_token_id is None and isinstance(tokenizer, MBartTokenizer):
+    if model.config.decoder_start_token_id is None and isinstance(tokenizer, (MBartTokenizer, MBartTokenizerFast)):
         assert (
             data_args.tgt_lang is not None and data_args.src_lang is not None
         ), "mBart requires --tgt_lang and --src_lang"
-        model.config.decoder_start_token_id = tokenizer.lang_code_to_id[data_args.tgt_lang]
+        if isinstance(tokenizer, MBartTokenizer):
+            model.config.decoder_start_token_id = tokenizer.lang_code_to_id[data_args.tgt_lang]
+        else:
+            model.config.decoder_start_token_id = tokenizer.convert_tokens_to_ids(data_args.tgt_lang)
```
**Collaborator (Author)**

@patrickvonplaten Wondering if this is the right way to do this for the fast tokenizer, as it has no `lang_code_to_id`.

**Contributor**

This should work, as …

**Contributor**

@patil-suraj, I think @sgugger was wondering about the fast tokenizers, so …

**Contributor**

Aah, I see. My bad.
```diff
     if model_args.freeze_embeds:
         freeze_embeds(model)
```
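For context on the mBART branch in the hunk above, here is a minimal sketch of how the two tokenizer classes resolve the target-language code. The checkpoint name and the `ro_RO` code are illustrative assumptions, not part of the diff:

```python
from transformers import MBartTokenizer, MBartTokenizerFast

# Illustrative checkpoint and language code (assumptions for this sketch).
slow_tok = MBartTokenizer.from_pretrained("facebook/mbart-large-cc25")
fast_tok = MBartTokenizerFast.from_pretrained("facebook/mbart-large-cc25")

# The slow tokenizer exposes an explicit language-code-to-id mapping ...
start_id_slow = slow_tok.lang_code_to_id["ro_RO"]

# ... while the patched example resolves the same special token on the fast
# tokenizer through the generic convert_tokens_to_ids API.
start_id_fast = fast_tok.convert_tokens_to_ids("ro_RO")

assert start_id_slow == start_id_fast
```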
```diff
@@ -284,7 +288,9 @@ def main():
         args=training_args,
         train_dataset=train_dataset,
         eval_dataset=eval_dataset,
-        data_collator=Seq2SeqDataCollator(tokenizer, data_args, training_args.tpu_num_cores),
+        data_collator=Seq2SeqDataCollator(
+            tokenizer, data_args, model.config.decoder_start_token_id, training_args.tpu_num_cores
+        ),
         compute_metrics=compute_metrics_fn,
         tokenizer=tokenizer,
     )
```
```diff
@@ -33,8 +33,9 @@
 from torch.utils.data import Dataset, Sampler

 from sentence_splitter import add_newline_to_end_of_each_sentence
-from transformers import BartTokenizer, EvalPrediction, PreTrainedTokenizer
+from transformers import BartTokenizer, EvalPrediction, PreTrainedTokenizer, T5Tokenizer
 from transformers.file_utils import cached_property
 from transformers.models.bart.modeling_bart import shift_tokens_right


 try:
```
```diff
@@ -274,9 +275,10 @@ def collate_fn(self, batch) -> Dict[str, torch.Tensor]:


 class Seq2SeqDataCollator:
-    def __init__(self, tokenizer, data_args, tpu_num_cores=None):
+    def __init__(self, tokenizer, data_args, decoder_start_token_id, tpu_num_cores=None):
         self.tokenizer = tokenizer
         self.pad_token_id = tokenizer.pad_token_id
+        self.decoder_start_token_id = decoder_start_token_id
         assert (
             self.pad_token_id is not None
         ), f"pad_token_id is not defined for ({self.tokenizer.__class__.__name__}), it must be defined."
```
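As a usage sketch of the new constructor signature (the surrounding variable names are placeholders mirroring the finetune script above, not part of this diff):

```python
# Hypothetical wiring that mirrors the updated Trainer setup in the example
# script: the decoder start token id is resolved once from the model config
# and handed to the collator, which then builds decoder_input_ids itself.
collator = Seq2SeqDataCollator(
    tokenizer,
    data_args,
    model.config.decoder_start_token_id,
    training_args.tpu_num_cores,
)
batch = collator(examples)  # dict with input_ids, attention_mask, decoder_input_ids, labels
```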
```diff
@@ -304,9 +306,15 @@ def __call__(self, batch) -> Dict[str, torch.Tensor]:
         labels = trim_batch(labels, self.pad_token_id)
         input_ids, attention_mask = trim_batch(input_ids, self.pad_token_id, attention_mask=attention_mask)

+        if isinstance(self.tokenizer, T5Tokenizer):
```
**Contributor**

Sorry, yeah this was my bad. I didn't get why we've inputted both … It's a tough decision on how to solve this long-term, I think: …

Maybe we should discuss this quickly offline.

**Contributor**

We prepared both …
```diff
+            decoder_input_ids = self._shift_right_t5(labels)
+        else:
+            decoder_input_ids = shift_tokens_right(labels, self.pad_token_id, self.decoder_start_token_id)

         batch = {
             "input_ids": input_ids,
             "attention_mask": attention_mask,
             "decoder_input_ids": decoder_input_ids,
             "labels": labels,
         }
         return batch
```
With the switch to fast tokenizers, this was actually not working.
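For readers unfamiliar with the shifting step, here is a rough, self-contained approximation of the behaviour the collator relies on from `shift_tokens_right` (written for illustration with toy ids; the real helper lives in `transformers.models.bart.modeling_bart`): labels are shifted one position to the right, the decoder start token id is placed first, and ignore-index values are replaced with the pad id.

```python
import torch


def shift_right_sketch(labels: torch.Tensor, pad_token_id: int, decoder_start_token_id: int) -> torch.Tensor:
    """Approximate the shifting used above: move labels right by one position
    and put the decoder start token id in the first slot."""
    shifted = labels.new_zeros(labels.shape)
    shifted[:, 1:] = labels[:, :-1].clone()
    shifted[:, 0] = decoder_start_token_id
    # -100 is the usual ignore index for the loss; it must not reach the decoder.
    shifted.masked_fill_(shifted == -100, pad_token_id)
    return shifted


labels = torch.tensor([[42, 43, 44, 2]])  # toy label ids, 2 assumed to be eos
print(shift_right_sketch(labels, pad_token_id=1, decoder_start_token_id=250020))
# tensor([[250020,     42,     43,     44]])
```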