adding search.PrefixConstrainedBeamSearch #2646
nicola-decao wants to merge 5 commits into facebookresearch:master from nicola-decao:add_PrefixConstrainedBeamSearch
Conversation
The test failed on something that is not part of the pull request.
myleott
left a comment
You can ignore the test_translation_multi_simple_epoch test failure (the psutil import failure has been fixed in trunk).
But the test_ensemble_sequence_generator (tests.test_sequence_generator.TestJitSequeneceGenerator) failure seems related (see comment below).
```diff
             if num_remaining_sent == 0:
                 break
-            if isinstance(self.search, search.PrefixConstrainedBeamSearch) and step >= max_len:
+            if self.search.stop_on_max_len and step >= max_len:
```
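The suggested change replaces the `isinstance` check against a specific search subclass with a `stop_on_max_len` attribute on the search object, which may be why the TorchScript-based `TestJitSequeneceGenerator` test was failing. A minimal sketch of the attribute-flag pattern (the class bodies here are illustrative, not fairseq's actual implementations; only the attribute name and the condition come from the suggested diff):

```python
# Base search strategy: by default, do not force-stop beams at max_len.
class Search:
    stop_on_max_len = False

# The constrained strategy opts in by overriding the flag.
class PrefixConstrainedBeamSearch(Search):
    stop_on_max_len = True

def should_stop(search, step, max_len):
    # Attribute check instead of:
    #   isinstance(search, PrefixConstrainedBeamSearch) and step >= max_len
    # An attribute read keeps the generator code independent of any
    # particular subclass, which is friendlier to scripting/compilation.
    return search.stop_on_max_len and step >= max_len
```
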
facebook-github-bot
left a comment
@myleott has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
I am not a Facebook employee, so I cannot see the warnings or why this fails.
I'm taking care of this :)
Summary:

# Before submitting

- [ ] Was this discussed/approved via a GitHub issue? (no need for typos, doc improvements)
- [x] Did you read the [contributor guideline](https://github.com/pytorch/fairseq/blob/master/CONTRIBUTING.md)?
- [x] Did you make sure to update the docs?
- [x] Did you write any new necessary tests?

## What does this PR do?

This adds a new decoding strategy, `search.PrefixConstrainedBeamSearch`, that limits the vocabulary of the next generated token given a prefix (i.e., the tokens generated so far during beam search). An end user only needs to pass the optional argument `prefix_allowed_tokens_fn` to `.generate` or `.sample` to activate `PrefixConstrainedBeamSearch`. `prefix_allowed_tokens_fn(batch_id, tokens)` is a callback that, given the `batch_id` and `tokens`, returns the list of allowed tokens for the next generation step.

## Did you have fun?

YES! 🙃

Pull Request resolved: facebookresearch/fairseq#2646
Reviewed By: fabiopetroni
Differential Revision: D24006805
Pulled By: myleott
fbshipit-source-id: 40b1a866c6ea9f936272db27e2a020b18dbf8164
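A common way to implement such a callback is a trie over the allowed token-id sequences. The sketch below is illustrative (the trie helpers and `bos_id` handling are assumptions, not fairseq code); only the `prefix_allowed_tokens_fn(batch_id, tokens)` signature and its contract of returning the allowed next tokens come from this PR:

```python
def build_trie(sequences):
    """Build a nested-dict trie over token-id sequences."""
    trie = {}
    for seq in sequences:
        node = trie
        for tok in seq:
            node = node.setdefault(tok, {})
    return trie

def make_prefix_allowed_tokens_fn(trie, bos_id=0):
    """Return a callback matching prefix_allowed_tokens_fn(batch_id, tokens)."""
    def prefix_allowed_tokens_fn(batch_id, tokens):
        node = trie
        # Skip a leading BOS token, then walk the trie along the
        # tokens generated so far for this hypothesis.
        prefix = tokens[1:] if tokens and tokens[0] == bos_id else tokens
        for tok in prefix:
            if tok not in node:
                return []  # prefix not in the allowed set: no continuations
            node = node[tok]
        return list(node.keys())  # token ids permitted at the next step
    return prefix_allowed_tokens_fn
```

With a trie built from `[[5, 6, 7], [5, 6, 8], [9]]`, the callback returns `[7, 8]` after the prefix `5 6` and an empty list for any prefix outside the allowed set, so the beam search assigns zero probability mass to disallowed continuations.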