[rfc] Prototype to make torch.compile work with DynamicCache #40328

Open

anijain2305 wants to merge 6 commits into huggingface:main from anijain2305:dynamic-cache-compile

Conversation

anijain2305 (Contributor) commented Aug 20, 2025

What does this PR do?

Background on StaticCache vs DynamicCache

A static KV cache must be preallocated for the model's maximum context window (e.g., 32k tokens), even if you only generate a few tokens. That leads to large memory pressure and wastage. A dynamic cache grows with the sequence actually in use: with P prefill tokens and D decode steps, memory scales with P + D instead of max_seq_length. For example, if you prefill 2,000 tokens and generate 50, the dynamic cache holds ~2,050 tokens' worth of K/V per layer, versus 32k for a static allocation, significantly reducing memory pressure.
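
As a rough back-of-the-envelope illustration (a sketch with assumed, Qwen3-like dimensions; none of these numbers are measurements from this PR):

# Rough KV-cache size: 2 (K and V) * layers * kv_heads * head_dim * seq_len * bytes_per_element.
# The model dimensions below are illustrative assumptions, not values taken from this PR.
def kv_cache_bytes(seq_len, num_layers=36, num_kv_heads=8, head_dim=128, batch=1, dtype_bytes=2):
    return 2 * num_layers * batch * num_kv_heads * head_dim * seq_len * dtype_bytes

static_gib = kv_cache_bytes(32_768) / 2**30   # preallocated for the full context window
dynamic_gib = kv_cache_bytes(2_050) / 2**30   # 2,000 prefill + 50 decode tokens
print(f"static: {static_gib:.2f} GiB, dynamic: {dynamic_gib:.2f} GiB")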

Let's run a quick experiment with Qwen3.

[Figure: benchmark table for Qwen3 comparing StaticCache and DynamicCache (memory footprint and generation time)]

In all cases, the output generated text was exactly the same.

As expected, StaticCache has a higher memory footprint, but the generation time is also higher. This is because the main tensors involved in the SDPA operation (key, value, attention_mask) have max-seq-length shapes. Even though the attention_mask restricts the useful portion of the compute, there is still more work to do than with the dynamic KV cache. (Side note: Driss and Boyuan told me we might be able to use better eager code or flex decoding to improve this, but that's not the focus of this doc.)
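
To make the extra work concrete, here is a small illustration of what a single decode step's SDPA sees in each case (shapes, head counts, and dims are made up for the example):

import torch
import torch.nn.functional as F

batch, heads, head_dim = 1, 8, 128
q = torch.randn(batch, heads, 1, head_dim)   # one decode-step query

# Static cache: K/V and mask are padded out to the max context window; the mask blocks
# the unused positions, but SDPA still reads and computes over the full 32k width.
k_static = torch.randn(batch, heads, 32_768, head_dim)
v_static = torch.randn_like(k_static)
mask = torch.zeros(batch, 1, 1, 32_768)
mask[..., 2_050:] = float("-inf")
out_static = F.scaled_dot_product_attention(q, k_static, v_static, attn_mask=mask)

# Dynamic cache: K/V only cover the tokens actually seen so far.
k_dyn = torch.randn(batch, heads, 2_050, head_dim)
v_dyn = torch.randn_like(k_dyn)
out_dyn = F.scaled_dot_product_attention(q, k_dyn, v_dyn)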

torch.compile and Dynamic Cache

transformers introduced StaticCache specifically for torch.compile. For Qwen3, torch.compile brings generation latency to 20 seconds, a 1.2x speedup, with a compile time of 25 seconds (tlparse).

But as pointed out earlier, DynamicCache is the better default for eager PyTorch. So we will look at the changes required to make torch.compile work with DynamicCache. With this PR, I see torch.compile improve generation latency to 3.6 seconds (eager was 8.88 seconds), a 2.4x speedup, with a cold-start compile time of 36 seconds (tlparse) and a warm start of 11 seconds.

For this RFC:

  • We will use fullgraph=True and do full model compilation (no regional compilation yet).
  • We will compile only the decode step to keep the discussion simple (see the sketch below).
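
A minimal sketch of that setup follows. The model id and greedy decode loop are placeholders, and the real wiring in this PR lives inside generate(); also note that without the changes described below, the compiled decode step keeps recompiling as the DynamicCache grows.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, DynamicCache

model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen3-0.6B", torch_dtype=torch.bfloat16).cuda()
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-0.6B")

decode_step = torch.compile(model.forward, fullgraph=True)   # compile only the decode forward

inputs = tokenizer("Hello", return_tensors="pt").to("cuda")
cache = DynamicCache()
out = model(**inputs, past_key_values=cache, use_cache=True)  # prefill stays eager
next_token = out.logits[:, -1:].argmax(-1)
for _ in range(50):                                           # decode loop runs compiled
    out = decode_step(input_ids=next_token, past_key_values=cache, use_cache=True)
    next_token = out.logits[:, -1:].argmax(-1)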

There are two main changes.

Avoid recompilations (symbolic seq length): Tell torch.compile to treat sequence length as dynamic for all KV-cache–related tensors (key, value, and attention_mask) using torch._dynamo.mark_dynamic. To ensure a single shared symbolic size is used everywhere (since mark_dynamic alone can yield distinct symbols), wrap the model and insert explicit size-equality checks (e.g., via torch._check) that tie the symbols together. This nudges the compiler to unify on one seq-len symbol, improving inductor codegen and, in the future, compile time.
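
A condensed sketch of the two ingredients (the helper names are mine; the attribute layout follows the DynamicCache code quoted later in this thread):

import torch

def mark_cache_dynamic(cache, attention_mask):
    # Dim 2 of each layer's keys/values is the sequence-length dimension,
    # dim 3 of the (4D) attention mask is the key-length dimension.
    for layer in cache.layers:
        torch._dynamo.mark_dynamic(layer.keys, 2)
        torch._dynamo.mark_dynamic(layer.values, 2)
    if isinstance(attention_mask, torch.Tensor):
        torch._dynamo.mark_dynamic(attention_mask, 3)

def tie_seq_len_symbols(cache):
    # Inside the compiled region: assert that every K/V tensor has the same seq length,
    # so dynamo can unify them into a single symbolic size.
    dynamic_dim = cache.layers[0].keys.size(2)
    for layer in cache.layers:
        torch._check(layer.keys.size(2) == dynamic_dim)
        torch._check(layer.values.size(2) == dynamic_dim)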

CUDA graphs: Disable cudagraphs for DynamicCache. Because the sequence length changes step-to-step, cudagraphs would constantly re-record, which undercuts their benefit and can hurt latency. (They remain useful when shapes are fixed, e.g., with a static cache.)
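
One way to express that (a sketch; the inductor flag exists in current PyTorch but its name may change across versions):

import torch
import torch._inductor.config
import torch.nn as nn

model = nn.Linear(16, 16)  # stand-in for the decode forward

# The default compile mode does not use cudagraphs, so simply avoid mode="reduce-overhead"
# on the DynamicCache path...
compiled = torch.compile(model, fullgraph=True)

# ...or, if cudagraphs were switched on elsewhere (e.g., via reduce-overhead), turn them
# off explicitly so the changing sequence length doesn't trigger constant re-recording.
torch._inductor.config.triton.cudagraphs = False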

Update

The PR now extends torch.compile to DynamicSlidingWindowLayer. There is one extra recompilation on the integer value of cumulative_length that I hope can be prevented by converting the integer into a scalar tensor.
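
For reference, the idea is the usual trick of replacing a guarded Python int with a 0-dim tensor, sketched below with a hypothetical counter:

import torch

# A plain Python int attribute gets guarded on by dynamo, so every new value
# (each decode step) triggers a recompilation:
cumulative_length = 0
cumulative_length += 1

# Keeping the counter as a scalar tensor moves its value into graph data instead:
cumulative_length_t = torch.zeros((), dtype=torch.long)
cumulative_length_t += 1                      # tensor op, no int guard
window_full = cumulative_length_t >= 4096     # comparisons stay tensor ops too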

Speedup: for Gemma2, PyTorch eager takes 6.66 seconds and compile takes 5.10 seconds, although the compile time is high (currently due to 2 compilations).

anijain2305 (Contributor, Author): cc @ArthurZucker @gante

Cyrilvallez (Member) left a comment:

Nice! Super happy to see you jump on this! Indeed, we waste A LOT of tokens when using StaticCache, so would be amazing if we can compile Dynamic with nice optimizations 🚀
Finally starting to understand the usage of torch._check a bit more!

Comment thread src/transformers/cache_utils.py Outdated
Comment on lines +968 to +978
def add_torch_size_checks(self):
    if not len(self.layers):
        return

    first_keys = self.layers[0].keys
    first_values = self.layers[0].values
    torch._check(first_keys.size(2) == first_values.size(2))

    for layer in self.layers:
        torch._check(first_keys.size(2) == layer.keys.size(2))
        torch._check(first_values.size(2) == layer.values.size(2))
Cyrilvallez (Member):

We need to be careful here; since #40039 this is not necessarily the case on all layers (we need to do it separately for full and sliding layers).
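
Something along these lines could work (a rough sketch; the is_sliding attribute used to split the groups is an assumption about how full vs. sliding layers can be told apart after #40039):

import torch

def add_torch_size_checks(layers):
    # Full-attention and sliding-window layers can have different cached lengths,
    # so tie the seq-len symbols together separately within each group.
    full = [l for l in layers if not getattr(l, "is_sliding", False)]
    sliding = [l for l in layers if getattr(l, "is_sliding", False)]
    for group in (full, sliding):
        if not group:
            continue
        ref = group[0].keys.size(2)
        for layer in group:
            torch._check(layer.keys.size(2) == ref)
            torch._check(layer.values.size(2) == ref)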

Comment thread src/transformers/cache_utils.py Outdated
# forward through all the layers
return len(self.layers)

def add_torch_size_checks(self):
Cyrilvallez (Member):

Let's have a nice docstring here explaining that the ._check calls here are used to tell the compiler in advance that they will all be the same symbolic size, which will allow it to optimize nicely in combination with torch._dynamo.mark_dynamic on this symbolic size.

Because ._check is not documented at all, we've tried using it a few times already, but it's never clear how it works and what it allows 👌
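
Something like the following could serve as a starting point (wording is only a suggestion):

def add_torch_size_checks(self):
    """Assert that all cached key/value tensors share the same (dynamic) sequence length.

    `torch._dynamo.mark_dynamic(tensor, 2)` alone gives each K/V tensor its own dynamic
    symbol; the `torch._check` equality assertions below let the compiler unify them into
    a single symbolic seq-len, enabling better inductor codegen and avoiding per-layer
    guards and recompilations.
    """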

Comment thread src/transformers/generation/utils.py Outdated
Comment on lines 2108 to 2111
# TODO - In my opinion, we should keep compile disable by default, and
# then introduce `enable_compile` in the generation config to do the
# compilation. With DynamicCache becoming compileable, the blast radius
# could be big.
# Override: honor `disable_compile` flag
Cyrilvallez (Member):

Agreed, we don't want to start compiling everything!

gante (Contributor):

+1, compilation by default places additional requirements on the user's setup, which occasionally results in issues

Comment thread src/transformers/generation/utils.py Outdated
assert isinstance(attention_mask, (dict, type(None)))
# With compileable caches, we get 4D masks (not sure why)
Cyrilvallez (Member):

When compiling, we create the final masks in advance, between the forwards, because compiling the mask creation is extremely tricky (borderline impossible in all scenarios), and it's not a bottleneck anyway (you can think of it as a kind of graph break, but we just do it outside).
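
Conceptually, something like this happens between decode steps (a schematic sketch, not the library's actual mask code; the "full_attention" dict key is an assumption):

import torch

# Build the final additive 4D mask eagerly, outside the compiled forward, then hand it in
# so the compiled graph only consumes a ready-made tensor.
def build_decode_mask(kv_len, valid_len, batch=1, dtype=torch.float32):
    mask = torch.zeros(batch, 1, 1, kv_len, dtype=dtype)   # (batch, heads, q_len, kv_len)
    mask[..., valid_len:] = torch.finfo(dtype).min          # block padded positions
    return mask

mask_4d = build_decode_mask(kv_len=2_050, valid_len=2_050)
# outputs = compiled_forward(input_ids=next_token, attention_mask={"full_attention": mask_4d}, ...)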

gante (Contributor) left a comment:

Very cool! 🔥

ArthurZucker (Collaborator) left a comment:

🚀

anijain2305 force-pushed the dynamic-cache-compile branch 2 times, most recently from da4b3e0 to 62d9714, on August 26, 2025 06:20
anijain2305 force-pushed the dynamic-cache-compile branch from 62d9714 to c408872 on August 27, 2025 17:32
ArthurZucker (Collaborator): @anijain2305 we are happy to help you / push this to merge it, tell us if you need anything!

Cyrilvallez (Member): Yes sorry everyone, I told @anijain2305 I would take over and wanted to do it this week, but did not find the time 🥲 Will do next week!

ArthurZucker added the Compilation (issues related to torchdynamo and torchinductor) and Core: Modeling (internals of the library; models) labels on Sep 15, 2025.
Cyrilvallez (Member) left a comment:

Alright @anijain2305, just a few questions to better understand the compiler needs/structure before continuing and going for the full integration! 🤗

Comment on lines +2780 to +2789
if isinstance(cache, DynamicCache):
    # Mark the sequence_length dimension as dynamic
    cache.mark_dynamic_for_compile()
    attention_mask = model_inputs.get("attention_mask")
    if isinstance(attention_mask, dict):
        for mask in attention_mask.values():
            if isinstance(mask, torch.Tensor):
                torch._dynamo.mark_dynamic(mask, 3)
    elif isinstance(attention_mask, torch.Tensor):
        torch._dynamo.mark_dynamic(attention_mask, 3)
Cyrilvallez (Member):

Do we need to mark them dynamic at each iteration, or could we do it only once at the beginning?
I assume we need it every iteration, but if so, can we move it into model_wrapper in get_compiled_call instead? Does it need to be outside the function that is actually being compiled? (i.e., any reason you defined it here specifically?) If possible, having all this logic in the wrapper would be simpler to understand!

Comment on lines +2796 to +2797
with torch.compiler.config.patch(dynamic_sources=".*cumulative_length"):
    outputs = model_forward(**model_inputs, return_dict=True)
Cyrilvallez (Member):

Any way to activate this when calling torch.compile on the forward, instead of using a context manager, by any chance?

Comment on lines +1043 to +1044
torch._check(layer.keys.size(2) == dynamic_dim)
torch._check(layer.values.size(2) == dynamic_dim)
Cyrilvallez (Member):

Just setting them all to the same value in this way is enough, right? Or does dynamo rely on the cross-object dependencies somehow?

ArthurZucker (Collaborator): @anijain2305 we are happy to push this further!
