Generate: consistently handle special tokens as tensors #29788
gante wants to merge 16 commits into huggingface:main from
Conversation
The logic of this function is now within _prepare_special_tokens, which preprocesses all special tokens
ALL preprocessing logic for the special tokens now resides in this function 🧹
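A minimal sketch of the kind of normalization such a preprocessing function performs. The function name and signature below are illustrative, not the actual transformers implementation: the idea is simply that ints, lists of ints, tensors, and None all get converted to a single canonical form (a 1-D long tensor, or None) up front.

```python
import torch


def prepare_special_tokens_sketch(eos_token_id, pad_token_id, device="cpu"):
    """Hypothetical helper: normalize special tokens into tensors.

    Accepts an int, a list of ints, a tensor, or None for each token,
    and returns a 1-D LongTensor (or None), so downstream generation
    code only ever sees one type.
    """

    def as_tensor(token):
        if token is None:
            return None
        if isinstance(token, torch.Tensor):
            return token.to(device=device, dtype=torch.long).reshape(-1)
        if isinstance(token, int):
            token = [token]
        return torch.tensor(token, device=device, dtype=torch.long)

    return as_tensor(eos_token_id), as_tensor(pad_token_id)


# A model may have several EOS candidates; all end up in one tensor.
eos, pad = prepare_special_tokens_sketch(eos_token_id=[2, 32000], pad_token_id=0)
```

After this single conversion point, no other part of the loop needs `isinstance` checks on special tokens.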
kwargs_has_attention_mask is an optional argument so that we can also use this function in tests to prepare special tokens.
The decoding functions are backward compatible (for now), and we can still pass int/list(int) as special tokens.
The doctests in generate test this.
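Backward compatibility at a call site can be kept with a thin normalization step, sketched below under the assumption that EOS membership is checked with `torch.isin` (the helper here is illustrative, not the library's actual code): legacy int or list inputs are converted to a tensor before the tensor-only check runs.

```python
import torch


def is_eos(last_tokens: torch.Tensor, eos_token_id) -> torch.Tensor:
    # Legacy int / list(int) inputs still work: normalize to a tensor first,
    # then do a pure tensor membership test.
    if not isinstance(eos_token_id, torch.Tensor):
        if isinstance(eos_token_id, int):
            eos_token_id = [eos_token_id]
        eos_token_id = torch.tensor(eos_token_id)
    return torch.isin(last_tokens, eos_token_id)


tokens = torch.tensor([5, 2, 7])
mask_int = is_eos(tokens, 2)        # int still accepted
mask_list = is_eos(tokens, [2, 7])  # list of ints still accepted
```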
Musicgen (and its melody variant) have their own custom generate, relying on this method.
I've intentionally not updated this custom generate, to pressure us into moving towards a single generate function.
The logic rewritten in functions like this is torch.compile(..., fullgraph=True) compatible 😉
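The key property is avoiding Python control flow that branches on tensor *values*, which forces graph breaks. A hedged sketch of that style (names are illustrative, not the actual transformers function): the "which sequences have finished" bookkeeping is expressed entirely in tensor ops.

```python
import torch


def update_unfinished(unfinished: torch.Tensor,
                      next_tokens: torch.Tensor,
                      eos_token_id: torch.Tensor) -> torch.Tensor:
    # No `if token == eos:` Python branch on tensor values -- only tensor
    # ops, so the function stays traceable without graph breaks (and is
    # the kind of logic torch.compile(..., fullgraph=True) can handle).
    hit_eos = torch.isin(next_tokens, eos_token_id)
    return unfinished & ~hit_eos


unfinished = torch.tensor([True, True, True])
next_tokens = torch.tensor([4, 2, 9])
eos = torch.tensor([2])  # special token as a tensor, not a Python int
unfinished = update_unfinished(unfinished, next_tokens, eos)
```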
zucchini-nlp left a comment
Thanks for working on this 😄
Co-authored-by: Raushan Turganbay <raushan.turganbay@alumni.nu.edu.kz>
Let's merge #29956 first, so the diff here becomes much smaller (the EOS-as-stopping-criteria made the diff more elaborate). (Arthur -- don't review this one until that is merged, I'll ping you again.)
What does this PR do?
To enable torch.compile with generate, some special token-related operations have to be rewritten into torch operations. That requires special tokens to be tensors instead of integers or a list of integers (see #29374 for a working prototype).

This PR reworks special token usage in generate to consistently treat them as a tensor, as opposed to e.g. keeping track of eos_token_id in integer and in tensor form.

👉 Review suggestion: start by reading _prepare_special_tokens and how it fits in generate.

Requirements before merging this PR:
Tests ran locally:
- (pytest --doctest-modules src/transformers/generation/logits_process.py -vv), needs requirement to be merged first
- (pytest --doctest-modules src/transformers/generation/utils.py -vv)
- (RUN_SLOW=1 py.test tests/generation/test_utils.py -vv)
- (RUN_SLOW=1 py.test tests/test_cache_utils.py -vv) -- same failures as in main
- (RUN_SLOW=1 py.test tests/models/llama/test_modeling_llama.py -vv)
- (RUN_SLOW=1 py.test tests/models/whisper/test_modeling_whisper.py -vv) -- same failures as in main
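To illustrate why the tensor form matters for torch.compile: with pad_token_id as a tensor, padding finished sequences can be a single torch.where, with no Python-level branching per step. This is a hedged sketch of the pattern, not the library's actual loop body.

```python
import torch

# Special token kept in tensor form, not as a Python int.
pad_token = torch.tensor(0)

next_tokens = torch.tensor([5, 2, 8])
unfinished = torch.tensor([True, False, True])

# Finished sequences keep emitting the pad token. torch.where broadcasts
# the scalar pad tensor, keeping the whole step a tensor op that can be
# traced without graph breaks.
next_tokens = torch.where(unfinished, next_tokens, pad_token)
```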