Registers StaticCache serialization functions for torch.export.export#39931
xadupre wants to merge 8 commits into huggingface:main
Conversation
cc @yuanyao-nv
@xadupre Is the intended usage to go through executorch after this change? We previously found that the executorch wrapper registers past k/v as buffers, and as a result the exported ONNX has them as initializers.
It is not. I wanted to simplify the wrapper around executorch because I need their code to patch vmap and write a full test with dynamic shapes.
@gante Is cache_utils.py the right place to put all export + caching related logic? I also want to add similar support for MambaCache, but executorch doesn't feel like the right place to put it since it will cover other export users like ONNX. Is there a dedicated directory for PT-2/compile related changes?
After the serialization functions are registered for all the caches, classes such as TorchExportableModuleWithHybridCache should not be needed anymore. |
cc @gante for StaticCache too!
@xadupre Just want to confirm! For individual cache classes like MambaCache, we still need to register them to pytree separately, right?
Yes, the cache type is the key for the registration. If you have something like
ArthurZucker left a comment
Sorry about the late review. I was waiting because #39797 was just merged, which has lazy inits again (it added a way not to use them and touched the executorch integration!)
Can you make sure you rebase? 🤗
Hey! Just a heads-up about #40075! Pytree registration conflicts with FSDP apparently - everything should be part of
@Cyrilvallez @ArthurZucker the fix is not specific to executorch, but is also needed for
I think we need a more general directory where we can put torch.export specific things. Executorch is one user of torch.export, so if we move the pytree registration into the executorch directory, other export users can't use it easily.
The fact is that export necessarily needs to register the correct mask function (the one without vmap) to work correctly. Otherwise export will fail. So IMO, for now at least, it makes sense to have everything there.
What does this PR do?
Registers serialization functions for StaticCache. After this, torch.export.export does not need extra work to export a model using a StaticCache.
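To illustrate what "registering serialization functions" means here, the following is a minimal pure-Python sketch of the idea, not the code in this PR. It assumes a toy registry keyed by cache type, where each entry holds a flatten function (object to flat list of leaves plus static context) and an unflatten function (the inverse). All names (`ToyStaticCache`, `register_node`, `tree_flatten`, `tree_unflatten`) are hypothetical; in PyTorch the corresponding mechanism lives in `torch.utils._pytree`, and the real StaticCache has a different layout.

```python
from dataclasses import dataclass, field
from typing import Any, Callable, Dict, List, Tuple, Type

# Hypothetical registry: cache type -> (flatten_fn, unflatten_fn).
# The type is the lookup key, so each cache class needs its own entry.
_REGISTRY: Dict[Type, Tuple[Callable, Callable]] = {}


def register_node(cls: Type, flatten_fn: Callable, unflatten_fn: Callable) -> None:
    """Record how to flatten and rebuild instances of `cls`."""
    _REGISTRY[cls] = (flatten_fn, unflatten_fn)


@dataclass
class ToyStaticCache:
    """Stand-in for a static KV cache (assumed layout, one entry per layer)."""
    max_cache_len: int
    key_cache: List[Any] = field(default_factory=list)
    value_cache: List[Any] = field(default_factory=list)


def flatten_cache(cache: ToyStaticCache):
    # Leaves are the per-layer tensors; the context holds static metadata.
    leaves = list(cache.key_cache) + list(cache.value_cache)
    context = (cache.max_cache_len, len(cache.key_cache))
    return leaves, context


def unflatten_cache(leaves: List[Any], context) -> ToyStaticCache:
    max_cache_len, n_layers = context
    return ToyStaticCache(max_cache_len, list(leaves[:n_layers]), list(leaves[n_layers:]))


register_node(ToyStaticCache, flatten_cache, unflatten_cache)


def tree_flatten(obj: Any):
    # Exact-type lookup: an unregistered cache class raises KeyError.
    flatten_fn, _ = _REGISTRY[type(obj)]
    leaves, context = flatten_fn(obj)
    return leaves, (type(obj), context)


def tree_unflatten(leaves: List[Any], spec) -> Any:
    cls, context = spec
    _, unflatten_fn = _REGISTRY[cls]
    return unflatten_fn(leaves, context)
```

With such a registration in place, an exporter can treat the cache as a flat list of tensor inputs and rebuild it afterwards: `tree_unflatten(*tree_flatten(cache))` returns an object equal to `cache`. Because the lookup is by type, a different cache class (for example a Mamba-style cache) would need its own registration in this sketch.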
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.