Implement TeaCache #12652
Conversation
Work done. Waiting for feedback and review :)

Hi @sayakpaul @dhruvrnaik any updates?

@LawJarp-A sorry about the delay on our end. @DN6 will review it soon.

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

Hi @LawJarp-A I think we would need TeaCache to be implemented in a model-agnostic way in order to merge the PR. The First Block Cache implementation is a good reference for this.

Yep @DN6, I agree. I wanted to first implement it just for a single model and get feedback on that before working on the full model-agnostic implementation. I'm working on it but haven't pushed it yet. I'll take a look at First Block Cache for reference as well.

@DN6 updated it in a more model-agnostic way.

…th auto-detection

Added multi-model support, still testing it thoroughly though.

Hi @DN6 @sayakpaul
In the meantime, any feedback would be appreciated.
Thanks @LawJarp-A!
You can refer to #12569 for testing
Yes, I think that is informative for users.
sayakpaul left a comment
Some initial feedback. The most important question is that it seems like we need to craft different logic based on the model. Can we not keep it model agnostic?
I am trying to think of ways we can avoid having a forward method for each model. Initially that seemed like the best approach and was fine when I wrote it for Flux, but Lumina needed multi-stage preprocessing.
Signed-off-by: Prajwal A <prajwalanagani@gmail.com>
@sayakpaul @DN6 checking in again :)
DN6 left a comment
Some high-level feedback on the design. The control flow is hard to follow as it switches between the hook object and adapter. The adapters themselves are thin wrappers around a modified forward function, so it would be better to just define them as standalone functions, e.g.

def _flux_forward(
    state: "TeaCacheState",  # pass the state to the function, not the hook object
    coefficients: List[float],
    rel_l1_thresh: float,
    module: torch.nn.Module,
    hidden_states: torch.Tensor,
    timestep: torch.Tensor,
    pooled_projections: torch.Tensor,
    encoder_hidden_states: torch.Tensor,
    txt_ids: torch.Tensor,
    img_ids: torch.Tensor,
    return_dict: bool = True,
    **kwargs,
):
    if _should_use_cache(state, modulated_inp, coefficients, rel_l1_thresh):
        hidden_states = _apply_cached_residual(state, hidden_states, modulated_inp)
    else:
        # run compute
        _update_cache(state, hidden_states, original_hidden_states, modulated_inp)

Since we're hooking the top-level forward of the model, we can map this forward function using the class name during hook initialization.
def initialize_hook(self, module):
    """Initialize hook with model-specific configuration."""
    model_config = _MODEL_CONFIG.get(module.__class__.__name__)
    if model_config is None:
        raise ValueError
    if self.config.coefficients is not None:
        self.coefficients = self.config.coefficients
    else:
        self.coefficients = model_config["coefficients"]
    # Initialize state
    self.state_manager = StateManager(TeaCacheState)
    self.forward_fn = model_config["forward_func"]
    return module

Where _MODEL_CONFIG is just a mapping for the forward functions and coefficients:
_MODEL_CONFIG = {
    "FluxTransformer2DModel": {
        "forward_func": _flux_forward,
        "coefficients": [4.98651651e02, -2.83781631e02, 5.58554382e01, -3.82021401e00, 2.64230861e-01],
    },
}

Similarly, the methods defined in the hook object could also be turned into utility functions.
def _compute_rescaled_distance(rel_distance: float, coefficients: List[float]) -> float:
    return (
        coefficients[0] * rel_distance**4
        + coefficients[1] * rel_distance**3
        + coefficients[2] * rel_distance**2
        + coefficients[3] * rel_distance
        + coefficients[4]
    )


def _should_use_cache(state: "TeaCacheState", ...):
    # Return True or False based on whether to use cache.
    return


def _update_cache(state: "TeaCacheState", ...):
    return


def _apply_cached_residual(
    state: "TeaCacheState", input_base: torch.Tensor, modulated_inp: torch.Tensor
) -> torch.Tensor:
    """
    Apply cached residual to input (fast path).
    """
    output = input_base + state.previous_residual
    state.previous_modulated_input = modulated_inp
    state.cnt += 1
    return output

Let's remove passing cache_fn and compute_fn between the hook and the adapter. Use operations directly on the cache state plus globally available utility methods. We can also remove the modulation extractors and move that logic into the model-specific forward functions.
Thanks for the feedback @DN6

The per-model forward code is unavoidable due to different model architectures. The adapter pattern was an attempt to organize this, but I agree standalone functions would be cleaner. I'll refactor.
Signed-off-by: Prajwal A <prajwalanagani@gmail.com>
…ctions Signed-off-by: Prajwal A <prajwalanagani@gmail.com>
Signed-off-by: Prajwal A <prajwalanagani@gmail.com>
Hi @DN6, I've updated the implementation as you requested.

This does introduce some code duplication - each forward function now has the same if/else pattern:

if _should_compute(state, modulated_inp, hook.coefficients, hook.config.rel_l1_thresh):
    # compute full transformer
    _update_state(state, output, original, modulated_inp)
else:
    output = _apply_cached_residual(state, input, modulated_inp)

But the control flow is now much clearer - you can read each forward function top-to-bottom without jumping between closures and hook methods. Let me know if you'd like any further changes!
Signed-off-by: Prajwal A <prajwalanagani@gmail.com>
Signed-off-by: Prajwal A <prajwalanagani@gmail.com>
… isolation Signed-off-by: Prajwal A <prajwalanagani@gmail.com>
Signed-off-by: Prajwal A <prajwalanagani@gmail.com>
…elpers Signed-off-by: Prajwal A <prajwalanagani@gmail.com>
@DN6 @sayakpaul I spent the weekend going over the code again to understand and simplify it.
I have kept the per-model forward functions as you requested, instead of the common adapter pattern I was using before. Btw, below are the images generated with and without cache.
…orward methods Signed-off-by: Prajwal A <prajwalanagani@gmail.com>
I left some comments. LMK if they make sense.
# Fallback to default context for backward compatibility with
# pipelines that don't call cache_context()
context = "_default"
Should this branch not error out like previous?
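For illustration, a minimal sketch of the stricter behavior being asked about; the helper name and the error message are assumptions, not the PR's actual code:

```python
from typing import Optional


def _resolve_context(current_context: Optional[str]) -> str:
    # Hypothetical helper: raise instead of silently falling back to "_default"
    # when no cache context has been opened.
    if current_context is None:
        raise ValueError(
            "No cache context is active. Call `cache_context()` on the pipeline "
            "before running the denoising loop."
        )
    return current_context
```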
# Iterate over all attributes of the hook to see if any of them have the type `StateManager`. If so, call `set_context` on them.
"""Set context on all StateManager attributes of this hook."""
Can we please revert the changes unrelated to this PR? Makes reviewing a bit easier since the diff becomes smaller.
def _get_model_config() -> Dict[str, Dict[str, Any]]:
    """Get model configuration mapping. Order matters: more specific variants before generic ones."""
    return {
        "FluxKontext": {
WDYT of using the actual model class names here? Will likely be easier to maintain and read.
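A minimal sketch of what that could look like, keyed by exact transformer class names. `_flux_teacache_forward` and the Flux coefficients are taken from elsewhere in this thread; the builder and lookup helpers are assumptions:

```python
from typing import Any, Callable, Dict

import torch


def _build_model_config(flux_forward: Callable) -> Dict[str, Dict[str, Any]]:
    # Keys are exact model class names instead of loose identifiers like "FluxKontext".
    return {
        "FluxTransformer2DModel": {
            "forward_func": flux_forward,
            "coefficients": [4.98651651e02, -2.83781631e02, 5.58554382e01, -3.82021401e00, 2.64230861e-01],
        },
    }


def _lookup_model_config(module: torch.nn.Module, model_config: Dict[str, Dict[str, Any]]) -> Dict[str, Any]:
    # Exact class-name lookup; no substring matching against the config path.
    return model_config[module.__class__.__name__]
```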
return {
    "FluxKontext": {
        "forward_func": _flux_teacache_forward,
        "coefficients": [-1.04655119e03, 3.12563399e02, -1.69500694e01, 4.10995971e-01, 3.74537863e-02],
(nit): would make a note on how these values were obtained.
if prev_mean.item() > 1e-9:
    return ((current - previous).abs().mean() / prev_mean).item()
Do we need to make it data-dependent (item() call)? Raising it because it makes torch.compile cry.
return 0.0 if current.abs().mean().item() < 1e-9 else float("inf")

@torch.compiler.disable
Is it because of the item() call?
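For context, a compile-friendly sketch that keeps the relative L1 distance as a tensor; the epsilon clamp replaces the explicit zero/inf branches quoted above, and the names are assumptions:

```python
import torch


def _relative_l1_distance(current: torch.Tensor, previous: torch.Tensor, eps: float = 1e-9) -> torch.Tensor:
    # No .item() here, so there is no data-dependent control flow and
    # torch.compile can trace the whole computation without a graph break.
    prev_mean = previous.abs().mean()
    return (current - previous).abs().mean() / prev_mean.clamp_min(eps)
```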
attention_kwargs, lora_scale = _extract_lora_scale(attention_kwargs)
if USE_PEFT_BACKEND:
    scale_lora_layers(module, lora_scale)
We should check if the underlying model class inherits from PeftLoaderMixin and if so, we should do it.
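A rough sketch of that guard, assuming the relevant mixin is `PeftAdapterMixin` from `diffusers.loaders` (the comment refers to it as `PeftLoaderMixin`):

```python
from diffusers.loaders import PeftAdapterMixin
from diffusers.utils import USE_PEFT_BACKEND, scale_lora_layers


def _maybe_scale_lora(module, lora_scale):
    # Only touch LoRA layers when the model class actually supports PEFT.
    if USE_PEFT_BACKEND and lora_scale is not None and isinstance(module, PeftAdapterMixin):
        scale_lora_layers(module, lora_scale)
```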
| """Get model configuration mapping. Order matters: more specific variants before generic ones.""" | ||
| return { | ||
| "FluxKontext": { | ||
| "forward_func": _flux_teacache_forward, |
One problem that I see with this is what happens when the core forward method undergoes some changes and we fail to propagate them in these modified forwards.
@bot /style

Style bot fixed some files and pushed the changes.
Pull request overview
This PR implements TeaCache (Timestep Embedding Aware Cache), a training-free caching technique that speeds up diffusion model inference by 1.5x-2.6x by reusing transformer block computations when consecutive timestep embeddings are similar.
Changes:
- Adds TeaCache hook system with model-specific forward implementations for FLUX, Mochi, Lumina2, and CogVideoX models
- Integrates TeaCache with the existing CacheMixin infrastructure for unified cache management
- Implements StateManager improvements for context-aware state isolation (CFG support)
Reviewed changes
Copilot reviewed 10 out of 10 changed files in this pull request and generated 6 comments.
Summary per file:
| File | Description |
|---|---|
| src/diffusers/hooks/teacache.py | Core TeaCache implementation with polynomial rescaling, model auto-detection, and specialized forward functions for each supported model |
| src/diffusers/models/cache_utils.py | Integration of TeaCacheConfig into enable_cache/disable_cache methods |
| src/diffusers/hooks/__init__.py | Export TeaCacheConfig, apply_teacache, and StateManager |
| src/diffusers/hooks/hooks.py | StateManager enhancement with default context fallback for backward compatibility |
| src/diffusers/models/transformers/transformer_lumina2.py | Add CacheMixin to Lumina2Transformer2DModel |
| tests/hooks/test_teacache.py | Comprehensive unit tests for config validation, state management, and model detection |
}
cache = state.cache_dict[cache_key]

is_boundary_step = state.cnt == 0 or state.cnt == state.num_steps - 1
Copilot AI (Jan 20, 2026)
The boundary step check may fail when num_steps is 0 (not initialized). If state.num_steps is 0, the condition state.cnt == state.num_steps - 1 evaluates to state.cnt == -1, which would never be true for a non-negative counter. Consider adding state.num_steps > 0 as a guard similar to line 118.
Suggested change:
- is_boundary_step = state.cnt == 0 or state.cnt == state.num_steps - 1
+ is_boundary_step = state.cnt == 0 or (state.num_steps > 0 and state.cnt == state.num_steps - 1)
# Track sequence length for step counting (CFG)
if state.uncond_seq_len is None:
    state.uncond_seq_len = cache_key
if cache_key != state.uncond_seq_len:
Copilot AI (Jan 20, 2026)
The step counter logic for Lumina2 is confusing and potentially incorrect. The counter is incremented when cache_key != state.uncond_seq_len (line 725), but this assumes CFG will always process unconditional first, then conditional. If the order changes or if CFG is not used, the counter may not increment correctly. Consider documenting this assumption more clearly or making the logic more robust.
Suggested change:
- # Track sequence length for step counting (CFG)
- if state.uncond_seq_len is None:
-     state.uncond_seq_len = cache_key
- if cache_key != state.uncond_seq_len:
+ # Track sequence length for step counting.
+ # We keep state.uncond_seq_len for backward compatibility but avoid assuming
+ # that unconditional is always processed before conditional.
+ if state.uncond_seq_len is None:
+     state.uncond_seq_len = cache_key
+ # Step counting strategy:
+ # * If we ever observe more than one distinct cache_key, we assume CFG-like
+ #   behavior and increment state.cnt once per "pair" (i.e., once we've seen
+ #   at least two distinct cache_keys in a step), independent of ordering.
+ # * If we only ever observe a single cache_key, we assume CFG is disabled and
+ #   increment state.cnt on every call.
+ if not hasattr(state, "_seen_cache_keys_in_step"):
+     state._seen_cache_keys_in_step = set()
+ if not hasattr(state, "_all_cache_keys_seen"):
+     state._all_cache_keys_seen = set()
+ state._seen_cache_keys_in_step.add(cache_key)
+ state._all_cache_keys_seen.add(cache_key)
+ has_multiple_cache_keys = len(state._all_cache_keys_seen) > 1
+ if has_multiple_cache_keys:
+     # CFG-like behavior: increment once per step after seeing multiple keys.
+     if len(state._seen_cache_keys_in_step) > 1:
+         state.cnt += 1
+         if state.cnt >= state.num_steps:
+             state.cnt = 0
+         # Reset for the next diffusion step.
+         state._seen_cache_keys_in_step.clear()
+ else:
+     # No-CFG behavior: only one cache_key is ever seen, so increment every call.
def _auto_detect_model_type(module: torch.nn.Module) -> str:
    """Auto-detect model type from class name and config path."""
    class_name = module.__class__.__name__
    config_path = getattr(getattr(module, "config", None), "_name_or_path", "").lower()
    model_config = _get_model_config()

    for model_type in model_config:
        if model_type.lower() in config_path or model_type in class_name:
            return model_type
Copilot AI (Jan 20, 2026)
The model auto-detection uses substring matching with model_type.lower() in config_path or model_type in class_name. This can lead to false positives. For example, "CogVideoX" would match "CogVideoX1.5-5B" configs. While the iteration order is designed to check specific variants first, the order dependency is fragile. Consider using more precise matching (e.g., checking for exact model identifier patterns or using startswith/endswith) to avoid potential mismatches.
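One possible direction, sketched here under the assumption that config-path matching is still wanted as a fallback: prefer an exact class-name hit and only then do a delimiter-bounded match against the config path (function name and regex are illustrative, not the PR's code):

```python
import re
from typing import Dict, Optional


def _detect_model_type(class_name: str, config_path: str, model_config: Dict[str, dict]) -> Optional[str]:
    # Exact class-name match first.
    if class_name in model_config:
        return class_name
    # Fall back to whole-token matches in the config path, so "CogVideoX"
    # does not accidentally match a "CogVideoX1.5-5B" path.
    path = config_path.lower()
    for model_type in model_config:
        if re.search(rf"(?:^|[/_\-]){re.escape(model_type.lower())}(?:$|[/_\-.])", path):
            return model_type
    return None
```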
| f"Please provide a float value between 0.1 and 1.0." | ||
| ) | ||
| if self.rel_l1_thresh <= 0: | ||
| raise ValueError( | ||
| f"rel_l1_thresh must be positive, got {self.rel_l1_thresh}. " | ||
| f"Based on the TeaCache paper, values between 0.1 and 0.3 work best. " | ||
| f"Try 0.25 for 1.5x speedup or 0.6 for 2x speedup." |
Copilot AI (Jan 20, 2026)
The validation logic checks for rel_l1_thresh <= 0 but zero values are arguably valid since they would force computation at every step (effectively disabling caching). Consider whether the check should be < 0 instead, or document why zero is explicitly disallowed.
| f"Please provide a float value between 0.1 and 1.0." | |
| ) | |
| if self.rel_l1_thresh <= 0: | |
| raise ValueError( | |
| f"rel_l1_thresh must be positive, got {self.rel_l1_thresh}. " | |
| f"Based on the TeaCache paper, values between 0.1 and 0.3 work best. " | |
| f"Try 0.25 for 1.5x speedup or 0.6 for 2x speedup." | |
| f"Please provide a float value >= 0.0 (values between 0.1 and 1.0 are recommended)." | |
| ) | |
| if self.rel_l1_thresh < 0: | |
| raise ValueError( | |
| f"rel_l1_thresh must be non-negative, got {self.rel_l1_thresh}. " | |
| f"Based on the TeaCache paper, values between 0.1 and 0.3 work best. " | |
| f"Try 0.25 for 1.5x speedup or 0.6 for 2x speedup. " | |
| f"Note that rel_l1_thresh=0.0 effectively disables caching by forcing computation at every step." |
state.cnt = 0
state.accumulated_rel_l1_distance = 0.0
state.previous_modulated_input = None
state.previous_residual = None
Copilot AI (Jan 20, 2026)
The _maybe_reset_state_for_new_inference method doesn't reset cache_dict and uncond_seq_len which are used by Lumina2. This could cause stale cache data to persist across inference runs when using Lumina2 models. Consider calling state.reset() instead of manually resetting individual fields, or add these Lumina2-specific fields to the reset logic.
Suggested change:
- state.previous_residual = None
+ state.previous_residual = None
+ # Reset Lumina2-specific state to avoid stale cache/data between inference runs
+ if hasattr(state, "cache_dict") and state.cache_dict is not None:
+     # Clear in-place to preserve any existing references to the cache dict
+     state.cache_dict.clear()
+ if hasattr(state, "uncond_seq_len"):
+     state.uncond_seq_len = None
Example:
    ```python
    >>> from diffusers import FluxPipeline
    >>> from diffusers.hooks import TeaCacheConfig
    >>> pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16)
    >>> pipe.to("cuda")
    >>> config = TeaCacheConfig(rel_l1_thresh=0.2)
    >>> pipe.transformer.enable_cache(config)
    >>> image = pipe("A cat sitting on a windowsill", num_inference_steps=4).images[0]
    >>> pipe.transformer.disable_cache()
    ```
Copilot AI (Jan 20, 2026)
The example code in the docstring references torch but doesn't show the import statement. Consider adding import torch to the example for completeness.




What does this PR do?
What is TeaCache?
TeaCache (Timestep Embedding Aware Cache) is a training-free caching technique that speeds up diffusion model inference by 1.5x-2.6x by reusing transformer block computations when consecutive timestep embeddings are similar.
Architecture & Design
TeaCache uses a `ModelHook` to intercept transformer forward passes without modifying model code. The algorithm rescales the relative L1 distance between consecutive timestep-modulated inputs with a fitted polynomial, `c[0]*x^4 + c[1]*x^3 + c[2]*x^2 + c[3]*x + c[4]`, accumulates the result, and reuses the cached transformer residual while the accumulated value stays below the threshold (sketched below).

Key Design Features:
- Integrates with `HookRegistry` and `CacheMixin` for lifecycle management
- Uses `StateManager` with context-aware state for CFG conditional/unconditional branches
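An illustrative sketch of that decision; variable and function names here are assumptions, not the PR's actual helpers:

```python
from typing import List, Tuple


def should_recompute(
    rel_distance: float, accumulated: float, coefficients: List[float], rel_l1_thresh: float
) -> Tuple[bool, float]:
    # Rescale the raw relative L1 distance with the fitted polynomial,
    # accumulate it, and keep reusing the cached residual until the
    # accumulated value crosses the threshold.
    rescaled = (
        coefficients[0] * rel_distance**4
        + coefficients[1] * rel_distance**3
        + coefficients[2] * rel_distance**2
        + coefficients[3] * rel_distance
        + coefficients[4]
    )
    accumulated += rescaled
    if accumulated < rel_l1_thresh:
        return False, accumulated  # fast path: reuse the cached residual
    return True, 0.0  # recompute the transformer blocks and reset the accumulator
```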
Supported Models

FLUX, Mochi, Lumina2, and CogVideoX. All models support automatic coefficient detection based on model class name and config path. Custom coefficients can also be provided via `TeaCacheConfig`.

Benchmark Results (FLUX.1-dev)
Benchmark Results (Lumina2)
Benchmark Results (CogVideoX-2b)
Benchmark Results (Mochi)
Test Hardware: NVIDIA H100
Framework: Diffusers with TeaCache hooks
All tests: Same seed (42) for reproducibility
Usage
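A minimal usage sketch, mirroring the docstring example quoted earlier in the review (with the missing `torch` import added); the exact model and step count are illustrative:

```python
import torch

from diffusers import FluxPipeline
from diffusers.hooks import TeaCacheConfig

pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16)
pipe.to("cuda")

# Enable TeaCache on the transformer with the default threshold from this PR.
pipe.transformer.enable_cache(TeaCacheConfig(rel_l1_thresh=0.2))
image = pipe("A cat sitting on a windowsill", num_inference_steps=4).images[0]

# Restore the uncached forward pass.
pipe.transformer.disable_cache()
```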
Configuration Options
The `TeaCacheConfig` supports the following parameters:

- `rel_l1_thresh` (float, default=0.2): Threshold for accumulated relative L1 distance. Recommended values: 0.25 for ~1.5x speedup, 0.4 for ~1.8x, 0.6 for ~2.0x. Mochi models require lower thresholds (0.06-0.09).
- `coefficients` (List[float], optional): Polynomial coefficients for rescaling L1 distance. Auto-detected based on model type if not provided.
- `num_inference_steps` (int, optional): Total inference steps. Ensures first/last timesteps are always computed. Auto-detected if not provided.
- `num_inference_steps_callback` (Callable[[], int], optional): Callback returning total inference steps. Alternative to `num_inference_steps`.
- `current_timestep_callback` (Callable[[], int], optional): Callback returning current timestep. Used for debugging/statistics.

Files Changed
- `src/diffusers/hooks/teacache.py` - Core implementation with model-specific forward functions
- `src/diffusers/models/cache_utils.py` - CacheMixin integration
- `src/diffusers/hooks/__init__.py` - Export TeaCacheConfig and apply_teacache
- `tests/hooks/test_teacache.py` - Comprehensive unit tests

Fixes # (issue)
#12589
#12635
Before submitting
- See the documentation guidelines, and here are tips on formatting docstrings.
Who can review?
@sayakpaul @yiyixuxu @DN6