Add Flash Attention 2 to Persimmon #27685
Conversation
younesbelkada
left a comment
Looks very nice, thanks a lot!
Some changes in the PR seem unrelated (e.g. changes on Phi); I think you need to install `ruff==0.1.5` and run `make style` again.
I'll also run the benchmarks later with FA2-Phi and update in this PR!
There are also some strange CI failures; can you try to rebase on main again?
Force-pushed from 1222223 to 5590ead
Re-based, installed `ruff==0.1.5`, and re-ran `make style`.
xhluca
left a comment
I left some comments regarding the `target_dtype` inference.
```python
if hasattr(self.config, "_pre_quantization_dtype"):
    target_dtype = self.config._pre_quantization_dtype
else:
    target_dtype = self.q_proj.weight.dtype
```
This line will give you an error because `self.q_proj` was never defined here (it is defined in Llama's `__init__`, which is why it worked there). I am not sure exactly what this is trying to achieve, but you might try some other module that is defined in the `__init__` of the `PersimmonAttention` class.
Yes, you should use `self.query_key_value` here.
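For clarity, a minimal sketch of the suggested fix (an illustrative fragment, not the final diff; `self` is the `PersimmonAttention` module, and `query_key_value` is the fused projection defined in its `__init__`):

```python
# Infer the dtype to cast hidden states to for Flash Attention 2 from a
# module that actually exists on PersimmonAttention (`query_key_value`),
# rather than Llama's `q_proj`, which is never defined on this class.
if hasattr(self.config, "_pre_quantization_dtype"):
    # Quantized models record the dtype the weights had before quantization.
    target_dtype = self.config._pre_quantization_dtype
else:
    target_dtype = self.query_key_value.weight.dtype
```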
cc @molbap as younes is offline
ArthurZucker
left a comment
Thanks for the update! Let's make sure to rebase on main and only include changes for persimmon!
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread. Please note that issues that do not follow the contributing guidelines are likely to be ignored.
What does this PR do?
Integrates FA2 into Persimmon per #26350 and #27052 (the former branch was messed up after trying to rebase, so this PR is from a new branch).
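For context, a minimal usage sketch (not part of the diff itself), assuming `flash-attn` is installed, a CUDA device is available, and `adept/persimmon-8b-base` as an example checkpoint:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "adept/persimmon-8b-base"  # example checkpoint, swap in your own

tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.float16,                # FA2 requires fp16 or bf16 weights
    attn_implementation="flash_attention_2",  # older versions used use_flash_attention_2=True
    device_map="auto",
)

inputs = tokenizer("The capital of France is", return_tensors="pt").to(model.device)
outputs = model.generate(**inputs, max_new_tokens=20)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```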
Before submitting
- Did you read the contributor guideline, Pull Request section?
- Did you make sure to update the documentation with your changes? Here are the documentation guidelines, and here are tips on formatting docstrings.
Who can review?
@younesbelkada @ArthurZucker
Notes
- Rebased on main.
- Followed [FA-2] Add Flash Attention to Phi #27661 for the `generate_padding_right` test. However, `Persimmon` tokenizer configs do not have either `eos` or `pad` tokens (both are set to `null`, see here), so simply copying the `LlamaModelTest` `generate_padding_right` test override does not work.
- Tried `dummy inputs` on the full pretrained model for the `generate_padding_right` test, no luck either -- this is left as the current implementation in `test_persimmon_modeling.py`.
- Looked at the `generate_padding_test` for other models for FA2 -- see comments.
- Marked the `generate_padding_right` test as `skip` for now (see the sketch after this list).
- Only files related to `persimmon` were changed in this PR due to fixes from running `make {quality, style, fixup}`.
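A hypothetical sketch of the skip described above; the test name here follows the common `transformers` FA2 test pattern and is an assumption, not taken from the diff:

```python
import unittest


class PersimmonFlashAttention2Test(unittest.TestCase):
    # Persimmon's tokenizer config sets both `eos` and `pad` tokens to null,
    # so the Llama-style right-padding generation test cannot be reused.
    @unittest.skip("Persimmon tokenizer has no eos/pad tokens; right-padding generation cannot be tested")
    def test_flash_attn_2_generate_padding_right(self):
        pass


if __name__ == "__main__":
    unittest.main()
```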