Generate: speculative decoding #27979
Conversation
Force-pushed from 18a4eda to 993c9ee
@patrickvonplaten tagging you here for a 2nd set of eyes on the speculative decoding method (changes in …)
These are not the best variable names, but it's hard to compare against the original algorithm if they don't match 🤔 As such, I've decided to keep the original names.
I'm fine with it as there are good comments and the other variables are well named, e.g. is_rejected :)
Thanks for adding this! Can we split this up into two separate PRs: one changing the assisted generation and the other adding speculative decoding?
@amyeroberts pulled the assisted generation changes into this PR: #28030. After it is merged, I will rebase this one and ping you again -- this one will become exclusively about speculative decoding 🤗
Force-pushed from 7bf05a9 to e234e1e
@amyeroberts I've rerun the slow tests, and I can confirm they are passing. Ready for a review :)
amyeroberts left a comment
Thanks for adding this!
Can we add some tests? In particular, one which checks case 1., and one which makes sure the correct logic branch is being selected, e.g. checking that candidate_logits is None when expected (might be a test on the candidate generator instead)?
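A hedged sketch of what such a branch-selection check could look like. `ToyCandidateGenerator` and `pick_branch` are hypothetical stand-ins for illustration, not the actual transformers test API or internals:

```python
# Toy sketch: a candidate generator that returns no logits should route
# assisted generation into the plain (non-speculative) verification branch.
# All names here are illustrative, not transformers APIs.
class ToyCandidateGenerator:
    def get_candidates(self, input_ids):
        # Candidates without associated logits (e.g. a lookup-based method).
        candidate_ids = input_ids + [7, 8]
        return candidate_ids, None  # candidate_logits is None

def pick_branch(candidate_logits, do_sample):
    # Speculative decoding needs candidate logits AND sampling enabled.
    if do_sample and candidate_logits is not None:
        return "speculative"
    return "plain"

ids, logits = ToyCandidateGenerator().get_candidates([1, 2, 3])
assert logits is None
assert pick_branch(logits, do_sample=True) == "plain"
assert pick_branch([[0.5, 0.5]], do_sample=True) == "speculative"
```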
```python
if do_sample:
    probs = new_logits.softmax(dim=-1)
    selected_tokens = torch.multinomial(probs[0, :, :], num_samples=1).squeeze(1)[None, :]
else:
    selected_tokens = new_logits.argmax(dim=-1)
```
It's probably time to factor this out into something like `selected_tokens = Categorical(new_logits / temperature).sample()` everywhere in generate
Yes! Then equivalent sampling/non-sampling methods (e.g. greedy decoding/sampling) could be merged into a single function, facilitating maintenance. I'm going to leave it to a follow-up PR, though, to keep this PR exclusively about speculative decoding.
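As a sketch of the unified helper being discussed, both branches could collapse into a single function. Assuming PyTorch; `select_tokens` is a hypothetical name, not a transformers API:

```python
import torch
from torch.distributions import Categorical

def select_tokens(new_logits: torch.Tensor, do_sample: bool, temperature: float = 1.0) -> torch.Tensor:
    """Hypothetical unified token-selection helper (illustrative only)."""
    if do_sample:
        # Categorical over the last dim replaces the softmax + torch.multinomial pair.
        return Categorical(logits=new_logits / temperature).sample()
    return new_logits.argmax(dim=-1)

logits = torch.tensor([[[0.1, 5.0, 0.2]]])
greedy = select_tokens(logits, do_sample=False)   # shape (1, 1)
sampled = select_tokens(logits, do_sample=True)   # shape (1, 1)
```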
```python
else:
    selected_tokens = new_logits.argmax(dim=-1)
if do_sample:
    probs = new_logits.softmax(dim=-1)
```
is this case still relevant? Not sure it's a good idea to have two "assisted decoding" do_sample=True cases in our generate. Should we maybe just deprecate this case?
Super cool addition!
Not really related to this PR, but I feel like we should start putting all the generation submethods (assisted decoding, greedy & sample (guess we can merge these two), beam search, ...) into their own files by now
My only important comment here is that I don't think it's great that we have 2 assisted generation cases now where do_sample=True. Can we deprecate the "non-official" one?
@patrickvonplaten the two types of sampling are needed :D New candidate-based methods are popping up (e.g. #27775), and they don't necessarily have logits. As such, speculative decoding, which needs the candidates' logits, can't be applied to those methods.
But shouldn't they just be their "own" method now? I.e. I don't think we should put #27775 into the speculative decoding method, no?
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
@patrickvonplaten #27775 does not introduce changes to assisted generation 🤗 In #28030 I've abstracted the candidate generation part of assisted generation. We now load candidate generators the same way as we load the logits processors (src/transformers/generation/utils.py, lines 899 to 919 at e6dcf8a). In assisted generation, we call the candidate generator to get candidate sequences, which may or may not contain associated logits, depending on the method (src/transformers/generation/utils.py, line 4588 at e6dcf8a). The technique in #27775 can thus be added by defining a new candidate generator. Because needing the logits (for speculative decoding) is a very limiting constraint, I'd rather keep the two sampling paths.
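A minimal sketch of the abstraction described above. The class and function names (`CandidateGenerator`, `get_candidate_generator`, etc.) mirror the description but are illustrative stand-ins, not the actual transformers code:

```python
from typing import List, Optional, Tuple

class CandidateGenerator:
    """Illustrative base class (not the transformers API): yields candidate ids,
    optionally together with the logits that produced them."""

    def get_candidates(self, input_ids: List[int]) -> Tuple[List[int], Optional[List[List[float]]]]:
        raise NotImplementedError

class AssistantModelCandidates(CandidateGenerator):
    """Model-based candidates: logits are available, so speculative decoding can apply."""

    def get_candidates(self, input_ids):
        candidate_ids = input_ids + [4, 5]            # toy continuation
        candidate_logits = [[0.1, 0.9], [0.8, 0.2]]   # toy per-token logits
        return candidate_ids, candidate_logits

class PromptLookupCandidates(CandidateGenerator):
    """Lookup-style candidates (in the spirit of #27775): no logits, so only
    the plain assisted-generation path can verify them."""

    def get_candidates(self, input_ids):
        return input_ids + [4, 5], None

def get_candidate_generator(prompt_lookup: bool) -> CandidateGenerator:
    # Illustrative dispatch, analogous to how logits processors are assembled.
    return PromptLookupCandidates() if prompt_lookup else AssistantModelCandidates()
```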
@amyeroberts PR comments addressed 🤗 @patrickvonplaten Unless you strongly oppose, I'd like to keep the two sampling paths, for the reasons I've written here -- I think it will be beneficial in the long run! :) (otherwise, a whole new generation method would have to be written for #27775)
@amyeroberts -- @patrickvonplaten and I had a chat about whether to keep the two sampling paths or not. For context, here's what we agreed on:
amyeroberts left a comment
Thanks for iterating!
@jmamou speculative decoding with …
@gante In the current implementation (4.38), … Is it intentional? If that's a bug, I can open a PR to fix it.
Not sure if this is a good idea
This is a good point! A PR to revert to the previous behaviour (with a test) would be appreciated 🙏
What does this PR do?
Useful context:
In a recent PR (#27750), the candidate generation in assisted generation got abstracted, so we can host new candidate generation techniques (such as #27722).
This PR:
- Reworks assisted candidate generation to call `.generate()`, instead of having its own custom generation loop. For most models this is nothing more than a nice abstraction. However, for models with a custom `generate()` function, this means the assistant model will now make use of it! (🤔 does this mean that DistilWhisper gets better numbers with this refactor?) Edit: moved to "Generate: assisted decoding now uses `generate` for the assistant" (#28030)

The following tests were run locally and are passing:
- `RUN_SLOW=1 py.test tests/models/whisper/ -k speculative`
- `py.test tests/ -k test_assisted` (which now triggers speculative decoding)

TODO:
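For context on what the tests above exercise, the core accept/resample rule of speculative decoding (per the original algorithm the PR follows) can be sketched as below. Assuming PyTorch; `speculative_accept` is an illustrative name, not the transformers implementation:

```python
import torch

def speculative_accept(p, q, candidates):
    """Toy sketch of the speculative decoding accept/resample rule.

    p: target-model probabilities, shape (num_candidates, vocab_size)
    q: draft-model probabilities,  shape (num_candidates, vocab_size)
    candidates: draft token ids,   shape (num_candidates,)
    Returns (number of accepted tokens, replacement token or None).
    """
    n_accepted = 0
    for i, tok in enumerate(candidates.tolist()):
        # Accept the draft token with probability min(1, p(tok) / q(tok)).
        if torch.rand(()) < torch.clamp(p[i, tok] / q[i, tok], max=1.0):
            n_accepted += 1
        else:
            # On rejection, resample from the residual max(0, p - q), renormalized.
            residual = torch.clamp(p[i] - q[i], min=0.0)
            fixup = torch.multinomial(residual / residual.sum(), num_samples=1)
            return n_accepted, fixup.item()
    return n_accepted, None  # every candidate token was accepted
```

When draft and target distributions agree exactly, every draft token is accepted; when the target assigns zero probability to a draft token, it is always rejected and replaced from the residual distribution, which preserves the target model's output distribution overall.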