Fix empty tensor shape issue in DynamicCache for torch.compile #42053
yashwantbezawada wants to merge 1 commit into huggingface:main
Conversation
Force-pushed 8bf24e6 to ffd4b63
Fixes huggingface#42027

This commit fixes a regression where torch.cat receives incorrectly shaped empty tensors during model tracing with torch.compile. The issue was introduced in commit dc11a3c, where empty cache tensors were initialized as 1D tensors with shape [0] using torch.tensor([]). When these are concatenated with 4D key/value tensors [batch, heads, seq, dim] along dim=-2, torch.compile's tracing fails.

Changes:
- Modified DynamicLayer.lazy_initialization() to create properly shaped 4D empty tensors [batch, heads, 0, dim] instead of 1D [0]
- Modified QuantizedLayer.update() to reset the cache with the proper 4D shape
- Used torch.zeros() with an explicit shape matching the key_states dimensions

This ensures torch.cat operations work correctly in both eager and compiled modes.
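To illustrate the fix described in the commit message, here is a minimal sketch (the shapes are arbitrary illustration values, not taken from the PR): an empty cache tensor shaped [batch, heads, 0, dim] concatenates cleanly with a 4D key tensor along dim=-2, producing the expected shape.

```python
import torch

# Hypothetical shapes for illustration: [batch, heads, seq, dim]
key_states = torch.randn(2, 4, 3, 8)

# Properly shaped 4D empty cache tensor, as in the fix:
# zero length along the sequence axis (dim=-2)
empty_cache = torch.zeros(2, 4, 0, 8)

# All non-concatenation dims match, so this traces and runs cleanly
out = torch.cat([empty_cache, key_states], dim=-2)
print(out.shape)  # torch.Size([2, 4, 3, 8])
```

A 1D torch.tensor([]) does not satisfy torch.cat's requirement that all non-concatenation dimensions match, which is why the 4D empty shape matters under tracing.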
Force-pushed ffd4b63 to 1375af8
I see that the CI tests are failing (tests_exotic_models, tests_generate, tests_torch), while the code quality checks pass. I'm unable to access the detailed CircleCI logs to understand the specific test failures. The changes I made:

This approach ensures torch.cat works correctly in torch.compile mode by providing properly shaped 4D tensors. Could someone help me understand which tests are failing and why? I'd be happy to adjust the approach if needed. I'm aware of PR #40328, which takes a more comprehensive approach to torch.compile + DynamicCache compatibility. cc @huggingface/transformers
This is an update to a PR from @Cyrilvallez, so I'll wait for him to approve it!
How are you using torch.compile? In general, it should not really be used with DynamicCache, as the shapes will keep changing at each iteration.
The original issue was about using a custom torch.compile backend. I noticed PR #40328 is doing the proper fix with symbolic shapes and mark_dynamic. That seems like the right long-term solution. Should I just close this one?
Hey @yashwantbezawada! Indeed, the PR you linked will provide much better support for compile options! I'll go back to it asap when things are calmer after the
Thanks for the update! Closing this one - looking forward to #40328. |
What does this PR do?
Fixes #42027
This PR fixes a regression where torch.cat receives incorrectly shaped empty tensors during GPT2 model tracing with torch.compile, causing compilation failures.
Background
The issue was introduced in commit dc11a3c (PR #39797), where empty cache tensors were initialized as 1D tensors with shape [0] using torch.tensor([]). When these are concatenated with 4D key/value tensors [batch_size, num_heads, seq_len, head_dim] along dim=-2, torch.compile's tracing fails with empty tensor errors.
Changes
- Modified DynamicLayer.lazy_initialization() to create properly shaped 4D empty tensors [batch, heads, 0, dim] instead of 1D [0]
- Modified QuantizedLayer.update() to reset the cache with the proper 4D shape
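The core of the change to both methods can be sketched as follows. This is a hedged illustration, not the actual transformers code: the function name and signature here are hypothetical, but the idea matches the PR description — derive batch, heads, and head dim from key_states and use a zero-length sequence axis.

```python
import torch

def init_empty_cache(key_states: torch.Tensor):
    """Hypothetical sketch of the fix: build empty 4D cache tensors whose
    batch/head/dim sizes match key_states, with 0 along the sequence axis,
    instead of the 1D torch.tensor([]) that broke torch.compile tracing."""
    batch, heads, _, dim = key_states.shape
    keys = torch.zeros(batch, heads, 0, dim,
                       dtype=key_states.dtype, device=key_states.device)
    values = torch.zeros(batch, heads, 0, dim,
                         dtype=key_states.dtype, device=key_states.device)
    return keys, values

# Usage: the empty tensors can then be concatenated with new states
k, v = init_empty_cache(torch.randn(1, 2, 5, 16))
print(k.shape)  # torch.Size([1, 2, 0, 16])
```

Matching dtype and device to key_states avoids a second class of tracing surprises when the cache later participates in torch.cat.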
Testing
The fix ensures torch.cat operations work correctly in both eager and compiled modes.
Impact