[rfc] Prototype to make torch.compile work with DynamicCache #40328
anijain2305 wants to merge 6 commits into huggingface:main
Conversation
Cyrilvallez left a comment
Nice! Super happy to see you jump on this! Indeed, we waste A LOT of tokens when using StaticCache, so would be amazing if we can compile Dynamic with nice optimizations 🚀
Finally starting to understand the usage of torch._check a bit more!
```python
def add_torch_size_checks(self):
    if not len(self.layers):
        return

    first_keys = self.layers[0].keys
    first_values = self.layers[0].values
    torch._check(first_keys.size(2) == first_values.size(2))

    for layer in self.layers:
        torch._check(first_keys.size(2) == layer.keys.size(2))
        torch._check(first_values.size(2) == layer.values.size(2))
```
We need to be careful here: since #40039, this is not necessarily the case on all layers (we need to do it separately on full and sliding layers)
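For illustration, a per-group variant might look like this (grouping by layer class is just one possible key; the attribute names follow the snippet above):

```python
def add_torch_size_checks(self):
    # Unify the symbolic seq length separately for full-attention and
    # sliding-window layers, since their cached lengths can differ.
    groups = {}
    for layer in self.layers:
        groups.setdefault(type(layer).__name__, []).append(layer)
    for layers in groups.values():
        first = layers[0]
        torch._check(first.keys.size(2) == first.values.size(2))
        for layer in layers[1:]:
            torch._check(layer.keys.size(2) == first.keys.size(2))
            torch._check(layer.values.size(2) == first.values.size(2))
```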
```python
    # forward through all the layers
    return len(self.layers)

def add_torch_size_checks(self):
```
Let's have a nice docstring here explaining that the ._check calls are used to tell the compiler in advance that they will all be the same symbolic size, which will allow it to optimize nicely in combination with torch._dynamo.mark_dynamic on this symbolic size.
Because ._check is not documented at all; we tried using it a few times already, but it's never clear how it works and what it allows 👌
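A possible version, to make the suggestion concrete (wording is only a proposal):

```python
def add_torch_size_checks(self):
    """Assert that all cached key/value tensors share one sequence length.

    `torch._check` records an assertion that the compiler can also use
    symbolically: stating that every layer's seq-length dim (dim 2) is equal
    tells dynamo they all share ONE symbolic size instead of minting a fresh
    symbol per tensor. Combined with `torch._dynamo.mark_dynamic` on that dim,
    this lets torch.compile treat the whole cache as having a single dynamic
    sequence length and optimize accordingly.
    """
```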
```python
# TODO - In my opinion, we should keep compile disabled by default, and
# then introduce `enable_compile` in the generation config to do the
# compilation. With DynamicCache becoming compileable, the blast radius
# could be big.
# Override: honor `disable_compile` flag
```
Agreed, we don't want to start compiling everything!
+1, compilation by default places additional requirements on the user's setup, which occasionally results in issues
```python
assert isinstance(attention_mask, (dict, type(None)))
# With compileable caches, we get 4D masks (not sure why)
```
When compiling, we create the final masks in advance between the forwards, because compiling the mask creation is extremely tricky (borderline impossible in all scenarios), and it's not a bottleneck anyway (you can think of it as a kind of graph break, but we just do it outside)
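Roughly, the pattern described here looks like this (build_causal_masks and compiled_forward are hypothetical names, not the actual transformers helpers):

```python
# Hypothetical decoding loop illustrating the pattern: the final 4D masks are
# built eagerly between forwards, so mask creation never has to be traced.
for _ in range(max_new_tokens):
    masks = build_causal_masks(input_ids, past_key_values)  # eager, not compiled
    outputs = compiled_forward(input_ids, attention_mask=masks,
                               past_key_values=past_key_values)
    input_ids = outputs.logits[:, -1:].argmax(-1)
```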
Force-pushed from da4b3e0 to 62d9714
Force-pushed from 62d9714 to c408872
@anijain2305 we are happy to help you / push this to merge it, tell us if you need anything!
Yes sorry everyone, I told @anijain2305 I would take over and wanted to do it this week, but did not find the time 🥲 Will do next week!
Cyrilvallez left a comment
Alright @anijain2305, just a few questions to better understand the compiler needs/structure before continuing and going for the full integration! 🤗
```python
if isinstance(cache, DynamicCache):
    # Mark the sequence_length dimension as dynamic
    cache.mark_dynamic_for_compile()
    attention_mask = model_inputs.get("attention_mask")
    if isinstance(attention_mask, dict):
        for mask in attention_mask.values():
            if isinstance(mask, torch.Tensor):
                torch._dynamo.mark_dynamic(mask, 3)
    elif isinstance(attention_mask, torch.Tensor):
        torch._dynamo.mark_dynamic(attention_mask, 3)
```
Do we need to mark them dynamic at each iteration, or could we do it only once at the beginning?
I assume we need it every iteration, but if so, can we move it into model_wrapper in get_compiled_call instead? Should it be outside the function that is actually being compiled (i.e., is there any reason you defined it here specifically)? If possible, having all this logic in the wrapper would be simpler to understand!
```python
with torch.compiler.config.patch(dynamic_sources=".*cumulative_length"):
    outputs = model_forward(**model_inputs, return_dict=True)
```
Any way to activate this when calling torch.compile on the forward, instead of via a context manager, by any chance?
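If direct assignment on the config module works (it does for other torch.compiler.config entries, but this is an assumption worth verifying), one alternative could be:

```python
# Sketch: set the dynamo config once, globally, instead of patching around
# every forward call (assumes torch.compiler.config.dynamic_sources is settable
# in the PyTorch version in use).
torch.compiler.config.dynamic_sources = ".*cumulative_length"
compiled_forward = torch.compile(model_forward)
```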
```python
torch._check(layer.keys.size(2) == dynamic_dim)
torch._check(layer.values.size(2) == dynamic_dim)
```
Just setting them all to the same value in this way is enough, right? Or does dynamo rely on the cross-object dependencies somehow?
@anijain2305 we are happy to push this further!
What does this PR do?
Background on StaticCache vs DynamicCache
A static KV cache must be preallocated for the model’s maximum context window (e.g., 32k tokens), even if you only generate a few tokens. That leads to large memory pressure and wastage. A dynamic cache grows with the sequence actually in use: with P prefill tokens and D decode steps, memory scales with P + D instead of max_seq_length. For example, if you prefill 2,000 tokens and generate 50, the dynamic cache holds ~2,050 tokens’ worth of K/V per layer, versus 32k for a static allocation, significantly reducing memory pressure.
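For concreteness, a minimal sketch of selecting the two cache types via generate (model name and lengths are illustrative):

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tok = AutoTokenizer.from_pretrained("Qwen/Qwen3-8B")
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen3-8B", torch_dtype=torch.bfloat16)
inputs = tok("Some long prompt ...", return_tensors="pt")

# DynamicCache (the default): K/V storage grows with prefill + decoded tokens.
out_dynamic = model.generate(**inputs, max_new_tokens=50)

# StaticCache: K/V preallocated up to the maximum sequence length.
out_static = model.generate(**inputs, max_new_tokens=50, cache_implementation="static")
```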
Let's do a quick experiment with Qwen3.
In all cases, the output generated text was exactly the same.
As expected, StaticCache has a higher memory footprint. But the generation time is also higher. This is because the main tensors involved in the SDPA operation (key, value, attention_mask) have max-seq-length shapes. Even though the attention_mask restricts the useful portion of the compute, there is still more work to do here compared to the dynamic KV cache. (Side note: Driss and Boyuan told me we might be able to use better eager code or flex decoding to improve this, but that's not the focus of this doc.)
torch.compile and Dynamic Cache
transformers introduced StaticCache specifically for torch.compile. For Qwen3 with StaticCache, generation latency under torch.compile is 20 seconds, a 1.2x speedup, with 25 seconds of compile time (tlparse).
But as pointed out earlier, DynamicCache is the better default for eager PyTorch. So, we will look at the changes required to make torch.compile work with DynamicCache. With this PR, I see torch.compile improve generation latency to 3.6 seconds (eager was 8.88 seconds), a 2.4x speedup, with a cold-start compile latency of 36 seconds (tlparse) and a warm start of 11 seconds.
For this RFC
There are two main changes:
Avoid recompilations (symbolic seq length): Tell torch.compile to treat the sequence length as dynamic for all KV-cache-related tensors (key, value, and attention_mask) using torch._dynamo.mark_dynamic. To ensure a single shared symbolic size is used everywhere (since mark_dynamic alone can yield distinct symbols), wrap the model and insert explicit size-equality checks via torch._check that tie the symbols together. This nudges the compiler to unify on one seq-len symbol, improving Inductor codegen and, in the future, compile time (see the sketch after this list).
CUDA graphs: Disable cudagraphs for DynamicCache. Because the sequence length changes from step to step, cudagraphs would constantly re-record, which undercuts their benefit and can hurt latency. (They remain useful when shapes are fixed, e.g., with a static cache.)
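A condensed sketch tying both changes together (helper structure mirrors the diffs quoted above; "triton.cudagraphs" is Inductor's cudagraphs knob, and the cache/layer attribute names are taken from the snippets in this thread):

```python
import torch

def prepare_compiled_forward(model, cache, attention_mask):
    # (1) Mark dim 2 (seq length) of every cached K/V tensor as dynamic, plus
    # the mask's last dim, so growing sequences don't trigger recompiles.
    for layer in cache.layers:
        torch._dynamo.mark_dynamic(layer.keys, 2)
        torch._dynamo.mark_dynamic(layer.values, 2)
    if isinstance(attention_mask, torch.Tensor):
        torch._dynamo.mark_dynamic(attention_mask, 3)

    # mark_dynamic alone can mint a distinct symbol per tensor; size-equality
    # _checks tie them into one shared seq-len symbol for better codegen.
    first = cache.layers[0]
    for layer in cache.layers:
        torch._check(layer.keys.size(2) == first.keys.size(2))
        torch._check(layer.values.size(2) == first.keys.size(2))

    # (2) Keep cudagraphs off: a changing seq length would force re-recording
    # at every step, erasing their benefit.
    return torch.compile(model.forward, options={"triton.cudagraphs": False})
```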
Update
The PR now extends torch.compile to DynamicSlidingWindowLayer. There is one extra recompilation on the integer value of cumulative_length, which I hope can be prevented by converting the integer into a scalar tensor. Speedup: PyTorch eager takes 6.66 seconds, compile takes 5.10 seconds for Gemma2, although the compile time is high (currently due to 2 compilations).
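A hypothetical sketch of that scalar-tensor idea (class and attribute placement are illustrative):

```python
import torch

class SlidingWindowState:
    def __init__(self):
        # As a Python int, every new concrete value of cumulative_length becomes
        # a dynamo guard, causing a recompile each time it changes. As a 0-dim
        # tensor, the update is traced into the graph instead.
        self.cumulative_length = torch.zeros((), dtype=torch.long)

    def advance(self, num_new_tokens: int) -> None:
        self.cumulative_length += num_new_tokens
```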