BLOOM Flax #18022
sanchit-gandhi
left a comment
Almost there! Just some small refactoring suggestions to clean the code up a bit
sanchit-gandhi
left a comment
Thanks for addressing the previous comments @younesbelkada. Two things from this round of review:
- Could we tidy up the `build_alibi_tensor_flax` function to avoid a triple nested function? Much of the logic can be copied over from PyTorch BLOOM!
- Big question for me is whether we keep `scan` or not - I'm in favour of removing it for Flax BLOOM (see comments below)
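To illustrate the suggestion, a rough sketch of a flat (non-nested) ALiBi builder following the PyTorch BLOOM logic. This is written with NumPy for clarity; the actual Flax version would use `jax.numpy`, and the function signature here is only illustrative, not the PR's actual code:

```python
import math
import numpy as np

def build_alibi_tensor(attention_mask, num_heads, dtype=np.float32):
    """Sketch of a flat ALiBi construction (no nested closures), mirroring PyTorch BLOOM."""
    batch_size, seq_length = attention_mask.shape
    # Head slopes form a geometric sequence; non-power-of-two head counts get
    # extra slopes interleaved from a second geometric sequence.
    closest_power_of_2 = 2 ** math.floor(math.log2(num_heads))
    base = 2.0 ** (-(2.0 ** -(math.log2(closest_power_of_2) - 3)))
    slopes = base ** np.arange(1, closest_power_of_2 + 1)
    if closest_power_of_2 != num_heads:
        extra_base = 2.0 ** (-(2.0 ** -(math.log2(2 * closest_power_of_2) - 3)))
        num_remaining = min(closest_power_of_2, num_heads - closest_power_of_2)
        slopes = np.concatenate(
            [slopes, extra_base ** np.arange(1, 2 * num_remaining + 1, 2)]
        )
    # Per-position offsets: cumulative position index, zeroed where the mask is 0.
    arange_tensor = ((attention_mask.cumsum(axis=-1) - 1) * attention_mask)[:, None, :]
    # Broadcast slopes over positions -> (batch_size, num_heads, seq_length)
    alibi = slopes[None, :, None] * arange_tensor
    return alibi.astype(dtype)
```

For example, with 4 heads and an all-ones mask of length 3, the first head's bias row is `slope * [0, 1, 2]` with slope `0.25`.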
    all_attentions = () if output_attentions else None
    all_hidden_states = () if output_hidden_states else None

    if self.use_scan:
We've currently left `scan` in the modelling code. Part of me thinks we should remove it for Transformers for the following reasons:
- `scan` adds a lot of boilerplate code that isn't very easy to understand
- It is only beneficial for compile times when the model size is large, mostly when training and less so for inference
In the latter case, users will also likely shard the model, meaning they can employ the standalone code in bloom-jax-inference where we can retain scan functionality.
We also found generation time to be slower when using scan vs not using it (despite a faster compile time). The generation time will amortise the compile time in any use case of Flax BLOOM.
Given that the philosophy of Transformers is functional, easy-to-understand code that is not necessarily fully optimised, I'm in favour of stripping scan from Flax BLOOM and leaving it to bloom-jax-inference to serve users that want to deploy larger variants of the model.
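As a toy sketch of the trade-off under discussion (NumPy stand-ins, not the actual Flax modules): with `scan`, per-layer parameters are stacked along a leading axis and a single layer body is iterated (in Flax this would be `nn.scan` over `jax.lax.scan`), whereas the unrolled version keeps a separate set of weights per layer. Both compute the same function; `scan` mainly buys shorter compile times at the cost of boilerplate:

```python
import numpy as np

def forward_unrolled(x, layer_weights):
    """Unrolled layer stack: one weight matrix per layer, plain Python loop.

    This is the straightforward style Transformers favours.
    """
    for w in layer_weights:
        x = np.tanh(x @ w)
    return x

def forward_scanned(x, stacked_weights):
    """Scan-style layer stack: weights stacked as (num_layers, hidden, hidden).

    The Python loop here stands in for jax.lax.scan iterating over the
    leading axis with `x` as the carry.
    """
    def body(carry, w):
        return np.tanh(carry @ w), None

    carry = x
    for w in stacked_weights:
        carry, _ = body(carry, w)
    return carry
```

The two variants are numerically identical; the difference only matters for how the computation is traced and compiled.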
- remove unused code
- refactor a bit
- revert import `torch`
- change build alibi
Should have addressed your new suggestions @sanchit-gandhi! Here I mainly focused on refactoring the `build_alibi` function to match the implementation style of the PyTorch version -> this way it seems more readable!
I will leave you, @patil-suraj and @patrickvonplaten to decide regarding the scan feature, and I'm happy to remove it once we agree on that!
Can also confirm the slow tests/conversion tests pass ;)
Force-pushed from d789a85 to dcdd563
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread. Please note that issues that do not follow the contributing guidelines are likely to be ignored.
Should we merge this one? cc @patrickvonplaten @patil-suraj @sanchit-gandhi
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint.
Both #17761 and this PR show a lot of work. Why isn't it merged?
Adding this to my TODOs
What does this PR do?
An attempt at adding a Flax implementation of BLOOM - original PR from @haileyschoelkopf in #17761
TODOs: