Add accelerate support for LongT5 models #20341

sgugger merged 4 commits into huggingface:main from pszemraj:long-t5-accelerate

Conversation
cc @KMFODA for inputs on tests & more 🤞
The documentation is not available anymore as the PR was closed or merged. |
Signed-off-by: peter szemraj <peterszemraj@gmail.com>
younesbelkada
left a comment
Very cool PR! Glad to see that 8-bit integration is gaining interest and attention on more models 🔥
Just a small typo on the Google Colab: the `.cuda()` is not needed after instantiating the model with `load_in_8bit` and `device_map="auto"`, so I would advise removing it ;)
Can you make sure the slow tests pass with the command `RUN_SLOW=1 pytest tests/models/longt5/test_modeling_longt5.py`? (You will need access to a GPU instance.) When I ran your fix, the accelerate tests were failing. You can fix them by adding the lines below, as was done for BART / NLLB in #19912:
```python
self.embed_tokens = nn.Embedding(config.vocab_size, config.d_model)
if embed_tokens is not None:
    # Re-point the stack's embedding weight at the shared Parameter so the
    # embeddings stay tied after accelerate places the weights.
    self.embed_tokens.weight = embed_tokens.weight
```
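As a sanity check of that pattern, here is a minimal standalone sketch (toy sizes, plain PyTorch rather than the LongT5 classes) showing that assigning the `weight` Parameter ties the two embeddings to a single tensor:

```python
import torch.nn as nn

vocab_size, d_model = 32, 8

# Stand-in for the model's shared embedding table.
shared = nn.Embedding(vocab_size, d_model)

# Stand-in for an encoder/decoder stack's embedding: create it, then
# re-point its weight at the shared Parameter, as in the snippet above.
stack_embed = nn.Embedding(vocab_size, d_model)
stack_embed.weight = shared.weight

# Same Parameter object, so there is only one tensor for accelerate to place.
print(stack_embed.weight is shared.weight)  # True
```

Because both modules hold the very same Parameter, device placement (or an in-place update) of one is automatically reflected in the other, which is what the accelerate tests rely on.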
Thanks for the feedback & good catch on the Colab! I've updated the notebook - will run and resolve the slow tests/accelerate items later today/tomorrow and revert back 👌
Hey @pszemraj !
Hi @pszemraj ! Is it ok if I try to take over the PR? This addition could be very nice to the lib! Let me know what you think :)
Hey! Let me give it a stab today (I was sick for a week); if you don't see anything by tomorrow, feel free to take it home!

On 12/6/2022 8:54:39 AM, Younes Belkada wrote:

> Hi @pszemraj ! Is it ok if I try to take over the PR? This addition could be very nice to the lib! Let me know what you think :)
@younesbelkada hey - was trying to get the tests to pass and evaluate further, but unfortunately the machine I have GPU access to was running into some install issues. If you're willing to finish this, that would probably be easiest 😅 I'll add the line for accelerate as you suggested and rebase as per the contrib guidelines; feel free to take whatever you find useful :)
Thanks a lot @pszemraj for your great efforts, will have a look ASAP ;) this is definitely on my TODO list
thanks so much! I see you pushed, so I will leave you to it (but feel free to let me know if you have questions or need me to change anything on my end); then we can get this bad boi usable on free Colab runtimes :)
Thanks for taking it home @younesbelkada! and thanks for the review @sgugger. Happy to help :)
* ✨ add accelerate support for LongT5 models

  Signed-off-by: peter szemraj <peterszemraj@gmail.com>

* fix `accelerate` tests

* Trigger CI test

  Signed-off-by: peter szemraj <peterszemraj@gmail.com>

Co-authored-by: younesbelkada <younesbelkada@gmail.com>
Signed-off-by: peter szemraj peterszemraj@gmail.com
What does this PR do?

This PR adds `accelerate` support for the LongT5 models (i.e., makes it possible to use `device_map="auto"`), so these models can be loaded in 8-bit using `load_in_8bit=True`. This helps enable inference with trained/fine-tuned SoTA long summarization models using limited memory ☺️
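As a hedged sketch of the usage this enables (the default checkpoint below is a public Google LongT5 model chosen for illustration, not necessarily the one from the Colab; `bitsandbytes` and a CUDA GPU are needed for the 8-bit path):

```python
def load_long_t5(checkpoint: str = "google/long-t5-tglobal-base",
                 in_8bit: bool = False):
    """Load a LongT5 checkpoint with accelerate's automatic device placement.

    Sketch only: assumes `transformers` and `accelerate` are installed; the
    8-bit path additionally requires `bitsandbytes` and a CUDA device.
    """
    from transformers import LongT5ForConditionalGeneration  # deferred import

    kwargs = {"device_map": "auto"}  # dispatch weights across available devices
    if in_8bit:
        kwargs["load_in_8bit"] = True  # int8 weights via bitsandbytes
    return LongT5ForConditionalGeneration.from_pretrained(checkpoint, **kwargs)
```

Note that, as pointed out in the review above, no `.cuda()` call should follow this: accelerate has already placed the weights, and moving an 8-bit model manually is unsupported.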
Took inspiration from reviewing similar PRs for other models: #19912 and #19927
cc @sgugger
test results

I made a Colab notebook that clones the branch from my fork to demo `load_in_8bit=True` working. Everything else is the same for comparison purposes (except the function that reports the model size) as the fp32/standard notebook listed on my fine-tuned model card.

I also ran the tests for LongT5 locally:

```shell
$ python -m pytest -n auto --dist=loadfile -s -v tests/models/longt5/test_modeling_longt5.py
( ... many things here ...)
=================================================== 196 passed, 58 skipped, 118 warnings in 30.49s ===================================================
```