diff --git a/examples/seq2seq/finetune_trainer.py b/examples/seq2seq/finetune_trainer.py
index dcbfe9163946..73123063d07d 100755
--- a/examples/seq2seq/finetune_trainer.py
+++ b/examples/seq2seq/finetune_trainer.py
@@ -26,6 +26,7 @@
     AutoTokenizer,
     HfArgumentParser,
     MBartTokenizer,
+    MBartTokenizerFast,
     Seq2SeqTrainer,
     Seq2SeqTrainingArguments,
     set_seed,
@@ -220,11 +221,14 @@ def main():
         data_args.eval_beams = model.config.num_beams
 
     # set decoder_start_token_id for MBart
-    if model.config.decoder_start_token_id is None and isinstance(tokenizer, MBartTokenizer):
+    if model.config.decoder_start_token_id is None and isinstance(tokenizer, (MBartTokenizer, MBartTokenizerFast)):
         assert (
             data_args.tgt_lang is not None and data_args.src_lang is not None
         ), "mBart requires --tgt_lang and --src_lang"
-        model.config.decoder_start_token_id = tokenizer.lang_code_to_id[data_args.tgt_lang]
+        if isinstance(tokenizer, MBartTokenizer):
+            model.config.decoder_start_token_id = tokenizer.lang_code_to_id[data_args.tgt_lang]
+        else:
+            model.config.decoder_start_token_id = tokenizer.convert_tokens_to_ids(data_args.tgt_lang)
 
     if model_args.freeze_embeds:
         freeze_embeds(model)
@@ -284,7 +288,9 @@ def main():
         args=training_args,
         train_dataset=train_dataset,
         eval_dataset=eval_dataset,
-        data_collator=Seq2SeqDataCollator(tokenizer, data_args, training_args.tpu_num_cores),
+        data_collator=Seq2SeqDataCollator(
+            tokenizer, data_args, model.config.decoder_start_token_id, training_args.tpu_num_cores
+        ),
         compute_metrics=compute_metrics_fn,
         tokenizer=tokenizer,
     )
diff --git a/examples/seq2seq/utils.py b/examples/seq2seq/utils.py
index 9df658f0218b..8b24bfdadcf6 100644
--- a/examples/seq2seq/utils.py
+++ b/examples/seq2seq/utils.py
@@ -33,8 +33,9 @@
 from torch.utils.data import Dataset, Sampler
 
 from sentence_splitter import add_newline_to_end_of_each_sentence
-from transformers import BartTokenizer, EvalPrediction, PreTrainedTokenizer
+from transformers import BartTokenizer, EvalPrediction, PreTrainedTokenizer, T5Tokenizer
 from transformers.file_utils import cached_property
+from transformers.models.bart.modeling_bart import shift_tokens_right
 
 
 try:
@@ -274,9 +275,10 @@ def collate_fn(self, batch) -> Dict[str, torch.Tensor]:
 
 
 class Seq2SeqDataCollator:
-    def __init__(self, tokenizer, data_args, tpu_num_cores=None):
+    def __init__(self, tokenizer, data_args, decoder_start_token_id, tpu_num_cores=None):
         self.tokenizer = tokenizer
         self.pad_token_id = tokenizer.pad_token_id
+        self.decoder_start_token_id = decoder_start_token_id
         assert (
             self.pad_token_id is not None
         ), f"pad_token_id is not defined for ({self.tokenizer.__class__.__name__}), it must be defined."
@@ -304,9 +306,15 @@ def __call__(self, batch) -> Dict[str, torch.Tensor]:
             labels = trim_batch(labels, self.pad_token_id)
             input_ids, attention_mask = trim_batch(input_ids, self.pad_token_id, attention_mask=attention_mask)
 
+        if isinstance(self.tokenizer, T5Tokenizer):
+            decoder_input_ids = self._shift_right_t5(labels)
+        else:
+            decoder_input_ids = shift_tokens_right(labels, self.pad_token_id, self.decoder_start_token_id)
+
         batch = {
             "input_ids": input_ids,
             "attention_mask": attention_mask,
+            "decoder_input_ids": decoder_input_ids,
             "labels": labels,
         }
         return batch
diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py
index c5e745577e75..8a3d7a43f533 100755
--- a/src/transformers/trainer.py
+++ b/src/transformers/trainer.py
@@ -1294,14 +1294,18 @@ def compute_loss(self, model, inputs):
 
         Subclass and override for custom behavior.
         """
+        if self.label_smoother is not None and "labels" in inputs:
+            labels = inputs.pop("labels")
+        else:
+            labels = None
         outputs = model(**inputs)
         # Save past state if it exists
         # TODO: this needs to be fixed and made cleaner later.
         if self.args.past_index >= 0:
             self._past = outputs[self.args.past_index]
 
-        if self.label_smoother is not None and "labels" in inputs:
-            return self.label_smoother(outputs, inputs["labels"])
+        if labels is not None:
+            return self.label_smoother(outputs, labels)
         else:
             # We don't use .loss here since the model may return tuples instead of ModelOutput.
             return outputs["loss"] if isinstance(outputs, dict) else outputs[0]
diff --git a/src/transformers/trainer_pt_utils.py b/src/transformers/trainer_pt_utils.py
index 850f8f841538..db7a080082e9 100644
--- a/src/transformers/trainer_pt_utils.py
+++ b/src/transformers/trainer_pt_utils.py
@@ -380,17 +380,26 @@ class LabelSmoother:
     ignore_index: int = -100
 
     def __call__(self, model_output, labels):
-        model_loss = model_output["loss"] if isinstance(model_output, dict) else model_output[0]
-        logits = model_output["logits"] if isinstance(model_output, dict) else model_output[1]
+        logits = model_output["logits"] if isinstance(model_output, dict) else model_output[0]
         log_probs = -torch.nn.functional.log_softmax(logits, dim=-1)
+        if labels.dim() == log_probs.dim() - 1:
+            labels = labels.unsqueeze(-1)
 
-        # Look at the ignored index and mask the corresponding log_probs.
-        padding_mask = labels.unsqueeze(-1).eq(self.ignore_index)
-        log_probs.masked_fill_(padding_mask, 0.0)
+        padding_mask = labels.eq(self.ignore_index)
+        # In case the ignore_index is -100, the gather will fail, so we replace labels by 0. The padding_mask
+        # will ignore them in any case.
+        labels.clamp_min_(0)
+        nll_loss = log_probs.gather(dim=-1, index=labels)
+        smoothed_loss = log_probs.sum(dim=-1, keepdim=True)
+
+        nll_loss.masked_fill_(padding_mask, 0.0)
+        smoothed_loss.masked_fill_(padding_mask, 0.0)
 
         # Take the mean over the label dimensions, then divide by the number of active elements (i.e. not-padded):
-        smoothed_loss = log_probs.mean(dim=-1).sum() / (padding_mask.numel() - padding_mask.long().sum())
-        return (1 - self.epsilon) * model_loss + self.epsilon * smoothed_loss
+        num_active_elements = padding_mask.numel() - padding_mask.long().sum()
+        nll_loss = nll_loss.sum() / num_active_elements
+        smoothed_loss = smoothed_loss.sum() / (num_active_elements * log_probs.shape[-1])
+        return (1 - self.epsilon) * nll_loss + self.epsilon * smoothed_loss
 
 
 def get_length_grouped_indices(lengths, batch_size, mega_batch_mult=None, generator=None):
diff --git a/tests/test_trainer_utils.py b/tests/test_trainer_utils.py
index f375ca5367e2..19dfa9b1d194 100644
--- a/tests/test_trainer_utils.py
+++ b/tests/test_trainer_utils.py
@@ -71,7 +71,7 @@ def test_label_smoothing(self):
         random_logits = torch.randn(4, 5, num_labels)
         random_labels = torch.randint(0, num_labels, (4, 5))
         loss = torch.nn.functional.cross_entropy(random_logits.view(-1, num_labels), random_labels.view(-1))
-        model_output = SequenceClassifierOutput(loss=loss, logits=random_logits)
+        model_output = SequenceClassifierOutput(logits=random_logits)
         label_smoothed_loss = LabelSmoother(0.1)(model_output, random_labels)
         log_probs = -torch.nn.functional.log_softmax(random_logits, dim=-1)
         expected_loss = (1 - epsilon) * loss + epsilon * log_probs.mean()
@@ -83,7 +83,7 @@ def test_label_smoothing(self):
         random_labels[2, 3] = -100
         loss = torch.nn.functional.cross_entropy(random_logits.view(-1, num_labels), random_labels.view(-1))
 
-        model_output = SequenceClassifierOutput(loss=loss, logits=random_logits)
+        model_output = SequenceClassifierOutput(logits=random_logits)
         label_smoothed_loss = LabelSmoother(0.1)(model_output, random_labels)
         log_probs = -torch.nn.functional.log_softmax(random_logits, dim=-1)
         # Mask the log probs with the -100 labels
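
Not part of the patch: a minimal standalone sketch of the reworked LabelSmoother logic above, run on dummy tensors. The shapes, epsilon value and the single -100 padding position are arbitrary example choices; the final assert only checks that the (1 - epsilon) NLL term agrees with plain cross-entropy, which likewise skips ignore_index targets.

# Illustrative sketch (not in the diff): recompute the label-smoothed loss the way the new
# LabelSmoother.__call__ does, using out-of-place ops for clarity.
import torch

epsilon, ignore_index, vocab = 0.1, -100, 10            # example hyperparameters
logits = torch.randn(4, 5, vocab)                       # (batch, seq_len, vocab)
labels = torch.randint(0, vocab, (4, 5))
labels[2, 3] = ignore_index                              # pretend one position is padding

log_probs = -torch.nn.functional.log_softmax(logits, dim=-1)
labels = labels.unsqueeze(-1)                            # (batch, seq_len, 1) so gather lines up
padding_mask = labels.eq(ignore_index)
labels = labels.clamp_min(0)                             # -100 would crash gather; the mask hides it anyway
nll_loss = log_probs.gather(dim=-1, index=labels).masked_fill(padding_mask, 0.0)
smoothed_loss = log_probs.sum(dim=-1, keepdim=True).masked_fill(padding_mask, 0.0)

num_active = padding_mask.numel() - padding_mask.long().sum()
nll_loss = nll_loss.sum() / num_active                   # mean -log p(target) over non-padded tokens
smoothed_loss = smoothed_loss.sum() / (num_active * vocab)  # mean -log p over all classes, non-padded tokens
loss = (1 - epsilon) * nll_loss + epsilon * smoothed_loss

# Sanity check: the NLL term matches plain cross-entropy, which also ignores -100 targets.
flat_labels = labels.view(-1).masked_fill(padding_mask.view(-1), ignore_index)
assert torch.allclose(nll_loss, torch.nn.functional.cross_entropy(logits.view(-1, vocab), flat_labels))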