Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 9 additions & 3 deletions examples/seq2seq/finetune_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
AutoTokenizer,
HfArgumentParser,
MBartTokenizer,
MBartTokenizerFast,
Seq2SeqTrainer,
Seq2SeqTrainingArguments,
set_seed,
Expand Down Expand Up @@ -220,11 +221,14 @@ def main():
data_args.eval_beams = model.config.num_beams

# set decoder_start_token_id for MBart
if model.config.decoder_start_token_id is None and isinstance(tokenizer, MBartTokenizer):
if model.config.decoder_start_token_id is None and isinstance(tokenizer, (MBartTokenizer, MBartTokenizerFast)):
Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

With the switch to fast tokenizers, this was actually not working.

assert (
data_args.tgt_lang is not None and data_args.src_lang is not None
), "mBart requires --tgt_lang and --src_lang"
model.config.decoder_start_token_id = tokenizer.lang_code_to_id[data_args.tgt_lang]
if isinstance(tokenizer, MBartTokenizer):
model.config.decoder_start_token_id = tokenizer.lang_code_to_id[data_args.tgt_lang]
else:
model.config.decoder_start_token_id = tokenizer.convert_tokens_to_ids(data_args.tgt_lang)
Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@patrickvonplaten Wondering if this the right way to do this for the fast tokenizer, as it has no lang_code_to_id.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this should work, as MBartTokenizer is inherited from XLMRobertaTokenizer and in convert_tokens_to_ids it first checks if the token is in fairseq_tokens_to_ids dict and returns the id if it is. All lang tokens are added in the fairseq_tokens_to_ids dict

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@patil-suraj, I think @sgugger was wondering about the fast tokenizers, so MBartTokenizerFast, and those don't have the fairseq_tokens_to_ids dict. However, I think your approach is the correct one here @sgugger . MBartTokenizerFast adds all languages codes, called FAIRSEQ_LANGUAGE_CODES to the special tokens so that they can be converted just the way you did above -> LGTM!

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Aah, I see. My bad.


if model_args.freeze_embeds:
freeze_embeds(model)
Expand Down Expand Up @@ -284,7 +288,9 @@ def main():
args=training_args,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
data_collator=Seq2SeqDataCollator(tokenizer, data_args, training_args.tpu_num_cores),
data_collator=Seq2SeqDataCollator(
tokenizer, data_args, model.config.decoder_start_token_id, training_args.tpu_num_cores
),
compute_metrics=compute_metrics_fn,
tokenizer=tokenizer,
)
Expand Down
12 changes: 10 additions & 2 deletions examples/seq2seq/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,9 @@
from torch.utils.data import Dataset, Sampler

from sentence_splitter import add_newline_to_end_of_each_sentence
from transformers import BartTokenizer, EvalPrediction, PreTrainedTokenizer
from transformers import BartTokenizer, EvalPrediction, PreTrainedTokenizer, T5Tokenizer
from transformers.file_utils import cached_property
from transformers.models.bart.modeling_bart import shift_tokens_right


try:
Expand Down Expand Up @@ -274,9 +275,10 @@ def collate_fn(self, batch) -> Dict[str, torch.Tensor]:


class Seq2SeqDataCollator:
def __init__(self, tokenizer, data_args, tpu_num_cores=None):
def __init__(self, tokenizer, data_args, decoder_start_token_id, tpu_num_cores=None):
self.tokenizer = tokenizer
self.pad_token_id = tokenizer.pad_token_id
self.decoder_start_token_id = decoder_start_token_id
assert (
self.pad_token_id is not None
), f"pad_token_id is not defined for ({self.tokenizer.__class__.__name__}), it must be defined."
Expand Down Expand Up @@ -304,9 +306,15 @@ def __call__(self, batch) -> Dict[str, torch.Tensor]:
labels = trim_batch(labels, self.pad_token_id)
input_ids, attention_mask = trim_batch(input_ids, self.pad_token_id, attention_mask=attention_mask)

if isinstance(self.tokenizer, T5Tokenizer):
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry, yeah this was my bad. I didn't get why we've inputted both decoder_input_ids and labels previously - should have thought about it a bit more!

It's a tough decision on how to solve this long-term I think:

  • there is no guarantee that shift_tokens_right is the same for all models, we already have three different methods (MBart, EncoderDecoder (not yet implemented but will be different), Bart & other enc-dec)
  • We don't really want to start an if isinstance(self.tokenizer)-else cascade in the data collator.
  • I don't really like to add a flag _do_compute_loss to the encoder-decoder models because it breaks consistency that only passing labels will trigger a loss calculation with all other models and adds more complexity to the user in that another config param has to be understood.
  • Another option would be to pass the models to the __init__ of the data collator and then to call: decoder_input_ids = model.shift_right(...) and forcing all functions to have the same name similar to prepare_inputs_for_generation(...)
  • I think overall, it would actually be better to get fully rid of the data collators all together and instead make use of datasets.map(...) to prepare the inputs and just use the default datacollator as shown in
    def default_data_collator(features: List[InputDataClass]) -> Dict[str, torch.Tensor]:
    => then we wouldn't have any of those problems and would have more control over what to input to the model

Maybe we should discuss this quickly offline

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I didn't get why we've inputted both decoder_input_ids and labels previously

we prepared both decoder_input_ids and labels because we need to pass decoder_input_ids to model explicitly when doing label smoothing as we don't pass labels to avoid calculating the loss twice.

decoder_input_ids = self._shift_right_t5(labels)
else:
decoder_input_ids = shift_tokens_right(labels, self.pad_token_id, self.decoder_start_token_id)

batch = {
"input_ids": input_ids,
"attention_mask": attention_mask,
"decoder_input_ids": decoder_input_ids,
"labels": labels,
}
return batch
Expand Down
8 changes: 6 additions & 2 deletions src/transformers/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1294,14 +1294,18 @@ def compute_loss(self, model, inputs):

Subclass and override for custom behavior.
"""
if self.label_smoother is not None and "labels" in inputs:
labels = inputs.pop("labels")
else:
labels = None
outputs = model(**inputs)
# Save past state if it exists
# TODO: this needs to be fixed and made cleaner later.
if self.args.past_index >= 0:
self._past = outputs[self.args.past_index]

if self.label_smoother is not None and "labels" in inputs:
return self.label_smoother(outputs, inputs["labels"])
if labels is not None:
return self.label_smoother(outputs, labels)
else:
# We don't use .loss here since the model may return tuples instead of ModelOutput.
return outputs["loss"] if isinstance(outputs, dict) else outputs[0]
Expand Down
23 changes: 16 additions & 7 deletions src/transformers/trainer_pt_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -380,17 +380,26 @@ class LabelSmoother:
ignore_index: int = -100

def __call__(self, model_output, labels):
    """Compute the label-smoothed cross-entropy loss from model logits.

    Args:
        model_output: model outputs, either a dict with a ``"logits"`` key or a
            tuple/list whose first element is the logits tensor. NOTE: the loss
            is recomputed here from the logits; any loss the model returned is
            ignored (callers are expected not to pass ``labels`` to the model).
        labels: tensor of target indices; entries equal to ``self.ignore_index``
            (padding) are excluded from the loss.

    Returns:
        A scalar tensor: ``(1 - epsilon) * nll_loss + epsilon * smoothed_loss``,
        where ``smoothed_loss`` is the mean negative log-probability over the
        whole vocabulary for the non-padded positions.
    """
    logits = model_output["logits"] if isinstance(model_output, dict) else model_output[0]
    log_probs = -torch.nn.functional.log_softmax(logits, dim=-1)
    if labels.dim() == log_probs.dim() - 1:
        # Add a trailing dim so `labels` can be used as a gather index.
        labels = labels.unsqueeze(-1)

    padding_mask = labels.eq(self.ignore_index)
    # In case the ignore_index is -100, the gather will fail, so we replace labels by 0. The padding_mask
    # will ignore them in any case.
    labels.clamp_min_(0)
    nll_loss = log_probs.gather(dim=-1, index=labels)
    # Sum over the vocabulary dim; dividing by vocab size happens below so the
    # padding mask (shaped like `labels`) can be applied first.
    smoothed_loss = log_probs.sum(dim=-1, keepdim=True)

    # Zero out contributions from padded positions before reducing.
    nll_loss.masked_fill_(padding_mask, 0.0)
    smoothed_loss.masked_fill_(padding_mask, 0.0)

    # Average over the non-padded elements only.
    num_active_elements = padding_mask.numel() - padding_mask.long().sum()
    nll_loss = nll_loss.sum() / num_active_elements
    smoothed_loss = smoothed_loss.sum() / (num_active_elements * log_probs.shape[-1])
    return (1 - self.epsilon) * nll_loss + self.epsilon * smoothed_loss


def get_length_grouped_indices(lengths, batch_size, mega_batch_mult=None, generator=None):
Expand Down
4 changes: 2 additions & 2 deletions tests/test_trainer_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ def test_label_smoothing(self):
random_logits = torch.randn(4, 5, num_labels)
random_labels = torch.randint(0, num_labels, (4, 5))
loss = torch.nn.functional.cross_entropy(random_logits.view(-1, num_labels), random_labels.view(-1))
model_output = SequenceClassifierOutput(loss=loss, logits=random_logits)
model_output = SequenceClassifierOutput(logits=random_logits)
label_smoothed_loss = LabelSmoother(0.1)(model_output, random_labels)
log_probs = -torch.nn.functional.log_softmax(random_logits, dim=-1)
expected_loss = (1 - epsilon) * loss + epsilon * log_probs.mean()
Expand All @@ -83,7 +83,7 @@ def test_label_smoothing(self):
random_labels[2, 3] = -100

loss = torch.nn.functional.cross_entropy(random_logits.view(-1, num_labels), random_labels.view(-1))
model_output = SequenceClassifierOutput(loss=loss, logits=random_logits)
model_output = SequenceClassifierOutput(logits=random_logits)
label_smoothed_loss = LabelSmoother(0.1)(model_output, random_labels)
log_probs = -torch.nn.functional.log_softmax(random_logits, dim=-1)
# Mask the log probs with the -100 labels
Expand Down