Update past_key_values in GPT-2 #9596
Conversation
The CircleCI error messages say as below. In [...]
Is there a difference between [...]? I first thought it might be a difference between the causal language model and the Seq2Seq language model, but it seems that both [...]. And as for the contents of `transformers/src/transformers/models/bart/modeling_bart.py`, lines 1236 to 1244 in 236cc36: [...]
I've updated `transformers/src/transformers/models/xlnet/modeling_xlnet.py`, lines 581 to 607 in 236cc36. It seems [...]
Hey @forest1988, your PR looks very nice! Yes, it is expected that the base implementation simply raises:

```python
def _reorder_cache(self, past, beam_idx):
    raise NotImplementedError(...)
```
I've just updated |
This way it's much cleaner and correct :-) The reason I'm proposing this change is that `_reorder_cache` strongly differs from model to model, so the default in `generation_utils.py` should simply be:

```python
def _reorder_cache(self, past, beam_idx):
    raise NotImplementedError(
        f"Make sure that a `_reorder_cache` function is correctly implemented in "
        f"{self.__class__.__module__} to enable beam search for {self.__class__}"
    )
```
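For context, a model whose cache is a tuple over layers of `(key, value)` tuples (the format this PR moves GPT-2 to) would then override `_reorder_cache` along these lines — a minimal sketch with hypothetical shapes, not necessarily the exact code merged here:

```python
import torch

def _reorder_cache(past, beam_idx):
    # `past` is a tuple over layers; each layer entry is a (key, value) tuple of
    # tensors whose first dimension is the batch (i.e. beam) dimension.
    return tuple(
        tuple(past_state.index_select(0, beam_idx) for past_state in layer_past)
        for layer_past in past
    )

# Usage with hypothetical shapes: 2 layers, 3 beams, 12 heads, 5 cached positions.
past = tuple((torch.randn(3, 12, 5, 64), torch.randn(3, 12, 5, 64)) for _ in range(2))
beam_idx = torch.tensor([2, 0, 0])  # beam 2 plus two copies of beam 0 survived
reordered = _reorder_cache(past, beam_idx)
assert torch.equal(reordered[0][0], past[0][0][beam_idx])
```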
I think this should solve the problems, let me know if you need more help :-)
Thank you for your advice! I'll update [...]
forest1988 force-pushed the branch from 89ee453 to d04b10c.
Thanks to your kind advice, I could solve the problem of [...]. The one remaining bug is: [...]. I think I should modify [...]
All checks have passed! However, in the documentation of [...]
```
called. This is required to match :obj:`past_key_values` or :obj:`mems` with the correct beam_idx at every
generation step.

For custom re-ordering of :obj:`past_key_values` or :obj:`mems`, the function should be implemented in
```
Remove those lines and `past_key_values` above.
I cleaned it as well.
patrickvonplaten left a comment:
The PR looks very nice - thanks so much for taking the time to tackle this @forest1988. Let's wait a bit to see how to proceed with `gradient_checkpointing` in GPT2, as this question will come up more often. IMO, `use_cache` should always be `False` for training, so either we update all `use_cache` in the models with a `use_cache = (not self.training) and (use_cache if use_cache is not None else self.config.use_cache)` or we force it somehow in the `Trainer`. Similarly, `gradient_checkpointing` should never be set to `True` when the model is not training IMO (we could also automatically disable this using `self.training`). Let's see what @LysandreJik and @sgugger think.
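As a concrete illustration of the first option — a minimal sketch, where `CacheAwareModule` and `config_use_cache` are hypothetical stand-ins for a real model class and `self.config.use_cache`:

```python
import torch.nn as nn

class CacheAwareModule(nn.Module):
    """Sketch: resolve `use_cache` so that it is always False in training mode."""

    def __init__(self, config_use_cache: bool = True):
        super().__init__()
        # Stand-in for `self.config.use_cache` in a real transformers model.
        self.config_use_cache = config_use_cache

    def resolve_use_cache(self, use_cache=None) -> bool:
        # Fall back to the config default when the caller passes nothing...
        use_cache = use_cache if use_cache is not None else self.config_use_cache
        # ...but never cache while training (mirrors the suggested expression).
        return (not self.training) and use_cache

module = CacheAwareModule()
module.train()
assert module.resolve_use_cache() is False   # caching disabled during training
module.eval()
assert module.resolve_use_cache() is True    # config default applies at inference
assert module.resolve_use_cache(use_cache=False) is False  # explicit override wins
```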
sgugger left a comment:
This is not a part of the library I'm very familiar with, so the changes look okay on my side, but I'm no expert.
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
LysandreJik left a comment:
These changes look good to me! Thanks for taking care of it @forest1988.
patrickvonplaten left a comment:
Great work @forest1988,
I hope it's fine for you that I went into the PR to do some final fixes. Thanks a lot for cleaning this up :-)
Of course! Thank you for adding fixes to make this PR more valuable!
LysandreJik left a comment:
Your commit looks good to me @patrickvonplaten! Thanks.
sgugger left a comment:
The new changes look good to me, thanks!
Awesome, merging - great job @forest1988!
Thank you for your advice and encouraging comments!
```diff
 if use_cache is True:
-    present = torch.stack((key.transpose(-2, -1), value))  # transpose to have same shapes for stacking
+    present = (key.transpose(-2, -1), value)  # transpose to have same shapes
```
This is the reason for the recent failure of the slow test:

```
RUN_SLOW=1 pytest tests/test_onnx.py::OnnxExportTestCase::test_export_pytorch
```

Can you fix the ONNX part easily? @mfuntowicz @Narsil
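To make the diff concrete, here is a minimal sketch of the two per-layer cache formats, using arbitrary example shapes (GPT-2's attention holds keys transposed relative to values, hence the transpose in the lines above):

```python
import torch

batch, n_head, seq_len, head_dim = 2, 12, 5, 64
key = torch.randn(batch, n_head, head_dim, seq_len)    # keys stored transposed
value = torch.randn(batch, n_head, seq_len, head_dim)

# Old per-layer cache entry: a single stacked tensor.
old_present = torch.stack((key.transpose(-2, -1), value))
assert old_present.shape == (2, batch, n_head, seq_len, head_dim)

# New per-layer cache entry: a plain (key, value) tuple, matching Bart's format.
new_present = (key.transpose(-2, -1), value)
assert new_present[0].shape == new_present[1].shape == (batch, n_head, seq_len, head_dim)
```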
What does this PR do?
It seems GPT-2 and BartDecoder have different styles of `past_key_values`. Advised by @patrickvonplaten, I opened this PR to change GPT-2's cache format from a single tensor to a tuple of 2 tensors. Once this problem is solved, it is expected that `past_key_values` in GPT-2 will be handled in the same way as in Bart.

Sorry, there remain some errors. This PR is [WIP]. I would appreciate your advice on how to update `generation_utils.py`. Can I modify `_reorder_cache` so that `past` is replaced from `Tuple[torch.Tensor]` to `Tuple[Tuple[torch.Tensor]]`, or should I consider other output variations, `outputs.mems` and `outputs.past_buckets_states`?
Fixes #9391
From patrickvonplaten:
This PR cleans up the `_reorder_cache` logic. Now `_reorder_cache` defaults to raising a `NotImplementedError` in `generation_utils.py`, forcing each model to implement its corresponding `_reorder_cache` in the `modeling_...py` file itself. This is cleaner, as `_reorder_cache` strongly differs from model to model. In addition, this PR makes sure that `gradient_checkpointing` can only be used if the model is in training mode, and makes sure that `use_cache` is disabled when training with `gradient_checkpointing` enabled, to prevent errors.

Before submitting
- Did you read the contributor guideline, Pull Request section?
- Was this discussed/approved via a Github issue or the forum? Please add a link to it if that's the case.
- Did you make sure to update the documentation with your changes? Here are the documentation guidelines, and here are tips on formatting docstrings.
Who can review?
GPT2: @LysandreJik, @patrickvonplaten