From 9653560f805a7da1907f3f7f15d1fda8c8083a63 Mon Sep 17 00:00:00 2001 From: Sylvain Gugger Date: Wed, 20 Jan 2021 19:06:19 -0500 Subject: [PATCH 1/4] Fix memory regression in Seq2Seq example --- examples/seq2seq/finetune_trainer.py | 4 +++- examples/seq2seq/utils.py | 12 ++++++++++-- src/transformers/trainer.py | 8 ++++++-- src/transformers/trainer_pt_utils.py | 20 +++++++++++++------- 4 files changed, 32 insertions(+), 12 deletions(-) diff --git a/examples/seq2seq/finetune_trainer.py b/examples/seq2seq/finetune_trainer.py index dcbfe9163946..f1520984f79f 100755 --- a/examples/seq2seq/finetune_trainer.py +++ b/examples/seq2seq/finetune_trainer.py @@ -284,7 +284,9 @@ def main(): args=training_args, train_dataset=train_dataset, eval_dataset=eval_dataset, - data_collator=Seq2SeqDataCollator(tokenizer, data_args, training_args.tpu_num_cores), + data_collator=Seq2SeqDataCollator( + tokenizer, data_args, model.config.decoder_start_token_id, training_args.tpu_num_cores + ), compute_metrics=compute_metrics_fn, tokenizer=tokenizer, ) diff --git a/examples/seq2seq/utils.py b/examples/seq2seq/utils.py index 9df658f0218b..8b24bfdadcf6 100644 --- a/examples/seq2seq/utils.py +++ b/examples/seq2seq/utils.py @@ -33,8 +33,9 @@ from torch.utils.data import Dataset, Sampler from sentence_splitter import add_newline_to_end_of_each_sentence -from transformers import BartTokenizer, EvalPrediction, PreTrainedTokenizer +from transformers import BartTokenizer, EvalPrediction, PreTrainedTokenizer, T5Tokenizer from transformers.file_utils import cached_property +from transformers.models.bart.modeling_bart import shift_tokens_right try: @@ -274,9 +275,10 @@ def collate_fn(self, batch) -> Dict[str, torch.Tensor]: class Seq2SeqDataCollator: - def __init__(self, tokenizer, data_args, tpu_num_cores=None): + def __init__(self, tokenizer, data_args, decoder_start_token_id, tpu_num_cores=None): self.tokenizer = tokenizer self.pad_token_id = tokenizer.pad_token_id + self.decoder_start_token_id = 
decoder_start_token_id assert ( self.pad_token_id is not None ), f"pad_token_id is not defined for ({self.tokenizer.__class__.__name__}), it must be defined." @@ -304,9 +306,15 @@ def __call__(self, batch) -> Dict[str, torch.Tensor]: labels = trim_batch(labels, self.pad_token_id) input_ids, attention_mask = trim_batch(input_ids, self.pad_token_id, attention_mask=attention_mask) + if isinstance(self.tokenizer, T5Tokenizer): + decoder_input_ids = self._shift_right_t5(labels) + else: + decoder_input_ids = shift_tokens_right(labels, self.pad_token_id, self.decoder_start_token_id) + batch = { "input_ids": input_ids, "attention_mask": attention_mask, + "decoder_input_ids": decoder_input_ids, "labels": labels, } return batch diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index c5e745577e75..8a3d7a43f533 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -1294,14 +1294,18 @@ def compute_loss(self, model, inputs): Subclass and override for custom behavior. """ + if self.label_smoother is not None and "labels" in inputs: + labels = inputs.pop("labels") + else: + labels = None outputs = model(**inputs) # Save past state if it exists # TODO: this needs to be fixed and made cleaner later. if self.args.past_index >= 0: self._past = outputs[self.args.past_index] - if self.label_smoother is not None and "labels" in inputs: - return self.label_smoother(outputs, inputs["labels"]) + if labels is not None: + return self.label_smoother(outputs, labels) else: # We don't use .loss here since the model may return tuples instead of ModelOutput. 
return outputs["loss"] if isinstance(outputs, dict) else outputs[0] diff --git a/src/transformers/trainer_pt_utils.py b/src/transformers/trainer_pt_utils.py index 850f8f841538..2eecc8c6d45a 100644 --- a/src/transformers/trainer_pt_utils.py +++ b/src/transformers/trainer_pt_utils.py @@ -380,17 +380,23 @@ class LabelSmoother: ignore_index: int = -100 def __call__(self, model_output, labels): - model_loss = model_output["loss"] if isinstance(model_output, dict) else model_output[0] - logits = model_output["logits"] if isinstance(model_output, dict) else model_output[1] + logits = model_output["logits"] if isinstance(model_output, dict) else model_output[0] log_probs = -torch.nn.functional.log_softmax(logits, dim=-1) + if labels.dim() == log_probs.dim() - 1: + labels = labels.unsqueeze(-1) - # Look at the ignored index and mask the corresponding log_probs. - padding_mask = labels.unsqueeze(-1).eq(self.ignore_index) - log_probs.masked_fill_(padding_mask, 0.0) + nll_loss = log_probs.gather(dim=-1, index=labels) + smoothed_loss = log_probs.sum(dim=-1, keepdim=True) + + padding_mask = labels.eq(self.ignore_index) + nll_loss.masked_fill_(padding_mask, 0.0) + smoothed_loss.masked_fill_(padding_mask, 0.0) # Take the mean over the label dimensions, then divide by the number of active elements (i.e. 
not-padded): - smoothed_loss = log_probs.mean(dim=-1).sum() / (padding_mask.numel() - padding_mask.long().sum()) - return (1 - self.epsilon) * model_loss + self.epsilon * smoothed_loss + num_active_elements = padding_mask.numel() - padding_mask.long().sum() + nll_loss = nll_loss.sum() / num_active_elements + smoothed_loss = smoothed_loss.sum() / (num_active_elements * log_probs.shape[-1]) + return (1 - self.epsilon) * nll_loss + self.epsilon * smoothed_loss def get_length_grouped_indices(lengths, batch_size, mega_batch_mult=None, generator=None): From 630f1cecff15831791a35683d45c4b9587577914 Mon Sep 17 00:00:00 2001 From: Sylvain Gugger Date: Wed, 20 Jan 2021 20:32:37 -0500 Subject: [PATCH 2/4] Fix test and properly deal with -100 --- src/transformers/trainer_pt_utils.py | 5 ++++- tests/test_trainer_utils.py | 4 ++-- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/src/transformers/trainer_pt_utils.py b/src/transformers/trainer_pt_utils.py index 2eecc8c6d45a..fd7e9bd10155 100644 --- a/src/transformers/trainer_pt_utils.py +++ b/src/transformers/trainer_pt_utils.py @@ -385,10 +385,13 @@ def __call__(self, model_output, labels): if labels.dim() == log_probs.dim() - 1: labels = labels.unsqueeze(-1) + padding_mask = labels.eq(self.ignore_index) + # In case the ignore_index is -100, the gather will fail, so we replace labels by 0. The padding_mask + # will ignore them in any case. 
+ labels = torch.where(padding_mask, torch.tensor(0), labels) nll_loss = log_probs.gather(dim=-1, index=labels) smoothed_loss = log_probs.sum(dim=-1, keepdim=True) - padding_mask = labels.eq(self.ignore_index) nll_loss.masked_fill_(padding_mask, 0.0) smoothed_loss.masked_fill_(padding_mask, 0.0) diff --git a/tests/test_trainer_utils.py b/tests/test_trainer_utils.py index f375ca5367e2..19dfa9b1d194 100644 --- a/tests/test_trainer_utils.py +++ b/tests/test_trainer_utils.py @@ -71,7 +71,7 @@ def test_label_smoothing(self): random_logits = torch.randn(4, 5, num_labels) random_labels = torch.randint(0, num_labels, (4, 5)) loss = torch.nn.functional.cross_entropy(random_logits.view(-1, num_labels), random_labels.view(-1)) - model_output = SequenceClassifierOutput(loss=loss, logits=random_logits) + model_output = SequenceClassifierOutput(logits=random_logits) label_smoothed_loss = LabelSmoother(0.1)(model_output, random_labels) log_probs = -torch.nn.functional.log_softmax(random_logits, dim=-1) expected_loss = (1 - epsilon) * loss + epsilon * log_probs.mean() @@ -83,7 +83,7 @@ def test_label_smoothing(self): random_labels[2, 3] = -100 loss = torch.nn.functional.cross_entropy(random_logits.view(-1, num_labels), random_labels.view(-1)) - model_output = SequenceClassifierOutput(loss=loss, logits=random_logits) + model_output = SequenceClassifierOutput(logits=random_logits) label_smoothed_loss = LabelSmoother(0.1)(model_output, random_labels) log_probs = -torch.nn.functional.log_softmax(random_logits, dim=-1) # Mask the log probs with the -100 labels From 67f5f284e955c089b801ff2d1d2fa6ea5c09d1db Mon Sep 17 00:00:00 2001 From: Sylvain Gugger Date: Wed, 20 Jan 2021 20:37:11 -0500 Subject: [PATCH 3/4] Easier condition with device safety --- src/transformers/trainer_pt_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/trainer_pt_utils.py b/src/transformers/trainer_pt_utils.py index fd7e9bd10155..db7a080082e9 100644 --- 
a/src/transformers/trainer_pt_utils.py +++ b/src/transformers/trainer_pt_utils.py @@ -388,7 +388,7 @@ def __call__(self, model_output, labels): padding_mask = labels.eq(self.ignore_index) # In case the ignore_index is -100, the gather will fail, so we replace labels by 0. The padding_mask # will ignore them in any case. - labels = torch.where(padding_mask, torch.tensor(0), labels) + labels.clamp_min_(0) nll_loss = log_probs.gather(dim=-1, index=labels) smoothed_loss = log_probs.sum(dim=-1, keepdim=True) From e467b131e81296fc805c86b149122356aaa906f2 Mon Sep 17 00:00:00 2001 From: Sylvain Gugger Date: Wed, 20 Jan 2021 21:26:08 -0500 Subject: [PATCH 4/4] Patch for MBartTokenizerFast --- examples/seq2seq/finetune_trainer.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/examples/seq2seq/finetune_trainer.py b/examples/seq2seq/finetune_trainer.py index f1520984f79f..73123063d07d 100755 --- a/examples/seq2seq/finetune_trainer.py +++ b/examples/seq2seq/finetune_trainer.py @@ -26,6 +26,7 @@ AutoTokenizer, HfArgumentParser, MBartTokenizer, + MBartTokenizerFast, Seq2SeqTrainer, Seq2SeqTrainingArguments, set_seed, @@ -220,11 +221,14 @@ def main(): data_args.eval_beams = model.config.num_beams # set decoder_start_token_id for MBart - if model.config.decoder_start_token_id is None and isinstance(tokenizer, MBartTokenizer): + if model.config.decoder_start_token_id is None and isinstance(tokenizer, (MBartTokenizer, MBartTokenizerFast)): assert ( data_args.tgt_lang is not None and data_args.src_lang is not None ), "mBart requires --tgt_lang and --src_lang" - model.config.decoder_start_token_id = tokenizer.lang_code_to_id[data_args.tgt_lang] + if isinstance(tokenizer, MBartTokenizer): + model.config.decoder_start_token_id = tokenizer.lang_code_to_id[data_args.tgt_lang] + else: + model.config.decoder_start_token_id = tokenizer.convert_tokens_to_ids(data_args.tgt_lang) if model_args.freeze_embeds: freeze_embeds(model)