
Add RoBERTa question answering & Update SQuAD runner to support RoBERTa #1386

Closed

stevezheng23 wants to merge 12 commits into huggingface:master from stevezheng23:dev/zheng/roberta

Conversation

@stevezheng23

No description provided.

@stevezheng23
Author

stevezheng23 commented Sep 30, 2019

@thomwolf / @LysandreJik / @VictorSanh / @julien-c Could you help review this PR? Thanks!

@codecov-io

codecov-io commented Oct 2, 2019

Codecov Report

Merging #1386 into master will decrease coverage by 0.15%.
The diff coverage is 21.21%.


@@            Coverage Diff             @@
##           master    #1386      +/-   ##
==========================================
- Coverage   86.16%   86.01%   -0.16%     
==========================================
  Files          91       91              
  Lines       13593    13626      +33     
==========================================
+ Hits        11713    11720       +7     
- Misses       1880     1906      +26
Impacted Files Coverage Δ
transformers/modeling_roberta.py 69.18% <21.21%> (-11.39%) ⬇️

Continue to review full report at Codecov.

Legend: Δ = absolute <relative> (impact), ø = not affected, ? = missing data
Powered by Codecov. Last update be916cb...ee83f98.

@stevezheng23
Author

Hi @thomwolf / @LysandreJik / @VictorSanh / @julien-c

I have also run experiments using the RoBERTa-large setting from the original paper and reproduced their results:

  • SQuAD v1.1
    {
    "exact": 88.25922421948913,
    "f1": 94.43790487416292,
    "total": 10570,
    "HasAns_exact": 88.25922421948913,
    "HasAns_f1": 94.43790487416292,
    "HasAns_total": 10570
    }
  • SQuAD v2.0
    {
    "exact": 86.05238777057188,
    "f1": 88.99602665148535,
    "total": 11873,
    "HasAns_exact": 83.38394062078272,
    "HasAns_f1": 89.27965999208608,
    "HasAns_total": 5928,
    "NoAns_exact": 88.71320437342304,
    "NoAns_f1": 88.71320437342304,
    "NoAns_total": 5945,
    "best_exact": 86.5914259243662,
    "best_exact_thresh": -2.146007537841797,
    "best_f1": 89.43104312625539,
    "best_f1_thresh": -2.146007537841797
    }
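The best_exact_thresh / best_f1_thresh fields come from the SQuAD v2.0 evaluator, which sweeps a threshold on the null (no-answer) score difference. A minimal sketch of how such a threshold is applied at prediction time (function and variable names here are illustrative, not the evaluator's actual code):

```python
def apply_null_threshold(null_score, best_span_score, best_span_text, thresh):
    """Predict no-answer when the null score beats the best span score
    by more than `thresh` (SQuAD v2.0 convention: score_diff > thresh)."""
    score_diff = null_score - best_span_score
    return "" if score_diff > thresh else best_span_text

# With a threshold near the reported -2.146, a null score slightly below the
# best span score still produces a no-answer prediction.
print(apply_null_threshold(-1.0, 0.5, "Denver", -2.146))   # no answer ("")
print(apply_null_threshold(-5.0, 0.5, "Denver", -2.146))   # "Denver"
```

A negative best threshold, as reported above, means the evaluation favors biasing the model slightly toward predicting no-answer.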

@julien-c
Member

julien-c commented Oct 3, 2019

Awesome @stevezheng23. Can I push on top of your PR to change a few things before we merge?

(We refactored the tokenizer to handle the encoding of sequence pairs, including special tokens. So we don't need to do it inside each example script anymore)
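For context, the refactored tokenizers build the pair encoding with model-specific special tokens. The helper below is a toy illustration of the two token layouts involved here (not the library API), assuming the standard BERT and RoBERTa special tokens:

```python
def build_pair(question_tokens, context_tokens, model_type):
    """Toy sketch of sequence-pair layouts with special tokens."""
    if model_type == "bert":
        # [CLS] question [SEP] context [SEP]
        return ["[CLS]"] + question_tokens + ["[SEP]"] + context_tokens + ["[SEP]"]
    if model_type == "roberta":
        # <s> question </s></s> context </s>
        return ["<s>"] + question_tokens + ["</s>", "</s>"] + context_tokens + ["</s>"]
    raise ValueError(f"unknown model_type: {model_type}")

print(build_pair(["who", "?"], ["He", "did", "."], "roberta"))
```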

@stevezheng23
Author

@julien-c sure, please add changes in this PR if needed 👍

@stevezheng23
Author

stevezheng23 commented Oct 3, 2019

@julien-c I've also uploaded the RoBERTa-large model fine-tuned on SQuAD v2.0 data, together with its prediction & evaluation results, to public cloud storage: https://storage.googleapis.com/mrc_data/squad/roberta.large.squad.v2.zip

@stevezheng23 stevezheng23 changed the title add roberta qa support & update squad runner/util Add RoBERTa question answering & Update SQuAD runner to support RoBERTa Oct 3, 2019
@julien-c
Member

julien-c commented Oct 3, 2019

Can you check my latest commit @stevezheng23? The main change is that I removed add_prefix_space for RoBERTa (which the RoBERTa authors don't use, as far as I know); it doesn't seem to make a significant difference.

@thomwolf @LysandreJik this is ready for review.

@julien-c julien-c mentioned this pull request Oct 3, 2019
@stevezheng23
Author

Everything looks good.

As for the add_prefix_space flag:

  • With add_prefix_space=True, the F1 score is around 89.4
  • With add_prefix_space=False, the F1 score is around 88.2
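The flag matters because RoBERTa's byte-level BPE encodes a preceding space into the token itself (the "Ġ" marker), so a question tokenized at string start gets different ids than the same words mid-sentence. A toy sketch of the mechanism (the marker convention is real; the tokenizer below is not):

```python
def toy_tokenize(text, add_prefix_space=False):
    """Whitespace-level sketch of byte-level BPE's space handling:
    every word preceded by a space is marked with 'Ġ'."""
    if add_prefix_space and not text.startswith(" "):
        text = " " + text
    words = [w for w in text.split(" ") if w]
    marked = []
    for i, w in enumerate(words):
        # Only the first word of a string with no leading space goes unmarked.
        if i == 0 and not text.startswith(" "):
            marked.append(w)
        else:
            marked.append("Ġ" + w)
    return marked

print(toy_tokenize("who wrote it"))                         # ['who', 'Ġwrote', 'Ġit']
print(toy_tokenize("who wrote it", add_prefix_space=True))  # ['Ġwho', 'Ġwrote', 'Ġit']
```

With add_prefix_space=True, the question's first word is tokenized the way it usually appears in pretraining text (after a space), which is one plausible reason for the F1 gap reported here.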

@LysandreJik
Member

Great! Good job on reimplementing the cross-entropy loss when start/end positions are given.
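For reference, the standard span-extraction loss averages cross-entropy over the start and end logits, clamping out-of-range gold positions. A self-contained sketch of that computation (not the PR's exact code):

```python
import math

def cross_entropy(logits, target):
    """-log softmax(logits)[target], computed stably."""
    m = max(logits)
    log_sum_exp = m + math.log(sum(math.exp(x - m) for x in logits))
    return log_sum_exp - logits[target]

def qa_span_loss(start_logits, end_logits, start_position, end_position):
    """Average of start- and end-position cross-entropy. Gold positions
    outside the sequence are clamped into range (the modeling code uses
    an ignored_index for the same purpose)."""
    last = len(start_logits) - 1
    start_position = min(max(start_position, 0), last)
    end_position = min(max(end_position, 0), last)
    return 0.5 * (cross_entropy(start_logits, start_position)
                  + cross_entropy(end_logits, end_position))
```

With uniform logits over 4 positions, both terms equal log 4, so the loss is log 4 ≈ 1.386.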


ALL_MODELS = sum((tuple(conf.pretrained_config_archive_map.keys()) \
    for conf in (BertConfig, RobertaConfig, XLNetConfig, XLMConfig)), ())
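The sum(..., ()) idiom concatenates the per-config key tuples into one flat tuple, starting from the empty tuple. A minimal sketch with illustrative stand-in model names (not the real archive maps):

```python
# Stand-ins for each config class's pretrained_config_archive_map.
archive_maps = {
    "bert": ("bert-base-uncased", "bert-large-uncased"),
    "roberta": ("roberta-base", "roberta-large"),
}

# sum(iterable_of_tuples, ()) repeatedly applies tuple +, flattening them.
ALL_MODELS = sum((tuple(names) for names in archive_maps.values()), ())
print(ALL_MODELS)
```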
Contributor


Do we need to add DistilBertConfig here?

Author


added

    query_tokens = tokenizer.tokenize(example.question_text, add_prefix_space=True)
else:
    query_tokens = tokenizer.tokenize(example.question_text)
Contributor

erenup commented Oct 5, 2019


I also observed an improvement with add_prefix_space=True when I used roberta

@thomwolf
Member

thomwolf commented Oct 9, 2019

Looks good to me.
We'll probably be able to simplify utils_squad a lot soon but that will be fine for now.
Do you want to add your experimental results with RoBERTa in examples/readme, with a recommendation to use add_prefix_space=True (fyi it's the opposite for NER)?

@thomwolf
Member

thomwolf commented Oct 9, 2019

@julien-c do you want to add the roberta model finetuned on squad by @stevezheng23 in our library?

@julien-c
Member

julien-c commented Oct 9, 2019

Yep @thomwolf

@stevezheng23
Author

@thomwolf I have updated the README file as you suggested; you can merge this PR when you think it's good to go. BTW, it seems the CI build is broken.

@thomwolf
Member

Ok thanks, I'll let @julien-c finish handling this PR when he's back.

@pminervini
Contributor

pminervini commented Oct 16, 2019

@julien-c I've also uploaded the RoBERTa-large model fine-tuned on SQuAD v2.0 data, together with its prediction & evaluation results, to public cloud storage: https://storage.googleapis.com/mrc_data/squad/roberta.large.squad.v2.zip

Hey @stevezheng23 !

I just tried to reproduce your model with slightly different hyperparameters (batch_size=2 and gradient_accumulation=6 instead of batch_size=12), and I am currently getting worse results.

Results with your model:

{
  "exact": 86.05238777057188,
  "f1": 88.99602665148535,
  "total": 11873,
  "HasAns_exact": 83.38394062078272,
  "HasAns_f1": 89.27965999208608,
  "HasAns_total": 5928,
  "NoAns_exact": 88.71320437342304,
  "NoAns_f1": 88.71320437342304,
  "NoAns_total": 5945
}

Results with the model I trained, on the best checkpoint I was able to obtain after training for 8 epochs:

{
  "exact": 82.85184873241809,
  "f1": 85.85477834702593,
  "total": 11873,
  "HasAns_exact": 77.80026990553306,
  "HasAns_f1": 83.8147407750069,
  "HasAns_total": 5928,
  "NoAns_exact": 87.88898233809924,
  "NoAns_f1": 87.88898233809924,
  "NoAns_total": 5945
}

Your hyperparameters:

Namespace(adam_epsilon=1e-08, cache_dir='', config_name='', device=device(type='cuda', index=0), do_eval=True, do_lower_case=False, do_train=True, doc_stride=128, eval_all_checkpoints=False, evaluate_during_training=False, fp16=False, fp16_opt_level='O1', gradient_accumulation_steps=1, learning_rate=1.5e-05, local_rank=0, logging_steps=50, max_answer_length=30, max_grad_norm=1.0, max_query_length=64, max_seq_length=512, max_steps=-1, model_name_or_path='roberta-large', model_type='roberta', n_best_size=20, n_gpu=1, no_cuda=False, null_score_diff_threshold=0.0, num_train_epochs=2.0, output_dir='output/squad/v2.0/roberta.large', overwrite_cache=False, overwrite_output_dir=False, per_gpu_eval_batch_size=12, per_gpu_train_batch_size=12, predict_file='data/squad/v2.0/dev-v2.0.json', save_steps=500, seed=42, server_ip='', server_port='', tokenizer_name='', train_batch_size=12, train_file='data/squad/v2.0/train-v2.0.json', verbose_logging=False, version_2_with_negative=True, warmup_steps=500, weight_decay=0.01)

My hyperparameters:

Namespace(adam_epsilon=1e-08, cache_dir='', config_name='', device=device(type='cuda'), do_eval=True, do_lower_case=False, do_train=True, doc_stride=128, eval_all_checkpoints=False, evaluate_during_training=False, fp16=False, fp16_opt_level='O1', gradient_accumulation_steps=6, learning_rate=1.5e-05, local_rank=-1, logging_steps=50, max_answer_length=30, max_grad_norm=1.0, max_query_length=64, max_seq_length=512, max_steps=-1, model_name_or_path='roberta-large', model_type='roberta', n_best_size=20, n_gpu=1, no_cuda=False, null_score_diff_threshold=0.0, num_train_epochs=8.0, output_dir='../roberta.large.squad2.v1p', overwrite_cache=False, overwrite_output_dir=False, per_gpu_eval_batch_size=2, per_gpu_train_batch_size=2, predict_file='/home/testing/drive/invariance//workspace/data/squad/dev-v2.0.json', save_steps=500, seed=42, server_ip='', server_port='', tokenizer_name='', train_batch_size=2, train_file='/home/testing/drive/invariance//workspace/data/squad/train-v2.0.json', verbose_logging=False, version_2_with_negative=True, warmup_steps=500, weight_decay=0.01)

Do you have any ideas why this is happening?

One thing that may be happening is that, when using max_grad_norm with gradient_accumulation=n, gradient-norm clipping seems to be applied n times rather than once, but I need to look deeper into this.

I'd like to see what happens without the need of gradient accumulation - anyone with a spare TPU to share? 😬

@stevezheng23
Author

stevezheng23 commented Oct 16, 2019

Ok thanks, I'll let @julien-c finish handling this PR when he's back.

thanks, @thomwolf

@stevezheng23
Author

@pminervini I haven't tried the max_grad_norm and gradient_accumulation=n combination before. One thing to pay attention to is that the checkpoint was trained with add_prefix_space=True for the RoBERTa tokenizer.

@pminervini
Contributor

pminervini commented Oct 16, 2019

@stevezheng23 if you look at it, max_grad_norm clipping is applied to the gradients at every sub-step of the accumulation - I think it should be done just once, right before the optimizer.step() call.

https://github.com/huggingface/transformers/blob/master/examples/run_squad.py#L163

@thomwolf what do you think? Should I go ahead and open a PR?
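The concern above is about where clipping happens relative to accumulation: clipping once over the summed gradient bounds the update optimizer.step() actually sees, whereas clipping every sub-step bounds each contribution separately. A minimal numeric sketch (plain lists standing in for parameter gradients, not the actual PyTorch training loop):

```python
def clip_norm(grads, max_norm):
    """Scale the gradient vector down so its L2 norm is at most max_norm."""
    total = sum(g * g for g in grads) ** 0.5
    if total > max_norm:
        scale = max_norm / total
        return [g * scale for g in grads]
    return list(grads)

def accumulate_then_clip(microbatch_grads, max_norm):
    """Sum gradients over all micro-batches, then clip ONCE before the
    (conceptual) optimizer.step() - the placement argued for above."""
    acc = [0.0] * len(microbatch_grads[0])
    for g in microbatch_grads:
        acc = [a + x for a, x in zip(acc, g)]
    return clip_norm(acc, max_norm)

# Clipping once: [3,0] + [0,4] = [3,4], norm 5, clipped to [0.6, 0.8].
# Clipping each micro-batch first would instead give [1,0] + [0,1] = [1,1],
# a different effective update direction with norm sqrt(2) > max_norm.
print(accumulate_then_clip([[3.0, 0.0], [0.0, 4.0]], max_norm=1.0))
```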

c0derm4n added a commit to c0derm4n/transformers that referenced this pull request Oct 31, 2019
huggingface#1386

from transformers import (WEIGHTS_NAME, BertConfig, BertForQuestionAnswering, BertTokenizer,
                          XLNetForQuestionAnswering,
                          XLNetTokenizer,
                          DistilBertConfig, DistilBertForQuestionAnswering, DistilBertTokenizer)
Contributor


Can you add RoBERTa to the title - (Finetuning the library models for question-answering...)


#### Fine-tuning RoBERTa on SQuAD

This is an example of using 4-GPU distributed training to fine-tune a RoBERTa-large model on the SQuAD v2.0 dataset:
Contributor


Can you add the GPU model, so it will be easy to tune per_gpu_eval_batch_size relative to the memory size of the GPU used in the example?

@julien-c
Member

@LysandreJik just significantly rewrote our SQuAD integration in #1984 so we were holding out on merging this.

Does anyone here want to revisit this PR with the changes from #1984? Otherwise, we'll do it, time permitting.

@erenup
Contributor

erenup commented Dec 11, 2019

Cool, I'm willing to revisit it. I will take a look at your changes and transformers' recent updates today (I have been away from the master branch for some time 😊).

@ethanjperez
Contributor

ethanjperez commented Dec 11, 2019

@julien-c I've also uploaded the RoBERTa-large model fine-tuned on SQuAD v2.0 data, together with its prediction & evaluation results, to public cloud storage: https://storage.googleapis.com/mrc_data/squad/roberta.large.squad.v2.zip

Hey @stevezheng23 !

I just tried to reproduce your model with slightly different hyperparameters (batch_size=2 and gradient_accumulation=6 instead of batch_size=12), and I am currently getting worse results.

Do you have any ideas why this is happening?

You're using num_train_epochs=8 instead of 2, which makes the learning rate decay more slowly. Maybe that is causing the difference?
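The effect can be seen in the warmup-then-linear-decay schedule the example scripts use: quadrupling the planned epochs quadruples total_steps, so at the same absolute step the learning rate is several times higher. A sketch of that schedule (the shape is assumed to match the scripts' WarmupLinearSchedule; step counts are illustrative):

```python
def linear_decay_lr(base_lr, step, warmup_steps, total_steps):
    """Linear warmup to base_lr, then linear decay to 0 at total_steps."""
    if step < warmup_steps:
        return base_lr * step / max(1, warmup_steps)
    return base_lr * max(0.0, (total_steps - step) / max(1, total_steps - warmup_steps))

# Same absolute step, two schedules: 2 epochs (~10k steps) vs 8 epochs (~40k).
lr_2_epochs = linear_decay_lr(1.5e-5, step=8000, warmup_steps=500, total_steps=10000)
lr_8_epochs = linear_decay_lr(1.5e-5, step=8000, warmup_steps=500, total_steps=40000)
print(lr_2_epochs, lr_8_epochs)
```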

@ethanjperez
Contributor

ethanjperez commented Dec 11, 2019

Regarding max_grad_norm - RoBERTa doesn't use gradient clipping, so the max_grad_norm changes aren't strictly necessary here

RoBERTa also uses adam_epsilon=1e-06 as I understand, but I'm not sure if it would change the results here

@erenup erenup mentioned this pull request Dec 14, 2019
@erenup
Contributor

erenup commented Dec 14, 2019

Hi @stevezheng23 @julien-c @thomwolf @ethanjperez , I updated the run squad with roberta in #2173
based on #1984 and #1386. Could you please help to review it? Thank you very much.

@julien-c
Member

Closed in favor of #2173 which should be merged soon.

@julien-c julien-c closed this Dec 20, 2019
