Skip to content

Registers StaticCache serialization functions for torch.export.export#39931

Open
xadupre wants to merge 8 commits intohuggingface:mainfrom
xadupre:series
Open

Registers StaticCache serialization functions for torch.export.export#39931
xadupre wants to merge 8 commits intohuggingface:mainfrom
xadupre:series

Conversation

@xadupre
Copy link
Copy Markdown
Contributor

@xadupre xadupre commented Aug 5, 2025

What does this PR do?

Registers serialization functions for StaticCache. After this, torch.export.export does not need extra work to export a model using a StaticCache.

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a Github issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@xadupre
Copy link
Copy Markdown
Contributor Author

xadupre commented Aug 5, 2025

cc @yuanyao-nv
cc @tugsbayasgalan, this PR should simplify the export for Executorch + StaticCache

@yuanyao-nv
Copy link
Copy Markdown

@xadupre Is the intended usage to go through executorch after this change? We previously found that the executorch wrapper registers past k/v as buffers, and as a result the exported ONNX has them as initializers.

@xadupre
Copy link
Copy Markdown
Contributor Author

xadupre commented Aug 5, 2025

It is not. I wanted to simplify the wrapper around executorch because I need their code to patch vmap and write a full test with dynamic shapes.

@tugsbayasgalan
Copy link
Copy Markdown
Contributor

@gante Is cache_utils.py the right place to put all export + caching related logic? I also want to add similar support for MambaCache but executorch doesn't feel like the right place to put it since it will cover other export users like ONNX. Is there dedicated directory for maybe PT-2/compile related changes?

@xadupre
Copy link
Copy Markdown
Contributor Author

xadupre commented Aug 6, 2025

@gante Is cache_utils.py the right place to put all export + caching related logic? I also want to add similar support for MambaCache but executorch doesn't feel like the right place to put it since it will cover other export users like ONNX. Is there dedicated directory for maybe PT-2/compile related changes?

After the serialization functions are registered for all the caches, classes such as TorchExportableModuleWithHybridCache should not be needed anymore.

@Rocketknight1
Copy link
Copy Markdown
Member

cc @gante for staticcache too!

@tugsbayasgalan
Copy link
Copy Markdown
Contributor

@gante Is cache_utils.py the right place to put all export + caching related logic? I also want to add similar support for MambaCache but executorch doesn't feel like the right place to put it since it will cover other export users like ONNX. Is there dedicated directory for maybe PT-2/compile related changes?

After the serialization functions are registered for all the caches, classes such as TorchExportableModuleWithHybridCache should not be needed anymore.

@xadupre Just want to confirm! For individual cache classes like MambaCache, we still need to register them to pytree seperately right?

@xadupre
Copy link
Copy Markdown
Contributor Author

xadupre commented Aug 6, 2025

Yes, the cache type is the key for the registration. If you have something like class MyDynamicCache(DynamicCache): pass, you need to registrer MyDynamicCache.

Copy link
Copy Markdown
Collaborator

@ArthurZucker ArthurZucker left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry about the late review, was waiting because #39797 was just merged, which has lazy inits again (it added a way not to use it and touched the executorch integration!)

@ArthurZucker
Copy link
Copy Markdown
Collaborator

Can you make sure your rebase! 🤗

@Cyrilvallez
Copy link
Copy Markdown
Member

Hey! Just a heads-up about #40075! Pytree registration conflicts with fsdp apparently - everything should be part of executorch integration only for this reason (and the fact that since it's specific to export, it should reside there as well 🤗)

@justinchuby
Copy link
Copy Markdown
Contributor

justinchuby commented Aug 21, 2025

@Cyrilvallez @ArthurZucker the fix is not specific to executorch, but is also needed for torch.export compatibility that is used by ONNX export and {xla etc.}. Is there a path to natively support this in transformers?

@tugsbayasgalan
Copy link
Copy Markdown
Contributor

I think we need more general directory where we can put torch.export specific things. Executorch is one user of torch.export, so if we move the pytree registration into executorch directory, other export users can't use them easily.

@Cyrilvallez
Copy link
Copy Markdown
Member

The fact is that export necessarily need to register correct mask function (the one without vmap) to work correctly. Otherwise export will fail. So IMO, for now at least it makes sense to have everything there

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

7 participants