From 781c7e3cfe3b509dc58bcfa4aaf508625ef858e2 Mon Sep 17 00:00:00 2001 From: Prajwal A Date: Thu, 13 Nov 2025 12:08:52 +0530 Subject: [PATCH 01/25] base implement teacahce flux; follow original impl --- src/diffusers/hooks/flux_teacache.py | 141 +++++++++++++++++++++++++++ 1 file changed, 141 insertions(+) create mode 100644 src/diffusers/hooks/flux_teacache.py diff --git a/src/diffusers/hooks/flux_teacache.py b/src/diffusers/hooks/flux_teacache.py new file mode 100644 index 000000000000..0fffdaeac63e --- /dev/null +++ b/src/diffusers/hooks/flux_teacache.py @@ -0,0 +1,141 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass +from typing import Callable, List, Optional + +import numpy as np +import torch + +from .hooks import BaseState, HookRegistry, ModelHook, StateManager + + +@dataclass +class FluxTeaCacheConfig: + """Configuration for FLUX TeaCache following original algorithm.""" + rel_l1_thresh: float = 0.2 # threshold for accumulated distance (based on paper 0.1->0.3 works best) + coefficients: Optional[List[float]] = None # FLUX-specific polynomial coefficients + current_timestep_callback: Optional[Callable[[], int]] = None + + def __post_init__(self): + if self.coefficients is None: + # original FLUX coefficients from TeaCache paper + self.coefficients = [4.98651651e+02, -2.83781631e+02, + 5.58554382e+01, -3.82021401e+00, 2.64230861e-01] + + +class FluxTeaCacheState(BaseState): + """State management following original TeaCache implementation.""" + def __init__(self): + self.cnt = 0 # Current timestep counter + self.num_steps = 0 # Total inference steps + self.accumulated_rel_l1_distance = 0.0 # Running accumulator + self.previous_modulated_input = None # Previous timestep modulated features + self.previous_residual = None # cached transformer residual + + def reset(self): + self.cnt = 0 + self.accumulated_rel_l1_distance = 0.0 + self.previous_modulated_input = None + self.previous_residual = None + + +class FluxTeaCacheHook(ModelHook): + """Main hook implementing FLUX TeaCache logic.""" + + _is_stateful = True + + def __init__(self, config: FluxTeaCacheConfig): + super().__init__() + self.config = config + self.rescale_func = np.poly1d(config.coefficients) + self.state_manager = StateManager(FluxTeaCacheState, (), {}) + + def initialize_hook(self, module): + self.state_manager.set_context("flux_teacache") + return module + + def new_forward(self, module, hidden_states, timestep, pooled_projections, + encoder_hidden_states, txt_ids, img_ids, **kwargs): + """Replace FLUX transformer forward with TeaCache logic.""" + state = self.state_manager.get_state() + + # Extract timestep embedding for FLUX + timestep = timestep.to(hidden_states.dtype) * 1000 + temb = module.time_text_embed(timestep, pooled_projections) + + # Extract modulated input from first transformer block + modulated_inp, _, _, _, _ = module.transformer_blocks[0].norm1( + hidden_states.clone(), emb=temb.clone() + ) + + # Make 
caching decision using original algorithm + should_calc = self._should_compute_full_transformer(state, modulated_inp) + + if not should_calc: + # Fast path: apply cached residual + output = hidden_states + state.previous_residual + else: + # Slow path: full computation + ori_hidden_states = hidden_states.clone() + output = self.fn_ref.original_forward( + hidden_states, timestep, pooled_projections, + encoder_hidden_states, txt_ids, img_ids, **kwargs + ) + # Cache the residual + state.previous_residual = output - ori_hidden_states + state.previous_modulated_input = modulated_inp + + state.cnt += 1 + return output + + def _should_compute_full_transformer(self, state, modulated_inp): + """Core caching decision logic from original TeaCache.""" + # Compute first and last timesteps (always compute) + if state.cnt == 0 or state.cnt == state.num_steps - 1: + state.accumulated_rel_l1_distance = 0 + return True + + # Compute relative L1 distance + rel_distance = ((modulated_inp - state.previous_modulated_input).abs().mean() + / state.previous_modulated_input.abs().mean()).cpu().item() + + # Apply polynomial rescaling + rescaled_distance = self.rescale_func(rel_distance) + state.accumulated_rel_l1_distance += rescaled_distance + + # Make decision based on accumulated threshold + if state.accumulated_rel_l1_distance < self.config.rel_l1_thresh: + return False + else: + state.accumulated_rel_l1_distance = 0 # Reset accumulator + return True + + def reset_state(self, module): + self.state_manager.reset() + return module + + +def apply_flux_teacache(module, config: FluxTeaCacheConfig): + """Apply TeaCache to FLUX transformer following diffusers patterns.""" + from ..models.transformers.transformer_flux import FluxTransformer2DModel + + # Validate FLUX model + if not isinstance(module, FluxTransformer2DModel): + raise ValueError("TeaCache supports only FLUX transformer model for now") + + # Register hook on main transformer + registry = HookRegistry.check_if_exists_or_initialize(module) + hook = FluxTeaCacheHook(config) + registry.register_hook(hook, "flux_teacache") From 549bf97731beaf33b4a05d289e87c63f91d995cf Mon Sep 17 00:00:00 2001 From: Prajwal A Date: Thu, 13 Nov 2025 21:49:27 +0530 Subject: [PATCH 02/25] update cache utils with teacache --- src/diffusers/hooks/__init__.py | 1 + src/diffusers/hooks/flux_teacache.py | 35 +++++++++++++++++++++++++--- src/diffusers/models/cache_utils.py | 19 ++++++++++++++- 3 files changed, 51 insertions(+), 4 deletions(-) diff --git a/src/diffusers/hooks/__init__.py b/src/diffusers/hooks/__init__.py index 524a92ea9966..ecec7322fac9 100644 --- a/src/diffusers/hooks/__init__.py +++ b/src/diffusers/hooks/__init__.py @@ -19,6 +19,7 @@ from .context_parallel import apply_context_parallel from .faster_cache import FasterCacheConfig, apply_faster_cache from .first_block_cache import FirstBlockCacheConfig, apply_first_block_cache + from .flux_teacache import FluxTeaCacheConfig, apply_flux_teacache from .group_offloading import apply_group_offloading from .hooks import HookRegistry, ModelHook from .layer_skip import LayerSkipConfig, apply_layer_skip diff --git a/src/diffusers/hooks/flux_teacache.py b/src/diffusers/hooks/flux_teacache.py index 0fffdaeac63e..f548d32c7b74 100644 --- a/src/diffusers/hooks/flux_teacache.py +++ b/src/diffusers/hooks/flux_teacache.py @@ -21,12 +21,16 @@ from .hooks import BaseState, HookRegistry, ModelHook, StateManager +_FLUX_TEACACHE_HOOK = "flux_teacache" + + @dataclass class FluxTeaCacheConfig: """Configuration for FLUX TeaCache following 
original algorithm.""" rel_l1_thresh: float = 0.2 # threshold for accumulated distance (based on paper 0.1->0.3 works best) coefficients: Optional[List[float]] = None # FLUX-specific polynomial coefficients current_timestep_callback: Optional[Callable[[], int]] = None + num_inference_steps_callback: Optional[Callable[[], int]] = None # Callback to get total inference steps def __post_init__(self): if self.coefficients is None: @@ -46,6 +50,7 @@ def __init__(self): def reset(self): self.cnt = 0 + self.num_steps = 0 self.accumulated_rel_l1_distance = 0.0 self.previous_modulated_input = None self.previous_residual = None @@ -71,6 +76,21 @@ def new_forward(self, module, hidden_states, timestep, pooled_projections, """Replace FLUX transformer forward with TeaCache logic.""" state = self.state_manager.get_state() + # Reset counter if we've completed all steps (new inference run) + if state.cnt == state.num_steps and state.num_steps > 0: + state.cnt = 0 + state.accumulated_rel_l1_distance = 0.0 + state.previous_modulated_input = None + state.previous_residual = None + + # Set num_steps on first timestep if not already set + if state.cnt == 0 and state.num_steps == 0: + if self.config.num_inference_steps_callback is not None: + state.num_steps = self.config.num_inference_steps_callback() + # If still not set, try to get from module attribute (set by pipeline) + if state.num_steps == 0 and hasattr(module, 'num_steps'): + state.num_steps = module.num_steps + # Extract timestep embedding for FLUX timestep = timestep.to(hidden_states.dtype) * 1000 temb = module.time_text_embed(timestep, pooled_projections) @@ -102,10 +122,19 @@ def new_forward(self, module, hidden_states, timestep, pooled_projections, def _should_compute_full_transformer(self, state, modulated_inp): """Core caching decision logic from original TeaCache.""" - # Compute first and last timesteps (always compute) - if state.cnt == 0 or state.cnt == state.num_steps - 1: + # compute first timestep + if state.cnt == 0: state.accumulated_rel_l1_distance = 0 return True + + # compute last timestep (if num_steps is set) + if state.num_steps > 0 and state.cnt == state.num_steps - 1: + state.accumulated_rel_l1_distance = 0 + return True + + # Need previous modulated input for comparison + if state.previous_modulated_input is None: + return True # Compute relative L1 distance rel_distance = ((modulated_inp - state.previous_modulated_input).abs().mean() @@ -138,4 +167,4 @@ def apply_flux_teacache(module, config: FluxTeaCacheConfig): # Register hook on main transformer registry = HookRegistry.check_if_exists_or_initialize(module) hook = FluxTeaCacheHook(config) - registry.register_hook(hook, "flux_teacache") + registry.register_hook(hook, _FLUX_TEACACHE_HOOK) diff --git a/src/diffusers/models/cache_utils.py b/src/diffusers/models/cache_utils.py index 605c0d588c8c..48055716307f 100644 --- a/src/diffusers/models/cache_utils.py +++ b/src/diffusers/models/cache_utils.py @@ -28,6 +28,7 @@ class CacheMixin: - [Pyramid Attention Broadcast](https://huggingface.co/papers/2408.12588) - [FasterCache](https://huggingface.co/papers/2410.19355) - [FirstBlockCache](https://github.com/chengzeyi/ParaAttention/blob/7a266123671b55e7e5a2fe9af3121f07a36afc78/README.md#first-block-cache-our-dynamic-caching) + - [TeaCache](https://huggingface.co/papers/2411.19108) (FLUX-specific) """ _cache_config = None @@ -66,9 +67,11 @@ def enable_cache(self, config) -> None: from ..hooks import ( FasterCacheConfig, FirstBlockCacheConfig, + FluxTeaCacheConfig, 
PyramidAttentionBroadcastConfig, apply_faster_cache, apply_first_block_cache, + apply_flux_teacache, apply_pyramid_attention_broadcast, ) @@ -83,15 +86,18 @@ def enable_cache(self, config) -> None: apply_first_block_cache(self, config) elif isinstance(config, PyramidAttentionBroadcastConfig): apply_pyramid_attention_broadcast(self, config) + elif isinstance(config, FluxTeaCacheConfig): + apply_flux_teacache(self, config) else: raise ValueError(f"Cache config {type(config)} is not supported.") self._cache_config = config def disable_cache(self) -> None: - from ..hooks import FasterCacheConfig, FirstBlockCacheConfig, HookRegistry, PyramidAttentionBroadcastConfig + from ..hooks import FasterCacheConfig, FirstBlockCacheConfig, FluxTeaCacheConfig, HookRegistry, PyramidAttentionBroadcastConfig from ..hooks.faster_cache import _FASTER_CACHE_BLOCK_HOOK, _FASTER_CACHE_DENOISER_HOOK from ..hooks.first_block_cache import _FBC_BLOCK_HOOK, _FBC_LEADER_BLOCK_HOOK + from ..hooks.flux_teacache import _FLUX_TEACACHE_HOOK from ..hooks.pyramid_attention_broadcast import _PYRAMID_ATTENTION_BROADCAST_HOOK if self._cache_config is None: @@ -107,6 +113,8 @@ def disable_cache(self) -> None: registry.remove_hook(_FBC_BLOCK_HOOK, recurse=True) elif isinstance(self._cache_config, PyramidAttentionBroadcastConfig): registry.remove_hook(_PYRAMID_ATTENTION_BROADCAST_HOOK, recurse=True) + elif isinstance(self._cache_config, FluxTeaCacheConfig): + registry.remove_hook(_FLUX_TEACACHE_HOOK, recurse=True) else: raise ValueError(f"Cache config {type(self._cache_config)} is not supported.") @@ -128,3 +136,12 @@ def cache_context(self, name: str): yield registry._set_context(None) + + def enable_flux_teacache(self, rel_l1_thresh: float = 0.2, **kwargs): + r""" + Enable FLUX TeaCache on the model. 
+ """ + from ..hooks import FluxTeaCacheConfig + + config = FluxTeaCacheConfig(rel_l1_thresh=rel_l1_thresh, **kwargs) + self.enable_cache(config) From 29d4ffc45427d82908401a639c667a949821a2b8 Mon Sep 17 00:00:00 2001 From: Prajwal A Date: Thu, 13 Nov 2025 17:22:38 +0000 Subject: [PATCH 03/25] change hook to inline transformer processing instead of calling original_forward --- src/diffusers/hooks/flux_teacache.py | 85 +++++++++++++++++++++------- 1 file changed, 64 insertions(+), 21 deletions(-) diff --git a/src/diffusers/hooks/flux_teacache.py b/src/diffusers/hooks/flux_teacache.py index f548d32c7b74..c9cee4b4eb00 100644 --- a/src/diffusers/hooks/flux_teacache.py +++ b/src/diffusers/hooks/flux_teacache.py @@ -71,11 +71,11 @@ def initialize_hook(self, module): self.state_manager.set_context("flux_teacache") return module - def new_forward(self, module, hidden_states, timestep, pooled_projections, + def new_forward(self, module, hidden_states, timestep, pooled_projections, encoder_hidden_states, txt_ids, img_ids, **kwargs): """Replace FLUX transformer forward with TeaCache logic.""" state = self.state_manager.get_state() - + # Reset counter if we've completed all steps (new inference run) if state.cnt == state.num_steps and state.num_steps > 0: state.cnt = 0 @@ -91,33 +91,76 @@ def new_forward(self, module, hidden_states, timestep, pooled_projections, if state.num_steps == 0 and hasattr(module, 'num_steps'): state.num_steps = module.num_steps - # Extract timestep embedding for FLUX - timestep = timestep.to(hidden_states.dtype) * 1000 - temb = module.time_text_embed(timestep, pooled_projections) - - # Extract modulated input from first transformer block - modulated_inp, _, _, _, _ = module.transformer_blocks[0].norm1( - hidden_states.clone(), emb=temb.clone() - ) - - # Make caching decision using original algorithm + # Process inputs like original TeaCache + # Must process hidden_states through x_embedder first + hidden_states = module.x_embedder(hidden_states) + + # Extract timestep embedding + timestep_scaled = timestep.to(hidden_states.dtype) * 1000 + if kwargs.get('guidance') is not None: + guidance = kwargs['guidance'].to(hidden_states.dtype) * 1000 + temb = module.time_text_embed(timestep_scaled, guidance, pooled_projections) + else: + temb = module.time_text_embed(timestep_scaled, pooled_projections) + + # Extract modulated input from first transformer block like original + inp = hidden_states.clone() + temb_clone = temb.clone() + modulated_inp, _, _, _, _ = module.transformer_blocks[0].norm1(inp, emb=temb_clone) + + # Make caching decision should_calc = self._should_compute_full_transformer(state, modulated_inp) - + if not should_calc: # Fast path: apply cached residual output = hidden_states + state.previous_residual else: - # Slow path: full computation + # Slow path: full computation inline (like original TeaCache) ori_hidden_states = hidden_states.clone() - output = self.fn_ref.original_forward( - hidden_states, timestep, pooled_projections, - encoder_hidden_states, txt_ids, img_ids, **kwargs - ) + + # Process encoder_hidden_states + encoder_hidden_states = module.context_embedder(encoder_hidden_states) + + # Process txt_ids and img_ids + if txt_ids.ndim == 3: + txt_ids = txt_ids[0] + if img_ids.ndim == 3: + img_ids = img_ids[0] + + ids = torch.cat((txt_ids, img_ids), dim=0) + image_rotary_emb = module.pos_embed(ids) + + # Process through transformer blocks + for block in module.transformer_blocks: + encoder_hidden_states, hidden_states = block( + hidden_states=hidden_states, + 
encoder_hidden_states=encoder_hidden_states, + temb=temb, + image_rotary_emb=image_rotary_emb, + joint_attention_kwargs=kwargs.get('joint_attention_kwargs'), + ) + + # Process through single transformer blocks + # Note: single blocks concatenate internally, so pass separately + for block in module.single_transformer_blocks: + encoder_hidden_states, hidden_states = block( + hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + temb=temb, + image_rotary_emb=image_rotary_emb, + joint_attention_kwargs=kwargs.get('joint_attention_kwargs'), + ) + # Cache the residual - state.previous_residual = output - ori_hidden_states - state.previous_modulated_input = modulated_inp - + state.previous_residual = hidden_states - ori_hidden_states + + state.previous_modulated_input = modulated_inp state.cnt += 1 + + # Apply final norm and projection (always needed) + hidden_states = module.norm_out(hidden_states, temb) + output = module.proj_out(hidden_states) + return output def _should_compute_full_transformer(self, state, modulated_inp): From 07c6718465fcd0f1bb1e9e6f2cfb4cbbe7136625 Mon Sep 17 00:00:00 2001 From: Prajwal A Date: Fri, 14 Nov 2025 06:43:17 +0000 Subject: [PATCH 04/25] add extensive docstring (auto gen); add repr --- src/diffusers/hooks/flux_teacache.py | 247 ++++++++++++++++++++++++--- 1 file changed, 223 insertions(+), 24 deletions(-) diff --git a/src/diffusers/hooks/flux_teacache.py b/src/diffusers/hooks/flux_teacache.py index c9cee4b4eb00..6e4253d3fafc 100644 --- a/src/diffusers/hooks/flux_teacache.py +++ b/src/diffusers/hooks/flux_teacache.py @@ -26,41 +26,167 @@ @dataclass class FluxTeaCacheConfig: - """Configuration for FLUX TeaCache following original algorithm.""" - rel_l1_thresh: float = 0.2 # threshold for accumulated distance (based on paper 0.1->0.3 works best) - coefficients: Optional[List[float]] = None # FLUX-specific polynomial coefficients + r""" + Configuration for [TeaCache](https://liewfeng.github.io/TeaCache/) applied to FLUX models. + + TeaCache (Timestep Embedding Aware Cache) is an adaptive caching technique that speeds up diffusion model + inference by reusing transformer block computations when consecutive timestep embeddings are similar. It uses + polynomial rescaling of L1 distances between modulated inputs to intelligently decide when to cache. + + Args: + rel_l1_thresh (`float`, defaults to `0.2`): + Threshold for accumulated relative L1 distance. When the accumulated distance is below this threshold, + the cached residual from the previous timestep is reused instead of computing the full transformer. + Based on the original TeaCache paper, values in the range [0.1, 0.3] work best for balancing speed + and quality: + - 0.25 for ~1.5x speedup with minimal quality loss + - 0.4 for ~1.8x speedup with slight quality loss + - 0.6 for ~2.0x speedup with noticeable quality loss + - 0.8 for ~2.25x speedup with significant quality loss + Higher thresholds lead to more aggressive caching and faster inference, but may reduce output quality. + coefficients (`List[float]`, *optional*, defaults to FLUX-specific polynomial coefficients): + FLUX-specific polynomial coefficients used for rescaling the raw L1 distance. These coefficients + transform the relative L1 distance into a model-specific caching signal. If not provided, defaults + to the coefficients determined for FLUX models in the TeaCache paper: + [4.98651651e+02, -2.83781631e+02, 5.58554382e+01, -3.82021401e+00, 2.64230861e-01]. 
+ The polynomial is evaluated as: `c[0]*x^4 + c[1]*x^3 + c[2]*x^2 + c[3]*x + c[4]` where x is the + relative L1 distance. + current_timestep_callback (`Callable[[], int]`, *optional*, defaults to `None`): + Callback function that returns the current timestep during inference. This is used internally for + debugging and statistics tracking. If not provided, TeaCache will still function correctly. + num_inference_steps_callback (`Callable[[], int]`, *optional*, defaults to `None`): + Callback function that returns the total number of inference steps. This is used to ensure the first + and last timesteps are always computed (never cached) for maximum quality. If not provided, TeaCache + will attempt to detect the number of steps automatically from the pipeline. + + Examples: + ```python + from diffusers import FluxPipeline + from diffusers.hooks import FluxTeaCacheConfig + + # Load FLUX pipeline + pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16) + pipe.to("cuda") + + # Enable TeaCache with default settings (1.5x speedup) + config = FluxTeaCacheConfig(rel_l1_thresh=0.2) + pipe.transformer.enable_cache(config) + + # Generate image with caching + image = pipe("A cat sitting on a windowsill", num_inference_steps=4).images[0] + + # Disable caching + pipe.transformer.disable_cache() + + # For more aggressive caching (2x speedup, slight quality loss) + config = FluxTeaCacheConfig(rel_l1_thresh=0.6) + pipe.transformer.enable_cache(config) + image = pipe("A cat sitting on a windowsill", num_inference_steps=4).images[0] + ``` + """ + rel_l1_thresh: float = 0.2 + coefficients: Optional[List[float]] = None current_timestep_callback: Optional[Callable[[], int]] = None - num_inference_steps_callback: Optional[Callable[[], int]] = None # Callback to get total inference steps - + num_inference_steps_callback: Optional[Callable[[], int]] = None + def __post_init__(self): if self.coefficients is None: - # original FLUX coefficients from TeaCache paper - self.coefficients = [4.98651651e+02, -2.83781631e+02, + # Original FLUX coefficients from TeaCache paper + self.coefficients = [4.98651651e+02, -2.83781631e+02, 5.58554382e+01, -3.82021401e+00, 2.64230861e-01] + def __repr__(self) -> str: + return ( + f"FluxTeaCacheConfig(\n" + f" rel_l1_thresh={self.rel_l1_thresh},\n" + f" coefficients={self.coefficients},\n" + f" current_timestep_callback={self.current_timestep_callback},\n" + f" num_inference_steps_callback={self.num_inference_steps_callback}\n" + f")" + ) + class FluxTeaCacheState(BaseState): - """State management following original TeaCache implementation.""" + """ + State management for FLUX TeaCache hook. + + This class tracks the caching state across diffusion timesteps, managing counters, accumulated distances, + and cached values needed for the TeaCache algorithm. The state persists across multiple forward passes + during a single inference run and is automatically reset when a new inference begins. + + Attributes: + cnt (int): + Current timestep counter, incremented with each forward pass. Used to identify first/last timesteps + which are always computed (never cached) for maximum quality. + num_steps (int): + Total number of inference steps for the current run. Used to identify the last timestep. Automatically + detected from callbacks or pipeline attributes if not explicitly set. + accumulated_rel_l1_distance (float): + Running accumulator for rescaled L1 distances between consecutive modulated inputs. 
Compared against + the threshold to make caching decisions. Reset to 0 when the decision is made to recompute. + previous_modulated_input (torch.Tensor): + Modulated input from the previous timestep, extracted from the first transformer block's norm1 layer. + Used for computing L1 distance to determine similarity between consecutive timesteps. + previous_residual (torch.Tensor): + Cached residual (output - input) from the previous timestep's full transformer computation. Applied + directly when caching is triggered instead of computing all transformer blocks. + """ def __init__(self): - self.cnt = 0 # Current timestep counter - self.num_steps = 0 # Total inference steps - self.accumulated_rel_l1_distance = 0.0 # Running accumulator - self.previous_modulated_input = None # Previous timestep modulated features - self.previous_residual = None # cached transformer residual - + self.cnt = 0 + self.num_steps = 0 + self.accumulated_rel_l1_distance = 0.0 + self.previous_modulated_input = None + self.previous_residual = None + def reset(self): + """Reset all state variables to initial values for a new inference run.""" self.cnt = 0 self.num_steps = 0 self.accumulated_rel_l1_distance = 0.0 - self.previous_modulated_input = None + self.previous_modulated_input = None self.previous_residual = None + def __repr__(self) -> str: + return ( + f"FluxTeaCacheState(\n" + f" cnt={self.cnt},\n" + f" num_steps={self.num_steps},\n" + f" accumulated_rel_l1_distance={self.accumulated_rel_l1_distance:.6f},\n" + f" previous_modulated_input={'cached' if self.previous_modulated_input is not None else 'None'},\n" + f" previous_residual={'cached' if self.previous_residual is not None else 'None'}\n" + f")" + ) + class FluxTeaCacheHook(ModelHook): - """Main hook implementing FLUX TeaCache logic.""" - + """ + ModelHook implementing TeaCache for FLUX transformer models. + + This hook intercepts the FLUX transformer forward pass and implements adaptive caching based on timestep + embedding similarity. It extracts modulated inputs from the first transformer block, computes L1 distances, + applies polynomial rescaling, and decides whether to reuse cached residuals or compute full transformer blocks. + + The hook follows the original TeaCache algorithm from the paper: + 1. Extract modulated input from first transformer block's norm1 layer with timestep embedding + 2. Compute relative L1 distance between current and previous modulated inputs + 3. Apply polynomial rescaling with FLUX-specific coefficients to the distance + 4. Accumulate rescaled distances and compare to threshold + 5. If below threshold: reuse cached residual (fast path, skip transformer computation) + 6. If above threshold: compute full transformer blocks and cache new residual (slow path) + + The first and last timesteps are always computed fully (never cached) to ensure maximum quality. + + Attributes: + config (FluxTeaCacheConfig): + Configuration containing threshold, polynomial coefficients, and optional callbacks. + rescale_func (np.poly1d): + Polynomial function for rescaling L1 distances using FLUX-specific coefficients. + state_manager (StateManager): + Manages FluxTeaCacheState across forward passes, maintaining counters and cached values. 
+ """ + _is_stateful = True - + def __init__(self, config: FluxTeaCacheConfig): super().__init__() self.config = config @@ -73,7 +199,26 @@ def initialize_hook(self, module): def new_forward(self, module, hidden_states, timestep, pooled_projections, encoder_hidden_states, txt_ids, img_ids, **kwargs): - """Replace FLUX transformer forward with TeaCache logic.""" + """ + Replace FLUX transformer forward pass with TeaCache-enabled version. + + This method implements the full TeaCache algorithm inline, processing transformer blocks directly instead + of calling the original forward method. It extracts modulated inputs, makes caching decisions, and either + applies cached residuals (fast path) or computes full transformer blocks (slow path). + + Args: + module: The FluxTransformer2DModel instance. + hidden_states (`torch.Tensor`): Input latent tensor of shape (batch, channels, height, width). + timestep (`torch.Tensor`): Current diffusion timestep. + pooled_projections (`torch.Tensor`): Pooled text embeddings for timestep conditioning. + encoder_hidden_states (`torch.Tensor`): Text encoder outputs for cross-attention. + txt_ids (`torch.Tensor`): Position IDs for text tokens. + img_ids (`torch.Tensor`): Position IDs for image tokens. + **kwargs: Additional arguments including 'guidance' and 'joint_attention_kwargs'. + + Returns: + `torch.Tensor`: Denoised output tensor. + """ state = self.state_manager.get_state() # Reset counter if we've completed all steps (new inference run) @@ -164,8 +309,24 @@ def new_forward(self, module, hidden_states, timestep, pooled_projections, return output def _should_compute_full_transformer(self, state, modulated_inp): - """Core caching decision logic from original TeaCache.""" - # compute first timestep + """ + Determine whether to compute full transformer blocks or reuse cached residual. + + This method implements the core caching decision logic from the TeaCache paper: + - Always compute first and last timesteps (for maximum quality) + - For intermediate timesteps, compute relative L1 distance between current and previous modulated inputs + - Apply polynomial rescaling to convert distance to model-specific caching signal + - Accumulate rescaled distances and compare to threshold + - Return True (compute) if accumulated distance exceeds threshold, False (cache) otherwise + + Args: + state (`FluxTeaCacheState`): Current state containing counters and cached values. + modulated_inp (`torch.Tensor`): Modulated input from first transformer block's norm1 layer. + + Returns: + `bool`: True to compute full transformer, False to reuse cached residual. + """ + # Compute first timestep if state.cnt == 0: state.accumulated_rel_l1_distance = 0 return True @@ -200,13 +361,51 @@ def reset_state(self, module): def apply_flux_teacache(module, config: FluxTeaCacheConfig): - """Apply TeaCache to FLUX transformer following diffusers patterns.""" + """ + Apply TeaCache optimization to a FLUX transformer model. + + This function registers a FluxTeaCacheHook on the provided FLUX transformer, enabling adaptive caching of + transformer block computations based on timestep embedding similarity. The hook intercepts the forward pass + and implements the TeaCache algorithm to achieve 1.5x-2x speedup with minimal quality loss. + + Args: + module: The FLUX transformer model (FluxTransformer2DModel) to optimize. + config (`FluxTeaCacheConfig`): Configuration specifying caching threshold and optional callbacks. + + Raises: + ValueError: If the module is not a FluxTransformer2DModel. 
+ + Examples: + ```python + from diffusers import FluxPipeline + from diffusers.hooks import FluxTeaCacheConfig, apply_flux_teacache + + # Load FLUX pipeline + pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16) + pipe.to("cuda") + + # Apply TeaCache directly to transformer + config = FluxTeaCacheConfig(rel_l1_thresh=0.2) + apply_flux_teacache(pipe.transformer, config) + + # Generate with caching enabled + image = pipe("A cat on a windowsill", num_inference_steps=4).images[0] + + # Or use the convenience method via CacheMixin + pipe.transformer.enable_cache(config) + ``` + + Note: + For most use cases, it's recommended to use the CacheMixin interface: + `pipe.transformer.enable_cache(FluxTeaCacheConfig(...))` which provides additional convenience methods + like `disable_cache()` for easy toggling. + """ from ..models.transformers.transformer_flux import FluxTransformer2DModel - + # Validate FLUX model if not isinstance(module, FluxTransformer2DModel): raise ValueError("TeaCache supports only FLUX transformer model for now") - + # Register hook on main transformer registry = HookRegistry.check_if_exists_or_initialize(module) hook = FluxTeaCacheHook(config) From a7598a12cfa2cb04e4c058aa0d71ff2a40285c5b Mon Sep 17 00:00:00 2001 From: Prajwal A Date: Fri, 14 Nov 2025 06:52:59 +0000 Subject: [PATCH 05/25] add param validation and error messages --- src/diffusers/hooks/flux_teacache.py | 52 +++++++++++++++++++++++++++- 1 file changed, 51 insertions(+), 1 deletion(-) diff --git a/src/diffusers/hooks/flux_teacache.py b/src/diffusers/hooks/flux_teacache.py index 6e4253d3fafc..4110f71d6689 100644 --- a/src/diffusers/hooks/flux_teacache.py +++ b/src/diffusers/hooks/flux_teacache.py @@ -90,11 +90,57 @@ class FluxTeaCacheConfig: num_inference_steps_callback: Optional[Callable[[], int]] = None def __post_init__(self): + # Validate rel_l1_thresh + if not isinstance(self.rel_l1_thresh, (int, float)): + raise TypeError( + f"rel_l1_thresh must be a number, got {type(self.rel_l1_thresh).__name__}. " + f"Please provide a float value between 0.1 and 1.0." + ) + if self.rel_l1_thresh <= 0: + raise ValueError( + f"rel_l1_thresh must be positive, got {self.rel_l1_thresh}. " + f"Based on the TeaCache paper, values between 0.1 and 0.3 work best. " + f"Try 0.25 for 1.5x speedup or 0.6 for 2x speedup." + ) + if self.rel_l1_thresh < 0.05: + import warnings + warnings.warn( + f"rel_l1_thresh={self.rel_l1_thresh} is very low and may result in minimal caching. " + f"Consider using values between 0.1 and 0.3 for optimal performance.", + UserWarning + ) + if self.rel_l1_thresh > 1.0: + import warnings + warnings.warn( + f"rel_l1_thresh={self.rel_l1_thresh} is very high and may cause quality degradation. " + f"Consider using values between 0.1 and 0.6 for better quality-speed tradeoff.", + UserWarning + ) + + # Set default coefficients if not provided if self.coefficients is None: # Original FLUX coefficients from TeaCache paper self.coefficients = [4.98651651e+02, -2.83781631e+02, 5.58554382e+01, -3.82021401e+00, 2.64230861e-01] + # Validate coefficients + if not isinstance(self.coefficients, (list, tuple)): + raise TypeError( + f"coefficients must be a list or tuple, got {type(self.coefficients).__name__}. " + f"Please provide a list of 5 polynomial coefficients." + ) + if len(self.coefficients) != 5: + raise ValueError( + f"coefficients must contain exactly 5 elements for 4th-degree polynomial, " + f"got {len(self.coefficients)}. 
The polynomial is evaluated as: " + f"c[0]*x^4 + c[1]*x^3 + c[2]*x^2 + c[3]*x + c[4]" + ) + if not all(isinstance(c, (int, float)) for c in self.coefficients): + raise TypeError( + f"All coefficients must be numbers. " + f"Got types: {[type(c).__name__ for c in self.coefficients]}" + ) + def __repr__(self) -> str: return ( f"FluxTeaCacheConfig(\n" @@ -404,7 +450,11 @@ def apply_flux_teacache(module, config: FluxTeaCacheConfig): # Validate FLUX model if not isinstance(module, FluxTransformer2DModel): - raise ValueError("TeaCache supports only FLUX transformer model for now") + raise ValueError( + f"TeaCache currently supports only FLUX transformer models. " + f"Got {type(module).__name__}. Please ensure you're applying TeaCache to a " + f"FluxTransformer2DModel instance (e.g., pipe.transformer)." + ) # Register hook on main transformer registry = HookRegistry.check_if_exists_or_initialize(module) From d9648e56574e0cd8dc30044033d0ac94d62068c4 Mon Sep 17 00:00:00 2001 From: Prajwal A Date: Fri, 14 Nov 2025 07:17:32 +0000 Subject: [PATCH 06/25] add basic logging --- src/diffusers/hooks/flux_teacache.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/src/diffusers/hooks/flux_teacache.py b/src/diffusers/hooks/flux_teacache.py index 4110f71d6689..c31723d99803 100644 --- a/src/diffusers/hooks/flux_teacache.py +++ b/src/diffusers/hooks/flux_teacache.py @@ -18,9 +18,12 @@ import numpy as np import torch +from ..utils import logging from .hooks import BaseState, HookRegistry, ModelHook, StateManager +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + _FLUX_TEACACHE_HOOK = "flux_teacache" @@ -269,6 +272,7 @@ def new_forward(self, module, hidden_states, timestep, pooled_projections, # Reset counter if we've completed all steps (new inference run) if state.cnt == state.num_steps and state.num_steps > 0: + logger.info("TeaCache inference completed") state.cnt = 0 state.accumulated_rel_l1_distance = 0.0 state.previous_modulated_input = None @@ -304,9 +308,17 @@ def new_forward(self, module, hidden_states, timestep, pooled_projections, if not should_calc: # Fast path: apply cached residual + logger.debug( + f"TeaCache: reusing cached residual at step {state.cnt}/{state.num_steps} " + f"(accumulated distance: {state.accumulated_rel_l1_distance:.6f})" + ) output = hidden_states + state.previous_residual else: # Slow path: full computation inline (like original TeaCache) + logger.debug( + f"TeaCache: computing full transformer at step {state.cnt}/{state.num_steps} " + f"(accumulated distance: {state.accumulated_rel_l1_distance:.6f})" + ) ori_hidden_states = hidden_states.clone() # Process encoder_hidden_states From 59cb890e57544c3b7f2842bddc939675be4cd7b6 Mon Sep 17 00:00:00 2001 From: Prajwal A Date: Fri, 14 Nov 2025 08:21:02 +0000 Subject: [PATCH 07/25] add compatible test --- tests/hooks/test_flux_teacache.py | 167 ++++++++++++++++++++++++++++++ 1 file changed, 167 insertions(+) create mode 100644 tests/hooks/test_flux_teacache.py diff --git a/tests/hooks/test_flux_teacache.py b/tests/hooks/test_flux_teacache.py new file mode 100644 index 000000000000..1a2bcd613520 --- /dev/null +++ b/tests/hooks/test_flux_teacache.py @@ -0,0 +1,167 @@ +# Copyright 2025 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +import warnings + +import torch + +from diffusers.hooks import FluxTeaCacheConfig, HookRegistry + + +class FluxTeaCacheConfigTests(unittest.TestCase): + """Tests for FluxTeaCacheConfig parameter validation.""" + + def test_valid_config(self): + """Test valid configuration is accepted.""" + config = FluxTeaCacheConfig(rel_l1_thresh=0.2) + self.assertEqual(config.rel_l1_thresh, 0.2) + self.assertIsNotNone(config.coefficients) + self.assertEqual(len(config.coefficients), 5) + + def test_invalid_type(self): + """Test invalid type for rel_l1_thresh raises TypeError.""" + with self.assertRaises(TypeError) as context: + FluxTeaCacheConfig(rel_l1_thresh="invalid") + self.assertIn("must be a number", str(context.exception)) + + def test_negative_value(self): + """Test negative threshold raises ValueError.""" + with self.assertRaises(ValueError) as context: + FluxTeaCacheConfig(rel_l1_thresh=-0.5) + self.assertIn("must be positive", str(context.exception)) + + def test_invalid_coefficients_length(self): + """Test wrong coefficient count raises ValueError.""" + with self.assertRaises(ValueError) as context: + FluxTeaCacheConfig(rel_l1_thresh=0.2, coefficients=[1.0, 2.0, 3.0]) + self.assertIn("exactly 5 elements", str(context.exception)) + + def test_invalid_coefficients_type(self): + """Test invalid coefficient types raise TypeError.""" + with self.assertRaises(TypeError) as context: + FluxTeaCacheConfig(rel_l1_thresh=0.2, coefficients=[1.0, 2.0, "invalid", 4.0, 5.0]) + self.assertIn("must be numbers", str(context.exception)) + + def test_warning_very_low_threshold(self): + """Test warning is issued for very low threshold.""" + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + FluxTeaCacheConfig(rel_l1_thresh=0.01) + self.assertEqual(len(w), 1) + self.assertIn("very low", str(w[0].message)) + + def test_warning_very_high_threshold(self): + """Test warning is issued for very high threshold.""" + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + FluxTeaCacheConfig(rel_l1_thresh=1.5) + self.assertEqual(len(w), 1) + self.assertIn("very high", str(w[0].message)) + + def test_config_repr(self): + """Test __repr__ method works correctly.""" + config = FluxTeaCacheConfig(rel_l1_thresh=0.25) + repr_str = repr(config) + self.assertIn("FluxTeaCacheConfig", repr_str) + self.assertIn("0.25", repr_str) + + def test_custom_coefficients(self): + """Test custom coefficients are accepted.""" + custom_coeffs = [1.0, 2.0, 3.0, 4.0, 5.0] + config = FluxTeaCacheConfig(rel_l1_thresh=0.2, coefficients=custom_coeffs) + self.assertEqual(config.coefficients, custom_coeffs) + + +class FluxTeaCacheStateTests(unittest.TestCase): + """Tests for FluxTeaCacheState.""" + + def test_state_initialization(self): + """Test state initializes with correct default values.""" + from diffusers.hooks.flux_teacache import FluxTeaCacheState + + state = FluxTeaCacheState() + self.assertEqual(state.cnt, 0) + self.assertEqual(state.num_steps, 0) + self.assertEqual(state.accumulated_rel_l1_distance, 0.0) + self.assertIsNone(state.previous_modulated_input) 
+ self.assertIsNone(state.previous_residual) + + def test_state_reset(self): + """Test state reset clears all values.""" + from diffusers.hooks.flux_teacache import FluxTeaCacheState + + state = FluxTeaCacheState() + # Modify state + state.cnt = 5 + state.num_steps = 10 + state.accumulated_rel_l1_distance = 0.5 + state.previous_modulated_input = torch.randn(1, 10) + state.previous_residual = torch.randn(1, 10) + + # Reset + state.reset() + + # Verify reset + self.assertEqual(state.cnt, 0) + self.assertEqual(state.num_steps, 0) + self.assertEqual(state.accumulated_rel_l1_distance, 0.0) + self.assertIsNone(state.previous_modulated_input) + self.assertIsNone(state.previous_residual) + + def test_state_repr(self): + """Test __repr__ method works correctly.""" + from diffusers.hooks.flux_teacache import FluxTeaCacheState + + state = FluxTeaCacheState() + state.cnt = 3 + state.num_steps = 10 + repr_str = repr(state) + self.assertIn("FluxTeaCacheState", repr_str) + self.assertIn("cnt=3", repr_str) + self.assertIn("num_steps=10", repr_str) + + +class FluxTeaCacheHookTests(unittest.TestCase): + """Tests for FluxTeaCacheHook functionality.""" + + def test_hook_initialization(self): + """Test hook initializes correctly with config.""" + from diffusers.hooks.flux_teacache import FluxTeaCacheHook + + config = FluxTeaCacheConfig(rel_l1_thresh=0.2) + hook = FluxTeaCacheHook(config) + + self.assertEqual(hook.config.rel_l1_thresh, 0.2) + self.assertIsNotNone(hook.rescale_func) + self.assertIsNotNone(hook.state_manager) + + def test_apply_flux_teacache_validation(self): + """Test apply_flux_teacache validates input module type.""" + from diffusers.hooks import apply_flux_teacache + + # Create a dummy module that's not a FluxTransformer2DModel + class DummyModule(torch.nn.Module): + pass + + module = DummyModule() + config = FluxTeaCacheConfig(rel_l1_thresh=0.2) + + with self.assertRaises(ValueError) as context: + apply_flux_teacache(module, config) + self.assertIn("FLUX transformer models", str(context.exception)) + + +if __name__ == "__main__": + unittest.main() From 44663de92bc7bd7c6547a3b37508a431ca154e29 Mon Sep 17 00:00:00 2001 From: Prajwal A Date: Wed, 26 Nov 2025 08:52:26 +0000 Subject: [PATCH 08/25] update it to make it model agnostic --- src/diffusers/hooks/__init__.py | 2 +- .../hooks/{flux_teacache.py => teacache.py} | 128 ++++++++++-------- src/diffusers/models/cache_utils.py | 26 ++-- ...test_flux_teacache.py => test_teacache.py} | 73 +++++----- 4 files changed, 121 insertions(+), 108 deletions(-) rename src/diffusers/hooks/{flux_teacache.py => teacache.py} (82%) rename tests/hooks/{test_flux_teacache.py => test_teacache.py} (68%) diff --git a/src/diffusers/hooks/__init__.py b/src/diffusers/hooks/__init__.py index ecec7322fac9..3ccd0056840f 100644 --- a/src/diffusers/hooks/__init__.py +++ b/src/diffusers/hooks/__init__.py @@ -19,7 +19,7 @@ from .context_parallel import apply_context_parallel from .faster_cache import FasterCacheConfig, apply_faster_cache from .first_block_cache import FirstBlockCacheConfig, apply_first_block_cache - from .flux_teacache import FluxTeaCacheConfig, apply_flux_teacache + from .teacache import TeaCacheConfig, apply_teacache from .group_offloading import apply_group_offloading from .hooks import HookRegistry, ModelHook from .layer_skip import LayerSkipConfig, apply_layer_skip diff --git a/src/diffusers/hooks/flux_teacache.py b/src/diffusers/hooks/teacache.py similarity index 82% rename from src/diffusers/hooks/flux_teacache.py rename to 
src/diffusers/hooks/teacache.py index c31723d99803..c0fa7bd9db28 100644 --- a/src/diffusers/hooks/flux_teacache.py +++ b/src/diffusers/hooks/teacache.py @@ -24,13 +24,27 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name -_FLUX_TEACACHE_HOOK = "flux_teacache" +_TEACACHE_HOOK = "teacache" + + +def _flux_modulated_input_extractor(module, hidden_states, timestep_emb): + """Extract modulated input for FLUX models.""" + return module.transformer_blocks[0].norm1(hidden_states, emb=timestep_emb)[0] + + +def _auto_detect_extractor(module): + """Auto-detect and return appropriate extractor based on model type.""" + module_class_name = module.__class__.__name__ + if "Flux" in module_class_name: + return _flux_modulated_input_extractor + # Add more model types as needed + return _flux_modulated_input_extractor # Default to FLUX for now @dataclass -class FluxTeaCacheConfig: +class TeaCacheConfig: r""" - Configuration for [TeaCache](https://liewfeng.github.io/TeaCache/) applied to FLUX models. + Configuration for [TeaCache](https://liewfeng.github.io/TeaCache/) applied to transformer models. TeaCache (Timestep Embedding Aware Cache) is an adaptive caching technique that speeds up diffusion model inference by reusing transformer block computations when consecutive timestep embeddings are similar. It uses @@ -47,13 +61,16 @@ class FluxTeaCacheConfig: - 0.6 for ~2.0x speedup with noticeable quality loss - 0.8 for ~2.25x speedup with significant quality loss Higher thresholds lead to more aggressive caching and faster inference, but may reduce output quality. - coefficients (`List[float]`, *optional*, defaults to FLUX-specific polynomial coefficients): - FLUX-specific polynomial coefficients used for rescaling the raw L1 distance. These coefficients + coefficients (`List[float]`, *optional*, defaults to polynomial coefficients from TeaCache paper): + Polynomial coefficients used for rescaling the raw L1 distance. These coefficients transform the relative L1 distance into a model-specific caching signal. If not provided, defaults to the coefficients determined for FLUX models in the TeaCache paper: [4.98651651e+02, -2.83781631e+02, 5.58554382e+01, -3.82021401e+00, 2.64230861e-01]. The polynomial is evaluated as: `c[0]*x^4 + c[1]*x^3 + c[2]*x^2 + c[3]*x + c[4]` where x is the relative L1 distance. + extract_modulated_input_fn (`Callable`, *optional*, defaults to auto-detection): + Function to extract modulated input from the transformer module. Takes (module, hidden_states, timestep_emb) + and returns the modulated input tensor. If not provided, auto-detects based on model type. current_timestep_callback (`Callable[[], int]`, *optional*, defaults to `None`): Callback function that returns the current timestep during inference. This is used internally for debugging and statistics tracking. If not provided, TeaCache will still function correctly. 
@@ -65,14 +82,17 @@ class FluxTeaCacheConfig: Examples: ```python from diffusers import FluxPipeline - from diffusers.hooks import FluxTeaCacheConfig + from diffusers.hooks import TeaCacheConfig # Load FLUX pipeline pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16) pipe.to("cuda") - # Enable TeaCache with default settings (1.5x speedup) - config = FluxTeaCacheConfig(rel_l1_thresh=0.2) + # Enable TeaCache with auto-detection (1.5x speedup) + pipe.transformer.enable_teacache(rel_l1_thresh=0.2) + + # Or with explicit config + config = TeaCacheConfig(rel_l1_thresh=0.2) pipe.transformer.enable_cache(config) # Generate image with caching @@ -80,15 +100,11 @@ class FluxTeaCacheConfig: # Disable caching pipe.transformer.disable_cache() - - # For more aggressive caching (2x speedup, slight quality loss) - config = FluxTeaCacheConfig(rel_l1_thresh=0.6) - pipe.transformer.enable_cache(config) - image = pipe("A cat sitting on a windowsill", num_inference_steps=4).images[0] ``` """ rel_l1_thresh: float = 0.2 coefficients: Optional[List[float]] = None + extract_modulated_input_fn: Optional[Callable] = None current_timestep_callback: Optional[Callable[[], int]] = None num_inference_steps_callback: Optional[Callable[[], int]] = None @@ -146,18 +162,19 @@ def __post_init__(self): def __repr__(self) -> str: return ( - f"FluxTeaCacheConfig(\n" + f"TeaCacheConfig(\n" f" rel_l1_thresh={self.rel_l1_thresh},\n" f" coefficients={self.coefficients},\n" + f" extract_modulated_input_fn={self.extract_modulated_input_fn},\n" f" current_timestep_callback={self.current_timestep_callback},\n" f" num_inference_steps_callback={self.num_inference_steps_callback}\n" f")" ) -class FluxTeaCacheState(BaseState): +class TeaCacheState(BaseState): """ - State management for FLUX TeaCache hook. + State management for TeaCache hook. This class tracks the caching state across diffusion timesteps, managing counters, accumulated distances, and cached values needed for the TeaCache algorithm. The state persists across multiple forward passes @@ -197,7 +214,7 @@ def reset(self): def __repr__(self) -> str: return ( - f"FluxTeaCacheState(\n" + f"TeaCacheState(\n" f" cnt={self.cnt},\n" f" num_steps={self.num_steps},\n" f" accumulated_rel_l1_distance={self.accumulated_rel_l1_distance:.6f},\n" @@ -207,18 +224,18 @@ def __repr__(self) -> str: ) -class FluxTeaCacheHook(ModelHook): +class TeaCacheHook(ModelHook): """ - ModelHook implementing TeaCache for FLUX transformer models. + ModelHook implementing TeaCache for transformer models. - This hook intercepts the FLUX transformer forward pass and implements adaptive caching based on timestep - embedding similarity. It extracts modulated inputs from the first transformer block, computes L1 distances, + This hook intercepts transformer forward pass and implements adaptive caching based on timestep + embedding similarity. It extracts modulated inputs, computes L1 distances, applies polynomial rescaling, and decides whether to reuse cached residuals or compute full transformer blocks. The hook follows the original TeaCache algorithm from the paper: - 1. Extract modulated input from first transformer block's norm1 layer with timestep embedding + 1. Extract modulated input using provided extractor function 2. Compute relative L1 distance between current and previous modulated inputs - 3. Apply polynomial rescaling with FLUX-specific coefficients to the distance + 3. Apply polynomial rescaling with model-specific coefficients to the distance 4. 
Accumulate rescaled distances and compare to threshold 5. If below threshold: reuse cached residual (fast path, skip transformer computation) 6. If above threshold: compute full transformer blocks and cache new residual (slow path) @@ -226,24 +243,30 @@ class FluxTeaCacheHook(ModelHook): The first and last timesteps are always computed fully (never cached) to ensure maximum quality. Attributes: - config (FluxTeaCacheConfig): + config (TeaCacheConfig): Configuration containing threshold, polynomial coefficients, and optional callbacks. rescale_func (np.poly1d): - Polynomial function for rescaling L1 distances using FLUX-specific coefficients. + Polynomial function for rescaling L1 distances using model-specific coefficients. state_manager (StateManager): - Manages FluxTeaCacheState across forward passes, maintaining counters and cached values. + Manages TeaCacheState across forward passes, maintaining counters and cached values. """ _is_stateful = True - def __init__(self, config: FluxTeaCacheConfig): + def __init__(self, config: TeaCacheConfig): super().__init__() self.config = config self.rescale_func = np.poly1d(config.coefficients) - self.state_manager = StateManager(FluxTeaCacheState, (), {}) + self.state_manager = StateManager(TeaCacheState, (), {}) + self.extractor_fn = None def initialize_hook(self, module): - self.state_manager.set_context("flux_teacache") + self.state_manager.set_context("teacache") + # Auto-detect extractor if not provided + if self.config.extract_modulated_input_fn is None: + self.extractor_fn = _auto_detect_extractor(module) + else: + self.extractor_fn = self.config.extract_modulated_input_fn return module def new_forward(self, module, hidden_states, timestep, pooled_projections, @@ -298,10 +321,10 @@ def new_forward(self, module, hidden_states, timestep, pooled_projections, else: temb = module.time_text_embed(timestep_scaled, pooled_projections) - # Extract modulated input from first transformer block like original + # Extract modulated input using configured extractor inp = hidden_states.clone() temb_clone = temb.clone() - modulated_inp, _, _, _, _ = module.transformer_blocks[0].norm1(inp, emb=temb_clone) + modulated_inp = self.extractor_fn(module, inp, temb_clone) # Make caching decision should_calc = self._should_compute_full_transformer(state, modulated_inp) @@ -312,7 +335,7 @@ def new_forward(self, module, hidden_states, timestep, pooled_projections, f"TeaCache: reusing cached residual at step {state.cnt}/{state.num_steps} " f"(accumulated distance: {state.accumulated_rel_l1_distance:.6f})" ) - output = hidden_states + state.previous_residual + hidden_states = hidden_states + state.previous_residual else: # Slow path: full computation inline (like original TeaCache) logger.debug( @@ -378,8 +401,8 @@ def _should_compute_full_transformer(self, state, modulated_inp): - Return True (compute) if accumulated distance exceeds threshold, False (cache) otherwise Args: - state (`FluxTeaCacheState`): Current state containing counters and cached values. - modulated_inp (`torch.Tensor`): Modulated input from first transformer block's norm1 layer. + state (`TeaCacheState`): Current state containing counters and cached values. + modulated_inp (`torch.Tensor`): Modulated input extracted using configured extractor function. Returns: `bool`: True to compute full transformer, False to reuse cached residual. 
@@ -418,57 +441,44 @@ def reset_state(self, module): return module -def apply_flux_teacache(module, config: FluxTeaCacheConfig): +def apply_teacache(module, config: TeaCacheConfig): """ - Apply TeaCache optimization to a FLUX transformer model. + Apply TeaCache optimization to a transformer model. - This function registers a FluxTeaCacheHook on the provided FLUX transformer, enabling adaptive caching of + This function registers a TeaCacheHook on the provided transformer, enabling adaptive caching of transformer block computations based on timestep embedding similarity. The hook intercepts the forward pass and implements the TeaCache algorithm to achieve 1.5x-2x speedup with minimal quality loss. Args: - module: The FLUX transformer model (FluxTransformer2DModel) to optimize. - config (`FluxTeaCacheConfig`): Configuration specifying caching threshold and optional callbacks. - - Raises: - ValueError: If the module is not a FluxTransformer2DModel. + module: The transformer model to optimize (e.g., FluxTransformer2DModel, CogVideoXTransformer3DModel). + config (`TeaCacheConfig`): Configuration specifying caching threshold and optional callbacks. Examples: ```python from diffusers import FluxPipeline - from diffusers.hooks import FluxTeaCacheConfig, apply_flux_teacache + from diffusers.hooks import TeaCacheConfig, apply_teacache # Load FLUX pipeline pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16) pipe.to("cuda") # Apply TeaCache directly to transformer - config = FluxTeaCacheConfig(rel_l1_thresh=0.2) - apply_flux_teacache(pipe.transformer, config) + config = TeaCacheConfig(rel_l1_thresh=0.2) + apply_teacache(pipe.transformer, config) # Generate with caching enabled image = pipe("A cat on a windowsill", num_inference_steps=4).images[0] # Or use the convenience method via CacheMixin - pipe.transformer.enable_cache(config) + pipe.transformer.enable_teacache(rel_l1_thresh=0.2) ``` Note: For most use cases, it's recommended to use the CacheMixin interface: - `pipe.transformer.enable_cache(FluxTeaCacheConfig(...))` which provides additional convenience methods + `pipe.transformer.enable_teacache(...)` which provides additional convenience methods like `disable_cache()` for easy toggling. """ - from ..models.transformers.transformer_flux import FluxTransformer2DModel - - # Validate FLUX model - if not isinstance(module, FluxTransformer2DModel): - raise ValueError( - f"TeaCache currently supports only FLUX transformer models. " - f"Got {type(module).__name__}. Please ensure you're applying TeaCache to a " - f"FluxTransformer2DModel instance (e.g., pipe.transformer)." 
- ) - # Register hook on main transformer registry = HookRegistry.check_if_exists_or_initialize(module) - hook = FluxTeaCacheHook(config) - registry.register_hook(hook, _FLUX_TEACACHE_HOOK) + hook = TeaCacheHook(config) + registry.register_hook(hook, _TEACACHE_HOOK) diff --git a/src/diffusers/models/cache_utils.py b/src/diffusers/models/cache_utils.py index 48055716307f..83c595b60507 100644 --- a/src/diffusers/models/cache_utils.py +++ b/src/diffusers/models/cache_utils.py @@ -28,7 +28,7 @@ class CacheMixin: - [Pyramid Attention Broadcast](https://huggingface.co/papers/2408.12588) - [FasterCache](https://huggingface.co/papers/2410.19355) - [FirstBlockCache](https://github.com/chengzeyi/ParaAttention/blob/7a266123671b55e7e5a2fe9af3121f07a36afc78/README.md#first-block-cache-our-dynamic-caching) - - [TeaCache](https://huggingface.co/papers/2411.19108) (FLUX-specific) + - [TeaCache](https://huggingface.co/papers/2411.19108) """ _cache_config = None @@ -67,12 +67,12 @@ def enable_cache(self, config) -> None: from ..hooks import ( FasterCacheConfig, FirstBlockCacheConfig, - FluxTeaCacheConfig, PyramidAttentionBroadcastConfig, + TeaCacheConfig, apply_faster_cache, apply_first_block_cache, - apply_flux_teacache, apply_pyramid_attention_broadcast, + apply_teacache, ) if self.is_cache_enabled: @@ -86,19 +86,19 @@ def enable_cache(self, config) -> None: apply_first_block_cache(self, config) elif isinstance(config, PyramidAttentionBroadcastConfig): apply_pyramid_attention_broadcast(self, config) - elif isinstance(config, FluxTeaCacheConfig): - apply_flux_teacache(self, config) + elif isinstance(config, TeaCacheConfig): + apply_teacache(self, config) else: raise ValueError(f"Cache config {type(config)} is not supported.") self._cache_config = config def disable_cache(self) -> None: - from ..hooks import FasterCacheConfig, FirstBlockCacheConfig, FluxTeaCacheConfig, HookRegistry, PyramidAttentionBroadcastConfig + from ..hooks import FasterCacheConfig, FirstBlockCacheConfig, HookRegistry, PyramidAttentionBroadcastConfig, TeaCacheConfig from ..hooks.faster_cache import _FASTER_CACHE_BLOCK_HOOK, _FASTER_CACHE_DENOISER_HOOK from ..hooks.first_block_cache import _FBC_BLOCK_HOOK, _FBC_LEADER_BLOCK_HOOK - from ..hooks.flux_teacache import _FLUX_TEACACHE_HOOK from ..hooks.pyramid_attention_broadcast import _PYRAMID_ATTENTION_BROADCAST_HOOK + from ..hooks.teacache import _TEACACHE_HOOK if self._cache_config is None: logger.warning("Caching techniques have not been enabled, so there's nothing to disable.") @@ -113,8 +113,8 @@ def disable_cache(self) -> None: registry.remove_hook(_FBC_BLOCK_HOOK, recurse=True) elif isinstance(self._cache_config, PyramidAttentionBroadcastConfig): registry.remove_hook(_PYRAMID_ATTENTION_BROADCAST_HOOK, recurse=True) - elif isinstance(self._cache_config, FluxTeaCacheConfig): - registry.remove_hook(_FLUX_TEACACHE_HOOK, recurse=True) + elif isinstance(self._cache_config, TeaCacheConfig): + registry.remove_hook(_TEACACHE_HOOK, recurse=True) else: raise ValueError(f"Cache config {type(self._cache_config)} is not supported.") @@ -137,11 +137,11 @@ def cache_context(self, name: str): registry._set_context(None) - def enable_flux_teacache(self, rel_l1_thresh: float = 0.2, **kwargs): + def enable_teacache(self, rel_l1_thresh: float = 0.2, **kwargs): r""" - Enable FLUX TeaCache on the model. + Enable TeaCache on the model. 
""" - from ..hooks import FluxTeaCacheConfig + from ..hooks import TeaCacheConfig - config = FluxTeaCacheConfig(rel_l1_thresh=rel_l1_thresh, **kwargs) + config = TeaCacheConfig(rel_l1_thresh=rel_l1_thresh, **kwargs) self.enable_cache(config) diff --git a/tests/hooks/test_flux_teacache.py b/tests/hooks/test_teacache.py similarity index 68% rename from tests/hooks/test_flux_teacache.py rename to tests/hooks/test_teacache.py index 1a2bcd613520..5fc061d374ff 100644 --- a/tests/hooks/test_flux_teacache.py +++ b/tests/hooks/test_teacache.py @@ -17,15 +17,15 @@ import torch -from diffusers.hooks import FluxTeaCacheConfig, HookRegistry +from diffusers.hooks import TeaCacheConfig, HookRegistry -class FluxTeaCacheConfigTests(unittest.TestCase): - """Tests for FluxTeaCacheConfig parameter validation.""" +class TeaCacheConfigTests(unittest.TestCase): + """Tests for TeaCacheConfig parameter validation.""" def test_valid_config(self): """Test valid configuration is accepted.""" - config = FluxTeaCacheConfig(rel_l1_thresh=0.2) + config = TeaCacheConfig(rel_l1_thresh=0.2) self.assertEqual(config.rel_l1_thresh, 0.2) self.assertIsNotNone(config.coefficients) self.assertEqual(len(config.coefficients), 5) @@ -33,32 +33,32 @@ def test_valid_config(self): def test_invalid_type(self): """Test invalid type for rel_l1_thresh raises TypeError.""" with self.assertRaises(TypeError) as context: - FluxTeaCacheConfig(rel_l1_thresh="invalid") + TeaCacheConfig(rel_l1_thresh="invalid") self.assertIn("must be a number", str(context.exception)) def test_negative_value(self): """Test negative threshold raises ValueError.""" with self.assertRaises(ValueError) as context: - FluxTeaCacheConfig(rel_l1_thresh=-0.5) + TeaCacheConfig(rel_l1_thresh=-0.5) self.assertIn("must be positive", str(context.exception)) def test_invalid_coefficients_length(self): """Test wrong coefficient count raises ValueError.""" with self.assertRaises(ValueError) as context: - FluxTeaCacheConfig(rel_l1_thresh=0.2, coefficients=[1.0, 2.0, 3.0]) + TeaCacheConfig(rel_l1_thresh=0.2, coefficients=[1.0, 2.0, 3.0]) self.assertIn("exactly 5 elements", str(context.exception)) def test_invalid_coefficients_type(self): """Test invalid coefficient types raise TypeError.""" with self.assertRaises(TypeError) as context: - FluxTeaCacheConfig(rel_l1_thresh=0.2, coefficients=[1.0, 2.0, "invalid", 4.0, 5.0]) + TeaCacheConfig(rel_l1_thresh=0.2, coefficients=[1.0, 2.0, "invalid", 4.0, 5.0]) self.assertIn("must be numbers", str(context.exception)) def test_warning_very_low_threshold(self): """Test warning is issued for very low threshold.""" with warnings.catch_warnings(record=True) as w: warnings.simplefilter("always") - FluxTeaCacheConfig(rel_l1_thresh=0.01) + TeaCacheConfig(rel_l1_thresh=0.01) self.assertEqual(len(w), 1) self.assertIn("very low", str(w[0].message)) @@ -66,32 +66,32 @@ def test_warning_very_high_threshold(self): """Test warning is issued for very high threshold.""" with warnings.catch_warnings(record=True) as w: warnings.simplefilter("always") - FluxTeaCacheConfig(rel_l1_thresh=1.5) + TeaCacheConfig(rel_l1_thresh=1.5) self.assertEqual(len(w), 1) self.assertIn("very high", str(w[0].message)) def test_config_repr(self): """Test __repr__ method works correctly.""" - config = FluxTeaCacheConfig(rel_l1_thresh=0.25) + config = TeaCacheConfig(rel_l1_thresh=0.25) repr_str = repr(config) - self.assertIn("FluxTeaCacheConfig", repr_str) + self.assertIn("TeaCacheConfig", repr_str) self.assertIn("0.25", repr_str) def test_custom_coefficients(self): """Test custom 
coefficients are accepted.""" custom_coeffs = [1.0, 2.0, 3.0, 4.0, 5.0] - config = FluxTeaCacheConfig(rel_l1_thresh=0.2, coefficients=custom_coeffs) + config = TeaCacheConfig(rel_l1_thresh=0.2, coefficients=custom_coeffs) self.assertEqual(config.coefficients, custom_coeffs) -class FluxTeaCacheStateTests(unittest.TestCase): - """Tests for FluxTeaCacheState.""" +class TeaCacheStateTests(unittest.TestCase): + """Tests for TeaCacheState.""" def test_state_initialization(self): """Test state initializes with correct default values.""" - from diffusers.hooks.flux_teacache import FluxTeaCacheState + from diffusers.hooks.teacache import TeaCacheState - state = FluxTeaCacheState() + state = TeaCacheState() self.assertEqual(state.cnt, 0) self.assertEqual(state.num_steps, 0) self.assertEqual(state.accumulated_rel_l1_distance, 0.0) @@ -100,9 +100,9 @@ def test_state_initialization(self): def test_state_reset(self): """Test state reset clears all values.""" - from diffusers.hooks.flux_teacache import FluxTeaCacheState + from diffusers.hooks.teacache import TeaCacheState - state = FluxTeaCacheState() + state = TeaCacheState() # Modify state state.cnt = 5 state.num_steps = 10 @@ -122,45 +122,48 @@ def test_state_reset(self): def test_state_repr(self): """Test __repr__ method works correctly.""" - from diffusers.hooks.flux_teacache import FluxTeaCacheState + from diffusers.hooks.teacache import TeaCacheState - state = FluxTeaCacheState() + state = TeaCacheState() state.cnt = 3 state.num_steps = 10 repr_str = repr(state) - self.assertIn("FluxTeaCacheState", repr_str) + self.assertIn("TeaCacheState", repr_str) self.assertIn("cnt=3", repr_str) self.assertIn("num_steps=10", repr_str) -class FluxTeaCacheHookTests(unittest.TestCase): - """Tests for FluxTeaCacheHook functionality.""" +class TeaCacheHookTests(unittest.TestCase): + """Tests for TeaCacheHook functionality.""" def test_hook_initialization(self): """Test hook initializes correctly with config.""" - from diffusers.hooks.flux_teacache import FluxTeaCacheHook + from diffusers.hooks.teacache import TeaCacheHook - config = FluxTeaCacheConfig(rel_l1_thresh=0.2) - hook = FluxTeaCacheHook(config) + config = TeaCacheConfig(rel_l1_thresh=0.2) + hook = TeaCacheHook(config) self.assertEqual(hook.config.rel_l1_thresh, 0.2) self.assertIsNotNone(hook.rescale_func) self.assertIsNotNone(hook.state_manager) - def test_apply_flux_teacache_validation(self): - """Test apply_flux_teacache validates input module type.""" - from diffusers.hooks import apply_flux_teacache + def test_apply_teacache_with_custom_extractor(self): + """Test apply_teacache works with custom extractor function.""" + from diffusers.hooks import apply_teacache - # Create a dummy module that's not a FluxTransformer2DModel class DummyModule(torch.nn.Module): pass module = DummyModule() - config = FluxTeaCacheConfig(rel_l1_thresh=0.2) - with self.assertRaises(ValueError) as context: - apply_flux_teacache(module, config) - self.assertIn("FLUX transformer models", str(context.exception)) + # Custom extractor function + def custom_extractor(mod, hidden_states, temb): + return hidden_states + + config = TeaCacheConfig(rel_l1_thresh=0.2, extract_modulated_input_fn=custom_extractor) + + # Should not raise - TeaCache is now model-agnostic + apply_teacache(module, config) if __name__ == "__main__": From 4d340206f6d90e585597e49f7523374fe16d780a Mon Sep 17 00:00:00 2001 From: Prajwal A Date: Wed, 26 Nov 2025 11:08:04 +0000 Subject: [PATCH 09/25] add TeaCache hook tests and ensure cache integration passes style 
and quality checks --- src/diffusers/hooks/__init__.py | 2 +- src/diffusers/hooks/teacache.py | 146 +++++++++++++++------------- src/diffusers/models/cache_utils.py | 8 +- tests/hooks/test_teacache.py | 40 +++++++- 4 files changed, 121 insertions(+), 75 deletions(-) diff --git a/src/diffusers/hooks/__init__.py b/src/diffusers/hooks/__init__.py index 3ccd0056840f..b5f80914842f 100644 --- a/src/diffusers/hooks/__init__.py +++ b/src/diffusers/hooks/__init__.py @@ -19,10 +19,10 @@ from .context_parallel import apply_context_parallel from .faster_cache import FasterCacheConfig, apply_faster_cache from .first_block_cache import FirstBlockCacheConfig, apply_first_block_cache - from .teacache import TeaCacheConfig, apply_teacache from .group_offloading import apply_group_offloading from .hooks import HookRegistry, ModelHook from .layer_skip import LayerSkipConfig, apply_layer_skip from .layerwise_casting import apply_layerwise_casting, apply_layerwise_casting_hook from .pyramid_attention_broadcast import PyramidAttentionBroadcastConfig, apply_pyramid_attention_broadcast from .smoothed_energy_guidance_utils import SmoothedEnergyGuidanceConfig + from .teacache import TeaCacheConfig, apply_teacache diff --git a/src/diffusers/hooks/teacache.py b/src/diffusers/hooks/teacache.py index c0fa7bd9db28..b7697d1648ee 100644 --- a/src/diffusers/hooks/teacache.py +++ b/src/diffusers/hooks/teacache.py @@ -46,38 +46,36 @@ class TeaCacheConfig: r""" Configuration for [TeaCache](https://liewfeng.github.io/TeaCache/) applied to transformer models. - TeaCache (Timestep Embedding Aware Cache) is an adaptive caching technique that speeds up diffusion model - inference by reusing transformer block computations when consecutive timestep embeddings are similar. It uses - polynomial rescaling of L1 distances between modulated inputs to intelligently decide when to cache. + TeaCache (Timestep Embedding Aware Cache) is an adaptive caching technique that speeds up diffusion model inference + by reusing transformer block computations when consecutive timestep embeddings are similar. It uses polynomial + rescaling of L1 distances between modulated inputs to intelligently decide when to cache. Args: rel_l1_thresh (`float`, defaults to `0.2`): - Threshold for accumulated relative L1 distance. When the accumulated distance is below this threshold, - the cached residual from the previous timestep is reused instead of computing the full transformer. - Based on the original TeaCache paper, values in the range [0.1, 0.3] work best for balancing speed - and quality: + Threshold for accumulated relative L1 distance. When the accumulated distance is below this threshold, the + cached residual from the previous timestep is reused instead of computing the full transformer. Based on + the original TeaCache paper, values in the range [0.1, 0.3] work best for balancing speed and quality: - 0.25 for ~1.5x speedup with minimal quality loss - 0.4 for ~1.8x speedup with slight quality loss - 0.6 for ~2.0x speedup with noticeable quality loss - 0.8 for ~2.25x speedup with significant quality loss Higher thresholds lead to more aggressive caching and faster inference, but may reduce output quality. coefficients (`List[float]`, *optional*, defaults to polynomial coefficients from TeaCache paper): - Polynomial coefficients used for rescaling the raw L1 distance. These coefficients - transform the relative L1 distance into a model-specific caching signal. 
If not provided, defaults - to the coefficients determined for FLUX models in the TeaCache paper: - [4.98651651e+02, -2.83781631e+02, 5.58554382e+01, -3.82021401e+00, 2.64230861e-01]. - The polynomial is evaluated as: `c[0]*x^4 + c[1]*x^3 + c[2]*x^2 + c[3]*x + c[4]` where x is the - relative L1 distance. + Polynomial coefficients used for rescaling the raw L1 distance. These coefficients transform the relative + L1 distance into a model-specific caching signal. If not provided, defaults to the coefficients determined + for FLUX models in the TeaCache paper: [4.98651651e+02, -2.83781631e+02, 5.58554382e+01, -3.82021401e+00, + 2.64230861e-01]. The polynomial is evaluated as: `c[0]*x^4 + c[1]*x^3 + c[2]*x^2 + c[3]*x + c[4]` where x + is the relative L1 distance. extract_modulated_input_fn (`Callable`, *optional*, defaults to auto-detection): - Function to extract modulated input from the transformer module. Takes (module, hidden_states, timestep_emb) - and returns the modulated input tensor. If not provided, auto-detects based on model type. + Function to extract modulated input from the transformer module. Takes (module, hidden_states, + timestep_emb) and returns the modulated input tensor. If not provided, auto-detects based on model type. current_timestep_callback (`Callable[[], int]`, *optional*, defaults to `None`): - Callback function that returns the current timestep during inference. This is used internally for - debugging and statistics tracking. If not provided, TeaCache will still function correctly. + Callback function that returns the current timestep during inference. This is used internally for debugging + and statistics tracking. If not provided, TeaCache will still function correctly. num_inference_steps_callback (`Callable[[], int]`, *optional*, defaults to `None`): - Callback function that returns the total number of inference steps. This is used to ensure the first - and last timesteps are always computed (never cached) for maximum quality. If not provided, TeaCache - will attempt to detect the number of steps automatically from the pipeline. + Callback function that returns the total number of inference steps. This is used to ensure the first and + last timesteps are always computed (never cached) for maximum quality. If not provided, TeaCache will + attempt to detect the number of steps automatically from the pipeline. Examples: ```python @@ -102,6 +100,7 @@ class TeaCacheConfig: pipe.transformer.disable_cache() ``` """ + rel_l1_thresh: float = 0.2 coefficients: Optional[List[float]] = None extract_modulated_input_fn: Optional[Callable] = None @@ -123,24 +122,25 @@ def __post_init__(self): ) if self.rel_l1_thresh < 0.05: import warnings + warnings.warn( f"rel_l1_thresh={self.rel_l1_thresh} is very low and may result in minimal caching. " f"Consider using values between 0.1 and 0.3 for optimal performance.", - UserWarning + UserWarning, ) if self.rel_l1_thresh > 1.0: import warnings + warnings.warn( f"rel_l1_thresh={self.rel_l1_thresh} is very high and may cause quality degradation. 
" f"Consider using values between 0.1 and 0.6 for better quality-speed tradeoff.", - UserWarning + UserWarning, ) # Set default coefficients if not provided if self.coefficients is None: # Original FLUX coefficients from TeaCache paper - self.coefficients = [4.98651651e+02, -2.83781631e+02, - 5.58554382e+01, -3.82021401e+00, 2.64230861e-01] + self.coefficients = [4.98651651e02, -2.83781631e02, 5.58554382e01, -3.82021401e00, 2.64230861e-01] # Validate coefficients if not isinstance(self.coefficients, (list, tuple)): @@ -156,8 +156,7 @@ def __post_init__(self): ) if not all(isinstance(c, (int, float)) for c in self.coefficients): raise TypeError( - f"All coefficients must be numbers. " - f"Got types: {[type(c).__name__ for c in self.coefficients]}" + f"All coefficients must be numbers. Got types: {[type(c).__name__ for c in self.coefficients]}" ) def __repr__(self) -> str: @@ -176,27 +175,28 @@ class TeaCacheState(BaseState): """ State management for TeaCache hook. - This class tracks the caching state across diffusion timesteps, managing counters, accumulated distances, - and cached values needed for the TeaCache algorithm. The state persists across multiple forward passes - during a single inference run and is automatically reset when a new inference begins. + This class tracks the caching state across diffusion timesteps, managing counters, accumulated distances, and + cached values needed for the TeaCache algorithm. The state persists across multiple forward passes during a single + inference run and is automatically reset when a new inference begins. Attributes: cnt (int): - Current timestep counter, incremented with each forward pass. Used to identify first/last timesteps - which are always computed (never cached) for maximum quality. + Current timestep counter, incremented with each forward pass. Used to identify first/last timesteps which + are always computed (never cached) for maximum quality. num_steps (int): Total number of inference steps for the current run. Used to identify the last timestep. Automatically detected from callbacks or pipeline attributes if not explicitly set. accumulated_rel_l1_distance (float): - Running accumulator for rescaled L1 distances between consecutive modulated inputs. Compared against - the threshold to make caching decisions. Reset to 0 when the decision is made to recompute. + Running accumulator for rescaled L1 distances between consecutive modulated inputs. Compared against the + threshold to make caching decisions. Reset to 0 when the decision is made to recompute. previous_modulated_input (torch.Tensor): - Modulated input from the previous timestep, extracted from the first transformer block's norm1 layer. - Used for computing L1 distance to determine similarity between consecutive timesteps. + Modulated input from the previous timestep, extracted from the first transformer block's norm1 layer. Used + for computing L1 distance to determine similarity between consecutive timesteps. previous_residual (torch.Tensor): Cached residual (output - input) from the previous timestep's full transformer computation. Applied directly when caching is triggered instead of computing all transformer blocks. """ + def __init__(self): self.cnt = 0 self.num_steps = 0 @@ -228,9 +228,9 @@ class TeaCacheHook(ModelHook): """ ModelHook implementing TeaCache for transformer models. - This hook intercepts transformer forward pass and implements adaptive caching based on timestep - embedding similarity. 
It extracts modulated inputs, computes L1 distances, - applies polynomial rescaling, and decides whether to reuse cached residuals or compute full transformer blocks. + This hook intercepts transformer forward pass and implements adaptive caching based on timestep embedding + similarity. It extracts modulated inputs, computes L1 distances, applies polynomial rescaling, and decides whether + to reuse cached residuals or compute full transformer blocks. The hook follows the original TeaCache algorithm from the paper: 1. Extract modulated input using provided extractor function @@ -259,7 +259,7 @@ def __init__(self, config: TeaCacheConfig): self.rescale_func = np.poly1d(config.coefficients) self.state_manager = StateManager(TeaCacheState, (), {}) self.extractor_fn = None - + def initialize_hook(self, module): self.state_manager.set_context("teacache") # Auto-detect extractor if not provided @@ -268,15 +268,16 @@ def initialize_hook(self, module): else: self.extractor_fn = self.config.extract_modulated_input_fn return module - - def new_forward(self, module, hidden_states, timestep, pooled_projections, - encoder_hidden_states, txt_ids, img_ids, **kwargs): + + def new_forward( + self, module, hidden_states, timestep, pooled_projections, encoder_hidden_states, txt_ids, img_ids, **kwargs + ): """ Replace FLUX transformer forward pass with TeaCache-enabled version. - This method implements the full TeaCache algorithm inline, processing transformer blocks directly instead - of calling the original forward method. It extracts modulated inputs, makes caching decisions, and either - applies cached residuals (fast path) or computes full transformer blocks (slow path). + This method implements the full TeaCache algorithm inline, processing transformer blocks directly instead of + calling the original forward method. It extracts modulated inputs, makes caching decisions, and either applies + cached residuals (fast path) or computes full transformer blocks (slow path). Args: module: The FluxTransformer2DModel instance. 
@@ -300,23 +301,23 @@ def new_forward(self, module, hidden_states, timestep, pooled_projections, state.accumulated_rel_l1_distance = 0.0 state.previous_modulated_input = None state.previous_residual = None - + # Set num_steps on first timestep if not already set if state.cnt == 0 and state.num_steps == 0: if self.config.num_inference_steps_callback is not None: state.num_steps = self.config.num_inference_steps_callback() # If still not set, try to get from module attribute (set by pipeline) - if state.num_steps == 0 and hasattr(module, 'num_steps'): + if state.num_steps == 0 and hasattr(module, "num_steps"): state.num_steps = module.num_steps - + # Process inputs like original TeaCache # Must process hidden_states through x_embedder first hidden_states = module.x_embedder(hidden_states) # Extract timestep embedding timestep_scaled = timestep.to(hidden_states.dtype) * 1000 - if kwargs.get('guidance') is not None: - guidance = kwargs['guidance'].to(hidden_states.dtype) * 1000 + if kwargs.get("guidance") is not None: + guidance = kwargs["guidance"].to(hidden_states.dtype) * 1000 temb = module.time_text_embed(timestep_scaled, guidance, pooled_projections) else: temb = module.time_text_embed(timestep_scaled, pooled_projections) @@ -363,7 +364,7 @@ def new_forward(self, module, hidden_states, timestep, pooled_projections, encoder_hidden_states=encoder_hidden_states, temb=temb, image_rotary_emb=image_rotary_emb, - joint_attention_kwargs=kwargs.get('joint_attention_kwargs'), + joint_attention_kwargs=kwargs.get("joint_attention_kwargs"), ) # Process through single transformer blocks @@ -374,7 +375,7 @@ def new_forward(self, module, hidden_states, timestep, pooled_projections, encoder_hidden_states=encoder_hidden_states, temb=temb, image_rotary_emb=image_rotary_emb, - joint_attention_kwargs=kwargs.get('joint_attention_kwargs'), + joint_attention_kwargs=kwargs.get("joint_attention_kwargs"), ) # Cache the residual @@ -388,7 +389,7 @@ def new_forward(self, module, hidden_states, timestep, pooled_projections, output = module.proj_out(hidden_states) return output - + def _should_compute_full_transformer(self, state, modulated_inp): """ Determine whether to compute full transformer blocks or reuse cached residual. 
@@ -411,31 +412,37 @@ def _should_compute_full_transformer(self, state, modulated_inp): if state.cnt == 0: state.accumulated_rel_l1_distance = 0 return True - + # compute last timestep (if num_steps is set) if state.num_steps > 0 and state.cnt == state.num_steps - 1: state.accumulated_rel_l1_distance = 0 return True - + # Need previous modulated input for comparison if state.previous_modulated_input is None: return True - - # Compute relative L1 distance - rel_distance = ((modulated_inp - state.previous_modulated_input).abs().mean() - / state.previous_modulated_input.abs().mean()).cpu().item() - + + # Compute relative L1 distance + rel_distance = ( + ( + (modulated_inp - state.previous_modulated_input).abs().mean() + / state.previous_modulated_input.abs().mean() + ) + .cpu() + .item() + ) + # Apply polynomial rescaling rescaled_distance = self.rescale_func(rel_distance) state.accumulated_rel_l1_distance += rescaled_distance - + # Make decision based on accumulated threshold if state.accumulated_rel_l1_distance < self.config.rel_l1_thresh: - return False + return False else: state.accumulated_rel_l1_distance = 0 # Reset accumulator - return True - + return True + def reset_state(self, module): self.state_manager.reset() return module @@ -445,9 +452,9 @@ def apply_teacache(module, config: TeaCacheConfig): """ Apply TeaCache optimization to a transformer model. - This function registers a TeaCacheHook on the provided transformer, enabling adaptive caching of - transformer block computations based on timestep embedding similarity. The hook intercepts the forward pass - and implements the TeaCache algorithm to achieve 1.5x-2x speedup with minimal quality loss. + This function registers a TeaCacheHook on the provided transformer, enabling adaptive caching of transformer block + computations based on timestep embedding similarity. The hook intercepts the forward pass and implements the + TeaCache algorithm to achieve 1.5x-2x speedup with minimal quality loss. Args: module: The transformer model to optimize (e.g., FluxTransformer2DModel, CogVideoXTransformer3DModel). @@ -474,9 +481,8 @@ def apply_teacache(module, config: TeaCacheConfig): ``` Note: - For most use cases, it's recommended to use the CacheMixin interface: - `pipe.transformer.enable_teacache(...)` which provides additional convenience methods - like `disable_cache()` for easy toggling. + For most use cases, it's recommended to use the CacheMixin interface: `pipe.transformer.enable_teacache(...)` + which provides additional convenience methods like `disable_cache()` for easy toggling. 
""" # Register hook on main transformer registry = HookRegistry.check_if_exists_or_initialize(module) diff --git a/src/diffusers/models/cache_utils.py b/src/diffusers/models/cache_utils.py index 83c595b60507..cded3c944f23 100644 --- a/src/diffusers/models/cache_utils.py +++ b/src/diffusers/models/cache_utils.py @@ -94,7 +94,13 @@ def enable_cache(self, config) -> None: self._cache_config = config def disable_cache(self) -> None: - from ..hooks import FasterCacheConfig, FirstBlockCacheConfig, HookRegistry, PyramidAttentionBroadcastConfig, TeaCacheConfig + from ..hooks import ( + FasterCacheConfig, + FirstBlockCacheConfig, + HookRegistry, + PyramidAttentionBroadcastConfig, + TeaCacheConfig, + ) from ..hooks.faster_cache import _FASTER_CACHE_BLOCK_HOOK, _FASTER_CACHE_DENOISER_HOOK from ..hooks.first_block_cache import _FBC_BLOCK_HOOK, _FBC_LEADER_BLOCK_HOOK from ..hooks.pyramid_attention_broadcast import _PYRAMID_ATTENTION_BROADCAST_HOOK diff --git a/tests/hooks/test_teacache.py b/tests/hooks/test_teacache.py index 5fc061d374ff..6af808093c4b 100644 --- a/tests/hooks/test_teacache.py +++ b/tests/hooks/test_teacache.py @@ -17,7 +17,7 @@ import torch -from diffusers.hooks import TeaCacheConfig, HookRegistry +from diffusers.hooks import HookRegistry, TeaCacheConfig class TeaCacheConfigTests(unittest.TestCase): @@ -147,12 +147,40 @@ def test_hook_initialization(self): self.assertIsNotNone(hook.rescale_func) self.assertIsNotNone(hook.state_manager) + def test_should_compute_full_transformer_logic(self): + """Test _should_compute_full_transformer decision logic.""" + from diffusers.hooks.teacache import TeaCacheHook, TeaCacheState + + config = TeaCacheConfig(rel_l1_thresh=1.0, coefficients=[1, 0, 0, 0, 0]) + hook = TeaCacheHook(config) + state = TeaCacheState() + + x0 = torch.ones(1, 4) + x1 = torch.ones(1, 4) * 1.1 + + # First step should always compute + self.assertTrue(hook._should_compute_full_transformer(state, x0)) + + state.previous_modulated_input = x0 + state.cnt = 1 + state.num_steps = 4 + + # Middle step: accumulate distance and stay below threshold => reuse cache + self.assertFalse(hook._should_compute_full_transformer(state, x1)) + + # Last step: must compute regardless of distance + state.cnt = state.num_steps - 1 + self.assertTrue(hook._should_compute_full_transformer(state, x1)) + def test_apply_teacache_with_custom_extractor(self): """Test apply_teacache works with custom extractor function.""" from diffusers.hooks import apply_teacache + from diffusers.models import CacheMixin - class DummyModule(torch.nn.Module): - pass + class DummyModule(torch.nn.Module, CacheMixin): + def __init__(self): + super().__init__() + self.dummy = torch.nn.Linear(4, 4) module = DummyModule() @@ -165,6 +193,12 @@ def custom_extractor(mod, hidden_states, temb): # Should not raise - TeaCache is now model-agnostic apply_teacache(module, config) + # Verify registry and disable path work + registry = HookRegistry.check_if_exists_or_initialize(module) + self.assertIn("teacache", registry.hooks) + + module.disable_cache() + if __name__ == "__main__": unittest.main() From 9dab52f37b55d4de9f641ff229c7f12e544c47b9 Mon Sep 17 00:00:00 2001 From: Prajwal A Date: Tue, 2 Dec 2025 09:50:46 +0000 Subject: [PATCH 10/25] Add multi-model TeaCache support for Mochi, Lumina2, and CogVideoX with auto-detection --- src/diffusers/hooks/teacache.py | 589 ++++++++++++++++++++++++++++++-- tests/hooks/test_teacache.py | 265 ++++++++++++++ 2 files changed, 829 insertions(+), 25 deletions(-) diff --git 
a/src/diffusers/hooks/teacache.py b/src/diffusers/hooks/teacache.py index b7697d1648ee..c5ffb6d6a711 100644 --- a/src/diffusers/hooks/teacache.py +++ b/src/diffusers/hooks/teacache.py @@ -26,19 +26,67 @@ _TEACACHE_HOOK = "teacache" +# Model-specific polynomial coefficients from TeaCache paper/reference implementations +_MODEL_COEFFICIENTS = { + "Flux": [4.98651651e02, -2.83781631e02, 5.58554382e01, -3.82021401e00, 2.64230861e-01], + "Mochi": [-3.51241319e03, 8.11675948e02, -6.09400215e01, 2.42429681e00, 3.05291719e-03], + "Lumina2": [393.76566581, -603.50993606, 209.10239044, -23.00726601, 0.86377344], + "CogVideoX": [ + -1.53880483e03, + 8.43202495e02, + -1.34363087e02, + 7.97131516e00, + -5.23162339e-02, + ], # Default to 5b variant + # CogVideoX model variants with specific coefficients + "CogVideoX-2b": [-3.10658903e01, 2.54732368e01, -5.92380459e00, 1.75769064e00, -3.61568434e-03], + "CogVideoX-5b": [-1.53880483e03, 8.43202495e02, -1.34363087e02, 7.97131516e00, -5.23162339e-02], + "CogVideoX1.5-5B": [2.50210439e02, -1.65061612e02, 3.57804877e01, -7.81551492e-01, 3.58559703e-02], + "CogVideoX1.5-5B-I2V": [1.22842302e02, -1.04088754e02, 2.62981677e01, -3.06001e-01, 3.71213220e-02], +} + def _flux_modulated_input_extractor(module, hidden_states, timestep_emb): """Extract modulated input for FLUX models.""" return module.transformer_blocks[0].norm1(hidden_states, emb=timestep_emb)[0] +def _mochi_modulated_input_extractor(module, hidden_states, timestep_emb): + """Extract modulated input for Mochi models.""" + # Mochi norm1 returns tuple: (modulated_inp, gate_msa, scale_mlp, gate_mlp) + return module.transformer_blocks[0].norm1(hidden_states, timestep_emb)[0] + + +def _lumina2_modulated_input_extractor(module, hidden_states, timestep_emb): + """Extract modulated input for Lumina2 models.""" + # Lumina2 uses 'layers' instead of 'transformer_blocks' and norm1 returns tuple + # Note: This extractor expects input_to_main_loop as hidden_states (after preprocessing) + return module.layers[0].norm1(hidden_states, timestep_emb)[0] + + +def _cogvideox_modulated_input_extractor(module, hidden_states, timestep_emb): + """Extract modulated input for CogVideoX models.""" + # CogVideoX uses the timestep embedding directly, not from a block + return timestep_emb + + def _auto_detect_extractor(module): """Auto-detect and return appropriate extractor based on model type.""" module_class_name = module.__class__.__name__ if "Flux" in module_class_name: return _flux_modulated_input_extractor - # Add more model types as needed - return _flux_modulated_input_extractor # Default to FLUX for now + elif "Mochi" in module_class_name: + return _mochi_modulated_input_extractor + elif "Lumina2" in module_class_name: + return _lumina2_modulated_input_extractor + elif "CogVideoX" in module_class_name: + return _cogvideox_modulated_input_extractor + # Default to FLUX for backward compatibility + logger.warning( + f"TeaCache: Unknown model type {module_class_name}, defaulting to FLUX extractor. " + f"Results may be incorrect. Please provide extract_modulated_input_fn explicitly." + ) + return _flux_modulated_input_extractor @dataclass @@ -50,6 +98,9 @@ class TeaCacheConfig: by reusing transformer block computations when consecutive timestep embeddings are similar. It uses polynomial rescaling of L1 distances between modulated inputs to intelligently decide when to cache. + Currently supports: FLUX, Mochi, Lumina2, and CogVideoX models. 
Model type is auto-detected, and model-specific + polynomial coefficients are automatically applied. + Args: rel_l1_thresh (`float`, defaults to `0.2`): Threshold for accumulated relative L1 distance. When the accumulated distance is below this threshold, the @@ -137,10 +188,12 @@ def __post_init__(self): UserWarning, ) - # Set default coefficients if not provided + # Set default coefficients if not provided (will be auto-detected in hook initialization) if self.coefficients is None: - # Original FLUX coefficients from TeaCache paper - self.coefficients = [4.98651651e02, -2.83781631e02, 5.58554382e01, -3.82021401e00, 2.64230861e-01] + # Default to FLUX coefficients (will be overridden by auto-detection if model is recognized) + self.coefficients = _MODEL_COEFFICIENTS.get( + "Flux", [4.98651651e02, -2.83781631e02, 5.58554382e01, -3.82021401e00, 2.64230861e-01] + ) # Validate coefficients if not isinstance(self.coefficients, (list, tuple)): @@ -195,6 +248,9 @@ class TeaCacheState(BaseState): previous_residual (torch.Tensor): Cached residual (output - input) from the previous timestep's full transformer computation. Applied directly when caching is triggered instead of computing all transformer blocks. + previous_residual_encoder (torch.Tensor, optional): + Cached encoder residual for models that cache both encoder and hidden_states residuals (e.g., CogVideoX). + None for models that only cache hidden_states residual. """ def __init__(self): @@ -203,6 +259,8 @@ def __init__(self): self.accumulated_rel_l1_distance = 0.0 self.previous_modulated_input = None self.previous_residual = None + # For models that cache both encoder and hidden_states residuals (e.g., CogVideoX) + self.previous_residual_encoder = None def reset(self): """Reset all state variables to initial values for a new inference run.""" @@ -211,6 +269,7 @@ def reset(self): self.accumulated_rel_l1_distance = 0.0 self.previous_modulated_input = None self.previous_residual = None + self.previous_residual_encoder = None def __repr__(self) -> str: return ( @@ -256,38 +315,100 @@ class TeaCacheHook(ModelHook): def __init__(self, config: TeaCacheConfig): super().__init__() self.config = config - self.rescale_func = np.poly1d(config.coefficients) + # Set default rescale_func with config coefficients (will be updated in initialize_hook if needed) + # This ensures rescale_func is always valid, even if initialize_hook isn't called (e.g., in tests) + default_coeffs = config.coefficients if config.coefficients else _MODEL_COEFFICIENTS["Flux"] + self.rescale_func = np.poly1d(default_coeffs) self.state_manager = StateManager(TeaCacheState, (), {}) self.extractor_fn = None + self.model_type = None def initialize_hook(self, module): self.state_manager.set_context("teacache") - # Auto-detect extractor if not provided + + # Auto-detect model type and extractor + module_class_name = module.__class__.__name__ if self.config.extract_modulated_input_fn is None: self.extractor_fn = _auto_detect_extractor(module) + # Detect model type for coefficient auto-detection + if "Flux" in module_class_name: + self.model_type = "Flux" + elif "Mochi" in module_class_name: + self.model_type = "Mochi" + elif "Lumina2" in module_class_name: + self.model_type = "Lumina2" + elif "CogVideoX" in module_class_name: + # Try to detect specific CogVideoX variant from config + if hasattr(module, "config") and hasattr(module.config, "_name_or_path"): + name_or_path = module.config._name_or_path.lower() + if "1.5" in name_or_path or "1-5" in name_or_path: + if "i2v" in 
name_or_path: + self.model_type = "CogVideoX1.5-5B-I2V" + else: + self.model_type = "CogVideoX1.5-5B" + elif "2b" in name_or_path: + self.model_type = "CogVideoX-2b" + elif "5b" in name_or_path: + self.model_type = "CogVideoX-5b" + else: + self.model_type = "CogVideoX" # Default to generic 5b + else: + self.model_type = "CogVideoX" # Default to generic 5b + else: + self.model_type = "Flux" # Default else: self.extractor_fn = self.config.extract_modulated_input_fn + self.model_type = "Flux" # Default if custom extractor provided + + # Auto-set coefficients from registry if not explicitly provided + if self.config.coefficients is None or ( + self.model_type in _MODEL_COEFFICIENTS and self.config.coefficients == _MODEL_COEFFICIENTS.get("Flux") + ): + if self.model_type in _MODEL_COEFFICIENTS: + self.config.coefficients = _MODEL_COEFFICIENTS[self.model_type] + logger.info(f"TeaCache: Auto-detected {self.model_type} coefficients") + + # Initialize rescale function with final coefficients + self.rescale_func = np.poly1d(self.config.coefficients) + return module - def new_forward( + def new_forward(self, module, *args, **kwargs): + """ + Route to model-specific forward handler based on detected model type. + """ + module_class_name = module.__class__.__name__ + + if "Flux" in module_class_name: + return self._handle_flux_forward(module, *args, **kwargs) + elif "Mochi" in module_class_name: + return self._handle_mochi_forward(module, *args, **kwargs) + elif "Lumina2" in module_class_name: + return self._handle_lumina2_forward(module, *args, **kwargs) + elif "CogVideoX" in module_class_name: + return self._handle_cogvideox_forward(module, *args, **kwargs) + else: + # Default to FLUX handler for backward compatibility + logger.warning( + f"TeaCache: Unknown model type {module_class_name}, using FLUX handler. Results may be incorrect." + ) + return self._handle_flux_forward(module, *args, **kwargs) + + def _handle_flux_forward( self, module, hidden_states, timestep, pooled_projections, encoder_hidden_states, txt_ids, img_ids, **kwargs ): """ - Replace FLUX transformer forward pass with TeaCache-enabled version. - - This method implements the full TeaCache algorithm inline, processing transformer blocks directly instead of - calling the original forward method. It extracts modulated inputs, makes caching decisions, and either applies - cached residuals (fast path) or computes full transformer blocks (slow path). + Handle FLUX transformer forward pass with TeaCache. Args: module: The FluxTransformer2DModel instance. - hidden_states (`torch.Tensor`): Input latent tensor of shape (batch, channels, height, width). + hidden_states (`torch.Tensor`): Input latent tensor. timestep (`torch.Tensor`): Current diffusion timestep. - pooled_projections (`torch.Tensor`): Pooled text embeddings for timestep conditioning. - encoder_hidden_states (`torch.Tensor`): Text encoder outputs for cross-attention. + pooled_projections (`torch.Tensor`): Pooled text embeddings. + encoder_hidden_states (`torch.Tensor`): Text encoder outputs. txt_ids (`torch.Tensor`): Position IDs for text tokens. img_ids (`torch.Tensor`): Position IDs for image tokens. - **kwargs: Additional arguments including 'guidance' and 'joint_attention_kwargs'. + **kwargs: Additional arguments. Returns: `torch.Tensor`: Denoised output tensor. 
@@ -332,17 +453,9 @@ def new_forward( if not should_calc: # Fast path: apply cached residual - logger.debug( - f"TeaCache: reusing cached residual at step {state.cnt}/{state.num_steps} " - f"(accumulated distance: {state.accumulated_rel_l1_distance:.6f})" - ) hidden_states = hidden_states + state.previous_residual else: # Slow path: full computation inline (like original TeaCache) - logger.debug( - f"TeaCache: computing full transformer at step {state.cnt}/{state.num_steps} " - f"(accumulated distance: {state.accumulated_rel_l1_distance:.6f})" - ) ori_hidden_states = hidden_states.clone() # Process encoder_hidden_states @@ -390,6 +503,432 @@ def new_forward( return output + def _handle_mochi_forward( + self, + module, + hidden_states, + encoder_hidden_states, + timestep, + encoder_attention_mask, + attention_kwargs=None, + return_dict=True, + ): + """ + Handle Mochi transformer forward pass with TeaCache. + + Args: + module: The MochiTransformer3DModel instance. + hidden_states (`torch.Tensor`): Input latent tensor. + encoder_hidden_states (`torch.Tensor`): Text encoder outputs. + timestep (`torch.Tensor`): Current diffusion timestep. + encoder_attention_mask (`torch.Tensor`): Attention mask for encoder. + attention_kwargs (`dict`, optional): Additional attention arguments. + return_dict (`bool`): Whether to return a dict. + + Returns: + `torch.Tensor` or `Transformer2DModelOutput`: Denoised output. + """ + from diffusers.models.modeling_outputs import Transformer2DModelOutput + from diffusers.utils import USE_PEFT_BACKEND, scale_lora_layers, unscale_lora_layers + + if attention_kwargs is not None: + attention_kwargs = attention_kwargs.copy() + lora_scale = attention_kwargs.pop("scale", 1.0) + else: + lora_scale = 1.0 + + if USE_PEFT_BACKEND: + scale_lora_layers(module, lora_scale) + + state = self.state_manager.get_state() + + # Reset counter if we've completed all steps + if state.cnt == state.num_steps and state.num_steps > 0: + logger.info("TeaCache inference completed") + state.cnt = 0 + state.accumulated_rel_l1_distance = 0.0 + state.previous_modulated_input = None + state.previous_residual = None + + # Set num_steps on first timestep if not already set + if state.cnt == 0 and state.num_steps == 0: + if self.config.num_inference_steps_callback is not None: + state.num_steps = self.config.num_inference_steps_callback() + elif hasattr(module, "num_steps"): + state.num_steps = module.num_steps + + batch_size, num_channels, num_frames, height, width = hidden_states.shape + p = module.config.patch_size + post_patch_height = height // p + post_patch_width = width // p + + # Process time embedding + temb, encoder_hidden_states = module.time_embed( + timestep, + encoder_hidden_states, + encoder_attention_mask, + hidden_dtype=hidden_states.dtype, + ) + + # Process patch embedding + hidden_states = hidden_states.permute(0, 2, 1, 3, 4).flatten(0, 1) + hidden_states = module.patch_embed(hidden_states) + hidden_states = hidden_states.unflatten(0, (batch_size, -1)).flatten(1, 2) + + # Get rotary embeddings + image_rotary_emb = module.rope( + module.pos_frequencies, + num_frames, + post_patch_height, + post_patch_width, + device=hidden_states.device, + dtype=torch.float32, + ) + + # Extract modulated input + inp = hidden_states.clone() + temb_clone = temb.clone() + modulated_inp = self.extractor_fn(module, inp, temb_clone) + + # Make caching decision + should_calc = self._should_compute_full_transformer(state, modulated_inp) + + if not should_calc: + # Fast path: apply cached residual + 
hidden_states = hidden_states + state.previous_residual + else: + # Slow path: full computation + ori_hidden_states = hidden_states.clone() + + # Process through transformer blocks + for block in module.transformer_blocks: + if torch.is_grad_enabled() and module.gradient_checkpointing: + hidden_states, encoder_hidden_states = module._gradient_checkpointing_func( + block, + hidden_states, + encoder_hidden_states, + temb, + encoder_attention_mask, + image_rotary_emb, + ) + else: + hidden_states, encoder_hidden_states = block( + hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + temb=temb, + encoder_attention_mask=encoder_attention_mask, + image_rotary_emb=image_rotary_emb, + ) + + # Cache the residual + state.previous_residual = hidden_states - ori_hidden_states + + state.previous_modulated_input = modulated_inp + state.cnt += 1 + + # Apply final norm and projection + hidden_states = module.norm_out(hidden_states, temb) + hidden_states = module.proj_out(hidden_states) + + # Reshape output + hidden_states = hidden_states.reshape(batch_size, num_frames, post_patch_height, post_patch_width, p, p, -1) + hidden_states = hidden_states.permute(0, 6, 1, 2, 4, 3, 5) + output = hidden_states.reshape(batch_size, -1, num_frames, height, width) + + if USE_PEFT_BACKEND: + unscale_lora_layers(module, lora_scale) + + if not return_dict: + return (output,) + return Transformer2DModelOutput(sample=output) + + def _handle_lumina2_forward( + self, + module, + hidden_states, + timestep, + encoder_hidden_states, + encoder_attention_mask, + attention_kwargs=None, + return_dict=True, + ): + """ + Handle Lumina2 transformer forward pass with TeaCache. + + Note: Lumina2 has complex preprocessing and uses 'layers' instead of 'transformer_blocks'. + The modulated input extraction happens after preprocessing to input_to_main_loop. + + Args: + module: The Lumina2Transformer2DModel instance. + hidden_states (`torch.Tensor`): Input latent tensor. + timestep (`torch.Tensor`): Current diffusion timestep. + encoder_hidden_states (`torch.Tensor`): Text encoder outputs. + encoder_attention_mask (`torch.Tensor`): Attention mask for encoder. + attention_kwargs (`dict`, optional): Additional attention arguments. + return_dict (`bool`): Whether to return a dict. + + Returns: + `torch.Tensor` or `Transformer2DModelOutput`: Denoised output. 
+ """ + from diffusers.models.modeling_outputs import Transformer2DModelOutput + from diffusers.utils import USE_PEFT_BACKEND, scale_lora_layers, unscale_lora_layers + + if attention_kwargs is not None: + attention_kwargs = attention_kwargs.copy() + lora_scale = attention_kwargs.pop("scale", 1.0) + else: + lora_scale = 1.0 + + if USE_PEFT_BACKEND: + scale_lora_layers(module, lora_scale) + + state = self.state_manager.get_state() + + # Reset counter if we've completed all steps + if state.cnt == state.num_steps and state.num_steps > 0: + logger.info("TeaCache inference completed") + state.cnt = 0 + state.accumulated_rel_l1_distance = 0.0 + state.previous_modulated_input = None + state.previous_residual = None + + # Set num_steps on first timestep if not already set + if state.cnt == 0 and state.num_steps == 0: + if self.config.num_inference_steps_callback is not None: + state.num_steps = self.config.num_inference_steps_callback() + elif hasattr(module, "num_steps"): + state.num_steps = module.num_steps + + batch_size, _, height, width = hidden_states.shape + + # Lumina2 preprocessing (matches original forward) + temb, encoder_hidden_states_processed = module.time_caption_embed( + hidden_states, timestep, encoder_hidden_states + ) + ( + image_patch_embeddings, + context_rotary_emb, + noise_rotary_emb, + joint_rotary_emb, + encoder_seq_lengths, + seq_lengths, + ) = module.rope_embedder(hidden_states, encoder_attention_mask) + image_patch_embeddings = module.x_embedder(image_patch_embeddings) + + for layer in module.context_refiner: + encoder_hidden_states_processed = layer( + encoder_hidden_states_processed, encoder_attention_mask, context_rotary_emb + ) + for layer in module.noise_refiner: + image_patch_embeddings = layer(image_patch_embeddings, None, noise_rotary_emb, temb) + + max_seq_len = max(seq_lengths) + input_to_main_loop = image_patch_embeddings.new_zeros(batch_size, max_seq_len, module.config.hidden_size) + for i, (enc_len, seq_len_val) in enumerate(zip(encoder_seq_lengths, seq_lengths)): + input_to_main_loop[i, :enc_len] = encoder_hidden_states_processed[i, :enc_len] + input_to_main_loop[i, enc_len:seq_len_val] = image_patch_embeddings[i] + + use_mask = len(set(seq_lengths)) > 1 + attention_mask_for_main_loop_arg = None + if use_mask: + mask = input_to_main_loop.new_zeros(batch_size, max_seq_len, dtype=torch.bool) + for i, (enc_len, seq_len_val) in enumerate(zip(encoder_seq_lengths, seq_lengths)): + mask[i, :seq_len_val] = True + attention_mask_for_main_loop_arg = mask + + # Extract modulated input (after preprocessing) + modulated_inp = self.extractor_fn(module, input_to_main_loop, temb) + + # Make caching decision + should_calc = self._should_compute_full_transformer(state, modulated_inp) + + if not should_calc: + # Fast path: apply cached residual + processed_hidden_states = input_to_main_loop + state.previous_residual + else: + # Slow path: full computation + current_processing_states = input_to_main_loop + for layer in module.layers: + current_processing_states = layer( + current_processing_states, attention_mask_for_main_loop_arg, joint_rotary_emb, temb + ) + processed_hidden_states = current_processing_states + # Cache the residual + state.previous_residual = processed_hidden_states - input_to_main_loop + + state.previous_modulated_input = modulated_inp + state.cnt += 1 + + # Apply final norm and reshape + output_after_norm = module.norm_out(processed_hidden_states, temb) + p = module.config.patch_size + final_output_list = [] + for i, (enc_len, seq_len_val) in 
enumerate(zip(encoder_seq_lengths, seq_lengths)): + image_part = output_after_norm[i][enc_len:seq_len_val] + h_p, w_p = height // p, width // p + reconstructed_image = ( + image_part.view(h_p, w_p, p, p, module.out_channels).permute(4, 0, 2, 1, 3).flatten(3, 4).flatten(1, 2) + ) + final_output_list.append(reconstructed_image) + + final_output_tensor = torch.stack(final_output_list, dim=0) + + if USE_PEFT_BACKEND: + unscale_lora_layers(module, lora_scale) + + if not return_dict: + return (final_output_tensor,) + return Transformer2DModelOutput(sample=final_output_tensor) + + def _handle_cogvideox_forward( + self, + module, + hidden_states, + encoder_hidden_states, + timestep, + timestep_cond=None, + ofs=None, + image_rotary_emb=None, + attention_kwargs=None, + return_dict=True, + ): + """ + Handle CogVideoX transformer forward pass with TeaCache. + + Note: CogVideoX uses timestep embedding directly (not from a block) and caches + both encoder_hidden_states and hidden_states residuals. + + Args: + module: The CogVideoXTransformer3DModel instance. + hidden_states (`torch.Tensor`): Input latent tensor. + encoder_hidden_states (`torch.Tensor`): Text encoder outputs. + timestep (`torch.Tensor`): Current diffusion timestep. + timestep_cond (`torch.Tensor`, optional): Additional timestep conditioning. + ofs (`torch.Tensor`, optional): Offset tensor. + image_rotary_emb (`torch.Tensor`, optional): Rotary embeddings. + attention_kwargs (`dict`, optional): Additional attention arguments. + return_dict (`bool`): Whether to return a dict. + + Returns: + `torch.Tensor` or `Transformer2DModelOutput`: Denoised output. + """ + from diffusers.models.modeling_outputs import Transformer2DModelOutput + from diffusers.utils import USE_PEFT_BACKEND, is_torch_version, scale_lora_layers, unscale_lora_layers + + if attention_kwargs is not None: + attention_kwargs = attention_kwargs.copy() + lora_scale = attention_kwargs.pop("scale", 1.0) + else: + lora_scale = 1.0 + + if USE_PEFT_BACKEND: + scale_lora_layers(module, lora_scale) + + state = self.state_manager.get_state() + + # Reset counter if we've completed all steps + if state.cnt == state.num_steps and state.num_steps > 0: + logger.info("TeaCache inference completed") + state.cnt = 0 + state.accumulated_rel_l1_distance = 0.0 + state.previous_modulated_input = None + state.previous_residual = None + state.previous_residual_encoder = None + + # Set num_steps on first timestep if not already set + if state.cnt == 0 and state.num_steps == 0: + if self.config.num_inference_steps_callback is not None: + state.num_steps = self.config.num_inference_steps_callback() + elif hasattr(module, "num_steps"): + state.num_steps = module.num_steps + + batch_size, num_frames, channels, height, width = hidden_states.shape + + # Process time embedding + timesteps = timestep + t_emb = module.time_proj(timesteps) + t_emb = t_emb.to(dtype=hidden_states.dtype) + emb = module.time_embedding(t_emb, timestep_cond) + + if module.ofs_embedding is not None: + ofs_emb = module.ofs_proj(ofs) + ofs_emb = ofs_emb.to(dtype=hidden_states.dtype) + ofs_emb = module.ofs_embedding(ofs_emb) + emb = emb + ofs_emb + + # Process patch embedding + hidden_states = module.patch_embed(encoder_hidden_states, hidden_states) + hidden_states = module.embedding_dropout(hidden_states) + + text_seq_length = encoder_hidden_states.shape[1] + encoder_hidden_states = hidden_states[:, :text_seq_length] + hidden_states = hidden_states[:, text_seq_length:] + + # Extract modulated input (CogVideoX uses timestep embedding 
directly)
+        modulated_inp = self.extractor_fn(module, hidden_states, emb)
+
+        # Make caching decision; force full computation until both residuals have been cached
+        should_calc = self._should_compute_full_transformer(state, modulated_inp)
+        if state.previous_residual_encoder is None:
+            should_calc = True
+
+        if not should_calc:
+            # Fast path: apply cached residuals (both encoder and hidden_states)
+            hidden_states = hidden_states + state.previous_residual
+            encoder_hidden_states = encoder_hidden_states + state.previous_residual_encoder
+        else:
+            # Slow path: full computation
+            ori_hidden_states = hidden_states.clone()
+            ori_encoder_hidden_states = encoder_hidden_states.clone()
+
+            # Process through transformer blocks
+            for block in module.transformer_blocks:
+                if torch.is_grad_enabled() and module.gradient_checkpointing:
+                    ckpt_kwargs = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
+                    hidden_states, encoder_hidden_states = torch.utils.checkpoint.checkpoint(
+                        lambda *args: block(*args),
+                        hidden_states,
+                        encoder_hidden_states,
+                        emb,
+                        image_rotary_emb,
+                        **ckpt_kwargs,
+                    )
+                else:
+                    hidden_states, encoder_hidden_states = block(
+                        hidden_states=hidden_states,
+                        encoder_hidden_states=encoder_hidden_states,
+                        temb=emb,
+                        image_rotary_emb=image_rotary_emb,
+                    )
+
+            # Cache both residuals
+            state.previous_residual = hidden_states - ori_hidden_states
+            state.previous_residual_encoder = encoder_hidden_states - ori_encoder_hidden_states
+
+        state.previous_modulated_input = modulated_inp
+        state.cnt += 1
+
+        # Apply final norm
+        if not module.config.use_rotary_positional_embeddings:
+            # CogVideoX-2B
+            hidden_states = module.norm_final(hidden_states)
+        else:
+            # CogVideoX-5B and CogVideoX1.5-5B
+            hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
+            hidden_states = module.norm_final(hidden_states)
+            hidden_states = hidden_states[:, text_seq_length:]
+
+        output = module.proj_out(hidden_states)
+
+        if USE_PEFT_BACKEND:
+            unscale_lora_layers(module, lora_scale)
+
+        if not return_dict:
+            return (output,)
+        return Transformer2DModelOutput(sample=output)
+
     def _should_compute_full_transformer(self, state, modulated_inp):
         """
         Determine whether to compute full transformer blocks or reuse cached residual.
diff --git a/tests/hooks/test_teacache.py b/tests/hooks/test_teacache.py index 6af808093c4b..5fb265509b5a 100644 --- a/tests/hooks/test_teacache.py +++ b/tests/hooks/test_teacache.py @@ -200,5 +200,270 @@ def custom_extractor(mod, hidden_states, temb): module.disable_cache() +class TeaCacheMultiModelTests(unittest.TestCase): + """Tests for TeaCache multi-model support (Mochi, Lumina2, CogVideoX).""" + + def test_model_coefficient_registry(self): + """Test that model coefficients are properly registered.""" + from diffusers.hooks.teacache import _MODEL_COEFFICIENTS + + self.assertIn("Flux", _MODEL_COEFFICIENTS) + self.assertIn("Mochi", _MODEL_COEFFICIENTS) + self.assertIn("Lumina2", _MODEL_COEFFICIENTS) + self.assertIn("CogVideoX", _MODEL_COEFFICIENTS) + + # Verify all coefficients are 5-element lists + for model_name, coeffs in _MODEL_COEFFICIENTS.items(): + self.assertEqual(len(coeffs), 5, f"{model_name} coefficients should have 5 elements") + self.assertTrue( + all(isinstance(c, (int, float)) for c in coeffs), f"{model_name} coefficients should be numbers" + ) + + def test_mochi_extractor(self): + """Test Mochi modulated input extractor.""" + from diffusers import MochiTransformer3DModel + from diffusers.hooks.teacache import _mochi_modulated_input_extractor + + # Create a minimal Mochi model for testing + model = MochiTransformer3DModel( + patch_size=2, + num_attention_heads=2, + attention_head_dim=8, + num_layers=2, + in_channels=4, + text_embed_dim=16, + time_embed_dim=4, + ) + + hidden_states = torch.randn(2, 4, 2, 8, 8) + timestep = torch.randint(0, 1000, (2,)) + encoder_hidden_states = torch.randn(2, 16, 16) + encoder_attention_mask = torch.ones(2, 16).bool() + + # Get timestep embedding + temb, _ = model.time_embed( + timestep, encoder_hidden_states, encoder_attention_mask, hidden_dtype=hidden_states.dtype + ) + hidden_states = hidden_states.permute(0, 2, 1, 3, 4).flatten(0, 1) + hidden_states = model.patch_embed(hidden_states) + hidden_states = hidden_states.unflatten(0, (2, -1)).flatten(1, 2) + + # Test extractor + modulated_inp = _mochi_modulated_input_extractor(model, hidden_states, temb) + self.assertIsInstance(modulated_inp, torch.Tensor) + self.assertEqual(modulated_inp.shape[0], hidden_states.shape[0]) + + def test_lumina2_extractor(self): + """Test Lumina2 modulated input extractor with simplified setup.""" + from diffusers import Lumina2Transformer2DModel + from diffusers.hooks.teacache import _lumina2_modulated_input_extractor + + # Create a minimal Lumina2 model for testing + model = Lumina2Transformer2DModel( + sample_size=16, + patch_size=2, + in_channels=4, + hidden_size=24, + num_layers=2, + num_refiner_layers=1, + num_attention_heads=3, + num_kv_heads=1, + ) + + # Create properly shaped inputs that match what the extractor expects + # The extractor expects input_to_main_loop (already preprocessed concatenated text+image tokens) + batch_size = 2 + seq_len = 100 # combined text + image sequence + hidden_size = model.config.hidden_size + + # Simulate input_to_main_loop (already preprocessed) + input_to_main_loop = torch.randn(batch_size, seq_len, hidden_size) + temb = torch.randn(batch_size, hidden_size) + + # Test extractor + modulated_inp = _lumina2_modulated_input_extractor(model, input_to_main_loop, temb) + self.assertIsInstance(modulated_inp, torch.Tensor) + self.assertEqual(modulated_inp.shape[0], batch_size) + + def test_cogvideox_extractor(self): + """Test CogVideoX modulated input extractor.""" + from diffusers import CogVideoXTransformer3DModel + from 
diffusers.hooks.teacache import _cogvideox_modulated_input_extractor + + # Create a minimal CogVideoX model for testing + model = CogVideoXTransformer3DModel( + num_attention_heads=2, + attention_head_dim=8, + in_channels=4, + num_layers=2, + text_embed_dim=16, + time_embed_dim=4, + ) + + hidden_states = torch.randn(2, 2, 4, 8, 8) + timestep = torch.randint(0, 1000, (2,)) + + # Get timestep embedding + t_emb = model.time_proj(timestep) + t_emb = t_emb.to(dtype=hidden_states.dtype) + emb = model.time_embedding(t_emb, None) + + # Test extractor (should return emb directly) + modulated_inp = _cogvideox_modulated_input_extractor(model, hidden_states, emb) + self.assertIsInstance(modulated_inp, torch.Tensor) + self.assertEqual(modulated_inp.shape, emb.shape) + + def test_auto_detect_mochi(self): + """Test auto-detection for Mochi models.""" + from diffusers import MochiTransformer3DModel + from diffusers.hooks import TeaCacheConfig, apply_teacache + from diffusers.hooks.teacache import _MODEL_COEFFICIENTS, _auto_detect_extractor + + model = MochiTransformer3DModel( + patch_size=2, + num_attention_heads=2, + attention_head_dim=8, + num_layers=2, + in_channels=4, + text_embed_dim=16, + time_embed_dim=4, + ) + + # Test extractor detection + extractor = _auto_detect_extractor(model) + self.assertIsNotNone(extractor) + + # Test coefficient auto-detection + config = TeaCacheConfig(rel_l1_thresh=0.2) + apply_teacache(model, config) + + registry = HookRegistry.check_if_exists_or_initialize(model) + hook = registry.get_hook("teacache") + self.assertIsNotNone(hook) + # Verify coefficients were auto-set + self.assertEqual(hook.config.coefficients, _MODEL_COEFFICIENTS["Mochi"]) + + model.disable_cache() + + def test_auto_detect_lumina2(self): + """Test auto-detection for Lumina2 models.""" + from diffusers import Lumina2Transformer2DModel + from diffusers.hooks import TeaCacheConfig, apply_teacache + from diffusers.hooks.teacache import _MODEL_COEFFICIENTS + + model = Lumina2Transformer2DModel( + sample_size=16, + patch_size=2, + in_channels=4, + hidden_size=24, + num_layers=2, + num_refiner_layers=1, + num_attention_heads=3, + num_kv_heads=1, + ) + + config = TeaCacheConfig(rel_l1_thresh=0.2) + apply_teacache(model, config) + + registry = HookRegistry.check_if_exists_or_initialize(model) + hook = registry.get_hook("teacache") + self.assertIsNotNone(hook) + # Verify coefficients were auto-set + self.assertEqual(hook.config.coefficients, _MODEL_COEFFICIENTS["Lumina2"]) + + # Lumina2 doesn't have CacheMixin, manually remove hook instead + registry.remove_hook("teacache") + + def test_auto_detect_cogvideox(self): + """Test auto-detection for CogVideoX models.""" + from diffusers import CogVideoXTransformer3DModel + from diffusers.hooks import TeaCacheConfig, apply_teacache + from diffusers.hooks.teacache import _MODEL_COEFFICIENTS + + model = CogVideoXTransformer3DModel( + num_attention_heads=2, + attention_head_dim=8, + in_channels=4, + num_layers=2, + text_embed_dim=16, + time_embed_dim=4, + ) + + config = TeaCacheConfig(rel_l1_thresh=0.2) + apply_teacache(model, config) + + registry = HookRegistry.check_if_exists_or_initialize(model) + hook = registry.get_hook("teacache") + self.assertIsNotNone(hook) + # Verify coefficients were auto-set + self.assertEqual(hook.config.coefficients, _MODEL_COEFFICIENTS["CogVideoX"]) + + model.disable_cache() + + def test_teacache_state_encoder_residual(self): + """Test that TeaCacheState supports encoder residual for CogVideoX.""" + from diffusers.hooks.teacache import 
TeaCacheState + + state = TeaCacheState() + self.assertIsNone(state.previous_residual_encoder) + + # Set encoder residual + state.previous_residual_encoder = torch.randn(2, 10, 16) + self.assertIsNotNone(state.previous_residual_encoder) + + # Reset should clear it + state.reset() + self.assertIsNone(state.previous_residual_encoder) + + def test_model_routing(self): + """Test that new_forward routes to correct handler based on model type.""" + from diffusers import CogVideoXTransformer3DModel, Lumina2Transformer2DModel, MochiTransformer3DModel + from diffusers.hooks.teacache import TeaCacheConfig, TeaCacheHook + + config = TeaCacheConfig(rel_l1_thresh=0.2) + + # Test Mochi routing + mochi_model = MochiTransformer3DModel( + patch_size=2, + num_attention_heads=2, + attention_head_dim=8, + num_layers=2, + in_channels=4, + text_embed_dim=16, + time_embed_dim=4, + ) + mochi_hook = TeaCacheHook(config) + mochi_hook.initialize_hook(mochi_model) + self.assertEqual(mochi_hook.model_type, "Mochi") + + # Test Lumina2 routing + lumina_model = Lumina2Transformer2DModel( + sample_size=16, + patch_size=2, + in_channels=4, + hidden_size=24, + num_layers=2, + num_refiner_layers=1, + num_attention_heads=3, + num_kv_heads=1, + ) + lumina_hook = TeaCacheHook(config) + lumina_hook.initialize_hook(lumina_model) + self.assertEqual(lumina_hook.model_type, "Lumina2") + + # Test CogVideoX routing + cogvideox_model = CogVideoXTransformer3DModel( + num_attention_heads=2, + attention_head_dim=8, + in_channels=4, + num_layers=2, + text_embed_dim=16, + time_embed_dim=4, + ) + cogvideox_hook = TeaCacheHook(config) + cogvideox_hook.initialize_hook(cogvideox_model) + self.assertEqual(cogvideox_hook.model_type, "CogVideoX") + + if __name__ == "__main__": unittest.main() From 4cb71d6ca5514aaec889bb4260ce3b63e5cdb269 Mon Sep 17 00:00:00 2001 From: Prajwal A Date: Mon, 8 Dec 2025 13:44:36 +0000 Subject: [PATCH 11/25] simplify model extract; use logger --- src/diffusers/hooks/teacache.py | 147 ++++++++++++++++++-------------- 1 file changed, 81 insertions(+), 66 deletions(-) diff --git a/src/diffusers/hooks/teacache.py b/src/diffusers/hooks/teacache.py index c5ffb6d6a711..45d2936604e0 100644 --- a/src/diffusers/hooks/teacache.py +++ b/src/diffusers/hooks/teacache.py @@ -29,6 +29,7 @@ # Model-specific polynomial coefficients from TeaCache paper/reference implementations _MODEL_COEFFICIENTS = { "Flux": [4.98651651e02, -2.83781631e02, 5.58554382e01, -3.82021401e00, 2.64230861e-01], + "FluxKontext": [-1.04655119e03, 3.12563399e02, -1.69500694e01, 4.10995971e-01, 3.74537863e-02], "Mochi": [-3.51241319e03, 8.11675948e02, -6.09400215e01, 2.42429681e00, 3.05291719e-03], "Lumina2": [393.76566581, -603.50993606, 209.10239044, -23.00726601, 0.86377344], "CogVideoX": [ @@ -70,23 +71,59 @@ def _cogvideox_modulated_input_extractor(module, hidden_states, timestep_emb): return timestep_emb +# Extractor registry - maps model types to extraction functions +# Multiple model variants can share the same extractor +# Order matters: more specific variants first (e.g., CogVideoX1.5-5B-I2V before CogVideoX) +_EXTRACTOR_REGISTRY = { + "FluxKontext": _flux_modulated_input_extractor, + "Flux": _flux_modulated_input_extractor, + "Mochi": _mochi_modulated_input_extractor, + "Lumina2": _lumina2_modulated_input_extractor, + "CogVideoX1.5-5B-I2V": _cogvideox_modulated_input_extractor, + "CogVideoX1.5-5B": _cogvideox_modulated_input_extractor, + "CogVideoX-2b": _cogvideox_modulated_input_extractor, + "CogVideoX-5b": _cogvideox_modulated_input_extractor, + 
"CogVideoX": _cogvideox_modulated_input_extractor, +} + + def _auto_detect_extractor(module): - """Auto-detect and return appropriate extractor based on model type.""" - module_class_name = module.__class__.__name__ - if "Flux" in module_class_name: - return _flux_modulated_input_extractor - elif "Mochi" in module_class_name: - return _mochi_modulated_input_extractor - elif "Lumina2" in module_class_name: - return _lumina2_modulated_input_extractor - elif "CogVideoX" in module_class_name: - return _cogvideox_modulated_input_extractor - # Default to FLUX for backward compatibility - logger.warning( - f"TeaCache: Unknown model type {module_class_name}, defaulting to FLUX extractor. " - f"Results may be incorrect. Please provide extract_modulated_input_fn explicitly." - ) - return _flux_modulated_input_extractor + """Auto-detect and return appropriate extractor.""" + return _EXTRACTOR_REGISTRY[_auto_detect_model_type(module)] + + +def _auto_detect_model_type(module): + """Auto-detect model type from class name and config path.""" + class_name = module.__class__.__name__ + config_path = getattr(getattr(module, "config", None), "_name_or_path", "").lower() + + # Check config path first (for variants), then class name (ordered most specific first) + for model_type in _EXTRACTOR_REGISTRY: + if model_type.lower() in config_path or model_type in class_name: + if model_type not in _MODEL_COEFFICIENTS: + raise ValueError(f"TeaCache: No coefficients for '{model_type}'") + return model_type + + raise ValueError(f"TeaCache: Unsupported model '{class_name}'. Supported: {', '.join(_EXTRACTOR_REGISTRY)}") + + +def _get_model_coefficients(model_type): + """Get polynomial coefficients for a specific model type. + + Args: + model_type: Model type string (e.g., "Flux", "Mochi") + + Raises: + ValueError: If coefficients not found for model type. + """ + if model_type not in _MODEL_COEFFICIENTS: + available_models = ", ".join(_MODEL_COEFFICIENTS.keys()) + raise ValueError( + f"TeaCache: No coefficients found for model type '{model_type}'. " + f"Available models: {available_models}. " + f"Please provide coefficients explicitly in TeaCacheConfig." + ) + return _MODEL_COEFFICIENTS[model_type] @dataclass @@ -98,8 +135,8 @@ class TeaCacheConfig: by reusing transformer block computations when consecutive timestep embeddings are similar. It uses polynomial rescaling of L1 distances between modulated inputs to intelligently decide when to cache. - Currently supports: FLUX, Mochi, Lumina2, and CogVideoX models. Model type is auto-detected, and model-specific - polynomial coefficients are automatically applied. + Currently supports: FLUX, FLUX-Kontext, Mochi, Lumina2, and CogVideoX models. Model type is auto-detected, and + model-specific polynomial coefficients are automatically applied. Args: rel_l1_thresh (`float`, defaults to `0.2`): @@ -172,20 +209,14 @@ def __post_init__(self): f"Try 0.25 for 1.5x speedup or 0.6 for 2x speedup." ) if self.rel_l1_thresh < 0.05: - import warnings - - warnings.warn( + logger.warning( f"rel_l1_thresh={self.rel_l1_thresh} is very low and may result in minimal caching. " - f"Consider using values between 0.1 and 0.3 for optimal performance.", - UserWarning, + f"Consider using values between 0.1 and 0.3 for optimal performance." ) if self.rel_l1_thresh > 1.0: - import warnings - - warnings.warn( + logger.warning( f"rel_l1_thresh={self.rel_l1_thresh} is very high and may cause quality degradation. 
" - f"Consider using values between 0.1 and 0.6 for better quality-speed tradeoff.", - UserWarning, + f"Consider using values between 0.1 and 0.6 for better quality-speed tradeoff." ) # Set default coefficients if not provided (will be auto-detected in hook initialization) @@ -326,47 +357,31 @@ def __init__(self, config: TeaCacheConfig): def initialize_hook(self, module): self.state_manager.set_context("teacache") - # Auto-detect model type and extractor - module_class_name = module.__class__.__name__ + # Strict auto-detection if self.config.extract_modulated_input_fn is None: - self.extractor_fn = _auto_detect_extractor(module) - # Detect model type for coefficient auto-detection - if "Flux" in module_class_name: - self.model_type = "Flux" - elif "Mochi" in module_class_name: - self.model_type = "Mochi" - elif "Lumina2" in module_class_name: - self.model_type = "Lumina2" - elif "CogVideoX" in module_class_name: - # Try to detect specific CogVideoX variant from config - if hasattr(module, "config") and hasattr(module.config, "_name_or_path"): - name_or_path = module.config._name_or_path.lower() - if "1.5" in name_or_path or "1-5" in name_or_path: - if "i2v" in name_or_path: - self.model_type = "CogVideoX1.5-5B-I2V" - else: - self.model_type = "CogVideoX1.5-5B" - elif "2b" in name_or_path: - self.model_type = "CogVideoX-2b" - elif "5b" in name_or_path: - self.model_type = "CogVideoX-5b" - else: - self.model_type = "CogVideoX" # Default to generic 5b - else: - self.model_type = "CogVideoX" # Default to generic 5b - else: - self.model_type = "Flux" # Default + self.extractor_fn = _auto_detect_extractor(module) # Raises if unsupported + self.model_type = _auto_detect_model_type(module) # Raises if unsupported else: self.extractor_fn = self.config.extract_modulated_input_fn - self.model_type = "Flux" # Default if custom extractor provided - - # Auto-set coefficients from registry if not explicitly provided - if self.config.coefficients is None or ( - self.model_type in _MODEL_COEFFICIENTS and self.config.coefficients == _MODEL_COEFFICIENTS.get("Flux") - ): - if self.model_type in _MODEL_COEFFICIENTS: - self.config.coefficients = _MODEL_COEFFICIENTS[self.model_type] - logger.info(f"TeaCache: Auto-detected {self.model_type} coefficients") + # Still try to detect model type for coefficients + try: + self.model_type = _auto_detect_model_type(module) + except ValueError: + self.model_type = None # User provided custom extractor + logger.warning( + f"TeaCache: Using custom extractor for {module.__class__.__name__}. " + f"Coefficients must be provided explicitly." + ) + + # Strict coefficient matching + if self.config.coefficients is None: + if self.model_type is None: + raise ValueError( + "TeaCache: Cannot auto-detect coefficients when using custom extractor. " + "Please provide coefficients explicitly in TeaCacheConfig." 
+ ) + self.config.coefficients = _get_model_coefficients(self.model_type) # Raises if not found + logger.info(f"TeaCache: Using {self.model_type} coefficients") # Initialize rescale function with final coefficients self.rescale_func = np.poly1d(self.config.coefficients) From f0abb3ca3728042516e7960a158e18293287f4b1 Mon Sep 17 00:00:00 2001 From: Prajwal A Date: Wed, 10 Dec 2025 05:00:39 +0000 Subject: [PATCH 12/25] fix(teacache): fix return_dict handling, CogVideoX fallback bug, add torch.compile support, and clean up coefficient flow Signed-off-by: Prajwal A --- src/diffusers/hooks/teacache.py | 205 +++++++++++++++----------------- tests/hooks/test_teacache.py | 45 ++++--- 2 files changed, 117 insertions(+), 133 deletions(-) diff --git a/src/diffusers/hooks/teacache.py b/src/diffusers/hooks/teacache.py index 45d2936604e0..08661e5e1296 100644 --- a/src/diffusers/hooks/teacache.py +++ b/src/diffusers/hooks/teacache.py @@ -15,7 +15,6 @@ from dataclasses import dataclass from typing import Callable, List, Optional -import numpy as np import torch from ..utils import logging @@ -219,29 +218,23 @@ def __post_init__(self): f"Consider using values between 0.1 and 0.6 for better quality-speed tradeoff." ) - # Set default coefficients if not provided (will be auto-detected in hook initialization) - if self.coefficients is None: - # Default to FLUX coefficients (will be overridden by auto-detection if model is recognized) - self.coefficients = _MODEL_COEFFICIENTS.get( - "Flux", [4.98651651e02, -2.83781631e02, 5.58554382e01, -3.82021401e00, 2.64230861e-01] - ) - - # Validate coefficients - if not isinstance(self.coefficients, (list, tuple)): - raise TypeError( - f"coefficients must be a list or tuple, got {type(self.coefficients).__name__}. " - f"Please provide a list of 5 polynomial coefficients." - ) - if len(self.coefficients) != 5: - raise ValueError( - f"coefficients must contain exactly 5 elements for 4th-degree polynomial, " - f"got {len(self.coefficients)}. The polynomial is evaluated as: " - f"c[0]*x^4 + c[1]*x^3 + c[2]*x^2 + c[3]*x + c[4]" - ) - if not all(isinstance(c, (int, float)) for c in self.coefficients): - raise TypeError( - f"All coefficients must be numbers. Got types: {[type(c).__name__ for c in self.coefficients]}" - ) + # Validate coefficients only if explicitly provided (None = auto-detect later) + if self.coefficients is not None: + if not isinstance(self.coefficients, (list, tuple)): + raise TypeError( + f"coefficients must be a list or tuple, got {type(self.coefficients).__name__}. " + f"Please provide a list of 5 polynomial coefficients." + ) + if len(self.coefficients) != 5: + raise ValueError( + f"coefficients must contain exactly 5 elements for 4th-degree polynomial, " + f"got {len(self.coefficients)}. The polynomial is evaluated as: " + f"c[0]*x^4 + c[1]*x^3 + c[2]*x^2 + c[3]*x + c[4]" + ) + if not all(isinstance(c, (int, float)) for c in self.coefficients): + raise TypeError( + f"All coefficients must be numbers. 
Got types: {[type(c).__name__ for c in self.coefficients]}" + ) def __repr__(self) -> str: return ( @@ -349,11 +342,50 @@ def __init__(self, config: TeaCacheConfig): # Set default rescale_func with config coefficients (will be updated in initialize_hook if needed) # This ensures rescale_func is always valid, even if initialize_hook isn't called (e.g., in tests) default_coeffs = config.coefficients if config.coefficients else _MODEL_COEFFICIENTS["Flux"] - self.rescale_func = np.poly1d(default_coeffs) + self.coefficients = default_coeffs + self.rescale_func = self._create_rescale_func(default_coeffs) self.state_manager = StateManager(TeaCacheState, (), {}) self.extractor_fn = None self.model_type = None + @staticmethod + def _create_rescale_func(coefficients): + """Create polynomial rescale function from coefficients. + + Evaluates: c[0]*x^4 + c[1]*x^3 + c[2]*x^2 + c[3]*x + c[4] + """ + def rescale(x): + return coefficients[0] * x**4 + coefficients[1] * x**3 + coefficients[2] * x**2 + coefficients[3] * x + coefficients[4] + return rescale + + def _maybe_reset_state_for_new_inference(self, state, module, reset_encoder_residual=False): + """Reset state if we've completed all steps (start of new inference run). + + Also initializes num_steps on first timestep if not set. + + Args: + state: TeaCacheState instance. + module: The transformer module. + reset_encoder_residual: If True, also reset previous_residual_encoder (for CogVideoX). + """ + # Reset counter if we've completed all steps (new inference run) + if state.cnt == state.num_steps and state.num_steps > 0: + logger.info("TeaCache inference completed") + state.cnt = 0 + state.accumulated_rel_l1_distance = 0.0 + state.previous_modulated_input = None + state.previous_residual = None + if reset_encoder_residual: + state.previous_residual_encoder = None + + # Set num_steps on first timestep if not already set + if state.cnt == 0 and state.num_steps == 0: + if self.config.num_inference_steps_callback is not None: + state.num_steps = self.config.num_inference_steps_callback() + # If still not set, try to get from module attribute (set by pipeline) + if state.num_steps == 0 and hasattr(module, "num_steps"): + state.num_steps = module.num_steps + def initialize_hook(self, module): self.state_manager.set_context("teacache") @@ -373,18 +405,21 @@ def initialize_hook(self, module): f"Coefficients must be provided explicitly." ) - # Strict coefficient matching + # Auto-detect coefficients if not provided by user if self.config.coefficients is None: if self.model_type is None: raise ValueError( "TeaCache: Cannot auto-detect coefficients when using custom extractor. " "Please provide coefficients explicitly in TeaCacheConfig." 
) - self.config.coefficients = _get_model_coefficients(self.model_type) # Raises if not found + self.coefficients = _get_model_coefficients(self.model_type) # Raises if not found logger.info(f"TeaCache: Using {self.model_type} coefficients") + else: + self.coefficients = self.config.coefficients + logger.info(f"TeaCache: Using user-provided coefficients") # Initialize rescale function with final coefficients - self.rescale_func = np.poly1d(self.config.coefficients) + self.rescale_func = self._create_rescale_func(self.coefficients) return module @@ -410,7 +445,16 @@ def new_forward(self, module, *args, **kwargs): return self._handle_flux_forward(module, *args, **kwargs) def _handle_flux_forward( - self, module, hidden_states, timestep, pooled_projections, encoder_hidden_states, txt_ids, img_ids, **kwargs + self, + module, + hidden_states, + timestep, + pooled_projections, + encoder_hidden_states, + txt_ids, + img_ids, + return_dict=True, + **kwargs, ): """ Handle FLUX transformer forward pass with TeaCache. @@ -423,28 +467,16 @@ def _handle_flux_forward( encoder_hidden_states (`torch.Tensor`): Text encoder outputs. txt_ids (`torch.Tensor`): Position IDs for text tokens. img_ids (`torch.Tensor`): Position IDs for image tokens. + return_dict (`bool`): Whether to return a dict. **kwargs: Additional arguments. Returns: - `torch.Tensor`: Denoised output tensor. + `torch.Tensor` or `Transformer2DModelOutput`: Denoised output. """ - state = self.state_manager.get_state() - - # Reset counter if we've completed all steps (new inference run) - if state.cnt == state.num_steps and state.num_steps > 0: - logger.info("TeaCache inference completed") - state.cnt = 0 - state.accumulated_rel_l1_distance = 0.0 - state.previous_modulated_input = None - state.previous_residual = None + from diffusers.models.modeling_outputs import Transformer2DModelOutput - # Set num_steps on first timestep if not already set - if state.cnt == 0 and state.num_steps == 0: - if self.config.num_inference_steps_callback is not None: - state.num_steps = self.config.num_inference_steps_callback() - # If still not set, try to get from module attribute (set by pipeline) - if state.num_steps == 0 and hasattr(module, "num_steps"): - state.num_steps = module.num_steps + state = self.state_manager.get_state() + self._maybe_reset_state_for_new_inference(state, module) # Process inputs like original TeaCache # Must process hidden_states through x_embedder first @@ -458,10 +490,8 @@ def _handle_flux_forward( else: temb = module.time_text_embed(timestep_scaled, pooled_projections) - # Extract modulated input using configured extractor - inp = hidden_states.clone() - temb_clone = temb.clone() - modulated_inp = self.extractor_fn(module, inp, temb_clone) + # Extract modulated input using configured extractor (extractors don't modify inputs) + modulated_inp = self.extractor_fn(module, hidden_states, temb) # Make caching decision should_calc = self._should_compute_full_transformer(state, modulated_inp) @@ -516,7 +546,9 @@ def _handle_flux_forward( hidden_states = module.norm_out(hidden_states, temb) output = module.proj_out(hidden_states) - return output + if not return_dict: + return (output,) + return Transformer2DModelOutput(sample=output) def _handle_mochi_forward( self, @@ -556,21 +588,7 @@ def _handle_mochi_forward( scale_lora_layers(module, lora_scale) state = self.state_manager.get_state() - - # Reset counter if we've completed all steps - if state.cnt == state.num_steps and state.num_steps > 0: - logger.info("TeaCache inference completed") 
- state.cnt = 0 - state.accumulated_rel_l1_distance = 0.0 - state.previous_modulated_input = None - state.previous_residual = None - - # Set num_steps on first timestep if not already set - if state.cnt == 0 and state.num_steps == 0: - if self.config.num_inference_steps_callback is not None: - state.num_steps = self.config.num_inference_steps_callback() - elif hasattr(module, "num_steps"): - state.num_steps = module.num_steps + self._maybe_reset_state_for_new_inference(state, module) batch_size, num_channels, num_frames, height, width = hidden_states.shape p = module.config.patch_size @@ -600,10 +618,8 @@ def _handle_mochi_forward( dtype=torch.float32, ) - # Extract modulated input - inp = hidden_states.clone() - temb_clone = temb.clone() - modulated_inp = self.extractor_fn(module, inp, temb_clone) + # Extract modulated input (extractors don't modify inputs) + modulated_inp = self.extractor_fn(module, hidden_states, temb) # Make caching decision should_calc = self._should_compute_full_transformer(state, modulated_inp) @@ -698,21 +714,7 @@ def _handle_lumina2_forward( scale_lora_layers(module, lora_scale) state = self.state_manager.get_state() - - # Reset counter if we've completed all steps - if state.cnt == state.num_steps and state.num_steps > 0: - logger.info("TeaCache inference completed") - state.cnt = 0 - state.accumulated_rel_l1_distance = 0.0 - state.previous_modulated_input = None - state.previous_residual = None - - # Set num_steps on first timestep if not already set - if state.cnt == 0 and state.num_steps == 0: - if self.config.num_inference_steps_callback is not None: - state.num_steps = self.config.num_inference_steps_callback() - elif hasattr(module, "num_steps"): - state.num_steps = module.num_steps + self._maybe_reset_state_for_new_inference(state, module) batch_size, _, height, width = hidden_states.shape @@ -840,22 +842,7 @@ def _handle_cogvideox_forward( scale_lora_layers(module, lora_scale) state = self.state_manager.get_state() - - # Reset counter if we've completed all steps - if state.cnt == state.num_steps and state.num_steps > 0: - logger.info("TeaCache inference completed") - state.cnt = 0 - state.accumulated_rel_l1_distance = 0.0 - state.previous_modulated_input = None - state.previous_residual = None - state.previous_residual_encoder = None - - # Set num_steps on first timestep if not already set - if state.cnt == 0 and state.num_steps == 0: - if self.config.num_inference_steps_callback is not None: - state.num_steps = self.config.num_inference_steps_callback() - elif hasattr(module, "num_steps"): - state.num_steps = module.num_steps + self._maybe_reset_state_for_new_inference(state, module, reset_encoder_residual=True) batch_size, num_frames, channels, height, width = hidden_states.shape @@ -885,16 +872,13 @@ def _handle_cogvideox_forward( # Make caching decision should_calc = self._should_compute_full_transformer(state, modulated_inp) - if not should_calc: - # Fast path: apply cached residuals (both encoder and hidden_states) - if state.previous_residual_encoder is not None: - hidden_states = hidden_states + state.previous_residual - encoder_hidden_states = encoder_hidden_states + state.previous_residual_encoder - else: - # Fallback: compute if encoder residual not cached - should_calc = True + # Fast path: apply cached residuals (both encoder and hidden_states) + # Must have both residuals cached to use fast path + if not should_calc and state.previous_residual_encoder is not None: + hidden_states = hidden_states + state.previous_residual + 
encoder_hidden_states = encoder_hidden_states + state.previous_residual_encoder else: - # Slow path: full computation + # Slow path: full computation (also runs when encoder residual not yet cached) ori_hidden_states = hidden_states.clone() ori_encoder_hidden_states = encoder_hidden_states.clone() @@ -944,6 +928,7 @@ def _handle_cogvideox_forward( return (output,) return Transformer2DModelOutput(sample=output) + @torch.compiler.disable def _should_compute_full_transformer(self, state, modulated_inp): """ Determine whether to compute full transformer blocks or reuse cached residual. diff --git a/tests/hooks/test_teacache.py b/tests/hooks/test_teacache.py index 5fb265509b5a..fd2ce6cfdf94 100644 --- a/tests/hooks/test_teacache.py +++ b/tests/hooks/test_teacache.py @@ -13,7 +13,6 @@ # limitations under the License. import unittest -import warnings import torch @@ -27,8 +26,8 @@ def test_valid_config(self): """Test valid configuration is accepted.""" config = TeaCacheConfig(rel_l1_thresh=0.2) self.assertEqual(config.rel_l1_thresh, 0.2) - self.assertIsNotNone(config.coefficients) - self.assertEqual(len(config.coefficients), 5) + # coefficients is None by default (auto-detected during hook initialization) + self.assertIsNone(config.coefficients) def test_invalid_type(self): """Test invalid type for rel_l1_thresh raises TypeError.""" @@ -54,21 +53,17 @@ def test_invalid_coefficients_type(self): TeaCacheConfig(rel_l1_thresh=0.2, coefficients=[1.0, 2.0, "invalid", 4.0, 5.0]) self.assertIn("must be numbers", str(context.exception)) - def test_warning_very_low_threshold(self): - """Test warning is issued for very low threshold.""" - with warnings.catch_warnings(record=True) as w: - warnings.simplefilter("always") - TeaCacheConfig(rel_l1_thresh=0.01) - self.assertEqual(len(w), 1) - self.assertIn("very low", str(w[0].message)) - - def test_warning_very_high_threshold(self): - """Test warning is issued for very high threshold.""" - with warnings.catch_warnings(record=True) as w: - warnings.simplefilter("always") - TeaCacheConfig(rel_l1_thresh=1.5) - self.assertEqual(len(w), 1) - self.assertIn("very high", str(w[0].message)) + def test_very_low_threshold_accepted(self): + """Test very low threshold is accepted (with logging warning).""" + # Very low threshold should be accepted but logged as warning + config = TeaCacheConfig(rel_l1_thresh=0.01) + self.assertEqual(config.rel_l1_thresh, 0.01) + + def test_very_high_threshold_accepted(self): + """Test very high threshold is accepted (with logging warning).""" + # Very high threshold should be accepted but logged as warning + config = TeaCacheConfig(rel_l1_thresh=1.5) + self.assertEqual(config.rel_l1_thresh, 1.5) def test_config_repr(self): """Test __repr__ method works correctly.""" @@ -188,9 +183,13 @@ def __init__(self): def custom_extractor(mod, hidden_states, temb): return hidden_states - config = TeaCacheConfig(rel_l1_thresh=0.2, extract_modulated_input_fn=custom_extractor) + # Must provide coefficients when using custom extractor (no auto-detection) + custom_coeffs = [1.0, 2.0, 3.0, 4.0, 5.0] + config = TeaCacheConfig( + rel_l1_thresh=0.2, extract_modulated_input_fn=custom_extractor, coefficients=custom_coeffs + ) - # Should not raise - TeaCache is now model-agnostic + # Should not raise - TeaCache works with custom extractor when coefficients provided apply_teacache(module, config) # Verify registry and disable path work @@ -341,7 +340,7 @@ def test_auto_detect_mochi(self): hook = registry.get_hook("teacache") self.assertIsNotNone(hook) # Verify 
coefficients were auto-set - self.assertEqual(hook.config.coefficients, _MODEL_COEFFICIENTS["Mochi"]) + self.assertEqual(hook.coefficients, _MODEL_COEFFICIENTS["Mochi"]) model.disable_cache() @@ -369,7 +368,7 @@ def test_auto_detect_lumina2(self): hook = registry.get_hook("teacache") self.assertIsNotNone(hook) # Verify coefficients were auto-set - self.assertEqual(hook.config.coefficients, _MODEL_COEFFICIENTS["Lumina2"]) + self.assertEqual(hook.coefficients, _MODEL_COEFFICIENTS["Lumina2"]) # Lumina2 doesn't have CacheMixin, manually remove hook instead registry.remove_hook("teacache") @@ -396,7 +395,7 @@ def test_auto_detect_cogvideox(self): hook = registry.get_hook("teacache") self.assertIsNotNone(hook) # Verify coefficients were auto-set - self.assertEqual(hook.config.coefficients, _MODEL_COEFFICIENTS["CogVideoX"]) + self.assertEqual(hook.coefficients, _MODEL_COEFFICIENTS["CogVideoX"]) model.disable_cache() From 4a6afef612a01985a14534f026564a70b11980cc Mon Sep 17 00:00:00 2001 From: Prajwal A Date: Wed, 10 Dec 2025 06:03:00 +0000 Subject: [PATCH 13/25] Fix TeaCache state management and add num_inference_steps param Signed-off-by: Prajwal A --- src/diffusers/hooks/teacache.py | 24 ++++++++++++++++++------ src/diffusers/models/cache_utils.py | 15 +++++++++++++-- 2 files changed, 31 insertions(+), 8 deletions(-) diff --git a/src/diffusers/hooks/teacache.py b/src/diffusers/hooks/teacache.py index 08661e5e1296..6f71753b40d7 100644 --- a/src/diffusers/hooks/teacache.py +++ b/src/diffusers/hooks/teacache.py @@ -159,10 +159,13 @@ class TeaCacheConfig: current_timestep_callback (`Callable[[], int]`, *optional*, defaults to `None`): Callback function that returns the current timestep during inference. This is used internally for debugging and statistics tracking. If not provided, TeaCache will still function correctly. + num_inference_steps (`int`, *optional*, defaults to `None`): + Total number of inference steps. Required for proper state management - ensures first and last timesteps + are always computed (never cached) and that state resets between inference runs. If not provided, + TeaCache will attempt to detect via callback or module attribute. num_inference_steps_callback (`Callable[[], int]`, *optional*, defaults to `None`): - Callback function that returns the total number of inference steps. This is used to ensure the first and - last timesteps are always computed (never cached) for maximum quality. If not provided, TeaCache will - attempt to detect the number of steps automatically from the pipeline. + Callback function that returns the total number of inference steps. Alternative to `num_inference_steps` + for dynamic step counts. 
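As a small illustrative sketch (assuming the `TeaCacheConfig` introduced by this series), the step count can either be pinned up front or resolved lazily through the callback; either way it lets the hook keep the first and last steps fully computed:

```python
from diffusers.hooks import TeaCacheConfig

num_steps = 28  # illustrative step count

# Pin the step count directly ...
config_static = TeaCacheConfig(rel_l1_thresh=0.2, num_inference_steps=num_steps)

# ... or resolve it lazily the first time the hook needs it.
config_lazy = TeaCacheConfig(
    rel_l1_thresh=0.2,
    num_inference_steps_callback=lambda: num_steps,
)
```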
Examples: ```python @@ -192,6 +195,7 @@ class TeaCacheConfig: coefficients: Optional[List[float]] = None extract_modulated_input_fn: Optional[Callable] = None current_timestep_callback: Optional[Callable[[], int]] = None + num_inference_steps: Optional[int] = None num_inference_steps_callback: Optional[Callable[[], int]] = None def __post_init__(self): @@ -380,11 +384,16 @@ def _maybe_reset_state_for_new_inference(self, state, module, reset_encoder_resi # Set num_steps on first timestep if not already set if state.cnt == 0 and state.num_steps == 0: - if self.config.num_inference_steps_callback is not None: + # Priority: config value > callback > module attribute + if self.config.num_inference_steps is not None: + state.num_steps = self.config.num_inference_steps + elif self.config.num_inference_steps_callback is not None: state.num_steps = self.config.num_inference_steps_callback() - # If still not set, try to get from module attribute (set by pipeline) - if state.num_steps == 0 and hasattr(module, "num_steps"): + elif hasattr(module, "num_steps"): state.num_steps = module.num_steps + + if state.num_steps > 0: + logger.info(f"TeaCache: Using {state.num_steps} inference steps") def initialize_hook(self, module): self.state_manager.set_context("teacache") @@ -975,6 +984,9 @@ def _should_compute_full_transformer(self, state, modulated_inp): rescaled_distance = self.rescale_func(rel_distance) state.accumulated_rel_l1_distance += rescaled_distance + # Debug logging (uncomment to debug) + # logger.warning(f"Step {state.cnt}: rel_l1={rel_distance:.6f}, rescaled={rescaled_distance:.6f}, accumulated={state.accumulated_rel_l1_distance:.6f}, thresh={self.config.rel_l1_thresh}") + # Make decision based on accumulated threshold if state.accumulated_rel_l1_distance < self.config.rel_l1_thresh: return False diff --git a/src/diffusers/models/cache_utils.py b/src/diffusers/models/cache_utils.py index 0811159fcace..03a33e32784d 100644 --- a/src/diffusers/models/cache_utils.py +++ b/src/diffusers/models/cache_utils.py @@ -154,11 +154,22 @@ def cache_context(self, name: str): registry._set_context(None) - def enable_teacache(self, rel_l1_thresh: float = 0.2, **kwargs): + def enable_teacache(self, rel_l1_thresh: float = 0.2, num_inference_steps: int = None, **kwargs): r""" Enable TeaCache on the model. + + Args: + rel_l1_thresh (`float`, defaults to `0.2`): + Threshold for caching decision. Higher = more aggressive caching. + num_inference_steps (`int`, *optional*): + Total number of inference steps. Required for proper state management. + **kwargs: Additional arguments passed to TeaCacheConfig. 
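For context, a minimal end-to-end sketch of the convenience method added here (checkpoint name and step count are illustrative; assumes the patch series is applied):

```python
import torch
from diffusers import FluxPipeline

pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16)
pipe.to("cuda")

# num_inference_steps should match the value passed to the pipeline call below so
# the hook can always compute the first and last steps in full.
pipe.transformer.enable_teacache(rel_l1_thresh=0.2, num_inference_steps=28)

image = pipe("A cat holding a sign that says hello world", num_inference_steps=28).images[0]
```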
""" from ..hooks import TeaCacheConfig - config = TeaCacheConfig(rel_l1_thresh=rel_l1_thresh, **kwargs) + config = TeaCacheConfig( + rel_l1_thresh=rel_l1_thresh, + num_inference_steps=num_inference_steps, + **kwargs + ) self.enable_cache(config) From 86469088e0e2c37aa0fed35dbd5029b90f92dbbe Mon Sep 17 00:00:00 2001 From: Prajwal A Date: Wed, 10 Dec 2025 11:58:45 +0000 Subject: [PATCH 14/25] fixed counter manage, cogvoideox missing norm proj added Signed-off-by: Prajwal A --- src/diffusers/hooks/teacache.py | 102 ++++++++++++++---- .../transformers/transformer_lumina2.py | 3 +- 2 files changed, 85 insertions(+), 20 deletions(-) diff --git a/src/diffusers/hooks/teacache.py b/src/diffusers/hooks/teacache.py index 6f71753b40d7..2640367a17dc 100644 --- a/src/diffusers/hooks/teacache.py +++ b/src/diffusers/hooks/teacache.py @@ -289,6 +289,9 @@ def __init__(self): self.previous_residual = None # For models that cache both encoder and hidden_states residuals (e.g., CogVideoX) self.previous_residual_encoder = None + # For models with variable sequence lengths (e.g., Lumina2) + self.cache_dict = {} + self.uncond_seq_len = None def reset(self): """Reset all state variables to initial values for a new inference run.""" @@ -298,6 +301,8 @@ def reset(self): self.previous_modulated_input = None self.previous_residual = None self.previous_residual_encoder = None + self.cache_dict = {} + self.uncond_seq_len = None def __repr__(self) -> str: return ( @@ -634,7 +639,7 @@ def _handle_mochi_forward( should_calc = self._should_compute_full_transformer(state, modulated_inp) if not should_calc: - # Fast path: apply cached residual + # Fast path: apply cached residual (already includes norm_out) hidden_states = hidden_states + state.previous_residual else: # Slow path: full computation @@ -660,14 +665,16 @@ def _handle_mochi_forward( image_rotary_emb=image_rotary_emb, ) - # Cache the residual + # Apply norm_out before caching residual (matches reference implementation) + hidden_states = module.norm_out(hidden_states, temb) + + # Cache the residual (includes norm_out transformation) state.previous_residual = hidden_states - ori_hidden_states state.previous_modulated_input = modulated_inp state.cnt += 1 - # Apply final norm and projection - hidden_states = module.norm_out(hidden_states, temb) + # Apply projection hidden_states = module.proj_out(hidden_states) # Reshape output @@ -765,12 +772,57 @@ def _handle_lumina2_forward( # Extract modulated input (after preprocessing) modulated_inp = self.extractor_fn(module, input_to_main_loop, temb) - # Make caching decision - should_calc = self._should_compute_full_transformer(state, modulated_inp) - - if not should_calc: + # Per-sequence-length cache for Lumina2 (handles variable sequence lengths) + cache_key = max_seq_len + if cache_key not in state.cache_dict: + state.cache_dict[cache_key] = { + "previous_modulated_input": None, + "previous_residual": None, + "accumulated_rel_l1_distance": 0.0, + } + current_cache = state.cache_dict[cache_key] + + # Make caching decision using per-cache values + if state.cnt == 0 or state.cnt == state.num_steps - 1: + should_calc = True + current_cache["accumulated_rel_l1_distance"] = 0.0 + else: + if current_cache["previous_modulated_input"] is not None: + prev_mod_input = current_cache["previous_modulated_input"] + prev_mean = prev_mod_input.abs().mean() + + if prev_mean.item() > 1e-9: + rel_l1_change = ((modulated_inp - prev_mod_input).abs().mean() / prev_mean).cpu().item() + else: + rel_l1_change = 0.0 if 
modulated_inp.abs().mean().item() < 1e-9 else float('inf') + + rescaled_distance = self.rescale_func(rel_l1_change) + current_cache["accumulated_rel_l1_distance"] += rescaled_distance + + if current_cache["accumulated_rel_l1_distance"] < self.config.rel_l1_thresh: + should_calc = False + else: + should_calc = True + current_cache["accumulated_rel_l1_distance"] = 0.0 + else: + should_calc = True + current_cache["accumulated_rel_l1_distance"] = 0.0 + + current_cache["previous_modulated_input"] = modulated_inp.clone() + + # Track unconditional sequence length for counter management + if state.uncond_seq_len is None: + state.uncond_seq_len = cache_key + # Only increment counter when not processing unconditional (different seq len) + if cache_key != state.uncond_seq_len: + state.cnt += 1 + if state.cnt >= state.num_steps: + state.cnt = 0 + + # Fast or slow path with per-cache residual + if not should_calc and current_cache["previous_residual"] is not None: # Fast path: apply cached residual - processed_hidden_states = input_to_main_loop + state.previous_residual + processed_hidden_states = input_to_main_loop + current_cache["previous_residual"] else: # Slow path: full computation current_processing_states = input_to_main_loop @@ -779,11 +831,8 @@ def _handle_lumina2_forward( current_processing_states, attention_mask_for_main_loop_arg, joint_rotary_emb, temb ) processed_hidden_states = current_processing_states - # Cache the residual - state.previous_residual = processed_hidden_states - input_to_main_loop - - state.previous_modulated_input = modulated_inp - state.cnt += 1 + # Cache the residual in per-cache storage + current_cache["previous_residual"] = processed_hidden_states - input_to_main_loop # Apply final norm and reshape output_after_norm = module.norm_out(processed_hidden_states, temb) @@ -881,13 +930,13 @@ def _handle_cogvideox_forward( # Make caching decision should_calc = self._should_compute_full_transformer(state, modulated_inp) - # Fast path: apply cached residuals (both encoder and hidden_states) - # Must have both residuals cached to use fast path - if not should_calc and state.previous_residual_encoder is not None: + # Fast or slow path based on caching decision + if not should_calc: + # Fast path: apply cached residuals (both encoder and hidden_states) hidden_states = hidden_states + state.previous_residual encoder_hidden_states = encoder_hidden_states + state.previous_residual_encoder else: - # Slow path: full computation (also runs when encoder residual not yet cached) + # Slow path: full computation ori_hidden_states = hidden_states.clone() ori_encoder_hidden_states = encoder_hidden_states.clone() @@ -928,7 +977,22 @@ def _handle_cogvideox_forward( hidden_states = module.norm_final(hidden_states) hidden_states = hidden_states[:, text_seq_length:] - output = module.proj_out(hidden_states) + # Final block + hidden_states = module.norm_out(hidden_states, temb=emb) + hidden_states = module.proj_out(hidden_states) + + # Unpatchify + p = module.config.patch_size + p_t = module.config.patch_size_t + + if p_t is None: + output = hidden_states.reshape(batch_size, num_frames, height // p, width // p, -1, p, p) + output = output.permute(0, 1, 4, 2, 5, 3, 6).flatten(5, 6).flatten(3, 4) + else: + output = hidden_states.reshape( + batch_size, (num_frames + p_t - 1) // p_t, height // p, width // p, -1, p_t, p, p + ) + output = output.permute(0, 1, 5, 4, 2, 6, 3, 7).flatten(6, 7).flatten(4, 5).flatten(1, 2) if USE_PEFT_BACKEND: unscale_lora_layers(module, lora_scale) diff --git 
a/src/diffusers/models/transformers/transformer_lumina2.py b/src/diffusers/models/transformers/transformer_lumina2.py index 77121edb9fc9..a08d33dbfa77 100644 --- a/src/diffusers/models/transformers/transformer_lumina2.py +++ b/src/diffusers/models/transformers/transformer_lumina2.py @@ -25,6 +25,7 @@ from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers from ..attention import LuminaFeedForward from ..attention_processor import Attention +from ..cache_utils import CacheMixin from ..embeddings import TimestepEmbedding, Timesteps, apply_rotary_emb, get_1d_rotary_pos_embed from ..modeling_outputs import Transformer2DModelOutput from ..modeling_utils import ModelMixin @@ -322,7 +323,7 @@ def forward(self, hidden_states: torch.Tensor, attention_mask: torch.Tensor): return hidden_states, cap_freqs_cis, img_freqs_cis, freqs_cis, l_effective_cap_len, seq_lengths -class Lumina2Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin): +class Lumina2Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, CacheMixin): r""" Lumina2NextDiT: Diffusion model with a Transformer backbone. From 38effa1f8f7cf0cdb6f55617217ac2c912b8c8e0 Mon Sep 17 00:00:00 2001 From: Prajwal A Date: Tue, 16 Dec 2025 08:30:22 +0000 Subject: [PATCH 15/25] Refactor TeaCache hook into adapters Signed-off-by: Prajwal A --- src/diffusers/hooks/teacache.py | 590 ++++++++++++++------------------ 1 file changed, 253 insertions(+), 337 deletions(-) diff --git a/src/diffusers/hooks/teacache.py b/src/diffusers/hooks/teacache.py index 2640367a17dc..f525fb343922 100644 --- a/src/diffusers/hooks/teacache.py +++ b/src/diffusers/hooks/teacache.py @@ -13,7 +13,7 @@ # limitations under the License. from dataclasses import dataclass -from typing import Callable, List, Optional +from typing import Any, Callable, Dict, List, Optional, Tuple, Union import torch @@ -437,66 +437,144 @@ def initialize_hook(self, module): return module - def new_forward(self, module, *args, **kwargs): + @torch.compiler.disable + def _run_single_residual_cache( + self, + state: "TeaCacheState", + modulated_inp: torch.Tensor, + compute_fn: Callable[[], Tuple[torch.Tensor, torch.Tensor]], + cached_fn: Callable[[], torch.Tensor], + ) -> torch.Tensor: """ - Route to model-specific forward handler based on detected model type. + Shared cache engine for models where we cache exactly one residual tensor. + + `compute_fn` must return `(y, x_base)` where residual is `y - x_base`. + `cached_fn` must return `y` computed from the current `x` and cached residual. """ - module_class_name = module.__class__.__name__ + should_calc = self._should_compute_full_transformer(state, modulated_inp) + if (not should_calc) and (state.previous_residual is None): + should_calc = True - if "Flux" in module_class_name: - return self._handle_flux_forward(module, *args, **kwargs) - elif "Mochi" in module_class_name: - return self._handle_mochi_forward(module, *args, **kwargs) - elif "Lumina2" in module_class_name: - return self._handle_lumina2_forward(module, *args, **kwargs) - elif "CogVideoX" in module_class_name: - return self._handle_cogvideox_forward(module, *args, **kwargs) + if should_calc: + y, x_base = compute_fn() + state.previous_residual = y - x_base else: - # Default to FLUX handler for backward compatibility - logger.warning( - f"TeaCache: Unknown model type {module_class_name}, using FLUX handler. Results may be incorrect." 
- ) - return self._handle_flux_forward(module, *args, **kwargs) + y = cached_fn() + + state.previous_modulated_input = modulated_inp + state.cnt += 1 + return y - def _handle_flux_forward( + @torch.compiler.disable + def _run_dual_residual_cache( self, - module, - hidden_states, - timestep, - pooled_projections, - encoder_hidden_states, - txt_ids, - img_ids, - return_dict=True, - **kwargs, - ): + state: "TeaCacheState", + modulated_inp: torch.Tensor, + compute_fn: Callable[[], Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]], + cached_fn: Callable[[], Tuple[torch.Tensor, torch.Tensor]], + ) -> Tuple[torch.Tensor, torch.Tensor]: """ - Handle FLUX transformer forward pass with TeaCache. + Shared cache engine for models where we cache two residual tensors: + - hidden_states residual + - encoder_hidden_states residual - Args: - module: The FluxTransformer2DModel instance. - hidden_states (`torch.Tensor`): Input latent tensor. - timestep (`torch.Tensor`): Current diffusion timestep. - pooled_projections (`torch.Tensor`): Pooled text embeddings. - encoder_hidden_states (`torch.Tensor`): Text encoder outputs. - txt_ids (`torch.Tensor`): Position IDs for text tokens. - img_ids (`torch.Tensor`): Position IDs for image tokens. - return_dict (`bool`): Whether to return a dict. - **kwargs: Additional arguments. - - Returns: - `torch.Tensor` or `Transformer2DModelOutput`: Denoised output. + `compute_fn` must return `(hs_y, enc_y, hs_x_base, enc_x_base)`. + `cached_fn` must return `(hs_y, enc_y)` computed using cached residuals. """ + should_calc = self._should_compute_full_transformer(state, modulated_inp) + if (not should_calc) and (state.previous_residual is None or state.previous_residual_encoder is None): + should_calc = True + + if should_calc: + hs_y, enc_y, hs_x_base, enc_x_base = compute_fn() + state.previous_residual = hs_y - hs_x_base + state.previous_residual_encoder = enc_y - enc_x_base + else: + hs_y, enc_y = cached_fn() + + state.previous_modulated_input = modulated_inp + state.cnt += 1 + return hs_y, enc_y + + def new_forward(self, module, *args, **kwargs): + module_class_name = module.__class__.__name__ + if "Flux" in module_class_name: + return _FluxTeaCacheAdapter.forward(self, module, *args, **kwargs) + if "Mochi" in module_class_name: + return _MochiTeaCacheAdapter.forward(self, module, *args, **kwargs) + if "Lumina2" in module_class_name: + return _Lumina2TeaCacheAdapter.forward(self, module, *args, **kwargs) + if "CogVideoX" in module_class_name: + return _CogVideoXTeaCacheAdapter.forward(self, module, *args, **kwargs) + + logger.warning(f"TeaCache: Unsupported model type {module_class_name}.") + return self.fn_ref.original_forward(*args, **kwargs) + + @torch.compiler.disable + def _should_compute_full_transformer(self, state, modulated_inp): + """ + Determine whether to compute full transformer blocks or reuse cached residual. 
+ + This method implements the core caching decision logic from the TeaCache paper: + - Always compute first and last timesteps (for maximum quality) + - For intermediate timesteps, compute relative L1 distance between current and previous modulated inputs + - Apply polynomial rescaling to convert distance to model-specific caching signal + - Accumulate rescaled distances and compare to threshold + - Return True (compute) if accumulated distance exceeds threshold, False (cache) otherwise + """ + if state.cnt == 0: + state.accumulated_rel_l1_distance = 0 + return True + + if state.num_steps > 0 and state.cnt == state.num_steps - 1: + state.accumulated_rel_l1_distance = 0 + return True + + if state.previous_modulated_input is None: + return True + + rel_distance = ( + ( + (modulated_inp - state.previous_modulated_input).abs().mean() + / state.previous_modulated_input.abs().mean() + ) + .cpu() + .item() + ) + rescaled_distance = self.rescale_func(rel_distance) + state.accumulated_rel_l1_distance += rescaled_distance + + if state.accumulated_rel_l1_distance < self.config.rel_l1_thresh: + return False + + state.accumulated_rel_l1_distance = 0 + return True + + def reset_state(self, module): + self.state_manager.reset() + return module + +class _FluxTeaCacheAdapter: + @staticmethod + def forward( + hook: TeaCacheHook, + module: torch.nn.Module, + hidden_states: torch.Tensor, + timestep: torch.Tensor, + pooled_projections: torch.Tensor, + encoder_hidden_states: torch.Tensor, + txt_ids: torch.Tensor, + img_ids: torch.Tensor, + return_dict: bool = True, + **kwargs, + ): from diffusers.models.modeling_outputs import Transformer2DModelOutput - state = self.state_manager.get_state() - self._maybe_reset_state_for_new_inference(state, module) + state = hook.state_manager.get_state() + hook._maybe_reset_state_for_new_inference(state, module) - # Process inputs like original TeaCache - # Must process hidden_states through x_embedder first hidden_states = module.x_embedder(hidden_states) - # Extract timestep embedding timestep_scaled = timestep.to(hidden_states.dtype) * 1000 if kwargs.get("guidance") is not None: guidance = kwargs["guidance"].to(hidden_states.dtype) * 1000 @@ -504,59 +582,50 @@ def _handle_flux_forward( else: temb = module.time_text_embed(timestep_scaled, pooled_projections) - # Extract modulated input using configured extractor (extractors don't modify inputs) - modulated_inp = self.extractor_fn(module, hidden_states, temb) + modulated_inp = hook.extractor_fn(module, hidden_states, temb) + joint_attention_kwargs = kwargs.get("joint_attention_kwargs") - # Make caching decision - should_calc = self._should_compute_full_transformer(state, modulated_inp) + def cached_fn(): + return hidden_states + state.previous_residual - if not should_calc: - # Fast path: apply cached residual - hidden_states = hidden_states + state.previous_residual - else: - # Slow path: full computation inline (like original TeaCache) - ori_hidden_states = hidden_states.clone() + def compute_fn(): + hs = hidden_states + ori_hs = hs.clone() - # Process encoder_hidden_states - encoder_hidden_states = module.context_embedder(encoder_hidden_states) + enc = module.context_embedder(encoder_hidden_states) - # Process txt_ids and img_ids if txt_ids.ndim == 3: - txt_ids = txt_ids[0] + txt = txt_ids[0] + else: + txt = txt_ids if img_ids.ndim == 3: - img_ids = img_ids[0] + img = img_ids[0] + else: + img = img_ids - ids = torch.cat((txt_ids, img_ids), dim=0) + ids = torch.cat((txt, img), dim=0) image_rotary_emb = 
module.pos_embed(ids) - # Process through transformer blocks for block in module.transformer_blocks: - encoder_hidden_states, hidden_states = block( - hidden_states=hidden_states, - encoder_hidden_states=encoder_hidden_states, + enc, hs = block( + hidden_states=hs, + encoder_hidden_states=enc, temb=temb, image_rotary_emb=image_rotary_emb, - joint_attention_kwargs=kwargs.get("joint_attention_kwargs"), + joint_attention_kwargs=joint_attention_kwargs, ) - - # Process through single transformer blocks - # Note: single blocks concatenate internally, so pass separately for block in module.single_transformer_blocks: - encoder_hidden_states, hidden_states = block( - hidden_states=hidden_states, - encoder_hidden_states=encoder_hidden_states, + enc, hs = block( + hidden_states=hs, + encoder_hidden_states=enc, temb=temb, image_rotary_emb=image_rotary_emb, - joint_attention_kwargs=kwargs.get("joint_attention_kwargs"), + joint_attention_kwargs=joint_attention_kwargs, ) + return hs, ori_hs - # Cache the residual - state.previous_residual = hidden_states - ori_hidden_states - - state.previous_modulated_input = modulated_inp - state.cnt += 1 + hidden_states = hook._run_single_residual_cache(state, modulated_inp, compute_fn=compute_fn, cached_fn=cached_fn) - # Apply final norm and projection (always needed) hidden_states = module.norm_out(hidden_states, temb) output = module.proj_out(hidden_states) @@ -564,31 +633,19 @@ def _handle_flux_forward( return (output,) return Transformer2DModelOutput(sample=output) - def _handle_mochi_forward( - self, - module, - hidden_states, - encoder_hidden_states, - timestep, - encoder_attention_mask, - attention_kwargs=None, - return_dict=True, - ): - """ - Handle Mochi transformer forward pass with TeaCache. - Args: - module: The MochiTransformer3DModel instance. - hidden_states (`torch.Tensor`): Input latent tensor. - encoder_hidden_states (`torch.Tensor`): Text encoder outputs. - timestep (`torch.Tensor`): Current diffusion timestep. - encoder_attention_mask (`torch.Tensor`): Attention mask for encoder. - attention_kwargs (`dict`, optional): Additional attention arguments. - return_dict (`bool`): Whether to return a dict. - - Returns: - `torch.Tensor` or `Transformer2DModelOutput`: Denoised output. 
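To make the adapter contract concrete, here is a toy, self-contained sketch of the single-residual engine (the names are local to the example, not the hook's API): the slow path runs the blocks and caches the delta `y - x`, and the fast path simply re-applies that delta to the current input.

```python
import torch

def run_with_residual_cache(x, cache, should_compute, blocks):
    """Toy version of the fast/slow path behind _run_single_residual_cache."""
    if should_compute or cache.get("residual") is None:
        y = blocks(x)                 # slow path: run the transformer blocks
        cache["residual"] = y - x     # cache what the blocks added on top of x
    else:
        y = x + cache["residual"]     # fast path: reuse the cached delta
    return y

blocks = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.GELU(), torch.nn.Linear(8, 8))
cache = {}
x0 = torch.randn(2, 8)
y0 = run_with_residual_cache(x0, cache, should_compute=True, blocks=blocks)   # computes and caches
x1 = x0 + 0.01 * torch.randn_like(x0)                                         # near-identical next step
y1 = run_with_residual_cache(x1, cache, should_compute=False, blocks=blocks)  # reuses the residual
```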
- """ +class _MochiTeaCacheAdapter: + @staticmethod + def forward( + hook: TeaCacheHook, + module: torch.nn.Module, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + timestep: torch.Tensor, + encoder_attention_mask: torch.Tensor, + attention_kwargs: Optional[Dict[str, Any]] = None, + return_dict: bool = True, + ): from diffusers.models.modeling_outputs import Transformer2DModelOutput from diffusers.utils import USE_PEFT_BACKEND, scale_lora_layers, unscale_lora_layers @@ -601,15 +658,14 @@ def _handle_mochi_forward( if USE_PEFT_BACKEND: scale_lora_layers(module, lora_scale) - state = self.state_manager.get_state() - self._maybe_reset_state_for_new_inference(state, module) + state = hook.state_manager.get_state() + hook._maybe_reset_state_for_new_inference(state, module) - batch_size, num_channels, num_frames, height, width = hidden_states.shape + batch_size, _, num_frames, height, width = hidden_states.shape p = module.config.patch_size post_patch_height = height // p post_patch_width = width // p - # Process time embedding temb, encoder_hidden_states = module.time_embed( timestep, encoder_hidden_states, @@ -617,12 +673,10 @@ def _handle_mochi_forward( hidden_dtype=hidden_states.dtype, ) - # Process patch embedding hidden_states = hidden_states.permute(0, 2, 1, 3, 4).flatten(0, 1) hidden_states = module.patch_embed(hidden_states) hidden_states = hidden_states.unflatten(0, (batch_size, -1)).flatten(1, 2) - # Get rotary embeddings image_rotary_emb = module.rope( module.pos_frequencies, num_frames, @@ -632,52 +686,39 @@ def _handle_mochi_forward( dtype=torch.float32, ) - # Extract modulated input (extractors don't modify inputs) - modulated_inp = self.extractor_fn(module, hidden_states, temb) + modulated_inp = hook.extractor_fn(module, hidden_states, temb) - # Make caching decision - should_calc = self._should_compute_full_transformer(state, modulated_inp) + def cached_fn(): + return hidden_states + state.previous_residual - if not should_calc: - # Fast path: apply cached residual (already includes norm_out) - hidden_states = hidden_states + state.previous_residual - else: - # Slow path: full computation - ori_hidden_states = hidden_states.clone() - - # Process through transformer blocks + def compute_fn(): + hs = hidden_states + ori_hs = hs.clone() + enc = encoder_hidden_states for block in module.transformer_blocks: if torch.is_grad_enabled() and module.gradient_checkpointing: - hidden_states, encoder_hidden_states = module._gradient_checkpointing_func( + hs, enc = module._gradient_checkpointing_func( block, - hidden_states, - encoder_hidden_states, + hs, + enc, temb, encoder_attention_mask, image_rotary_emb, ) else: - hidden_states, encoder_hidden_states = block( - hidden_states=hidden_states, - encoder_hidden_states=encoder_hidden_states, + hs, enc = block( + hidden_states=hs, + encoder_hidden_states=enc, temb=temb, encoder_attention_mask=encoder_attention_mask, image_rotary_emb=image_rotary_emb, ) + hs = module.norm_out(hs, temb) + return hs, ori_hs - # Apply norm_out before caching residual (matches reference implementation) - hidden_states = module.norm_out(hidden_states, temb) - - # Cache the residual (includes norm_out transformation) - state.previous_residual = hidden_states - ori_hidden_states - - state.previous_modulated_input = modulated_inp - state.cnt += 1 + hidden_states = hook._run_single_residual_cache(state, modulated_inp, compute_fn=compute_fn, cached_fn=cached_fn) - # Apply projection hidden_states = module.proj_out(hidden_states) - - # Reshape output 
hidden_states = hidden_states.reshape(batch_size, num_frames, post_patch_height, post_patch_width, p, p, -1) hidden_states = hidden_states.permute(0, 6, 1, 2, 4, 3, 5) output = hidden_states.reshape(batch_size, -1, num_frames, height, width) @@ -689,34 +730,20 @@ def _handle_mochi_forward( return (output,) return Transformer2DModelOutput(sample=output) - def _handle_lumina2_forward( - self, - module, - hidden_states, - timestep, - encoder_hidden_states, - encoder_attention_mask, - attention_kwargs=None, - return_dict=True, - ): - """ - Handle Lumina2 transformer forward pass with TeaCache. - - Note: Lumina2 has complex preprocessing and uses 'layers' instead of 'transformer_blocks'. - The modulated input extraction happens after preprocessing to input_to_main_loop. - Args: - module: The Lumina2Transformer2DModel instance. - hidden_states (`torch.Tensor`): Input latent tensor. - timestep (`torch.Tensor`): Current diffusion timestep. - encoder_hidden_states (`torch.Tensor`): Text encoder outputs. - encoder_attention_mask (`torch.Tensor`): Attention mask for encoder. - attention_kwargs (`dict`, optional): Additional attention arguments. - return_dict (`bool`): Whether to return a dict. - - Returns: - `torch.Tensor` or `Transformer2DModelOutput`: Denoised output. - """ +class _Lumina2TeaCacheAdapter: + @staticmethod + def forward( + hook: TeaCacheHook, + module: torch.nn.Module, + hidden_states: torch.Tensor, + timestep: torch.Tensor, + encoder_hidden_states: torch.Tensor, + encoder_attention_mask: torch.Tensor, + attention_kwargs: Optional[Dict[str, Any]] = None, + return_dict: bool = True, + ): + # Keep Lumina2 logic isolated (variable seq lens + per-len caches). from diffusers.models.modeling_outputs import Transformer2DModelOutput from diffusers.utils import USE_PEFT_BACKEND, scale_lora_layers, unscale_lora_layers @@ -729,15 +756,12 @@ def _handle_lumina2_forward( if USE_PEFT_BACKEND: scale_lora_layers(module, lora_scale) - state = self.state_manager.get_state() - self._maybe_reset_state_for_new_inference(state, module) + state = hook.state_manager.get_state() + hook._maybe_reset_state_for_new_inference(state, module) batch_size, _, height, width = hidden_states.shape - # Lumina2 preprocessing (matches original forward) - temb, encoder_hidden_states_processed = module.time_caption_embed( - hidden_states, timestep, encoder_hidden_states - ) + temb, encoder_hidden_states_processed = module.time_caption_embed(hidden_states, timestep, encoder_hidden_states) ( image_patch_embeddings, context_rotary_emb, @@ -769,10 +793,8 @@ def _handle_lumina2_forward( mask[i, :seq_len_val] = True attention_mask_for_main_loop_arg = mask - # Extract modulated input (after preprocessing) - modulated_inp = self.extractor_fn(module, input_to_main_loop, temb) + modulated_inp = hook.extractor_fn(module, input_to_main_loop, temb) - # Per-sequence-length cache for Lumina2 (handles variable sequence lengths) cache_key = max_seq_len if cache_key not in state.cache_dict: state.cache_dict[cache_key] = { @@ -782,7 +804,6 @@ def _handle_lumina2_forward( } current_cache = state.cache_dict[cache_key] - # Make caching decision using per-cache values if state.cnt == 0 or state.cnt == state.num_steps - 1: should_calc = True current_cache["accumulated_rel_l1_distance"] = 0.0 @@ -790,16 +811,13 @@ def _handle_lumina2_forward( if current_cache["previous_modulated_input"] is not None: prev_mod_input = current_cache["previous_modulated_input"] prev_mean = prev_mod_input.abs().mean() - if prev_mean.item() > 1e-9: rel_l1_change = 
((modulated_inp - prev_mod_input).abs().mean() / prev_mean).cpu().item() else: - rel_l1_change = 0.0 if modulated_inp.abs().mean().item() < 1e-9 else float('inf') - - rescaled_distance = self.rescale_func(rel_l1_change) + rel_l1_change = 0.0 if modulated_inp.abs().mean().item() < 1e-9 else float("inf") + rescaled_distance = hook.rescale_func(rel_l1_change) current_cache["accumulated_rel_l1_distance"] += rescaled_distance - - if current_cache["accumulated_rel_l1_distance"] < self.config.rel_l1_thresh: + if current_cache["accumulated_rel_l1_distance"] < hook.config.rel_l1_thresh: should_calc = False else: should_calc = True @@ -810,31 +828,24 @@ def _handle_lumina2_forward( current_cache["previous_modulated_input"] = modulated_inp.clone() - # Track unconditional sequence length for counter management if state.uncond_seq_len is None: state.uncond_seq_len = cache_key - # Only increment counter when not processing unconditional (different seq len) if cache_key != state.uncond_seq_len: state.cnt += 1 if state.cnt >= state.num_steps: state.cnt = 0 - # Fast or slow path with per-cache residual if not should_calc and current_cache["previous_residual"] is not None: - # Fast path: apply cached residual processed_hidden_states = input_to_main_loop + current_cache["previous_residual"] else: - # Slow path: full computation current_processing_states = input_to_main_loop for layer in module.layers: current_processing_states = layer( current_processing_states, attention_mask_for_main_loop_arg, joint_rotary_emb, temb ) processed_hidden_states = current_processing_states - # Cache the residual in per-cache storage current_cache["previous_residual"] = processed_hidden_states - input_to_main_loop - # Apply final norm and reshape output_after_norm = module.norm_out(processed_hidden_states, temb) p = module.config.patch_size final_output_list = [] @@ -842,10 +853,12 @@ def _handle_lumina2_forward( image_part = output_after_norm[i][enc_len:seq_len_val] h_p, w_p = height // p, width // p reconstructed_image = ( - image_part.view(h_p, w_p, p, p, module.out_channels).permute(4, 0, 2, 1, 3).flatten(3, 4).flatten(1, 2) + image_part.view(h_p, w_p, p, p, module.out_channels) + .permute(4, 0, 2, 1, 3) + .flatten(3, 4) + .flatten(1, 2) ) final_output_list.append(reconstructed_image) - final_output_tensor = torch.stack(final_output_list, dim=0) if USE_PEFT_BACKEND: @@ -855,38 +868,21 @@ def _handle_lumina2_forward( return (final_output_tensor,) return Transformer2DModelOutput(sample=final_output_tensor) - def _handle_cogvideox_forward( - self, - module, - hidden_states, - encoder_hidden_states, - timestep, - timestep_cond=None, - ofs=None, - image_rotary_emb=None, - attention_kwargs=None, - return_dict=True, - ): - """ - Handle CogVideoX transformer forward pass with TeaCache. - Note: CogVideoX uses timestep embedding directly (not from a block) and caches - both encoder_hidden_states and hidden_states residuals. - - Args: - module: The CogVideoXTransformer3DModel instance. - hidden_states (`torch.Tensor`): Input latent tensor. - encoder_hidden_states (`torch.Tensor`): Text encoder outputs. - timestep (`torch.Tensor`): Current diffusion timestep. - timestep_cond (`torch.Tensor`, optional): Additional timestep conditioning. - ofs (`torch.Tensor`, optional): Offset tensor. - image_rotary_emb (`torch.Tensor`, optional): Rotary embeddings. - attention_kwargs (`dict`, optional): Additional attention arguments. - return_dict (`bool`): Whether to return a dict. 
- - Returns: - `torch.Tensor` or `Transformer2DModelOutput`: Denoised output. - """ +class _CogVideoXTeaCacheAdapter: + @staticmethod + def forward( + hook: TeaCacheHook, + module: torch.nn.Module, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + timestep: Union[int, float, torch.LongTensor], + timestep_cond: Optional[torch.Tensor] = None, + ofs: Optional[Union[int, float, torch.LongTensor]] = None, + image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + attention_kwargs: Optional[Dict[str, Any]] = None, + return_dict: bool = True, + ): from diffusers.models.modeling_outputs import Transformer2DModelOutput from diffusers.utils import USE_PEFT_BACKEND, is_torch_version, scale_lora_layers, unscale_lora_layers @@ -899,14 +895,12 @@ def _handle_cogvideox_forward( if USE_PEFT_BACKEND: scale_lora_layers(module, lora_scale) - state = self.state_manager.get_state() - self._maybe_reset_state_for_new_inference(state, module, reset_encoder_residual=True) + state = hook.state_manager.get_state() + hook._maybe_reset_state_for_new_inference(state, module, reset_encoder_residual=True) - batch_size, num_frames, channels, height, width = hidden_states.shape + batch_size, num_frames, _, height, width = hidden_states.shape - # Process time embedding - timesteps = timestep - t_emb = module.time_proj(timesteps) + t_emb = module.time_proj(timestep) t_emb = t_emb.to(dtype=hidden_states.dtype) emb = module.time_embedding(t_emb, timestep_cond) @@ -916,80 +910,63 @@ def _handle_cogvideox_forward( ofs_emb = module.ofs_embedding(ofs_emb) emb = emb + ofs_emb - # Process patch embedding - hidden_states = module.patch_embed(encoder_hidden_states, hidden_states) - hidden_states = module.embedding_dropout(hidden_states) + hs = module.patch_embed(encoder_hidden_states, hidden_states) + hs = module.embedding_dropout(hs) text_seq_length = encoder_hidden_states.shape[1] - encoder_hidden_states = hidden_states[:, :text_seq_length] - hidden_states = hidden_states[:, text_seq_length:] + enc = hs[:, :text_seq_length] + hs = hs[:, text_seq_length:] - # Extract modulated input (CogVideoX uses timestep embedding directly) - modulated_inp = self.extractor_fn(module, hidden_states, emb) + modulated_inp = hook.extractor_fn(module, hs, emb) - # Make caching decision - should_calc = self._should_compute_full_transformer(state, modulated_inp) - - # Fast or slow path based on caching decision - if not should_calc: - # Fast path: apply cached residuals (both encoder and hidden_states) - hidden_states = hidden_states + state.previous_residual - encoder_hidden_states = encoder_hidden_states + state.previous_residual_encoder - else: - # Slow path: full computation - ori_hidden_states = hidden_states.clone() - ori_encoder_hidden_states = encoder_hidden_states.clone() + def cached_fn(): + return hs + state.previous_residual, enc + state.previous_residual_encoder - # Process through transformer blocks + def compute_fn(): + hs0 = hs + enc0 = enc + ori_hs = hs0.clone() + ori_enc = enc0.clone() + hs1, enc1 = hs0, enc0 for block in module.transformer_blocks: if torch.is_grad_enabled() and module.gradient_checkpointing: ckpt_kwargs = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} - hidden_states, encoder_hidden_states = torch.utils.checkpoint.checkpoint( - lambda *args: block(*args), - hidden_states, - encoder_hidden_states, + hs1, enc1 = torch.utils.checkpoint.checkpoint( + lambda *a: block(*a), + hs1, + enc1, emb, image_rotary_emb, **ckpt_kwargs, ) else: - hidden_states, 
encoder_hidden_states = block( - hidden_states=hidden_states, - encoder_hidden_states=encoder_hidden_states, + hs1, enc1 = block( + hidden_states=hs1, + encoder_hidden_states=enc1, temb=emb, image_rotary_emb=image_rotary_emb, ) + return hs1, enc1, ori_hs, ori_enc - # Cache both residuals - state.previous_residual = hidden_states - ori_hidden_states - state.previous_residual_encoder = encoder_hidden_states - ori_encoder_hidden_states + hs, enc = hook._run_dual_residual_cache(state, modulated_inp, compute_fn=compute_fn, cached_fn=cached_fn) - state.previous_modulated_input = modulated_inp - state.cnt += 1 - - # Apply final norm if not module.config.use_rotary_positional_embeddings: - # CogVideoX-2B - hidden_states = module.norm_final(hidden_states) + hs = module.norm_final(hs) else: - # CogVideoX-5B and CogVideoX1.5-5B - hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1) - hidden_states = module.norm_final(hidden_states) - hidden_states = hidden_states[:, text_seq_length:] + hs_cat = torch.cat([enc, hs], dim=1) + hs_cat = module.norm_final(hs_cat) + hs = hs_cat[:, text_seq_length:] - # Final block - hidden_states = module.norm_out(hidden_states, temb=emb) - hidden_states = module.proj_out(hidden_states) + hs = module.norm_out(hs, temb=emb) + hs = module.proj_out(hs) - # Unpatchify p = module.config.patch_size p_t = module.config.patch_size_t - if p_t is None: - output = hidden_states.reshape(batch_size, num_frames, height // p, width // p, -1, p, p) + output = hs.reshape(batch_size, num_frames, height // p, width // p, -1, p, p) output = output.permute(0, 1, 4, 2, 5, 3, 6).flatten(5, 6).flatten(3, 4) else: - output = hidden_states.reshape( + output = hs.reshape( batch_size, (num_frames + p_t - 1) // p_t, height // p, width // p, -1, p_t, p, p ) output = output.permute(0, 1, 5, 4, 2, 6, 3, 7).flatten(6, 7).flatten(4, 5).flatten(1, 2) @@ -1001,67 +978,6 @@ def _handle_cogvideox_forward( return (output,) return Transformer2DModelOutput(sample=output) - @torch.compiler.disable - def _should_compute_full_transformer(self, state, modulated_inp): - """ - Determine whether to compute full transformer blocks or reuse cached residual. - - This method implements the core caching decision logic from the TeaCache paper: - - Always compute first and last timesteps (for maximum quality) - - For intermediate timesteps, compute relative L1 distance between current and previous modulated inputs - - Apply polynomial rescaling to convert distance to model-specific caching signal - - Accumulate rescaled distances and compare to threshold - - Return True (compute) if accumulated distance exceeds threshold, False (cache) otherwise - - Args: - state (`TeaCacheState`): Current state containing counters and cached values. - modulated_inp (`torch.Tensor`): Modulated input extracted using configured extractor function. - - Returns: - `bool`: True to compute full transformer, False to reuse cached residual. 
- """ - # Compute first timestep - if state.cnt == 0: - state.accumulated_rel_l1_distance = 0 - return True - - # compute last timestep (if num_steps is set) - if state.num_steps > 0 and state.cnt == state.num_steps - 1: - state.accumulated_rel_l1_distance = 0 - return True - - # Need previous modulated input for comparison - if state.previous_modulated_input is None: - return True - - # Compute relative L1 distance - rel_distance = ( - ( - (modulated_inp - state.previous_modulated_input).abs().mean() - / state.previous_modulated_input.abs().mean() - ) - .cpu() - .item() - ) - - # Apply polynomial rescaling - rescaled_distance = self.rescale_func(rel_distance) - state.accumulated_rel_l1_distance += rescaled_distance - - # Debug logging (uncomment to debug) - # logger.warning(f"Step {state.cnt}: rel_l1={rel_distance:.6f}, rescaled={rescaled_distance:.6f}, accumulated={state.accumulated_rel_l1_distance:.6f}, thresh={self.config.rel_l1_thresh}") - - # Make decision based on accumulated threshold - if state.accumulated_rel_l1_distance < self.config.rel_l1_thresh: - return False - else: - state.accumulated_rel_l1_distance = 0 # Reset accumulator - return True - - def reset_state(self, module): - self.state_manager.reset() - return module - def apply_teacache(module, config: TeaCacheConfig): """ From c605f4dd4692aa3a560122b67c65af1a33545ef1 Mon Sep 17 00:00:00 2001 From: Prajwal A Date: Tue, 6 Jan 2026 09:49:08 +0530 Subject: [PATCH 16/25] refactor teacache: replace adapter classes with standalone functions Signed-off-by: Prajwal A --- src/diffusers/hooks/teacache.py | 1027 +++++++++++++++---------------- tests/hooks/test_teacache.py | 58 +- 2 files changed, 524 insertions(+), 561 deletions(-) diff --git a/src/diffusers/hooks/teacache.py b/src/diffusers/hooks/teacache.py index f525fb343922..71ec937572d6 100644 --- a/src/diffusers/hooks/teacache.py +++ b/src/diffusers/hooks/teacache.py @@ -17,112 +17,73 @@ import torch -from ..utils import logging +from ..utils import get_logger from .hooks import BaseState, HookRegistry, ModelHook, StateManager -logger = logging.get_logger(__name__) # pylint: disable=invalid-name +logger = get_logger(__name__) # pylint: disable=invalid-name _TEACACHE_HOOK = "teacache" -# Model-specific polynomial coefficients from TeaCache paper/reference implementations -_MODEL_COEFFICIENTS = { - "Flux": [4.98651651e02, -2.83781631e02, 5.58554382e01, -3.82021401e00, 2.64230861e-01], - "FluxKontext": [-1.04655119e03, 3.12563399e02, -1.69500694e01, 4.10995971e-01, 3.74537863e-02], - "Mochi": [-3.51241319e03, 8.11675948e02, -6.09400215e01, 2.42429681e00, 3.05291719e-03], - "Lumina2": [393.76566581, -603.50993606, 209.10239044, -23.00726601, 0.86377344], - "CogVideoX": [ - -1.53880483e03, - 8.43202495e02, - -1.34363087e02, - 7.97131516e00, - -5.23162339e-02, - ], # Default to 5b variant - # CogVideoX model variants with specific coefficients - "CogVideoX-2b": [-3.10658903e01, 2.54732368e01, -5.92380459e00, 1.75769064e00, -3.61568434e-03], - "CogVideoX-5b": [-1.53880483e03, 8.43202495e02, -1.34363087e02, 7.97131516e00, -5.23162339e-02], - "CogVideoX1.5-5B": [2.50210439e02, -1.65061612e02, 3.57804877e01, -7.81551492e-01, 3.58559703e-02], - "CogVideoX1.5-5B-I2V": [1.22842302e02, -1.04088754e02, 2.62981677e01, -3.06001e-01, 3.71213220e-02], -} - - -def _flux_modulated_input_extractor(module, hidden_states, timestep_emb): - """Extract modulated input for FLUX models.""" - return module.transformer_blocks[0].norm1(hidden_states, emb=timestep_emb)[0] - - -def 
_mochi_modulated_input_extractor(module, hidden_states, timestep_emb): - """Extract modulated input for Mochi models.""" - # Mochi norm1 returns tuple: (modulated_inp, gate_msa, scale_mlp, gate_mlp) - return module.transformer_blocks[0].norm1(hidden_states, timestep_emb)[0] - - -def _lumina2_modulated_input_extractor(module, hidden_states, timestep_emb): - """Extract modulated input for Lumina2 models.""" - # Lumina2 uses 'layers' instead of 'transformer_blocks' and norm1 returns tuple - # Note: This extractor expects input_to_main_loop as hidden_states (after preprocessing) - return module.layers[0].norm1(hidden_states, timestep_emb)[0] - - -def _cogvideox_modulated_input_extractor(module, hidden_states, timestep_emb): - """Extract modulated input for CogVideoX models.""" - # CogVideoX uses the timestep embedding directly, not from a block - return timestep_emb - - -# Extractor registry - maps model types to extraction functions -# Multiple model variants can share the same extractor -# Order matters: more specific variants first (e.g., CogVideoX1.5-5B-I2V before CogVideoX) -_EXTRACTOR_REGISTRY = { - "FluxKontext": _flux_modulated_input_extractor, - "Flux": _flux_modulated_input_extractor, - "Mochi": _mochi_modulated_input_extractor, - "Lumina2": _lumina2_modulated_input_extractor, - "CogVideoX1.5-5B-I2V": _cogvideox_modulated_input_extractor, - "CogVideoX1.5-5B": _cogvideox_modulated_input_extractor, - "CogVideoX-2b": _cogvideox_modulated_input_extractor, - "CogVideoX-5b": _cogvideox_modulated_input_extractor, - "CogVideoX": _cogvideox_modulated_input_extractor, -} - - -def _auto_detect_extractor(module): - """Auto-detect and return appropriate extractor.""" - return _EXTRACTOR_REGISTRY[_auto_detect_model_type(module)] + +def _get_model_config(): + """Get model configuration mapping. + + Returns dict at runtime when forward functions are defined. + Order matters: more specific model variants must come before generic ones. 
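+    Keys are matched against both the transformer class name and the lowercased
+    ``config._name_or_path``, so a variant checkpoint such as "CogVideoX1.5-5B-I2V"
+    resolves to its specific coefficients before the generic "CogVideoX" entry.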
+ """ + return { + "FluxKontext": { + "forward_func": _flux_teacache_forward, + "coefficients": [-1.04655119e03, 3.12563399e02, -1.69500694e01, 4.10995971e-01, 3.74537863e-02], + }, + "Flux": { + "forward_func": _flux_teacache_forward, + "coefficients": [4.98651651e02, -2.83781631e02, 5.58554382e01, -3.82021401e00, 2.64230861e-01], + }, + "Mochi": { + "forward_func": _mochi_teacache_forward, + "coefficients": [-3.51241319e03, 8.11675948e02, -6.09400215e01, 2.42429681e00, 3.05291719e-03], + }, + "Lumina2": { + "forward_func": _lumina2_teacache_forward, + "coefficients": [393.76566581, -603.50993606, 209.10239044, -23.00726601, 0.86377344], + }, + "CogVideoX1.5-5B-I2V": { + "forward_func": _cogvideox_teacache_forward, + "coefficients": [1.22842302e02, -1.04088754e02, 2.62981677e01, -3.06001e-01, 3.71213220e-02], + }, + "CogVideoX1.5-5B": { + "forward_func": _cogvideox_teacache_forward, + "coefficients": [2.50210439e02, -1.65061612e02, 3.57804877e01, -7.81551492e-01, 3.58559703e-02], + }, + "CogVideoX-2b": { + "forward_func": _cogvideox_teacache_forward, + "coefficients": [-3.10658903e01, 2.54732368e01, -5.92380459e00, 1.75769064e00, -3.61568434e-03], + }, + "CogVideoX-5b": { + "forward_func": _cogvideox_teacache_forward, + "coefficients": [-1.53880483e03, 8.43202495e02, -1.34363087e02, 7.97131516e00, -5.23162339e-02], + }, + "CogVideoX": { + "forward_func": _cogvideox_teacache_forward, + "coefficients": [-1.53880483e03, 8.43202495e02, -1.34363087e02, 7.97131516e00, -5.23162339e-02], + }, + } def _auto_detect_model_type(module): """Auto-detect model type from class name and config path.""" class_name = module.__class__.__name__ config_path = getattr(getattr(module, "config", None), "_name_or_path", "").lower() - + model_config = _get_model_config() + # Check config path first (for variants), then class name (ordered most specific first) - for model_type in _EXTRACTOR_REGISTRY: + for model_type in model_config: if model_type.lower() in config_path or model_type in class_name: - if model_type not in _MODEL_COEFFICIENTS: - raise ValueError(f"TeaCache: No coefficients for '{model_type}'") return model_type - - raise ValueError(f"TeaCache: Unsupported model '{class_name}'. Supported: {', '.join(_EXTRACTOR_REGISTRY)}") - -def _get_model_coefficients(model_type): - """Get polynomial coefficients for a specific model type. - - Args: - model_type: Model type string (e.g., "Flux", "Mochi") - - Raises: - ValueError: If coefficients not found for model type. - """ - if model_type not in _MODEL_COEFFICIENTS: - available_models = ", ".join(_MODEL_COEFFICIENTS.keys()) - raise ValueError( - f"TeaCache: No coefficients found for model type '{model_type}'. " - f"Available models: {available_models}. " - f"Please provide coefficients explicitly in TeaCacheConfig." - ) - return _MODEL_COEFFICIENTS[model_type] + raise ValueError(f"TeaCache: Unsupported model '{class_name}'. Supported: {', '.join(model_config.keys())}") @dataclass @@ -134,7 +95,7 @@ class TeaCacheConfig: by reusing transformer block computations when consecutive timestep embeddings are similar. It uses polynomial rescaling of L1 distances between modulated inputs to intelligently decide when to cache. - Currently supports: FLUX, FLUX-Kontext, Mochi, Lumina2, and CogVideoX models. Model type is auto-detected, and + Currently supports: FLUX, FLUX-Kontext, Mochi, Lumina2, and CogVideoX models. Model type is auto-detected, and model-specific polynomial coefficients are automatically applied. 
Args: @@ -167,7 +128,7 @@ class TeaCacheConfig: Callback function that returns the total number of inference steps. Alternative to `num_inference_steps` for dynamic step counts. - Examples: + Example: ```python from diffusers import FluxPipeline from diffusers.hooks import TeaCacheConfig @@ -337,8 +298,8 @@ class TeaCacheHook(ModelHook): Attributes: config (TeaCacheConfig): Configuration containing threshold, polynomial coefficients, and optional callbacks. - rescale_func (np.poly1d): - Polynomial function for rescaling L1 distances using model-specific coefficients. + coefficients (List[float]): + Polynomial coefficients for rescaling L1 distances (auto-detected or user-provided). state_manager (StateManager): Manages TeaCacheState across forward passes, maintaining counters and cached values. """ @@ -348,30 +309,27 @@ class TeaCacheHook(ModelHook): def __init__(self, config: TeaCacheConfig): super().__init__() self.config = config - # Set default rescale_func with config coefficients (will be updated in initialize_hook if needed) - # This ensures rescale_func is always valid, even if initialize_hook isn't called (e.g., in tests) - default_coeffs = config.coefficients if config.coefficients else _MODEL_COEFFICIENTS["Flux"] - self.coefficients = default_coeffs - self.rescale_func = self._create_rescale_func(default_coeffs) + # Default coefficients (will be updated in initialize_hook if needed) + # This ensures coefficients are always valid, even if initialize_hook isn't called (e.g., in tests) + self.coefficients = ( + config.coefficients + if config.coefficients + else [4.98651651e02, -2.83781631e02, 5.58554382e01, -3.82021401e00, 2.64230861e-01] + ) self.state_manager = StateManager(TeaCacheState, (), {}) - self.extractor_fn = None self.model_type = None + self._forward_func = None - @staticmethod - def _create_rescale_func(coefficients): - """Create polynomial rescale function from coefficients. - - Evaluates: c[0]*x^4 + c[1]*x^3 + c[2]*x^2 + c[3]*x + c[4] - """ - def rescale(x): - return coefficients[0] * x**4 + coefficients[1] * x**3 + coefficients[2] * x**2 + coefficients[3] * x + coefficients[4] - return rescale + def _rescale_distance(self, x: float) -> float: + """Evaluate polynomial: c[0]*x^4 + c[1]*x^3 + c[2]*x^2 + c[3]*x + c[4]""" + c = self.coefficients + return c[0] * x**4 + c[1] * x**3 + c[2] * x**2 + c[3] * x + c[4] def _maybe_reset_state_for_new_inference(self, state, module, reset_encoder_residual=False): """Reset state if we've completed all steps (start of new inference run). - + Also initializes num_steps on first timestep if not set. - + Args: state: TeaCacheState instance. module: The transformer module. @@ -396,24 +354,30 @@ def _maybe_reset_state_for_new_inference(self, state, module, reset_encoder_resi state.num_steps = self.config.num_inference_steps_callback() elif hasattr(module, "num_steps"): state.num_steps = module.num_steps - + if state.num_steps > 0: logger.info(f"TeaCache: Using {state.num_steps} inference steps") def initialize_hook(self, module): + # TODO: DN6 raised concern about context setting timing. + # Currently set in initialize_hook(). Should this be in denoising loop instead? + # See PR #12652 for discussion. Keeping current behavior pending clarification. 
self.state_manager.set_context("teacache") - # Strict auto-detection + model_config = _get_model_config() + + # Auto-detect model type and get forward function if self.config.extract_modulated_input_fn is None: - self.extractor_fn = _auto_detect_extractor(module) # Raises if unsupported - self.model_type = _auto_detect_model_type(module) # Raises if unsupported + self.model_type = _auto_detect_model_type(module) # Raises if unsupported + self._forward_func = model_config[self.model_type]["forward_func"] else: - self.extractor_fn = self.config.extract_modulated_input_fn - # Still try to detect model type for coefficients + # User provided custom extractor - try to detect model type for coefficients try: self.model_type = _auto_detect_model_type(module) + self._forward_func = model_config[self.model_type]["forward_func"] except ValueError: - self.model_type = None # User provided custom extractor + self.model_type = None + self._forward_func = None # Will fall back to class name matching in new_forward logger.warning( f"TeaCache: Using custom extractor for {module.__class__.__name__}. " f"Coefficients must be provided explicitly." @@ -426,14 +390,11 @@ def initialize_hook(self, module): "TeaCache: Cannot auto-detect coefficients when using custom extractor. " "Please provide coefficients explicitly in TeaCacheConfig." ) - self.coefficients = _get_model_coefficients(self.model_type) # Raises if not found + self.coefficients = model_config[self.model_type]["coefficients"] logger.info(f"TeaCache: Using {self.model_type} coefficients") else: self.coefficients = self.config.coefficients - logger.info(f"TeaCache: Using user-provided coefficients") - - # Initialize rescale function with final coefficients - self.rescale_func = self._create_rescale_func(self.coefficients) + logger.info("TeaCache: Using user-provided coefficients") return module @@ -497,15 +458,20 @@ def _run_dual_residual_cache( return hs_y, enc_y def new_forward(self, module, *args, **kwargs): + # Use stored forward function if available (set during initialize_hook) + if self._forward_func is not None: + return self._forward_func(self, module, *args, **kwargs) + + # Fallback to class name matching for backwards compatibility module_class_name = module.__class__.__name__ if "Flux" in module_class_name: - return _FluxTeaCacheAdapter.forward(self, module, *args, **kwargs) + return _flux_teacache_forward(self, module, *args, **kwargs) if "Mochi" in module_class_name: - return _MochiTeaCacheAdapter.forward(self, module, *args, **kwargs) + return _mochi_teacache_forward(self, module, *args, **kwargs) if "Lumina2" in module_class_name: - return _Lumina2TeaCacheAdapter.forward(self, module, *args, **kwargs) + return _lumina2_teacache_forward(self, module, *args, **kwargs) if "CogVideoX" in module_class_name: - return _CogVideoXTeaCacheAdapter.forward(self, module, *args, **kwargs) + return _cogvideox_teacache_forward(self, module, *args, **kwargs) logger.warning(f"TeaCache: Unsupported model type {module_class_name}.") return self.fn_ref.original_forward(*args, **kwargs) @@ -541,7 +507,7 @@ def _should_compute_full_transformer(self, state, modulated_inp): .cpu() .item() ) - rescaled_distance = self.rescale_func(rel_distance) + rescaled_distance = self._rescale_distance(rel_distance) state.accumulated_rel_l1_distance += rescaled_distance if state.accumulated_rel_l1_distance < self.config.rel_l1_thresh: @@ -554,429 +520,424 @@ def reset_state(self, module): self.state_manager.reset() return module -class _FluxTeaCacheAdapter: - @staticmethod 
- def forward( - hook: TeaCacheHook, - module: torch.nn.Module, - hidden_states: torch.Tensor, - timestep: torch.Tensor, - pooled_projections: torch.Tensor, - encoder_hidden_states: torch.Tensor, - txt_ids: torch.Tensor, - img_ids: torch.Tensor, - return_dict: bool = True, - **kwargs, - ): - from diffusers.models.modeling_outputs import Transformer2DModelOutput - - state = hook.state_manager.get_state() - hook._maybe_reset_state_for_new_inference(state, module) - - hidden_states = module.x_embedder(hidden_states) - - timestep_scaled = timestep.to(hidden_states.dtype) * 1000 - if kwargs.get("guidance") is not None: - guidance = kwargs["guidance"].to(hidden_states.dtype) * 1000 - temb = module.time_text_embed(timestep_scaled, guidance, pooled_projections) - else: - temb = module.time_text_embed(timestep_scaled, pooled_projections) - - modulated_inp = hook.extractor_fn(module, hidden_states, temb) - joint_attention_kwargs = kwargs.get("joint_attention_kwargs") - def cached_fn(): - return hidden_states + state.previous_residual - - def compute_fn(): - hs = hidden_states - ori_hs = hs.clone() - - enc = module.context_embedder(encoder_hidden_states) - - if txt_ids.ndim == 3: - txt = txt_ids[0] - else: - txt = txt_ids - if img_ids.ndim == 3: - img = img_ids[0] - else: - img = img_ids - - ids = torch.cat((txt, img), dim=0) - image_rotary_emb = module.pos_embed(ids) - - for block in module.transformer_blocks: - enc, hs = block( - hidden_states=hs, - encoder_hidden_states=enc, - temb=temb, - image_rotary_emb=image_rotary_emb, - joint_attention_kwargs=joint_attention_kwargs, +def _flux_teacache_forward( + hook: "TeaCacheHook", + module: torch.nn.Module, + hidden_states: torch.Tensor, + timestep: torch.Tensor, + pooled_projections: torch.Tensor, + encoder_hidden_states: torch.Tensor, + txt_ids: torch.Tensor, + img_ids: torch.Tensor, + return_dict: bool = True, + **kwargs, +): + """TeaCache forward for Flux models.""" + from diffusers.models.modeling_outputs import Transformer2DModelOutput + + state = hook.state_manager.get_state() + hook._maybe_reset_state_for_new_inference(state, module) + + hidden_states = module.x_embedder(hidden_states) + + timestep_scaled = timestep.to(hidden_states.dtype) * 1000 + if kwargs.get("guidance") is not None: + guidance = kwargs["guidance"].to(hidden_states.dtype) * 1000 + temb = module.time_text_embed(timestep_scaled, guidance, pooled_projections) + else: + temb = module.time_text_embed(timestep_scaled, pooled_projections) + + # Inline extractor: Flux uses transformer_blocks[0].norm1 + modulated_inp = module.transformer_blocks[0].norm1(hidden_states, emb=temb)[0] + joint_attention_kwargs = kwargs.get("joint_attention_kwargs") + + def cached_fn(): + return hidden_states + state.previous_residual + + def compute_fn(): + hs = hidden_states + ori_hs = hs.clone() + + enc = module.context_embedder(encoder_hidden_states) + + if txt_ids.ndim == 3: + txt = txt_ids[0] + else: + txt = txt_ids + if img_ids.ndim == 3: + img = img_ids[0] + else: + img = img_ids + + ids = torch.cat((txt, img), dim=0) + image_rotary_emb = module.pos_embed(ids) + + for block in module.transformer_blocks: + enc, hs = block( + hidden_states=hs, + encoder_hidden_states=enc, + temb=temb, + image_rotary_emb=image_rotary_emb, + joint_attention_kwargs=joint_attention_kwargs, + ) + for block in module.single_transformer_blocks: + enc, hs = block( + hidden_states=hs, + encoder_hidden_states=enc, + temb=temb, + image_rotary_emb=image_rotary_emb, + joint_attention_kwargs=joint_attention_kwargs, + ) + return hs, 
ori_hs + + hidden_states = hook._run_single_residual_cache(state, modulated_inp, compute_fn=compute_fn, cached_fn=cached_fn) + + hidden_states = module.norm_out(hidden_states, temb) + output = module.proj_out(hidden_states) + + if not return_dict: + return (output,) + return Transformer2DModelOutput(sample=output) + + +def _mochi_teacache_forward( + hook: "TeaCacheHook", + module: torch.nn.Module, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + timestep: torch.Tensor, + encoder_attention_mask: torch.Tensor, + attention_kwargs: Optional[Dict[str, Any]] = None, + return_dict: bool = True, +): + """TeaCache forward for Mochi models.""" + from diffusers.models.modeling_outputs import Transformer2DModelOutput + from diffusers.utils import USE_PEFT_BACKEND, scale_lora_layers, unscale_lora_layers + + if attention_kwargs is not None: + attention_kwargs = attention_kwargs.copy() + lora_scale = attention_kwargs.pop("scale", 1.0) + else: + lora_scale = 1.0 + + if USE_PEFT_BACKEND: + scale_lora_layers(module, lora_scale) + + state = hook.state_manager.get_state() + hook._maybe_reset_state_for_new_inference(state, module) + + batch_size, _, num_frames, height, width = hidden_states.shape + p = module.config.patch_size + post_patch_height = height // p + post_patch_width = width // p + + temb, encoder_hidden_states = module.time_embed( + timestep, + encoder_hidden_states, + encoder_attention_mask, + hidden_dtype=hidden_states.dtype, + ) + + hidden_states = hidden_states.permute(0, 2, 1, 3, 4).flatten(0, 1) + hidden_states = module.patch_embed(hidden_states) + hidden_states = hidden_states.unflatten(0, (batch_size, -1)).flatten(1, 2) + + image_rotary_emb = module.rope( + module.pos_frequencies, + num_frames, + post_patch_height, + post_patch_width, + device=hidden_states.device, + dtype=torch.float32, + ) + + # Inline extractor: Mochi norm1 returns tuple (modulated_inp, gate_msa, scale_mlp, gate_mlp) + modulated_inp = module.transformer_blocks[0].norm1(hidden_states, temb)[0] + + def cached_fn(): + return hidden_states + state.previous_residual + + def compute_fn(): + hs = hidden_states + ori_hs = hs.clone() + enc = encoder_hidden_states + for block in module.transformer_blocks: + if torch.is_grad_enabled() and module.gradient_checkpointing: + hs, enc = module._gradient_checkpointing_func( + block, + hs, + enc, + temb, + encoder_attention_mask, + image_rotary_emb, ) - for block in module.single_transformer_blocks: - enc, hs = block( + else: + hs, enc = block( hidden_states=hs, encoder_hidden_states=enc, temb=temb, + encoder_attention_mask=encoder_attention_mask, image_rotary_emb=image_rotary_emb, - joint_attention_kwargs=joint_attention_kwargs, ) - return hs, ori_hs - - hidden_states = hook._run_single_residual_cache(state, modulated_inp, compute_fn=compute_fn, cached_fn=cached_fn) - - hidden_states = module.norm_out(hidden_states, temb) - output = module.proj_out(hidden_states) - - if not return_dict: - return (output,) - return Transformer2DModelOutput(sample=output) - - -class _MochiTeaCacheAdapter: - @staticmethod - def forward( - hook: TeaCacheHook, - module: torch.nn.Module, - hidden_states: torch.Tensor, - encoder_hidden_states: torch.Tensor, - timestep: torch.Tensor, - encoder_attention_mask: torch.Tensor, - attention_kwargs: Optional[Dict[str, Any]] = None, - return_dict: bool = True, - ): - from diffusers.models.modeling_outputs import Transformer2DModelOutput - from diffusers.utils import USE_PEFT_BACKEND, scale_lora_layers, unscale_lora_layers - - if attention_kwargs 
is not None: - attention_kwargs = attention_kwargs.copy() - lora_scale = attention_kwargs.pop("scale", 1.0) - else: - lora_scale = 1.0 - - if USE_PEFT_BACKEND: - scale_lora_layers(module, lora_scale) - - state = hook.state_manager.get_state() - hook._maybe_reset_state_for_new_inference(state, module) - - batch_size, _, num_frames, height, width = hidden_states.shape - p = module.config.patch_size - post_patch_height = height // p - post_patch_width = width // p - - temb, encoder_hidden_states = module.time_embed( - timestep, - encoder_hidden_states, - encoder_attention_mask, - hidden_dtype=hidden_states.dtype, + hs = module.norm_out(hs, temb) + return hs, ori_hs + + hidden_states = hook._run_single_residual_cache(state, modulated_inp, compute_fn=compute_fn, cached_fn=cached_fn) + + hidden_states = module.proj_out(hidden_states) + hidden_states = hidden_states.reshape(batch_size, num_frames, post_patch_height, post_patch_width, p, p, -1) + hidden_states = hidden_states.permute(0, 6, 1, 2, 4, 3, 5) + output = hidden_states.reshape(batch_size, -1, num_frames, height, width) + + if USE_PEFT_BACKEND: + unscale_lora_layers(module, lora_scale) + + if not return_dict: + return (output,) + return Transformer2DModelOutput(sample=output) + + +def _lumina2_teacache_forward( + hook: "TeaCacheHook", + module: torch.nn.Module, + hidden_states: torch.Tensor, + timestep: torch.Tensor, + encoder_hidden_states: torch.Tensor, + encoder_attention_mask: torch.Tensor, + attention_kwargs: Optional[Dict[str, Any]] = None, + return_dict: bool = True, +): + """TeaCache forward for Lumina2 models (handles variable seq lens + per-len caches).""" + from diffusers.models.modeling_outputs import Transformer2DModelOutput + from diffusers.utils import USE_PEFT_BACKEND, scale_lora_layers, unscale_lora_layers + + if attention_kwargs is not None: + attention_kwargs = attention_kwargs.copy() + lora_scale = attention_kwargs.pop("scale", 1.0) + else: + lora_scale = 1.0 + + if USE_PEFT_BACKEND: + scale_lora_layers(module, lora_scale) + + state = hook.state_manager.get_state() + hook._maybe_reset_state_for_new_inference(state, module) + + batch_size, _, height, width = hidden_states.shape + + temb, encoder_hidden_states_processed = module.time_caption_embed(hidden_states, timestep, encoder_hidden_states) + ( + image_patch_embeddings, + context_rotary_emb, + noise_rotary_emb, + joint_rotary_emb, + encoder_seq_lengths, + seq_lengths, + ) = module.rope_embedder(hidden_states, encoder_attention_mask) + image_patch_embeddings = module.x_embedder(image_patch_embeddings) + + for layer in module.context_refiner: + encoder_hidden_states_processed = layer( + encoder_hidden_states_processed, encoder_attention_mask, context_rotary_emb ) - - hidden_states = hidden_states.permute(0, 2, 1, 3, 4).flatten(0, 1) - hidden_states = module.patch_embed(hidden_states) - hidden_states = hidden_states.unflatten(0, (batch_size, -1)).flatten(1, 2) - - image_rotary_emb = module.rope( - module.pos_frequencies, - num_frames, - post_patch_height, - post_patch_width, - device=hidden_states.device, - dtype=torch.float32, - ) - - modulated_inp = hook.extractor_fn(module, hidden_states, temb) - - def cached_fn(): - return hidden_states + state.previous_residual - - def compute_fn(): - hs = hidden_states - ori_hs = hs.clone() - enc = encoder_hidden_states - for block in module.transformer_blocks: - if torch.is_grad_enabled() and module.gradient_checkpointing: - hs, enc = module._gradient_checkpointing_func( - block, - hs, - enc, - temb, - encoder_attention_mask, - 
image_rotary_emb, - ) - else: - hs, enc = block( - hidden_states=hs, - encoder_hidden_states=enc, - temb=temb, - encoder_attention_mask=encoder_attention_mask, - image_rotary_emb=image_rotary_emb, - ) - hs = module.norm_out(hs, temb) - return hs, ori_hs - - hidden_states = hook._run_single_residual_cache(state, modulated_inp, compute_fn=compute_fn, cached_fn=cached_fn) - - hidden_states = module.proj_out(hidden_states) - hidden_states = hidden_states.reshape(batch_size, num_frames, post_patch_height, post_patch_width, p, p, -1) - hidden_states = hidden_states.permute(0, 6, 1, 2, 4, 3, 5) - output = hidden_states.reshape(batch_size, -1, num_frames, height, width) - - if USE_PEFT_BACKEND: - unscale_lora_layers(module, lora_scale) - - if not return_dict: - return (output,) - return Transformer2DModelOutput(sample=output) - - -class _Lumina2TeaCacheAdapter: - @staticmethod - def forward( - hook: TeaCacheHook, - module: torch.nn.Module, - hidden_states: torch.Tensor, - timestep: torch.Tensor, - encoder_hidden_states: torch.Tensor, - encoder_attention_mask: torch.Tensor, - attention_kwargs: Optional[Dict[str, Any]] = None, - return_dict: bool = True, - ): - # Keep Lumina2 logic isolated (variable seq lens + per-len caches). - from diffusers.models.modeling_outputs import Transformer2DModelOutput - from diffusers.utils import USE_PEFT_BACKEND, scale_lora_layers, unscale_lora_layers - - if attention_kwargs is not None: - attention_kwargs = attention_kwargs.copy() - lora_scale = attention_kwargs.pop("scale", 1.0) - else: - lora_scale = 1.0 - - if USE_PEFT_BACKEND: - scale_lora_layers(module, lora_scale) - - state = hook.state_manager.get_state() - hook._maybe_reset_state_for_new_inference(state, module) - - batch_size, _, height, width = hidden_states.shape - - temb, encoder_hidden_states_processed = module.time_caption_embed(hidden_states, timestep, encoder_hidden_states) - ( - image_patch_embeddings, - context_rotary_emb, - noise_rotary_emb, - joint_rotary_emb, - encoder_seq_lengths, - seq_lengths, - ) = module.rope_embedder(hidden_states, encoder_attention_mask) - image_patch_embeddings = module.x_embedder(image_patch_embeddings) - - for layer in module.context_refiner: - encoder_hidden_states_processed = layer( - encoder_hidden_states_processed, encoder_attention_mask, context_rotary_emb - ) - for layer in module.noise_refiner: - image_patch_embeddings = layer(image_patch_embeddings, None, noise_rotary_emb, temb) - - max_seq_len = max(seq_lengths) - input_to_main_loop = image_patch_embeddings.new_zeros(batch_size, max_seq_len, module.config.hidden_size) + for layer in module.noise_refiner: + image_patch_embeddings = layer(image_patch_embeddings, None, noise_rotary_emb, temb) + + max_seq_len = max(seq_lengths) + input_to_main_loop = image_patch_embeddings.new_zeros(batch_size, max_seq_len, module.config.hidden_size) + for i, (enc_len, seq_len_val) in enumerate(zip(encoder_seq_lengths, seq_lengths)): + input_to_main_loop[i, :enc_len] = encoder_hidden_states_processed[i, :enc_len] + input_to_main_loop[i, enc_len:seq_len_val] = image_patch_embeddings[i] + + use_mask = len(set(seq_lengths)) > 1 + attention_mask_for_main_loop_arg = None + if use_mask: + mask = input_to_main_loop.new_zeros(batch_size, max_seq_len, dtype=torch.bool) for i, (enc_len, seq_len_val) in enumerate(zip(encoder_seq_lengths, seq_lengths)): - input_to_main_loop[i, :enc_len] = encoder_hidden_states_processed[i, :enc_len] - input_to_main_loop[i, enc_len:seq_len_val] = image_patch_embeddings[i] - - use_mask = len(set(seq_lengths)) 
> 1 - attention_mask_for_main_loop_arg = None - if use_mask: - mask = input_to_main_loop.new_zeros(batch_size, max_seq_len, dtype=torch.bool) - for i, (enc_len, seq_len_val) in enumerate(zip(encoder_seq_lengths, seq_lengths)): - mask[i, :seq_len_val] = True - attention_mask_for_main_loop_arg = mask - - modulated_inp = hook.extractor_fn(module, input_to_main_loop, temb) - - cache_key = max_seq_len - if cache_key not in state.cache_dict: - state.cache_dict[cache_key] = { - "previous_modulated_input": None, - "previous_residual": None, - "accumulated_rel_l1_distance": 0.0, - } - current_cache = state.cache_dict[cache_key] - - if state.cnt == 0 or state.cnt == state.num_steps - 1: - should_calc = True - current_cache["accumulated_rel_l1_distance"] = 0.0 - else: - if current_cache["previous_modulated_input"] is not None: - prev_mod_input = current_cache["previous_modulated_input"] - prev_mean = prev_mod_input.abs().mean() - if prev_mean.item() > 1e-9: - rel_l1_change = ((modulated_inp - prev_mod_input).abs().mean() / prev_mean).cpu().item() - else: - rel_l1_change = 0.0 if modulated_inp.abs().mean().item() < 1e-9 else float("inf") - rescaled_distance = hook.rescale_func(rel_l1_change) - current_cache["accumulated_rel_l1_distance"] += rescaled_distance - if current_cache["accumulated_rel_l1_distance"] < hook.config.rel_l1_thresh: - should_calc = False - else: - should_calc = True - current_cache["accumulated_rel_l1_distance"] = 0.0 + mask[i, :seq_len_val] = True + attention_mask_for_main_loop_arg = mask + + # Inline extractor: Lumina2 uses layers[0].norm1 + modulated_inp = module.layers[0].norm1(input_to_main_loop, temb)[0] + + cache_key = max_seq_len + if cache_key not in state.cache_dict: + state.cache_dict[cache_key] = { + "previous_modulated_input": None, + "previous_residual": None, + "accumulated_rel_l1_distance": 0.0, + } + current_cache = state.cache_dict[cache_key] + + if state.cnt == 0 or state.cnt == state.num_steps - 1: + should_calc = True + current_cache["accumulated_rel_l1_distance"] = 0.0 + else: + if current_cache["previous_modulated_input"] is not None: + prev_mod_input = current_cache["previous_modulated_input"] + prev_mean = prev_mod_input.abs().mean() + if prev_mean.item() > 1e-9: + rel_l1_change = ((modulated_inp - prev_mod_input).abs().mean() / prev_mean).cpu().item() + else: + rel_l1_change = 0.0 if modulated_inp.abs().mean().item() < 1e-9 else float("inf") + rescaled_distance = hook._rescale_distance(rel_l1_change) + current_cache["accumulated_rel_l1_distance"] += rescaled_distance + if current_cache["accumulated_rel_l1_distance"] < hook.config.rel_l1_thresh: + should_calc = False else: should_calc = True current_cache["accumulated_rel_l1_distance"] = 0.0 + else: + should_calc = True + current_cache["accumulated_rel_l1_distance"] = 0.0 - current_cache["previous_modulated_input"] = modulated_inp.clone() + current_cache["previous_modulated_input"] = modulated_inp.clone() - if state.uncond_seq_len is None: - state.uncond_seq_len = cache_key - if cache_key != state.uncond_seq_len: - state.cnt += 1 - if state.cnt >= state.num_steps: - state.cnt = 0 + if state.uncond_seq_len is None: + state.uncond_seq_len = cache_key + if cache_key != state.uncond_seq_len: + state.cnt += 1 + if state.cnt >= state.num_steps: + state.cnt = 0 - if not should_calc and current_cache["previous_residual"] is not None: - processed_hidden_states = input_to_main_loop + current_cache["previous_residual"] - else: - current_processing_states = input_to_main_loop - for layer in module.layers: - 
current_processing_states = layer( - current_processing_states, attention_mask_for_main_loop_arg, joint_rotary_emb, temb + if not should_calc and current_cache["previous_residual"] is not None: + processed_hidden_states = input_to_main_loop + current_cache["previous_residual"] + else: + current_processing_states = input_to_main_loop + for layer in module.layers: + current_processing_states = layer( + current_processing_states, attention_mask_for_main_loop_arg, joint_rotary_emb, temb + ) + processed_hidden_states = current_processing_states + current_cache["previous_residual"] = processed_hidden_states - input_to_main_loop + + output_after_norm = module.norm_out(processed_hidden_states, temb) + p = module.config.patch_size + final_output_list = [] + for i, (enc_len, seq_len_val) in enumerate(zip(encoder_seq_lengths, seq_lengths)): + image_part = output_after_norm[i][enc_len:seq_len_val] + h_p, w_p = height // p, width // p + reconstructed_image = ( + image_part.view(h_p, w_p, p, p, module.out_channels).permute(4, 0, 2, 1, 3).flatten(3, 4).flatten(1, 2) + ) + final_output_list.append(reconstructed_image) + final_output_tensor = torch.stack(final_output_list, dim=0) + + if USE_PEFT_BACKEND: + unscale_lora_layers(module, lora_scale) + + if not return_dict: + return (final_output_tensor,) + return Transformer2DModelOutput(sample=final_output_tensor) + + +def _cogvideox_teacache_forward( + hook: "TeaCacheHook", + module: torch.nn.Module, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + timestep: Union[int, float, torch.LongTensor], + timestep_cond: Optional[torch.Tensor] = None, + ofs: Optional[Union[int, float, torch.LongTensor]] = None, + image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + attention_kwargs: Optional[Dict[str, Any]] = None, + return_dict: bool = True, +): + """TeaCache forward for CogVideoX models (handles dual residual caching).""" + from diffusers.models.modeling_outputs import Transformer2DModelOutput + from diffusers.utils import USE_PEFT_BACKEND, is_torch_version, scale_lora_layers, unscale_lora_layers + + if attention_kwargs is not None: + attention_kwargs = attention_kwargs.copy() + lora_scale = attention_kwargs.pop("scale", 1.0) + else: + lora_scale = 1.0 + + if USE_PEFT_BACKEND: + scale_lora_layers(module, lora_scale) + + state = hook.state_manager.get_state() + hook._maybe_reset_state_for_new_inference(state, module, reset_encoder_residual=True) + + batch_size, num_frames, _, height, width = hidden_states.shape + + t_emb = module.time_proj(timestep) + t_emb = t_emb.to(dtype=hidden_states.dtype) + emb = module.time_embedding(t_emb, timestep_cond) + + if module.ofs_embedding is not None: + ofs_emb = module.ofs_proj(ofs) + ofs_emb = ofs_emb.to(dtype=hidden_states.dtype) + ofs_emb = module.ofs_embedding(ofs_emb) + emb = emb + ofs_emb + + hs = module.patch_embed(encoder_hidden_states, hidden_states) + hs = module.embedding_dropout(hs) + + text_seq_length = encoder_hidden_states.shape[1] + enc = hs[:, :text_seq_length] + hs = hs[:, text_seq_length:] + + # Inline extractor: CogVideoX uses timestep embedding directly + modulated_inp = emb + + def cached_fn(): + return hs + state.previous_residual, enc + state.previous_residual_encoder + + def compute_fn(): + hs0 = hs + enc0 = enc + ori_hs = hs0.clone() + ori_enc = enc0.clone() + hs1, enc1 = hs0, enc0 + for block in module.transformer_blocks: + if torch.is_grad_enabled() and module.gradient_checkpointing: + ckpt_kwargs = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") 
else {} + hs1, enc1 = torch.utils.checkpoint.checkpoint( + lambda *a: block(*a), + hs1, + enc1, + emb, + image_rotary_emb, + **ckpt_kwargs, ) - processed_hidden_states = current_processing_states - current_cache["previous_residual"] = processed_hidden_states - input_to_main_loop + else: + hs1, enc1 = block( + hidden_states=hs1, + encoder_hidden_states=enc1, + temb=emb, + image_rotary_emb=image_rotary_emb, + ) + return hs1, enc1, ori_hs, ori_enc - output_after_norm = module.norm_out(processed_hidden_states, temb) - p = module.config.patch_size - final_output_list = [] - for i, (enc_len, seq_len_val) in enumerate(zip(encoder_seq_lengths, seq_lengths)): - image_part = output_after_norm[i][enc_len:seq_len_val] - h_p, w_p = height // p, width // p - reconstructed_image = ( - image_part.view(h_p, w_p, p, p, module.out_channels) - .permute(4, 0, 2, 1, 3) - .flatten(3, 4) - .flatten(1, 2) - ) - final_output_list.append(reconstructed_image) - final_output_tensor = torch.stack(final_output_list, dim=0) - - if USE_PEFT_BACKEND: - unscale_lora_layers(module, lora_scale) - - if not return_dict: - return (final_output_tensor,) - return Transformer2DModelOutput(sample=final_output_tensor) - - -class _CogVideoXTeaCacheAdapter: - @staticmethod - def forward( - hook: TeaCacheHook, - module: torch.nn.Module, - hidden_states: torch.Tensor, - encoder_hidden_states: torch.Tensor, - timestep: Union[int, float, torch.LongTensor], - timestep_cond: Optional[torch.Tensor] = None, - ofs: Optional[Union[int, float, torch.LongTensor]] = None, - image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, - attention_kwargs: Optional[Dict[str, Any]] = None, - return_dict: bool = True, - ): - from diffusers.models.modeling_outputs import Transformer2DModelOutput - from diffusers.utils import USE_PEFT_BACKEND, is_torch_version, scale_lora_layers, unscale_lora_layers - - if attention_kwargs is not None: - attention_kwargs = attention_kwargs.copy() - lora_scale = attention_kwargs.pop("scale", 1.0) - else: - lora_scale = 1.0 - - if USE_PEFT_BACKEND: - scale_lora_layers(module, lora_scale) - - state = hook.state_manager.get_state() - hook._maybe_reset_state_for_new_inference(state, module, reset_encoder_residual=True) - - batch_size, num_frames, _, height, width = hidden_states.shape - - t_emb = module.time_proj(timestep) - t_emb = t_emb.to(dtype=hidden_states.dtype) - emb = module.time_embedding(t_emb, timestep_cond) - - if module.ofs_embedding is not None: - ofs_emb = module.ofs_proj(ofs) - ofs_emb = ofs_emb.to(dtype=hidden_states.dtype) - ofs_emb = module.ofs_embedding(ofs_emb) - emb = emb + ofs_emb - - hs = module.patch_embed(encoder_hidden_states, hidden_states) - hs = module.embedding_dropout(hs) - - text_seq_length = encoder_hidden_states.shape[1] - enc = hs[:, :text_seq_length] - hs = hs[:, text_seq_length:] - - modulated_inp = hook.extractor_fn(module, hs, emb) - - def cached_fn(): - return hs + state.previous_residual, enc + state.previous_residual_encoder - - def compute_fn(): - hs0 = hs - enc0 = enc - ori_hs = hs0.clone() - ori_enc = enc0.clone() - hs1, enc1 = hs0, enc0 - for block in module.transformer_blocks: - if torch.is_grad_enabled() and module.gradient_checkpointing: - ckpt_kwargs = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} - hs1, enc1 = torch.utils.checkpoint.checkpoint( - lambda *a: block(*a), - hs1, - enc1, - emb, - image_rotary_emb, - **ckpt_kwargs, - ) - else: - hs1, enc1 = block( - hidden_states=hs1, - encoder_hidden_states=enc1, - temb=emb, - 
image_rotary_emb=image_rotary_emb, - ) - return hs1, enc1, ori_hs, ori_enc - - hs, enc = hook._run_dual_residual_cache(state, modulated_inp, compute_fn=compute_fn, cached_fn=cached_fn) - - if not module.config.use_rotary_positional_embeddings: - hs = module.norm_final(hs) - else: - hs_cat = torch.cat([enc, hs], dim=1) - hs_cat = module.norm_final(hs_cat) - hs = hs_cat[:, text_seq_length:] - - hs = module.norm_out(hs, temb=emb) - hs = module.proj_out(hs) - - p = module.config.patch_size - p_t = module.config.patch_size_t - if p_t is None: - output = hs.reshape(batch_size, num_frames, height // p, width // p, -1, p, p) - output = output.permute(0, 1, 4, 2, 5, 3, 6).flatten(5, 6).flatten(3, 4) - else: - output = hs.reshape( - batch_size, (num_frames + p_t - 1) // p_t, height // p, width // p, -1, p_t, p, p - ) - output = output.permute(0, 1, 5, 4, 2, 6, 3, 7).flatten(6, 7).flatten(4, 5).flatten(1, 2) + hs, enc = hook._run_dual_residual_cache(state, modulated_inp, compute_fn=compute_fn, cached_fn=cached_fn) + + if not module.config.use_rotary_positional_embeddings: + hs = module.norm_final(hs) + else: + hs_cat = torch.cat([enc, hs], dim=1) + hs_cat = module.norm_final(hs_cat) + hs = hs_cat[:, text_seq_length:] + + hs = module.norm_out(hs, temb=emb) + hs = module.proj_out(hs) + + p = module.config.patch_size + p_t = module.config.patch_size_t + if p_t is None: + output = hs.reshape(batch_size, num_frames, height // p, width // p, -1, p, p) + output = output.permute(0, 1, 4, 2, 5, 3, 6).flatten(5, 6).flatten(3, 4) + else: + output = hs.reshape(batch_size, (num_frames + p_t - 1) // p_t, height // p, width // p, -1, p_t, p, p) + output = output.permute(0, 1, 5, 4, 2, 6, 3, 7).flatten(6, 7).flatten(4, 5).flatten(1, 2) - if USE_PEFT_BACKEND: - unscale_lora_layers(module, lora_scale) + if USE_PEFT_BACKEND: + unscale_lora_layers(module, lora_scale) - if not return_dict: - return (output,) - return Transformer2DModelOutput(sample=output) + if not return_dict: + return (output,) + return Transformer2DModelOutput(sample=output) def apply_teacache(module, config: TeaCacheConfig): @@ -991,7 +952,7 @@ def apply_teacache(module, config: TeaCacheConfig): module: The transformer model to optimize (e.g., FluxTransformer2DModel, CogVideoXTransformer3DModel). config (`TeaCacheConfig`): Configuration specifying caching threshold and optional callbacks. 
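+            Model type and polynomial coefficients are auto-detected from the module; when a custom
+            `extract_modulated_input_fn` is supplied, `coefficients` must be provided explicitly.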
- Examples: + Example: ```python from diffusers import FluxPipeline from diffusers.hooks import TeaCacheConfig, apply_teacache diff --git a/tests/hooks/test_teacache.py b/tests/hooks/test_teacache.py index fd2ce6cfdf94..333d4dbd5823 100644 --- a/tests/hooks/test_teacache.py +++ b/tests/hooks/test_teacache.py @@ -139,7 +139,7 @@ def test_hook_initialization(self): hook = TeaCacheHook(config) self.assertEqual(hook.config.rel_l1_thresh, 0.2) - self.assertIsNotNone(hook.rescale_func) + self.assertIsNotNone(hook.coefficients) self.assertIsNotNone(hook.state_manager) def test_should_compute_full_transformer_logic(self): @@ -204,24 +204,26 @@ class TeaCacheMultiModelTests(unittest.TestCase): def test_model_coefficient_registry(self): """Test that model coefficients are properly registered.""" - from diffusers.hooks.teacache import _MODEL_COEFFICIENTS + from diffusers.hooks.teacache import _get_model_config - self.assertIn("Flux", _MODEL_COEFFICIENTS) - self.assertIn("Mochi", _MODEL_COEFFICIENTS) - self.assertIn("Lumina2", _MODEL_COEFFICIENTS) - self.assertIn("CogVideoX", _MODEL_COEFFICIENTS) + model_config = _get_model_config() + + self.assertIn("Flux", model_config) + self.assertIn("Mochi", model_config) + self.assertIn("Lumina2", model_config) + self.assertIn("CogVideoX", model_config) # Verify all coefficients are 5-element lists - for model_name, coeffs in _MODEL_COEFFICIENTS.items(): + for model_name, config in model_config.items(): + coeffs = config["coefficients"] self.assertEqual(len(coeffs), 5, f"{model_name} coefficients should have 5 elements") self.assertTrue( all(isinstance(c, (int, float)) for c in coeffs), f"{model_name} coefficients should be numbers" ) def test_mochi_extractor(self): - """Test Mochi modulated input extractor.""" + """Test Mochi modulated input extraction (now inlined in forward function).""" from diffusers import MochiTransformer3DModel - from diffusers.hooks.teacache import _mochi_modulated_input_extractor # Create a minimal Mochi model for testing model = MochiTransformer3DModel( @@ -247,15 +249,14 @@ def test_mochi_extractor(self): hidden_states = model.patch_embed(hidden_states) hidden_states = hidden_states.unflatten(0, (2, -1)).flatten(1, 2) - # Test extractor - modulated_inp = _mochi_modulated_input_extractor(model, hidden_states, temb) + # Test inline extraction logic: Mochi norm1 returns tuple (modulated_inp, ...) 
+ modulated_inp = model.transformer_blocks[0].norm1(hidden_states, temb)[0] self.assertIsInstance(modulated_inp, torch.Tensor) self.assertEqual(modulated_inp.shape[0], hidden_states.shape[0]) def test_lumina2_extractor(self): - """Test Lumina2 modulated input extractor with simplified setup.""" + """Test Lumina2 modulated input extraction (now inlined in forward function).""" from diffusers import Lumina2Transformer2DModel - from diffusers.hooks.teacache import _lumina2_modulated_input_extractor # Create a minimal Lumina2 model for testing model = Lumina2Transformer2DModel( @@ -279,15 +280,14 @@ def test_lumina2_extractor(self): input_to_main_loop = torch.randn(batch_size, seq_len, hidden_size) temb = torch.randn(batch_size, hidden_size) - # Test extractor - modulated_inp = _lumina2_modulated_input_extractor(model, input_to_main_loop, temb) + # Test inline extraction logic: Lumina2 uses layers[0].norm1 + modulated_inp = model.layers[0].norm1(input_to_main_loop, temb)[0] self.assertIsInstance(modulated_inp, torch.Tensor) self.assertEqual(modulated_inp.shape[0], batch_size) def test_cogvideox_extractor(self): - """Test CogVideoX modulated input extractor.""" + """Test CogVideoX modulated input extraction (now inlined in forward function).""" from diffusers import CogVideoXTransformer3DModel - from diffusers.hooks.teacache import _cogvideox_modulated_input_extractor # Create a minimal CogVideoX model for testing model = CogVideoXTransformer3DModel( @@ -307,8 +307,8 @@ def test_cogvideox_extractor(self): t_emb = t_emb.to(dtype=hidden_states.dtype) emb = model.time_embedding(t_emb, None) - # Test extractor (should return emb directly) - modulated_inp = _cogvideox_modulated_input_extractor(model, hidden_states, emb) + # Test inline extraction logic: CogVideoX uses timestep embedding directly + modulated_inp = emb self.assertIsInstance(modulated_inp, torch.Tensor) self.assertEqual(modulated_inp.shape, emb.shape) @@ -316,7 +316,7 @@ def test_auto_detect_mochi(self): """Test auto-detection for Mochi models.""" from diffusers import MochiTransformer3DModel from diffusers.hooks import TeaCacheConfig, apply_teacache - from diffusers.hooks.teacache import _MODEL_COEFFICIENTS, _auto_detect_extractor + from diffusers.hooks.teacache import _get_model_config model = MochiTransformer3DModel( patch_size=2, @@ -328,9 +328,7 @@ def test_auto_detect_mochi(self): time_embed_dim=4, ) - # Test extractor detection - extractor = _auto_detect_extractor(model) - self.assertIsNotNone(extractor) + model_config = _get_model_config() # Test coefficient auto-detection config = TeaCacheConfig(rel_l1_thresh=0.2) @@ -340,7 +338,7 @@ def test_auto_detect_mochi(self): hook = registry.get_hook("teacache") self.assertIsNotNone(hook) # Verify coefficients were auto-set - self.assertEqual(hook.coefficients, _MODEL_COEFFICIENTS["Mochi"]) + self.assertEqual(hook.coefficients, model_config["Mochi"]["coefficients"]) model.disable_cache() @@ -348,7 +346,7 @@ def test_auto_detect_lumina2(self): """Test auto-detection for Lumina2 models.""" from diffusers import Lumina2Transformer2DModel from diffusers.hooks import TeaCacheConfig, apply_teacache - from diffusers.hooks.teacache import _MODEL_COEFFICIENTS + from diffusers.hooks.teacache import _get_model_config model = Lumina2Transformer2DModel( sample_size=16, @@ -361,6 +359,8 @@ def test_auto_detect_lumina2(self): num_kv_heads=1, ) + model_config = _get_model_config() + config = TeaCacheConfig(rel_l1_thresh=0.2) apply_teacache(model, config) @@ -368,7 +368,7 @@ def 
test_auto_detect_lumina2(self): hook = registry.get_hook("teacache") self.assertIsNotNone(hook) # Verify coefficients were auto-set - self.assertEqual(hook.coefficients, _MODEL_COEFFICIENTS["Lumina2"]) + self.assertEqual(hook.coefficients, model_config["Lumina2"]["coefficients"]) # Lumina2 doesn't have CacheMixin, manually remove hook instead registry.remove_hook("teacache") @@ -377,7 +377,7 @@ def test_auto_detect_cogvideox(self): """Test auto-detection for CogVideoX models.""" from diffusers import CogVideoXTransformer3DModel from diffusers.hooks import TeaCacheConfig, apply_teacache - from diffusers.hooks.teacache import _MODEL_COEFFICIENTS + from diffusers.hooks.teacache import _get_model_config model = CogVideoXTransformer3DModel( num_attention_heads=2, @@ -388,6 +388,8 @@ def test_auto_detect_cogvideox(self): time_embed_dim=4, ) + model_config = _get_model_config() + config = TeaCacheConfig(rel_l1_thresh=0.2) apply_teacache(model, config) @@ -395,7 +397,7 @@ def test_auto_detect_cogvideox(self): hook = registry.get_hook("teacache") self.assertIsNotNone(hook) # Verify coefficients were auto-set - self.assertEqual(hook.coefficients, _MODEL_COEFFICIENTS["CogVideoX"]) + self.assertEqual(hook.coefficients, model_config["CogVideoX"]["coefficients"]) model.disable_cache() From 883699e6a570bb22c23afad3568e8e01a4716875 Mon Sep 17 00:00:00 2001 From: Prajwal A Date: Tue, 6 Jan 2026 10:21:14 +0530 Subject: [PATCH 17/25] refactor teacache: remove closure pattern, use standalone utility functions Signed-off-by: Prajwal A --- src/diffusers/hooks/teacache.py | 305 ++++++++++++---------------- src/diffusers/models/cache_utils.py | 31 +-- tests/hooks/test_teacache.py | 17 +- 3 files changed, 148 insertions(+), 205 deletions(-) diff --git a/src/diffusers/hooks/teacache.py b/src/diffusers/hooks/teacache.py index 71ec937572d6..00609a54b6b6 100644 --- a/src/diffusers/hooks/teacache.py +++ b/src/diffusers/hooks/teacache.py @@ -86,6 +86,91 @@ def _auto_detect_model_type(module): raise ValueError(f"TeaCache: Unsupported model '{class_name}'. 
Supported: {', '.join(model_config.keys())}") +# ============================================================================= +# Standalone utility functions (per DN6's feedback - no closures, direct state ops) +# ============================================================================= + + +def _rescale_distance(coefficients, x): + """Polynomial rescaling: c[0]*x^4 + c[1]*x^3 + c[2]*x^2 + c[3]*x + c[4]""" + c = coefficients + return c[0] * x**4 + c[1] * x**3 + c[2] * x**2 + c[3] * x + c[4] + + +@torch.compiler.disable +def _should_compute(state, modulated_inp, coefficients, rel_l1_thresh): + """Determine if full computation is needed (single residual models).""" + # First timestep always computes + if state.cnt == 0: + state.accumulated_rel_l1_distance = 0 + return True + # Last timestep always computes + if state.num_steps > 0 and state.cnt == state.num_steps - 1: + state.accumulated_rel_l1_distance = 0 + return True + # No previous state - must compute + if state.previous_modulated_input is None: + return True + if state.previous_residual is None: + return True + + # Compute L1 distance and check threshold + rel_distance = ( + ((modulated_inp - state.previous_modulated_input).abs().mean() / state.previous_modulated_input.abs().mean()) + .cpu() + .item() + ) + rescaled = _rescale_distance(coefficients, rel_distance) + state.accumulated_rel_l1_distance += rescaled + + if state.accumulated_rel_l1_distance < rel_l1_thresh: + return False + + state.accumulated_rel_l1_distance = 0 + return True + + +@torch.compiler.disable +def _should_compute_dual(state, modulated_inp, coefficients, rel_l1_thresh): + """Determine if full computation is needed (dual residual models like CogVideoX).""" + # Also check encoder residual + if state.previous_residual is None or state.previous_residual_encoder is None: + return True + return _should_compute(state, modulated_inp, coefficients, rel_l1_thresh) + + +def _update_state(state, output, original_input, modulated_inp): + """Update cache state after full computation (single residual).""" + state.previous_residual = output - original_input + state.previous_modulated_input = modulated_inp + state.cnt += 1 + + +def _update_state_dual(state, hs_output, enc_output, hs_original, enc_original, modulated_inp): + """Update cache state after full computation (dual residual).""" + state.previous_residual = hs_output - hs_original + state.previous_residual_encoder = enc_output - enc_original + state.previous_modulated_input = modulated_inp + state.cnt += 1 + + +def _apply_cached_residual(state, input_tensor, modulated_inp): + """Apply cached residual - fast path (single residual).""" + output = input_tensor + state.previous_residual + state.previous_modulated_input = modulated_inp + state.cnt += 1 + return output + + +def _apply_cached_residual_dual(state, hs, enc, modulated_inp): + """Apply cached residuals - fast path (dual residual).""" + hs_out = hs + state.previous_residual + enc_out = enc + state.previous_residual_encoder + state.previous_modulated_input = modulated_inp + state.cnt += 1 + return hs_out, enc_out + + @dataclass class TeaCacheConfig: r""" @@ -138,9 +223,6 @@ class TeaCacheConfig: pipe.to("cuda") # Enable TeaCache with auto-detection (1.5x speedup) - pipe.transformer.enable_teacache(rel_l1_thresh=0.2) - - # Or with explicit config config = TeaCacheConfig(rel_l1_thresh=0.2) pipe.transformer.enable_cache(config) @@ -320,11 +402,6 @@ def __init__(self, config: TeaCacheConfig): self.model_type = None self._forward_func = None - def 
_rescale_distance(self, x: float) -> float: - """Evaluate polynomial: c[0]*x^4 + c[1]*x^3 + c[2]*x^2 + c[3]*x + c[4]""" - c = self.coefficients - return c[0] * x**4 + c[1] * x**3 + c[2] * x**2 + c[3] * x + c[4] - def _maybe_reset_state_for_new_inference(self, state, module, reset_encoder_residual=False): """Reset state if we've completed all steps (start of new inference run). @@ -398,65 +475,6 @@ def initialize_hook(self, module): return module - @torch.compiler.disable - def _run_single_residual_cache( - self, - state: "TeaCacheState", - modulated_inp: torch.Tensor, - compute_fn: Callable[[], Tuple[torch.Tensor, torch.Tensor]], - cached_fn: Callable[[], torch.Tensor], - ) -> torch.Tensor: - """ - Shared cache engine for models where we cache exactly one residual tensor. - - `compute_fn` must return `(y, x_base)` where residual is `y - x_base`. - `cached_fn` must return `y` computed from the current `x` and cached residual. - """ - should_calc = self._should_compute_full_transformer(state, modulated_inp) - if (not should_calc) and (state.previous_residual is None): - should_calc = True - - if should_calc: - y, x_base = compute_fn() - state.previous_residual = y - x_base - else: - y = cached_fn() - - state.previous_modulated_input = modulated_inp - state.cnt += 1 - return y - - @torch.compiler.disable - def _run_dual_residual_cache( - self, - state: "TeaCacheState", - modulated_inp: torch.Tensor, - compute_fn: Callable[[], Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]], - cached_fn: Callable[[], Tuple[torch.Tensor, torch.Tensor]], - ) -> Tuple[torch.Tensor, torch.Tensor]: - """ - Shared cache engine for models where we cache two residual tensors: - - hidden_states residual - - encoder_hidden_states residual - - `compute_fn` must return `(hs_y, enc_y, hs_x_base, enc_x_base)`. - `cached_fn` must return `(hs_y, enc_y)` computed using cached residuals. - """ - should_calc = self._should_compute_full_transformer(state, modulated_inp) - if (not should_calc) and (state.previous_residual is None or state.previous_residual_encoder is None): - should_calc = True - - if should_calc: - hs_y, enc_y, hs_x_base, enc_x_base = compute_fn() - state.previous_residual = hs_y - hs_x_base - state.previous_residual_encoder = enc_y - enc_x_base - else: - hs_y, enc_y = cached_fn() - - state.previous_modulated_input = modulated_inp - state.cnt += 1 - return hs_y, enc_y - def new_forward(self, module, *args, **kwargs): # Use stored forward function if available (set during initialize_hook) if self._forward_func is not None: @@ -476,46 +494,6 @@ def new_forward(self, module, *args, **kwargs): logger.warning(f"TeaCache: Unsupported model type {module_class_name}.") return self.fn_ref.original_forward(*args, **kwargs) - @torch.compiler.disable - def _should_compute_full_transformer(self, state, modulated_inp): - """ - Determine whether to compute full transformer blocks or reuse cached residual. 
- - This method implements the core caching decision logic from the TeaCache paper: - - Always compute first and last timesteps (for maximum quality) - - For intermediate timesteps, compute relative L1 distance between current and previous modulated inputs - - Apply polynomial rescaling to convert distance to model-specific caching signal - - Accumulate rescaled distances and compare to threshold - - Return True (compute) if accumulated distance exceeds threshold, False (cache) otherwise - """ - if state.cnt == 0: - state.accumulated_rel_l1_distance = 0 - return True - - if state.num_steps > 0 and state.cnt == state.num_steps - 1: - state.accumulated_rel_l1_distance = 0 - return True - - if state.previous_modulated_input is None: - return True - - rel_distance = ( - ( - (modulated_inp - state.previous_modulated_input).abs().mean() - / state.previous_modulated_input.abs().mean() - ) - .cpu() - .item() - ) - rescaled_distance = self._rescale_distance(rel_distance) - state.accumulated_rel_l1_distance += rescaled_distance - - if state.accumulated_rel_l1_distance < self.config.rel_l1_thresh: - return False - - state.accumulated_rel_l1_distance = 0 - return True - def reset_state(self, module): self.state_manager.reset() return module @@ -550,48 +528,39 @@ def _flux_teacache_forward( # Inline extractor: Flux uses transformer_blocks[0].norm1 modulated_inp = module.transformer_blocks[0].norm1(hidden_states, emb=temb)[0] - joint_attention_kwargs = kwargs.get("joint_attention_kwargs") - - def cached_fn(): - return hidden_states + state.previous_residual - - def compute_fn(): - hs = hidden_states - ori_hs = hs.clone() + # Caching decision and execution + if _should_compute(state, modulated_inp, hook.coefficients, hook.config.rel_l1_thresh): + # Full computation path + ori_hs = hidden_states.clone() enc = module.context_embedder(encoder_hidden_states) - if txt_ids.ndim == 3: - txt = txt_ids[0] - else: - txt = txt_ids - if img_ids.ndim == 3: - img = img_ids[0] - else: - img = img_ids - + txt = txt_ids[0] if txt_ids.ndim == 3 else txt_ids + img = img_ids[0] if img_ids.ndim == 3 else img_ids ids = torch.cat((txt, img), dim=0) image_rotary_emb = module.pos_embed(ids) + joint_attention_kwargs = kwargs.get("joint_attention_kwargs") for block in module.transformer_blocks: - enc, hs = block( - hidden_states=hs, + enc, hidden_states = block( + hidden_states=hidden_states, encoder_hidden_states=enc, temb=temb, image_rotary_emb=image_rotary_emb, joint_attention_kwargs=joint_attention_kwargs, ) for block in module.single_transformer_blocks: - enc, hs = block( - hidden_states=hs, + enc, hidden_states = block( + hidden_states=hidden_states, encoder_hidden_states=enc, temb=temb, image_rotary_emb=image_rotary_emb, joint_attention_kwargs=joint_attention_kwargs, ) - return hs, ori_hs - - hidden_states = hook._run_single_residual_cache(state, modulated_inp, compute_fn=compute_fn, cached_fn=cached_fn) + _update_state(state, hidden_states, ori_hs, modulated_inp) + else: + # Cached path + hidden_states = _apply_cached_residual(state, hidden_states, modulated_inp) hidden_states = module.norm_out(hidden_states, temb) output = module.proj_out(hidden_states) @@ -655,35 +624,34 @@ def _mochi_teacache_forward( # Inline extractor: Mochi norm1 returns tuple (modulated_inp, gate_msa, scale_mlp, gate_mlp) modulated_inp = module.transformer_blocks[0].norm1(hidden_states, temb)[0] - def cached_fn(): - return hidden_states + state.previous_residual - - def compute_fn(): - hs = hidden_states - ori_hs = hs.clone() + # Caching decision 
and execution + if _should_compute(state, modulated_inp, hook.coefficients, hook.config.rel_l1_thresh): + # Full computation path + ori_hs = hidden_states.clone() enc = encoder_hidden_states for block in module.transformer_blocks: if torch.is_grad_enabled() and module.gradient_checkpointing: - hs, enc = module._gradient_checkpointing_func( + hidden_states, enc = module._gradient_checkpointing_func( block, - hs, + hidden_states, enc, temb, encoder_attention_mask, image_rotary_emb, ) else: - hs, enc = block( - hidden_states=hs, + hidden_states, enc = block( + hidden_states=hidden_states, encoder_hidden_states=enc, temb=temb, encoder_attention_mask=encoder_attention_mask, image_rotary_emb=image_rotary_emb, ) - hs = module.norm_out(hs, temb) - return hs, ori_hs - - hidden_states = hook._run_single_residual_cache(state, modulated_inp, compute_fn=compute_fn, cached_fn=cached_fn) + hidden_states = module.norm_out(hidden_states, temb) + _update_state(state, hidden_states, ori_hs, modulated_inp) + else: + # Cached path + hidden_states = _apply_cached_residual(state, hidden_states, modulated_inp) hidden_states = module.proj_out(hidden_states) hidden_states = hidden_states.reshape(batch_size, num_frames, post_patch_height, post_patch_width, p, p, -1) @@ -781,7 +749,7 @@ def _lumina2_teacache_forward( rel_l1_change = ((modulated_inp - prev_mod_input).abs().mean() / prev_mean).cpu().item() else: rel_l1_change = 0.0 if modulated_inp.abs().mean().item() < 1e-9 else float("inf") - rescaled_distance = hook._rescale_distance(rel_l1_change) + rescaled_distance = _rescale_distance(hook.coefficients, rel_l1_change) current_cache["accumulated_rel_l1_distance"] += rescaled_distance if current_cache["accumulated_rel_l1_distance"] < hook.config.rel_l1_thresh: should_calc = False @@ -882,36 +850,33 @@ def _cogvideox_teacache_forward( # Inline extractor: CogVideoX uses timestep embedding directly modulated_inp = emb - def cached_fn(): - return hs + state.previous_residual, enc + state.previous_residual_encoder - - def compute_fn(): - hs0 = hs - enc0 = enc - ori_hs = hs0.clone() - ori_enc = enc0.clone() - hs1, enc1 = hs0, enc0 + # Caching decision and execution (dual residual) + if _should_compute_dual(state, modulated_inp, hook.coefficients, hook.config.rel_l1_thresh): + # Full computation path + ori_hs = hs.clone() + ori_enc = enc.clone() for block in module.transformer_blocks: if torch.is_grad_enabled() and module.gradient_checkpointing: ckpt_kwargs = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} - hs1, enc1 = torch.utils.checkpoint.checkpoint( + hs, enc = torch.utils.checkpoint.checkpoint( lambda *a: block(*a), - hs1, - enc1, + hs, + enc, emb, image_rotary_emb, **ckpt_kwargs, ) else: - hs1, enc1 = block( - hidden_states=hs1, - encoder_hidden_states=enc1, + hs, enc = block( + hidden_states=hs, + encoder_hidden_states=enc, temb=emb, image_rotary_emb=image_rotary_emb, ) - return hs1, enc1, ori_hs, ori_enc - - hs, enc = hook._run_dual_residual_cache(state, modulated_inp, compute_fn=compute_fn, cached_fn=cached_fn) + _update_state_dual(state, hs, enc, ori_hs, ori_enc, modulated_inp) + else: + # Cached path + hs, enc = _apply_cached_residual_dual(state, hs, enc, modulated_inp) if not module.config.use_rotary_positional_embeddings: hs = module.norm_final(hs) @@ -955,25 +920,25 @@ def apply_teacache(module, config: TeaCacheConfig): Example: ```python from diffusers import FluxPipeline - from diffusers.hooks import TeaCacheConfig, apply_teacache + from diffusers.hooks import TeaCacheConfig # Load 
FLUX pipeline pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16) pipe.to("cuda") - # Apply TeaCache directly to transformer + # Enable TeaCache via CacheMixin (recommended) config = TeaCacheConfig(rel_l1_thresh=0.2) - apply_teacache(pipe.transformer, config) + pipe.transformer.enable_cache(config) # Generate with caching enabled image = pipe("A cat on a windowsill", num_inference_steps=4).images[0] - # Or use the convenience method via CacheMixin - pipe.transformer.enable_teacache(rel_l1_thresh=0.2) + # Disable caching + pipe.transformer.disable_cache() ``` Note: - For most use cases, it's recommended to use the CacheMixin interface: `pipe.transformer.enable_teacache(...)` + For most use cases, it's recommended to use the CacheMixin interface: `pipe.transformer.enable_cache(...)` which provides additional convenience methods like `disable_cache()` for easy toggling. """ # Register hook on main transformer diff --git a/src/diffusers/models/cache_utils.py b/src/diffusers/models/cache_utils.py index 03a33e32784d..1bc98bf9007f 100644 --- a/src/diffusers/models/cache_utils.py +++ b/src/diffusers/models/cache_utils.py @@ -68,16 +68,13 @@ def enable_cache(self, config) -> None: FasterCacheConfig, FirstBlockCacheConfig, PyramidAttentionBroadcastConfig, - TeaCacheConfig, - apply_faster_cache, - apply_first_block_cache, - apply_pyramid_attention_broadcast, - apply_teacache, TaylorSeerCacheConfig, + TeaCacheConfig, apply_faster_cache, apply_first_block_cache, apply_pyramid_attention_broadcast, apply_taylorseer_cache, + apply_teacache, ) if self.is_cache_enabled: @@ -106,14 +103,14 @@ def disable_cache(self) -> None: FirstBlockCacheConfig, HookRegistry, PyramidAttentionBroadcastConfig, - TeaCacheConfig, TaylorSeerCacheConfig, + TeaCacheConfig, ) from ..hooks.faster_cache import _FASTER_CACHE_BLOCK_HOOK, _FASTER_CACHE_DENOISER_HOOK from ..hooks.first_block_cache import _FBC_BLOCK_HOOK, _FBC_LEADER_BLOCK_HOOK from ..hooks.pyramid_attention_broadcast import _PYRAMID_ATTENTION_BROADCAST_HOOK - from ..hooks.teacache import _TEACACHE_HOOK from ..hooks.taylorseer_cache import _TAYLORSEER_CACHE_HOOK + from ..hooks.teacache import _TEACACHE_HOOK if self._cache_config is None: logger.warning("Caching techniques have not been enabled, so there's nothing to disable.") @@ -153,23 +150,3 @@ def cache_context(self, name: str): yield registry._set_context(None) - - def enable_teacache(self, rel_l1_thresh: float = 0.2, num_inference_steps: int = None, **kwargs): - r""" - Enable TeaCache on the model. - - Args: - rel_l1_thresh (`float`, defaults to `0.2`): - Threshold for caching decision. Higher = more aggressive caching. - num_inference_steps (`int`, *optional*): - Total number of inference steps. Required for proper state management. - **kwargs: Additional arguments passed to TeaCacheConfig. 
- """ - from ..hooks import TeaCacheConfig - - config = TeaCacheConfig( - rel_l1_thresh=rel_l1_thresh, - num_inference_steps=num_inference_steps, - **kwargs - ) - self.enable_cache(config) diff --git a/tests/hooks/test_teacache.py b/tests/hooks/test_teacache.py index 333d4dbd5823..cc8a49a90eb5 100644 --- a/tests/hooks/test_teacache.py +++ b/tests/hooks/test_teacache.py @@ -142,30 +142,31 @@ def test_hook_initialization(self): self.assertIsNotNone(hook.coefficients) self.assertIsNotNone(hook.state_manager) - def test_should_compute_full_transformer_logic(self): - """Test _should_compute_full_transformer decision logic.""" - from diffusers.hooks.teacache import TeaCacheHook, TeaCacheState + def test_should_compute_logic(self): + """Test _should_compute decision logic.""" + from diffusers.hooks.teacache import TeaCacheState, _should_compute - config = TeaCacheConfig(rel_l1_thresh=1.0, coefficients=[1, 0, 0, 0, 0]) - hook = TeaCacheHook(config) + coefficients = [1, 0, 0, 0, 0] + rel_l1_thresh = 1.0 state = TeaCacheState() x0 = torch.ones(1, 4) x1 = torch.ones(1, 4) * 1.1 # First step should always compute - self.assertTrue(hook._should_compute_full_transformer(state, x0)) + self.assertTrue(_should_compute(state, x0, coefficients, rel_l1_thresh)) state.previous_modulated_input = x0 + state.previous_residual = torch.zeros(1, 4) # Need residual to skip compute state.cnt = 1 state.num_steps = 4 # Middle step: accumulate distance and stay below threshold => reuse cache - self.assertFalse(hook._should_compute_full_transformer(state, x1)) + self.assertFalse(_should_compute(state, x1, coefficients, rel_l1_thresh)) # Last step: must compute regardless of distance state.cnt = state.num_steps - 1 - self.assertTrue(hook._should_compute_full_transformer(state, x1)) + self.assertTrue(_should_compute(state, x1, coefficients, rel_l1_thresh)) def test_apply_teacache_with_custom_extractor(self): """Test apply_teacache works with custom extractor function.""" From 29110bf3d540771bf8c12941aa65335c6e45f03a Mon Sep 17 00:00:00 2001 From: Prajwal A Date: Tue, 6 Jan 2026 10:27:05 +0530 Subject: [PATCH 18/25] cleanup: remove dead comments Signed-off-by: Prajwal A --- src/diffusers/hooks/teacache.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/src/diffusers/hooks/teacache.py b/src/diffusers/hooks/teacache.py index 00609a54b6b6..e6b032ae0c8a 100644 --- a/src/diffusers/hooks/teacache.py +++ b/src/diffusers/hooks/teacache.py @@ -86,11 +86,6 @@ def _auto_detect_model_type(module): raise ValueError(f"TeaCache: Unsupported model '{class_name}'. 
Supported: {', '.join(model_config.keys())}") -# ============================================================================= -# Standalone utility functions (per DN6's feedback - no closures, direct state ops) -# ============================================================================= - - def _rescale_distance(coefficients, x): """Polynomial rescaling: c[0]*x^4 + c[1]*x^3 + c[2]*x^2 + c[3]*x + c[4]""" c = coefficients From e328ccc41a268e15b89058b039484eae81e9c3b4 Mon Sep 17 00:00:00 2001 From: Prajwal A Date: Thu, 8 Jan 2026 12:25:31 +0530 Subject: [PATCH 19/25] cleanup: code quality fixes Signed-off-by: Prajwal A --- src/diffusers/hooks/__init__.py | 2 +- src/diffusers/hooks/teacache.py | 8 ++++---- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/src/diffusers/hooks/__init__.py b/src/diffusers/hooks/__init__.py index 198393584f54..b81d194ebd40 100644 --- a/src/diffusers/hooks/__init__.py +++ b/src/diffusers/hooks/__init__.py @@ -25,5 +25,5 @@ from .layerwise_casting import apply_layerwise_casting, apply_layerwise_casting_hook from .pyramid_attention_broadcast import PyramidAttentionBroadcastConfig, apply_pyramid_attention_broadcast from .smoothed_energy_guidance_utils import SmoothedEnergyGuidanceConfig - from .teacache import TeaCacheConfig, apply_teacache from .taylorseer_cache import TaylorSeerCacheConfig, apply_taylorseer_cache + from .teacache import TeaCacheConfig, apply_teacache diff --git a/src/diffusers/hooks/teacache.py b/src/diffusers/hooks/teacache.py index e6b032ae0c8a..8e74414d8632 100644 --- a/src/diffusers/hooks/teacache.py +++ b/src/diffusers/hooks/teacache.py @@ -29,8 +29,8 @@ def _get_model_config(): """Get model configuration mapping. - Returns dict at runtime when forward functions are defined. - Order matters: more specific model variants must come before generic ones. + Returns dict at runtime when forward functions are defined. Order matters: more specific model variants must come + before generic ones. """ return { "FluxKontext": { @@ -202,8 +202,8 @@ class TeaCacheConfig: and statistics tracking. If not provided, TeaCache will still function correctly. num_inference_steps (`int`, *optional*, defaults to `None`): Total number of inference steps. Required for proper state management - ensures first and last timesteps - are always computed (never cached) and that state resets between inference runs. If not provided, - TeaCache will attempt to detect via callback or module attribute. + are always computed (never cached) and that state resets between inference runs. If not provided, TeaCache + will attempt to detect via callback or module attribute. num_inference_steps_callback (`Callable[[], int]`, *optional*, defaults to `None`): Callback function that returns the total number of inference steps. Alternative to `num_inference_steps` for dynamic step counts. 
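Not part of any patch in this series — a minimal standalone sketch of the decision rule that `_rescale_distance` and `_should_compute` implement after the refactor above, assuming the FLUX coefficients and the 0.2 threshold quoted in the diffs. It covers only intermediate steps; in the actual hook the first and last timesteps always run the full forward.

```python
# Standalone illustration of the TeaCache caching decision (not library code).
def rescale_distance(coefficients, x):
    # 4th-degree polynomial: c[0]*x^4 + c[1]*x^3 + c[2]*x^2 + c[3]*x + c[4]
    c = coefficients
    return c[0] * x**4 + c[1] * x**3 + c[2] * x**2 + c[3] * x + c[4]


def should_compute(accumulated, rel_distance, coefficients, rel_l1_thresh):
    # Accumulate the rescaled relative L1 distance; force a full forward pass
    # (and reset the accumulator) once it crosses the threshold, otherwise
    # reuse the cached residual.
    accumulated += rescale_distance(coefficients, rel_distance)
    if accumulated >= rel_l1_thresh:
        return True, 0.0
    return False, accumulated


flux_coefficients = [4.98651651e02, -2.83781631e02, 5.58554382e01, -3.82021401e00, 2.64230861e-01]
accumulated = 0.0
# Hypothetical per-step relative L1 distances between consecutive modulated inputs.
for step, rel_distance in enumerate([0.03, 0.04, 0.05, 0.02]):
    compute, accumulated = should_compute(accumulated, rel_distance, flux_coefficients, rel_l1_thresh=0.2)
    print(f"step {step}: {'full forward' if compute else 'reuse cached residual'}")
```

Because the rescaled distances are accumulated rather than compared per step, several small drifts between consecutive modulated inputs eventually force a recompute, which is what keeps quality loss bounded while skipping redundant transformer passes.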
From 8946aeb7a39c7de87eca91f655b523b30039da91 Mon Sep 17 00:00:00 2001 From: Prajwal A Date: Thu, 8 Jan 2026 11:01:50 +0000 Subject: [PATCH 20/25] refactor: improve code quality, type hints Signed-off-by: Prajwal A --- src/diffusers/hooks/teacache.py | 218 ++++++++++++++++++++------------ tests/hooks/test_teacache.py | 35 ++--- 2 files changed, 149 insertions(+), 104 deletions(-) diff --git a/src/diffusers/hooks/teacache.py b/src/diffusers/hooks/teacache.py index 8e74414d8632..d23bc25950dc 100644 --- a/src/diffusers/hooks/teacache.py +++ b/src/diffusers/hooks/teacache.py @@ -17,15 +17,36 @@ import torch -from ..utils import get_logger +from ..models.modeling_outputs import Transformer2DModelOutput +from ..utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers from .hooks import BaseState, HookRegistry, ModelHook, StateManager -logger = get_logger(__name__) # pylint: disable=invalid-name +logger = logging.get_logger(__name__) # pylint: disable=invalid-name _TEACACHE_HOOK = "teacache" +def _handle_accelerate_hook(module: torch.nn.Module, *args, **kwargs) -> Tuple[tuple, dict]: + """Handle compatibility with accelerate's CPU offload hooks. + + When TeaCache's new_forward replaces the forward chain, accelerate's hooks are bypassed. + This function manually triggers accelerate's pre_forward to ensure proper device placement. + + Args: + module: The model module that may have accelerate hooks attached. + *args: Forward arguments to potentially move to the execution device. + **kwargs: Forward keyword arguments to potentially move to the execution device. + + Returns: + Tuple of (args, kwargs) potentially moved to the correct device. + """ + if hasattr(module, "_hf_hook") and hasattr(module._hf_hook, "pre_forward"): + # Accelerate's CpuOffload hook will move the module to GPU and return modified args/kwargs + args, kwargs = module._hf_hook.pre_forward(module, *args, **kwargs) + return args, kwargs + + def _get_model_config(): """Get model configuration mapping. @@ -110,11 +131,13 @@ def _should_compute(state, modulated_inp, coefficients, rel_l1_thresh): return True # Compute L1 distance and check threshold - rel_distance = ( - ((modulated_inp - state.previous_modulated_input).abs().mean() / state.previous_modulated_input.abs().mean()) - .cpu() - .item() - ) + # Note: .item() implicitly syncs GPU->CPU. This is necessary for the threshold comparison. + prev_mean = state.previous_modulated_input.abs().mean() + if prev_mean.item() > 1e-9: + rel_distance = ((modulated_inp - state.previous_modulated_input).abs().mean() / prev_mean).item() + else: + # Handle near-zero previous input: if current is also near-zero, no change; otherwise force recompute + rel_distance = 0.0 if modulated_inp.abs().mean().item() < 1e-9 else float("inf") rescaled = _rescale_distance(coefficients, rel_distance) state.accumulated_rel_l1_distance += rescaled @@ -169,12 +192,14 @@ def _apply_cached_residual_dual(state, hs, enc, modulated_inp): @dataclass class TeaCacheConfig: r""" - Configuration for [TeaCache](https://liewfeng.github.io/TeaCache/) applied to transformer models. + Configuration for [TeaCache](https://arxiv.org/abs/2411.19108) applied to transformer models. TeaCache (Timestep Embedding Aware Cache) is an adaptive caching technique that speeds up diffusion model inference by reusing transformer block computations when consecutive timestep embeddings are similar. It uses polynomial rescaling of L1 distances between modulated inputs to intelligently decide when to cache. 
+ Reference: [TeaCache: Timestep Embedding Aware Cache for Efficient Diffusion Model Inference](https://arxiv.org/abs/2411.19108) + Currently supports: FLUX, FLUX-Kontext, Mochi, Lumina2, and CogVideoX models. Model type is auto-detected, and model-specific polynomial coefficients are automatically applied. @@ -188,15 +213,13 @@ class TeaCacheConfig: - 0.6 for ~2.0x speedup with noticeable quality loss - 0.8 for ~2.25x speedup with significant quality loss Higher thresholds lead to more aggressive caching and faster inference, but may reduce output quality. + Note: Mochi models require lower thresholds (0.06-0.09) due to different coefficient scaling. coefficients (`List[float]`, *optional*, defaults to polynomial coefficients from TeaCache paper): Polynomial coefficients used for rescaling the raw L1 distance. These coefficients transform the relative L1 distance into a model-specific caching signal. If not provided, defaults to the coefficients determined for FLUX models in the TeaCache paper: [4.98651651e+02, -2.83781631e+02, 5.58554382e+01, -3.82021401e+00, 2.64230861e-01]. The polynomial is evaluated as: `c[0]*x^4 + c[1]*x^3 + c[2]*x^2 + c[3]*x + c[4]` where x is the relative L1 distance. - extract_modulated_input_fn (`Callable`, *optional*, defaults to auto-detection): - Function to extract modulated input from the transformer module. Takes (module, hidden_states, - timestep_emb) and returns the modulated input tensor. If not provided, auto-detects based on model type. current_timestep_callback (`Callable[[], int]`, *optional*, defaults to `None`): Callback function that returns the current timestep during inference. This is used internally for debugging and statistics tracking. If not provided, TeaCache will still function correctly. @@ -231,7 +254,6 @@ class TeaCacheConfig: rel_l1_thresh: float = 0.2 coefficients: Optional[List[float]] = None - extract_modulated_input_fn: Optional[Callable] = None current_timestep_callback: Optional[Callable[[], int]] = None num_inference_steps: Optional[int] = None num_inference_steps_callback: Optional[Callable[[], int]] = None @@ -283,8 +305,8 @@ def __repr__(self) -> str: f"TeaCacheConfig(\n" f" rel_l1_thresh={self.rel_l1_thresh},\n" f" coefficients={self.coefficients},\n" - f" extract_modulated_input_fn={self.extract_modulated_input_fn},\n" f" current_timestep_callback={self.current_timestep_callback},\n" + f" num_inference_steps={self.num_inference_steps},\n" f" num_inference_steps_callback={self.num_inference_steps_callback}\n" f")" ) @@ -325,11 +347,12 @@ def __init__(self): self.accumulated_rel_l1_distance = 0.0 self.previous_modulated_input = None self.previous_residual = None - # For models that cache both encoder and hidden_states residuals (e.g., CogVideoX) - self.previous_residual_encoder = None - # For models with variable sequence lengths (e.g., Lumina2) - self.cache_dict = {} - self.uncond_seq_len = None + # CogVideoX-specific: dual residual caching (encoder + hidden_states) + self.previous_residual_encoder = None # Only used by CogVideoX + # Lumina2-specific: per-sequence-length caching for variable sequence lengths + # Other models don't use these fields but they're allocated for simplicity + self.cache_dict = {} # Only used by Lumina2 + self.uncond_seq_len = None # Only used by Lumina2 def reset(self): """Reset all state variables to initial values for a new inference run.""" @@ -386,18 +409,15 @@ class TeaCacheHook(ModelHook): def __init__(self, config: TeaCacheConfig): super().__init__() self.config = config - # Default 
coefficients (will be updated in initialize_hook if needed) - # This ensures coefficients are always valid, even if initialize_hook isn't called (e.g., in tests) - self.coefficients = ( - config.coefficients - if config.coefficients - else [4.98651651e02, -2.83781631e02, 5.58554382e01, -3.82021401e00, 2.64230861e-01] - ) + # Coefficients will be set in initialize_hook() via auto-detection or user config + self.coefficients: Optional[List[float]] = config.coefficients self.state_manager = StateManager(TeaCacheState, (), {}) - self.model_type = None - self._forward_func = None + self.model_type: Optional[str] = None + self._forward_func: Optional[Callable] = None - def _maybe_reset_state_for_new_inference(self, state, module, reset_encoder_residual=False): + def _maybe_reset_state_for_new_inference( + self, state: TeaCacheState, module: torch.nn.Module, reset_encoder_residual: bool = False + ) -> None: """Reset state if we've completed all steps (start of new inference run). Also initializes num_steps on first timestep if not set. @@ -409,7 +429,7 @@ def _maybe_reset_state_for_new_inference(self, state, module, reset_encoder_resi """ # Reset counter if we've completed all steps (new inference run) if state.cnt == state.num_steps and state.num_steps > 0: - logger.info("TeaCache inference completed") + logger.debug("TeaCache: Inference run completed, resetting state") state.cnt = 0 state.accumulated_rel_l1_distance = 0.0 state.previous_modulated_input = None @@ -428,7 +448,7 @@ def _maybe_reset_state_for_new_inference(self, state, module, reset_encoder_resi state.num_steps = module.num_steps if state.num_steps > 0: - logger.info(f"TeaCache: Using {state.num_steps} inference steps") + logger.debug(f"TeaCache: Using {state.num_steps} inference steps") def initialize_hook(self, module): # TODO: DN6 raised concern about context setting timing. @@ -439,55 +459,34 @@ def initialize_hook(self, module): model_config = _get_model_config() # Auto-detect model type and get forward function - if self.config.extract_modulated_input_fn is None: - self.model_type = _auto_detect_model_type(module) # Raises if unsupported - self._forward_func = model_config[self.model_type]["forward_func"] - else: - # User provided custom extractor - try to detect model type for coefficients - try: - self.model_type = _auto_detect_model_type(module) - self._forward_func = model_config[self.model_type]["forward_func"] - except ValueError: - self.model_type = None - self._forward_func = None # Will fall back to class name matching in new_forward - logger.warning( - f"TeaCache: Using custom extractor for {module.__class__.__name__}. " - f"Coefficients must be provided explicitly." 
- ) + self.model_type = _auto_detect_model_type(module) + self._forward_func = model_config[self.model_type]["forward_func"] + + # Validate model has required attributes for TeaCache + if self.model_type in ("Flux", "FluxKontext", "Mochi"): + if not hasattr(module, "transformer_blocks") or len(module.transformer_blocks) == 0: + raise ValueError(f"TeaCache: {self.model_type} model missing transformer_blocks") + if not hasattr(module.transformer_blocks[0], "norm1"): + raise ValueError(f"TeaCache: {self.model_type} transformer_blocks[0] missing norm1") + elif self.model_type == "Lumina2": + if not hasattr(module, "layers") or len(module.layers) == 0: + raise ValueError(f"TeaCache: Lumina2 model missing layers") + elif "CogVideoX" in self.model_type: + if not hasattr(module, "transformer_blocks") or len(module.transformer_blocks) == 0: + raise ValueError(f"TeaCache: {self.model_type} model missing transformer_blocks") # Auto-detect coefficients if not provided by user if self.config.coefficients is None: - if self.model_type is None: - raise ValueError( - "TeaCache: Cannot auto-detect coefficients when using custom extractor. " - "Please provide coefficients explicitly in TeaCacheConfig." - ) self.coefficients = model_config[self.model_type]["coefficients"] - logger.info(f"TeaCache: Using {self.model_type} coefficients") + logger.debug(f"TeaCache: Using {self.model_type} coefficients") else: self.coefficients = self.config.coefficients - logger.info("TeaCache: Using user-provided coefficients") + logger.debug("TeaCache: Using user-provided coefficients") return module def new_forward(self, module, *args, **kwargs): - # Use stored forward function if available (set during initialize_hook) - if self._forward_func is not None: - return self._forward_func(self, module, *args, **kwargs) - - # Fallback to class name matching for backwards compatibility - module_class_name = module.__class__.__name__ - if "Flux" in module_class_name: - return _flux_teacache_forward(self, module, *args, **kwargs) - if "Mochi" in module_class_name: - return _mochi_teacache_forward(self, module, *args, **kwargs) - if "Lumina2" in module_class_name: - return _lumina2_teacache_forward(self, module, *args, **kwargs) - if "CogVideoX" in module_class_name: - return _cogvideox_teacache_forward(self, module, *args, **kwargs) - - logger.warning(f"TeaCache: Unsupported model type {module_class_name}.") - return self.fn_ref.original_forward(*args, **kwargs) + return self._forward_func(self, module, *args, **kwargs) def reset_state(self, module): self.state_manager.reset() @@ -507,7 +506,21 @@ def _flux_teacache_forward( **kwargs, ): """TeaCache forward for Flux models.""" - from diffusers.models.modeling_outputs import Transformer2DModelOutput + # Handle accelerate CPU offload compatibility - moves module and inputs to GPU if needed + args, extra_kwargs = _handle_accelerate_hook( + module, + hidden_states, + timestep, + pooled_projections, + encoder_hidden_states, + txt_ids, + img_ids, + return_dict=return_dict, + **kwargs, + ) + hidden_states, timestep, pooled_projections, encoder_hidden_states, txt_ids, img_ids = args + return_dict = extra_kwargs.pop("return_dict", return_dict) + kwargs = extra_kwargs state = hook.state_manager.get_state() hook._maybe_reset_state_for_new_inference(state, module) @@ -576,8 +589,19 @@ def _mochi_teacache_forward( return_dict: bool = True, ): """TeaCache forward for Mochi models.""" - from diffusers.models.modeling_outputs import Transformer2DModelOutput - from diffusers.utils import 
USE_PEFT_BACKEND, scale_lora_layers, unscale_lora_layers + # Handle accelerate CPU offload compatibility - moves module and inputs to GPU if needed + args, kwargs = _handle_accelerate_hook( + module, + hidden_states, + encoder_hidden_states, + timestep, + encoder_attention_mask, + attention_kwargs=attention_kwargs, + return_dict=return_dict, + ) + hidden_states, encoder_hidden_states, timestep, encoder_attention_mask = args + attention_kwargs = kwargs.get("attention_kwargs", attention_kwargs) + return_dict = kwargs.get("return_dict", return_dict) if attention_kwargs is not None: attention_kwargs = attention_kwargs.copy() @@ -642,10 +666,11 @@ def _mochi_teacache_forward( encoder_attention_mask=encoder_attention_mask, image_rotary_emb=image_rotary_emb, ) + # norm_out is included in residual (matches original TeaCache implementation) hidden_states = module.norm_out(hidden_states, temb) _update_state(state, hidden_states, ori_hs, modulated_inp) else: - # Cached path + # Cached path - residual already includes norm_out effect hidden_states = _apply_cached_residual(state, hidden_states, modulated_inp) hidden_states = module.proj_out(hidden_states) @@ -672,8 +697,19 @@ def _lumina2_teacache_forward( return_dict: bool = True, ): """TeaCache forward for Lumina2 models (handles variable seq lens + per-len caches).""" - from diffusers.models.modeling_outputs import Transformer2DModelOutput - from diffusers.utils import USE_PEFT_BACKEND, scale_lora_layers, unscale_lora_layers + # Handle accelerate CPU offload compatibility - moves module and inputs to GPU if needed + args, kwargs = _handle_accelerate_hook( + module, + hidden_states, + timestep, + encoder_hidden_states, + encoder_attention_mask, + attention_kwargs=attention_kwargs, + return_dict=return_dict, + ) + hidden_states, timestep, encoder_hidden_states, encoder_attention_mask = args + attention_kwargs = kwargs.get("attention_kwargs", attention_kwargs) + return_dict = kwargs.get("return_dict", return_dict) if attention_kwargs is not None: attention_kwargs = attention_kwargs.copy() @@ -741,7 +777,7 @@ def _lumina2_teacache_forward( prev_mod_input = current_cache["previous_modulated_input"] prev_mean = prev_mod_input.abs().mean() if prev_mean.item() > 1e-9: - rel_l1_change = ((modulated_inp - prev_mod_input).abs().mean() / prev_mean).cpu().item() + rel_l1_change = ((modulated_inp - prev_mod_input).abs().mean() / prev_mean).item() else: rel_l1_change = 0.0 if modulated_inp.abs().mean().item() < 1e-9 else float("inf") rescaled_distance = _rescale_distance(hook.coefficients, rel_l1_change) @@ -808,8 +844,24 @@ def _cogvideox_teacache_forward( return_dict: bool = True, ): """TeaCache forward for CogVideoX models (handles dual residual caching).""" - from diffusers.models.modeling_outputs import Transformer2DModelOutput - from diffusers.utils import USE_PEFT_BACKEND, is_torch_version, scale_lora_layers, unscale_lora_layers + # Handle accelerate CPU offload compatibility - moves module and inputs to GPU if needed + args, kwargs = _handle_accelerate_hook( + module, + hidden_states, + encoder_hidden_states, + timestep, + timestep_cond=timestep_cond, + ofs=ofs, + image_rotary_emb=image_rotary_emb, + attention_kwargs=attention_kwargs, + return_dict=return_dict, + ) + hidden_states, encoder_hidden_states, timestep = args + timestep_cond = kwargs.get("timestep_cond", timestep_cond) + ofs = kwargs.get("ofs", ofs) + image_rotary_emb = kwargs.get("image_rotary_emb", image_rotary_emb) + attention_kwargs = kwargs.get("attention_kwargs", attention_kwargs) + 
return_dict = kwargs.get("return_dict", return_dict) if attention_kwargs is not None: attention_kwargs = attention_kwargs.copy() @@ -900,7 +952,7 @@ def _cogvideox_teacache_forward( return Transformer2DModelOutput(sample=output) -def apply_teacache(module, config: TeaCacheConfig): +def apply_teacache(module: torch.nn.Module, config: TeaCacheConfig) -> None: """ Apply TeaCache optimization to a transformer model. @@ -908,9 +960,13 @@ def apply_teacache(module, config: TeaCacheConfig): computations based on timestep embedding similarity. The hook intercepts the forward pass and implements the TeaCache algorithm to achieve 1.5x-2x speedup with minimal quality loss. + Reference: [TeaCache: Timestep Embedding Aware Cache for Efficient Diffusion Model Inference](https://arxiv.org/abs/2411.19108) + Args: - module: The transformer model to optimize (e.g., FluxTransformer2DModel, CogVideoXTransformer3DModel). - config (`TeaCacheConfig`): Configuration specifying caching threshold and optional callbacks. + module (`torch.nn.Module`): + The transformer model to optimize (e.g., FluxTransformer2DModel, CogVideoXTransformer3DModel). + config (`TeaCacheConfig`): + Configuration specifying caching threshold and optional callbacks. Example: ```python diff --git a/tests/hooks/test_teacache.py b/tests/hooks/test_teacache.py index cc8a49a90eb5..ee9eca4046ff 100644 --- a/tests/hooks/test_teacache.py +++ b/tests/hooks/test_teacache.py @@ -139,7 +139,8 @@ def test_hook_initialization(self): hook = TeaCacheHook(config) self.assertEqual(hook.config.rel_l1_thresh, 0.2) - self.assertIsNotNone(hook.coefficients) + # coefficients is None until initialize_hook() is called with a model (auto-detection) + self.assertIsNone(hook.coefficients) self.assertIsNotNone(hook.state_manager) def test_should_compute_logic(self): @@ -168,36 +169,24 @@ def test_should_compute_logic(self): state.cnt = state.num_steps - 1 self.assertTrue(_should_compute(state, x1, coefficients, rel_l1_thresh)) - def test_apply_teacache_with_custom_extractor(self): - """Test apply_teacache works with custom extractor function.""" + def test_apply_teacache_unsupported_model_raises_error(self): + """Test that apply_teacache raises error for unsupported models.""" from diffusers.hooks import apply_teacache from diffusers.models import CacheMixin - class DummyModule(torch.nn.Module, CacheMixin): + class UnsupportedModule(torch.nn.Module, CacheMixin): def __init__(self): super().__init__() self.dummy = torch.nn.Linear(4, 4) - module = DummyModule() - - # Custom extractor function - def custom_extractor(mod, hidden_states, temb): - return hidden_states - - # Must provide coefficients when using custom extractor (no auto-detection) - custom_coeffs = [1.0, 2.0, 3.0, 4.0, 5.0] - config = TeaCacheConfig( - rel_l1_thresh=0.2, extract_modulated_input_fn=custom_extractor, coefficients=custom_coeffs - ) - - # Should not raise - TeaCache works with custom extractor when coefficients provided - apply_teacache(module, config) - - # Verify registry and disable path work - registry = HookRegistry.check_if_exists_or_initialize(module) - self.assertIn("teacache", registry.hooks) + module = UnsupportedModule() + config = TeaCacheConfig(rel_l1_thresh=0.2) - module.disable_cache() + # Should raise ValueError for unsupported model type + with self.assertRaises(ValueError) as context: + apply_teacache(module, config) + self.assertIn("Unsupported model", str(context.exception)) + self.assertIn("UnsupportedModule", str(context.exception)) class 
TeaCacheMultiModelTests(unittest.TestCase): From abb24e0330f81b928751327ec7dddccd31cd930a Mon Sep 17 00:00:00 2001 From: Prajwal A Date: Mon, 12 Jan 2026 09:02:16 +0000 Subject: [PATCH 21/25] fix: move TeaCache context setting to denoising loop for proper state isolation Signed-off-by: Prajwal A --- src/diffusers/hooks/hooks.py | 13 ++- src/diffusers/hooks/teacache.py | 195 ++++++++++++++------------------ tests/hooks/test_teacache.py | 76 +++++++++++++ 3 files changed, 168 insertions(+), 116 deletions(-) diff --git a/src/diffusers/hooks/hooks.py b/src/diffusers/hooks/hooks.py index 6e097e5882a0..293af07d15ca 100644 --- a/src/diffusers/hooks/hooks.py +++ b/src/diffusers/hooks/hooks.py @@ -40,11 +40,14 @@ def __init__(self, state_cls: BaseState, init_args=None, init_kwargs=None): self._current_context = None def get_state(self): - if self._current_context is None: - raise ValueError("No context is set. Please set a context before retrieving the state.") - if self._current_context not in self._state_cache.keys(): - self._state_cache[self._current_context] = self._state_cls(*self._init_args, **self._init_kwargs) - return self._state_cache[self._current_context] + context = self._current_context + if context is None: + # Fallback to default context for backward compatibility with + # pipelines that don't call cache_context() + context = "_default" + if context not in self._state_cache: + self._state_cache[context] = self._state_cls(*self._init_args, **self._init_kwargs) + return self._state_cache[context] def set_context(self, name: str) -> None: self._current_context = name diff --git a/src/diffusers/hooks/teacache.py b/src/diffusers/hooks/teacache.py index d23bc25950dc..5f9eda09c764 100644 --- a/src/diffusers/hooks/teacache.py +++ b/src/diffusers/hooks/teacache.py @@ -28,25 +28,20 @@ def _handle_accelerate_hook(module: torch.nn.Module, *args, **kwargs) -> Tuple[tuple, dict]: - """Handle compatibility with accelerate's CPU offload hooks. - - When TeaCache's new_forward replaces the forward chain, accelerate's hooks are bypassed. - This function manually triggers accelerate's pre_forward to ensure proper device placement. - - Args: - module: The model module that may have accelerate hooks attached. - *args: Forward arguments to potentially move to the execution device. - **kwargs: Forward keyword arguments to potentially move to the execution device. - - Returns: - Tuple of (args, kwargs) potentially moved to the correct device. - """ + """Handle compatibility with accelerate's CPU offload hooks.""" if hasattr(module, "_hf_hook") and hasattr(module._hf_hook, "pre_forward"): - # Accelerate's CpuOffload hook will move the module to GPU and return modified args/kwargs args, kwargs = module._hf_hook.pre_forward(module, *args, **kwargs) return args, kwargs +def _extract_lora_scale(attention_kwargs: Optional[Dict[str, Any]]) -> Tuple[Optional[Dict[str, Any]], float]: + """Extract LoRA scale from attention kwargs, returning (modified_kwargs, scale).""" + if attention_kwargs is None: + return None, 1.0 + attention_kwargs = attention_kwargs.copy() + return attention_kwargs, attention_kwargs.pop("scale", 1.0) + + def _get_model_config(): """Get model configuration mapping. 
@@ -113,39 +108,35 @@ def _rescale_distance(coefficients, x): return c[0] * x**4 + c[1] * x**3 + c[2] * x**2 + c[3] * x + c[4] +def _compute_rel_l1_distance(current: torch.Tensor, previous: torch.Tensor) -> float: + """Compute relative L1 distance between current and previous tensors.""" + prev_mean = previous.abs().mean() + if prev_mean.item() > 1e-9: + return ((current - previous).abs().mean() / prev_mean).item() + # Near-zero previous: if current also near-zero, no change; otherwise force recompute + return 0.0 if current.abs().mean().item() < 1e-9 else float("inf") + + @torch.compiler.disable def _should_compute(state, modulated_inp, coefficients, rel_l1_thresh): """Determine if full computation is needed (single residual models).""" - # First timestep always computes - if state.cnt == 0: - state.accumulated_rel_l1_distance = 0 - return True - # Last timestep always computes - if state.num_steps > 0 and state.cnt == state.num_steps - 1: + # First/last timesteps and missing state always require computation + is_first_step = state.cnt == 0 + is_last_step = state.num_steps > 0 and state.cnt == state.num_steps - 1 + missing_state = state.previous_modulated_input is None or state.previous_residual is None + + if is_first_step or is_last_step or missing_state: state.accumulated_rel_l1_distance = 0 return True - # No previous state - must compute - if state.previous_modulated_input is None: - return True - if state.previous_residual is None: - return True - - # Compute L1 distance and check threshold - # Note: .item() implicitly syncs GPU->CPU. This is necessary for the threshold comparison. - prev_mean = state.previous_modulated_input.abs().mean() - if prev_mean.item() > 1e-9: - rel_distance = ((modulated_inp - state.previous_modulated_input).abs().mean() / prev_mean).item() - else: - # Handle near-zero previous input: if current is also near-zero, no change; otherwise force recompute - rel_distance = 0.0 if modulated_inp.abs().mean().item() < 1e-9 else float("inf") - rescaled = _rescale_distance(coefficients, rel_distance) - state.accumulated_rel_l1_distance += rescaled - if state.accumulated_rel_l1_distance < rel_l1_thresh: - return False + # Compute accumulated distance and check threshold + rel_distance = _compute_rel_l1_distance(modulated_inp, state.previous_modulated_input) + state.accumulated_rel_l1_distance += _rescale_distance(coefficients, rel_distance) - state.accumulated_rel_l1_distance = 0 - return True + if state.accumulated_rel_l1_distance >= rel_l1_thresh: + state.accumulated_rel_l1_distance = 0 + return True + return False @torch.compiler.disable @@ -259,7 +250,11 @@ class TeaCacheConfig: num_inference_steps_callback: Optional[Callable[[], int]] = None def __post_init__(self): - # Validate rel_l1_thresh + self._validate_threshold() + self._validate_coefficients() + + def _validate_threshold(self): + """Validate rel_l1_thresh parameter.""" if not isinstance(self.rel_l1_thresh, (int, float)): raise TypeError( f"rel_l1_thresh must be a number, got {type(self.rel_l1_thresh).__name__}. " @@ -282,23 +277,25 @@ def __post_init__(self): f"Consider using values between 0.1 and 0.6 for better quality-speed tradeoff." ) - # Validate coefficients only if explicitly provided (None = auto-detect later) - if self.coefficients is not None: - if not isinstance(self.coefficients, (list, tuple)): - raise TypeError( - f"coefficients must be a list or tuple, got {type(self.coefficients).__name__}. " - f"Please provide a list of 5 polynomial coefficients." 
- ) - if len(self.coefficients) != 5: - raise ValueError( - f"coefficients must contain exactly 5 elements for 4th-degree polynomial, " - f"got {len(self.coefficients)}. The polynomial is evaluated as: " - f"c[0]*x^4 + c[1]*x^3 + c[2]*x^2 + c[3]*x + c[4]" - ) - if not all(isinstance(c, (int, float)) for c in self.coefficients): - raise TypeError( - f"All coefficients must be numbers. Got types: {[type(c).__name__ for c in self.coefficients]}" - ) + def _validate_coefficients(self): + """Validate coefficients parameter if provided.""" + if self.coefficients is None: + return + if not isinstance(self.coefficients, (list, tuple)): + raise TypeError( + f"coefficients must be a list or tuple, got {type(self.coefficients).__name__}. " + f"Please provide a list of 5 polynomial coefficients." + ) + if len(self.coefficients) != 5: + raise ValueError( + f"coefficients must contain exactly 5 elements for 4th-degree polynomial, " + f"got {len(self.coefficients)}. The polynomial is evaluated as: " + f"c[0]*x^4 + c[1]*x^3 + c[2]*x^2 + c[3]*x + c[4]" + ) + if not all(isinstance(c, (int, float)) for c in self.coefficients): + raise TypeError( + f"All coefficients must be numbers. Got types: {[type(c).__name__ for c in self.coefficients]}" + ) def __repr__(self) -> str: return ( @@ -451,11 +448,9 @@ def _maybe_reset_state_for_new_inference( logger.debug(f"TeaCache: Using {state.num_steps} inference steps") def initialize_hook(self, module): - # TODO: DN6 raised concern about context setting timing. - # Currently set in initialize_hook(). Should this be in denoising loop instead? - # See PR #12652 for discussion. Keeping current behavior pending clarification. - self.state_manager.set_context("teacache") - + # Context is set by pipeline's cache_context() calls in the denoising loop. + # This enables proper state isolation between cond/uncond branches. + # See PR #12652 for discussion on this design decision. 
model_config = _get_model_config() # Auto-detect model type and get forward function @@ -589,7 +584,6 @@ def _mochi_teacache_forward( return_dict: bool = True, ): """TeaCache forward for Mochi models.""" - # Handle accelerate CPU offload compatibility - moves module and inputs to GPU if needed args, kwargs = _handle_accelerate_hook( module, hidden_states, @@ -603,12 +597,7 @@ def _mochi_teacache_forward( attention_kwargs = kwargs.get("attention_kwargs", attention_kwargs) return_dict = kwargs.get("return_dict", return_dict) - if attention_kwargs is not None: - attention_kwargs = attention_kwargs.copy() - lora_scale = attention_kwargs.pop("scale", 1.0) - else: - lora_scale = 1.0 - + attention_kwargs, lora_scale = _extract_lora_scale(attention_kwargs) if USE_PEFT_BACKEND: scale_lora_layers(module, lora_scale) @@ -697,7 +686,6 @@ def _lumina2_teacache_forward( return_dict: bool = True, ): """TeaCache forward for Lumina2 models (handles variable seq lens + per-len caches).""" - # Handle accelerate CPU offload compatibility - moves module and inputs to GPU if needed args, kwargs = _handle_accelerate_hook( module, hidden_states, @@ -711,12 +699,7 @@ def _lumina2_teacache_forward( attention_kwargs = kwargs.get("attention_kwargs", attention_kwargs) return_dict = kwargs.get("return_dict", return_dict) - if attention_kwargs is not None: - attention_kwargs = attention_kwargs.copy() - lora_scale = attention_kwargs.pop("scale", 1.0) - else: - lora_scale = 1.0 - + attention_kwargs, lora_scale = _extract_lora_scale(attention_kwargs) if USE_PEFT_BACKEND: scale_lora_layers(module, lora_scale) @@ -760,6 +743,7 @@ def _lumina2_teacache_forward( # Inline extractor: Lumina2 uses layers[0].norm1 modulated_inp = module.layers[0].norm1(input_to_main_loop, temb)[0] + # Per-sequence-length caching for variable sequence lengths cache_key = max_seq_len if cache_key not in state.cache_dict: state.cache_dict[cache_key] = { @@ -767,32 +751,27 @@ def _lumina2_teacache_forward( "previous_residual": None, "accumulated_rel_l1_distance": 0.0, } - current_cache = state.cache_dict[cache_key] + cache = state.cache_dict[cache_key] - if state.cnt == 0 or state.cnt == state.num_steps - 1: + # Determine if computation is needed + is_boundary_step = state.cnt == 0 or state.cnt == state.num_steps - 1 + has_previous = cache["previous_modulated_input"] is not None + + if is_boundary_step or not has_previous: should_calc = True - current_cache["accumulated_rel_l1_distance"] = 0.0 + cache["accumulated_rel_l1_distance"] = 0.0 else: - if current_cache["previous_modulated_input"] is not None: - prev_mod_input = current_cache["previous_modulated_input"] - prev_mean = prev_mod_input.abs().mean() - if prev_mean.item() > 1e-9: - rel_l1_change = ((modulated_inp - prev_mod_input).abs().mean() / prev_mean).item() - else: - rel_l1_change = 0.0 if modulated_inp.abs().mean().item() < 1e-9 else float("inf") - rescaled_distance = _rescale_distance(hook.coefficients, rel_l1_change) - current_cache["accumulated_rel_l1_distance"] += rescaled_distance - if current_cache["accumulated_rel_l1_distance"] < hook.config.rel_l1_thresh: - should_calc = False - else: - should_calc = True - current_cache["accumulated_rel_l1_distance"] = 0.0 - else: + rel_distance = _compute_rel_l1_distance(modulated_inp, cache["previous_modulated_input"]) + cache["accumulated_rel_l1_distance"] += _rescale_distance(hook.coefficients, rel_distance) + if cache["accumulated_rel_l1_distance"] >= hook.config.rel_l1_thresh: should_calc = True - current_cache["accumulated_rel_l1_distance"] = 
0.0 + cache["accumulated_rel_l1_distance"] = 0.0 + else: + should_calc = False - current_cache["previous_modulated_input"] = modulated_inp.clone() + cache["previous_modulated_input"] = modulated_inp.clone() + # Track sequence length for step counting (CFG handling) if state.uncond_seq_len is None: state.uncond_seq_len = cache_key if cache_key != state.uncond_seq_len: @@ -800,16 +779,16 @@ def _lumina2_teacache_forward( if state.cnt >= state.num_steps: state.cnt = 0 - if not should_calc and current_cache["previous_residual"] is not None: - processed_hidden_states = input_to_main_loop + current_cache["previous_residual"] + # Apply cached residual or compute full forward + if not should_calc and cache["previous_residual"] is not None: + processed_hidden_states = input_to_main_loop + cache["previous_residual"] else: - current_processing_states = input_to_main_loop + processed_hidden_states = input_to_main_loop for layer in module.layers: - current_processing_states = layer( - current_processing_states, attention_mask_for_main_loop_arg, joint_rotary_emb, temb + processed_hidden_states = layer( + processed_hidden_states, attention_mask_for_main_loop_arg, joint_rotary_emb, temb ) - processed_hidden_states = current_processing_states - current_cache["previous_residual"] = processed_hidden_states - input_to_main_loop + cache["previous_residual"] = processed_hidden_states - input_to_main_loop output_after_norm = module.norm_out(processed_hidden_states, temb) p = module.config.patch_size @@ -844,7 +823,6 @@ def _cogvideox_teacache_forward( return_dict: bool = True, ): """TeaCache forward for CogVideoX models (handles dual residual caching).""" - # Handle accelerate CPU offload compatibility - moves module and inputs to GPU if needed args, kwargs = _handle_accelerate_hook( module, hidden_states, @@ -863,12 +841,7 @@ def _cogvideox_teacache_forward( attention_kwargs = kwargs.get("attention_kwargs", attention_kwargs) return_dict = kwargs.get("return_dict", return_dict) - if attention_kwargs is not None: - attention_kwargs = attention_kwargs.copy() - lora_scale = attention_kwargs.pop("scale", 1.0) - else: - lora_scale = 1.0 - + attention_kwargs, lora_scale = _extract_lora_scale(attention_kwargs) if USE_PEFT_BACKEND: scale_lora_layers(module, lora_scale) diff --git a/tests/hooks/test_teacache.py b/tests/hooks/test_teacache.py index ee9eca4046ff..bc74ad849f2e 100644 --- a/tests/hooks/test_teacache.py +++ b/tests/hooks/test_teacache.py @@ -456,5 +456,81 @@ def test_model_routing(self): self.assertEqual(cogvideox_hook.model_type, "CogVideoX") +class StateManagerContextTests(unittest.TestCase): + """Tests for StateManager context isolation and backward compatibility.""" + + def test_context_isolation(self): + """Test that different contexts maintain separate states.""" + from diffusers.hooks import StateManager + from diffusers.hooks.teacache import TeaCacheState + + state_manager = StateManager(TeaCacheState, (), {}) + + # Set context "cond" and modify state + state_manager.set_context("cond") + cond_state = state_manager.get_state() + cond_state.cnt = 5 + cond_state.accumulated_rel_l1_distance = 0.3 + + # Set context "uncond" and modify state + state_manager.set_context("uncond") + uncond_state = state_manager.get_state() + uncond_state.cnt = 10 + uncond_state.accumulated_rel_l1_distance = 0.7 + + # Verify isolation - switch back to "cond" + state_manager.set_context("cond") + self.assertEqual(state_manager.get_state().cnt, 5) + self.assertEqual(state_manager.get_state().accumulated_rel_l1_distance, 0.3) + 
+ # Verify isolation - switch back to "uncond" + state_manager.set_context("uncond") + self.assertEqual(state_manager.get_state().cnt, 10) + self.assertEqual(state_manager.get_state().accumulated_rel_l1_distance, 0.7) + + def test_default_context_fallback(self): + """Test that state works without explicit context (backward compatibility).""" + from diffusers.hooks import StateManager + from diffusers.hooks.teacache import TeaCacheState + + state_manager = StateManager(TeaCacheState, (), {}) + + # Don't set context - should use "_default" fallback + state = state_manager.get_state() + self.assertIsNotNone(state) + self.assertEqual(state.cnt, 0) + + # Modify state + state.cnt = 42 + + # Should still get the same state via default context + state2 = state_manager.get_state() + self.assertEqual(state2.cnt, 42) + + def test_default_context_separate_from_named(self): + """Test that default context is separate from named contexts.""" + from diffusers.hooks import StateManager + from diffusers.hooks.teacache import TeaCacheState + + state_manager = StateManager(TeaCacheState, (), {}) + + # Use default context (no explicit set_context) + default_state = state_manager.get_state() + default_state.cnt = 100 + + # Now set a named context + state_manager.set_context("named") + named_state = state_manager.get_state() + named_state.cnt = 200 + + # Clear context to use default again + state_manager._current_context = None + self.assertEqual(state_manager.get_state().cnt, 100) + + # Named context should still have its value + state_manager.set_context("named") + self.assertEqual(state_manager.get_state().cnt, 200) + + if __name__ == "__main__": unittest.main() From c45af5101b4241ba409a6ee95e66f1544c90e255 Mon Sep 17 00:00:00 2001 From: Prajwal A Date: Mon, 12 Jan 2026 10:43:10 +0000 Subject: [PATCH 22/25] simplify implementation and reduce code duplication Signed-off-by: Prajwal A --- src/diffusers/hooks/__init__.py | 2 +- src/diffusers/hooks/hooks.py | 20 +-- src/diffusers/hooks/teacache.py | 229 +++++++++----------------------- tests/hooks/test_teacache.py | 5 - 4 files changed, 66 insertions(+), 190 deletions(-) diff --git a/src/diffusers/hooks/__init__.py b/src/diffusers/hooks/__init__.py index b81d194ebd40..f8bd28464cfd 100644 --- a/src/diffusers/hooks/__init__.py +++ b/src/diffusers/hooks/__init__.py @@ -20,7 +20,7 @@ from .faster_cache import FasterCacheConfig, apply_faster_cache from .first_block_cache import FirstBlockCacheConfig, apply_first_block_cache from .group_offloading import apply_group_offloading - from .hooks import HookRegistry, ModelHook + from .hooks import HookRegistry, ModelHook, StateManager from .layer_skip import LayerSkipConfig, apply_layer_skip from .layerwise_casting import apply_layerwise_casting, apply_layerwise_casting_hook from .pyramid_attention_broadcast import PyramidAttentionBroadcastConfig, apply_pyramid_attention_broadcast diff --git a/src/diffusers/hooks/hooks.py b/src/diffusers/hooks/hooks.py index 293af07d15ca..b461ed18a696 100644 --- a/src/diffusers/hooks/hooks.py +++ b/src/diffusers/hooks/hooks.py @@ -136,7 +136,7 @@ def reset_state(self, module: torch.nn.Module): return module def _set_context(self, module: torch.nn.Module, name: str) -> None: - # Iterate over all attributes of the hook to see if any of them have the type `StateManager`. If so, call `set_context` on them. 
+ """Set context on all StateManager attributes of this hook.""" for attr_name in dir(self): attr = getattr(self, attr_name) if isinstance(attr, StateManager): @@ -145,22 +145,12 @@ def _set_context(self, module: torch.nn.Module, name: str) -> None: class HookFunctionReference: - def __init__(self) -> None: - """A container class that maintains mutable references to forward pass functions in a hook chain. - - Its mutable nature allows the hook system to modify the execution chain dynamically without rebuilding the - entire forward pass structure. + """Mutable container for forward pass function references in a hook chain. - Attributes: - pre_forward: A callable that processes inputs before the main forward pass. - post_forward: A callable that processes outputs after the main forward pass. - forward: The current forward function in the hook chain. - original_forward: The original forward function, stored when a hook provides a custom new_forward. + Enables dynamic modification of the execution chain without rebuilding. + """ - The class enables hook removal by allowing updates to the forward chain through reference modification rather - than requiring reconstruction of the entire chain. When a hook is removed, only the relevant references need to - be updated, preserving the execution order of the remaining hooks. - """ + def __init__(self) -> None: self.pre_forward = None self.post_forward = None self.forward = None diff --git a/src/diffusers/hooks/teacache.py b/src/diffusers/hooks/teacache.py index 5f9eda09c764..379d0a5479f6 100644 --- a/src/diffusers/hooks/teacache.py +++ b/src/diffusers/hooks/teacache.py @@ -28,14 +28,14 @@ def _handle_accelerate_hook(module: torch.nn.Module, *args, **kwargs) -> Tuple[tuple, dict]: - """Handle compatibility with accelerate's CPU offload hooks.""" + """Handle accelerate CPU offload hook compatibility.""" if hasattr(module, "_hf_hook") and hasattr(module._hf_hook, "pre_forward"): args, kwargs = module._hf_hook.pre_forward(module, *args, **kwargs) return args, kwargs def _extract_lora_scale(attention_kwargs: Optional[Dict[str, Any]]) -> Tuple[Optional[Dict[str, Any]], float]: - """Extract LoRA scale from attention kwargs, returning (modified_kwargs, scale).""" + """Extract LoRA scale from attention kwargs.""" if attention_kwargs is None: return None, 1.0 attention_kwargs = attention_kwargs.copy() @@ -43,11 +43,7 @@ def _extract_lora_scale(attention_kwargs: Optional[Dict[str, Any]]) -> Tuple[Opt def _get_model_config(): - """Get model configuration mapping. - - Returns dict at runtime when forward functions are defined. Order matters: more specific model variants must come - before generic ones. - """ + """Get model configuration mapping. 
Order matters: more specific variants before generic ones.""" return { "FluxKontext": { "forward_func": _flux_teacache_forward, @@ -94,7 +90,6 @@ def _auto_detect_model_type(module): config_path = getattr(getattr(module, "config", None), "_name_or_path", "").lower() model_config = _get_model_config() - # Check config path first (for variants), then class name (ordered most specific first) for model_type in model_config: if model_type.lower() in config_path or model_type in class_name: return model_type @@ -109,18 +104,16 @@ def _rescale_distance(coefficients, x): def _compute_rel_l1_distance(current: torch.Tensor, previous: torch.Tensor) -> float: - """Compute relative L1 distance between current and previous tensors.""" + """Compute relative L1 distance between tensors.""" prev_mean = previous.abs().mean() if prev_mean.item() > 1e-9: return ((current - previous).abs().mean() / prev_mean).item() - # Near-zero previous: if current also near-zero, no change; otherwise force recompute return 0.0 if current.abs().mean().item() < 1e-9 else float("inf") @torch.compiler.disable def _should_compute(state, modulated_inp, coefficients, rel_l1_thresh): """Determine if full computation is needed (single residual models).""" - # First/last timesteps and missing state always require computation is_first_step = state.cnt == 0 is_last_step = state.num_steps > 0 and state.cnt == state.num_steps - 1 missing_state = state.previous_modulated_input is None or state.previous_residual is None @@ -129,7 +122,6 @@ def _should_compute(state, modulated_inp, coefficients, rel_l1_thresh): state.accumulated_rel_l1_distance = 0 return True - # Compute accumulated distance and check threshold rel_distance = _compute_rel_l1_distance(modulated_inp, state.previous_modulated_input) state.accumulated_rel_l1_distance += _rescale_distance(coefficients, rel_distance) @@ -142,21 +134,20 @@ def _should_compute(state, modulated_inp, coefficients, rel_l1_thresh): @torch.compiler.disable def _should_compute_dual(state, modulated_inp, coefficients, rel_l1_thresh): """Determine if full computation is needed (dual residual models like CogVideoX).""" - # Also check encoder residual if state.previous_residual is None or state.previous_residual_encoder is None: return True return _should_compute(state, modulated_inp, coefficients, rel_l1_thresh) def _update_state(state, output, original_input, modulated_inp): - """Update cache state after full computation (single residual).""" + """Update cache state after full computation.""" state.previous_residual = output - original_input state.previous_modulated_input = modulated_inp state.cnt += 1 def _update_state_dual(state, hs_output, enc_output, hs_original, enc_original, modulated_inp): - """Update cache state after full computation (dual residual).""" + """Update cache state after full computation (dual residual for CogVideoX).""" state.previous_residual = hs_output - hs_original state.previous_residual_encoder = enc_output - enc_original state.previous_modulated_input = modulated_inp @@ -164,7 +155,7 @@ def _update_state_dual(state, hs_output, enc_output, hs_original, enc_original, def _apply_cached_residual(state, input_tensor, modulated_inp): - """Apply cached residual - fast path (single residual).""" + """Apply cached residual (fast path).""" output = input_tensor + state.previous_residual state.previous_modulated_input = modulated_inp state.cnt += 1 @@ -172,7 +163,7 @@ def _apply_cached_residual(state, input_tensor, modulated_inp): def _apply_cached_residual_dual(state, hs, enc, 
modulated_inp): - """Apply cached residuals - fast path (dual residual).""" + """Apply cached residuals (fast path for CogVideoX).""" hs_out = hs + state.previous_residual enc_out = enc + state.previous_residual_encoder state.previous_modulated_input = modulated_inp @@ -183,63 +174,42 @@ def _apply_cached_residual_dual(state, hs, enc, modulated_inp): @dataclass class TeaCacheConfig: r""" - Configuration for [TeaCache](https://arxiv.org/abs/2411.19108) applied to transformer models. - - TeaCache (Timestep Embedding Aware Cache) is an adaptive caching technique that speeds up diffusion model inference - by reusing transformer block computations when consecutive timestep embeddings are similar. It uses polynomial - rescaling of L1 distances between modulated inputs to intelligently decide when to cache. + Configuration for [TeaCache](https://huggingface.co/papers/2411.19108). - Reference: [TeaCache: Timestep Embedding Aware Cache for Efficient Diffusion Model Inference](https://arxiv.org/abs/2411.19108) + TeaCache (Timestep Embedding Aware Cache) speeds up diffusion model inference by reusing transformer block + computations when consecutive timestep embeddings are similar. It uses polynomial rescaling of L1 distances + between modulated inputs to decide when to cache. - Currently supports: FLUX, FLUX-Kontext, Mochi, Lumina2, and CogVideoX models. Model type is auto-detected, and - model-specific polynomial coefficients are automatically applied. + Currently supports: FLUX, FLUX-Kontext, Mochi, Lumina2, and CogVideoX models. Model type is auto-detected. - Args: + Attributes: rel_l1_thresh (`float`, defaults to `0.2`): - Threshold for accumulated relative L1 distance. When the accumulated distance is below this threshold, the - cached residual from the previous timestep is reused instead of computing the full transformer. Based on - the original TeaCache paper, values in the range [0.1, 0.3] work best for balancing speed and quality: - - 0.25 for ~1.5x speedup with minimal quality loss - - 0.4 for ~1.8x speedup with slight quality loss - - 0.6 for ~2.0x speedup with noticeable quality loss - - 0.8 for ~2.25x speedup with significant quality loss - Higher thresholds lead to more aggressive caching and faster inference, but may reduce output quality. - Note: Mochi models require lower thresholds (0.06-0.09) due to different coefficient scaling. - coefficients (`List[float]`, *optional*, defaults to polynomial coefficients from TeaCache paper): - Polynomial coefficients used for rescaling the raw L1 distance. These coefficients transform the relative - L1 distance into a model-specific caching signal. If not provided, defaults to the coefficients determined - for FLUX models in the TeaCache paper: [4.98651651e+02, -2.83781631e+02, 5.58554382e+01, -3.82021401e+00, - 2.64230861e-01]. The polynomial is evaluated as: `c[0]*x^4 + c[1]*x^3 + c[2]*x^2 + c[3]*x + c[4]` where x - is the relative L1 distance. - current_timestep_callback (`Callable[[], int]`, *optional*, defaults to `None`): - Callback function that returns the current timestep during inference. This is used internally for debugging - and statistics tracking. If not provided, TeaCache will still function correctly. - num_inference_steps (`int`, *optional*, defaults to `None`): - Total number of inference steps. Required for proper state management - ensures first and last timesteps - are always computed (never cached) and that state resets between inference runs. 
If not provided, TeaCache - will attempt to detect via callback or module attribute. - num_inference_steps_callback (`Callable[[], int]`, *optional*, defaults to `None`): - Callback function that returns the total number of inference steps. Alternative to `num_inference_steps` - for dynamic step counts. + Threshold for accumulated relative L1 distance. When below this threshold, the cached residual is reused. + Recommended values: 0.25 for ~1.5x speedup, 0.4 for ~1.8x, 0.6 for ~2.0x. Mochi models require lower + thresholds (0.06-0.09). + coefficients (`List[float]`, *optional*): + Polynomial coefficients for rescaling L1 distance. Auto-detected based on model type if not provided. + Evaluated as: `c[0]*x^4 + c[1]*x^3 + c[2]*x^2 + c[3]*x + c[4]`. + current_timestep_callback (`Callable[[], int]`, *optional*): + Callback returning current timestep. Used for debugging/statistics. + num_inference_steps (`int`, *optional*): + Total inference steps. Ensures first/last timesteps are always computed. Auto-detected if not provided. + num_inference_steps_callback (`Callable[[], int]`, *optional*): + Callback returning total inference steps. Alternative to `num_inference_steps`. Example: ```python - from diffusers import FluxPipeline - from diffusers.hooks import TeaCacheConfig - - # Load FLUX pipeline - pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16) - pipe.to("cuda") + >>> from diffusers import FluxPipeline + >>> from diffusers.hooks import TeaCacheConfig - # Enable TeaCache with auto-detection (1.5x speedup) - config = TeaCacheConfig(rel_l1_thresh=0.2) - pipe.transformer.enable_cache(config) + >>> pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16) + >>> pipe.to("cuda") - # Generate image with caching - image = pipe("A cat sitting on a windowsill", num_inference_steps=4).images[0] + >>> config = TeaCacheConfig(rel_l1_thresh=0.2) + >>> pipe.transformer.enable_cache(config) - # Disable caching - pipe.transformer.disable_cache() + >>> image = pipe("A cat sitting on a windowsill", num_inference_steps=4).images[0] + >>> pipe.transformer.disable_cache() ``` """ @@ -310,32 +280,10 @@ def __repr__(self) -> str: class TeaCacheState(BaseState): - """ - State management for TeaCache hook. - - This class tracks the caching state across diffusion timesteps, managing counters, accumulated distances, and - cached values needed for the TeaCache algorithm. The state persists across multiple forward passes during a single - inference run and is automatically reset when a new inference begins. + r""" + State for [TeaCache](https://huggingface.co/papers/2411.19108). - Attributes: - cnt (int): - Current timestep counter, incremented with each forward pass. Used to identify first/last timesteps which - are always computed (never cached) for maximum quality. - num_steps (int): - Total number of inference steps for the current run. Used to identify the last timestep. Automatically - detected from callbacks or pipeline attributes if not explicitly set. - accumulated_rel_l1_distance (float): - Running accumulator for rescaled L1 distances between consecutive modulated inputs. Compared against the - threshold to make caching decisions. Reset to 0 when the decision is made to recompute. - previous_modulated_input (torch.Tensor): - Modulated input from the previous timestep, extracted from the first transformer block's norm1 layer. Used - for computing L1 distance to determine similarity between consecutive timesteps. 
- previous_residual (torch.Tensor): - Cached residual (output - input) from the previous timestep's full transformer computation. Applied - directly when caching is triggered instead of computing all transformer blocks. - previous_residual_encoder (torch.Tensor, optional): - Cached encoder residual for models that cache both encoder and hidden_states residuals (e.g., CogVideoX). - None for models that only cache hidden_states residual. + Tracks caching state across diffusion timesteps, including counters, accumulated distances, and cached residuals. """ def __init__(self): @@ -375,30 +323,11 @@ def __repr__(self) -> str: class TeaCacheHook(ModelHook): - """ - ModelHook implementing TeaCache for transformer models. - - This hook intercepts transformer forward pass and implements adaptive caching based on timestep embedding - similarity. It extracts modulated inputs, computes L1 distances, applies polynomial rescaling, and decides whether - to reuse cached residuals or compute full transformer blocks. - - The hook follows the original TeaCache algorithm from the paper: - 1. Extract modulated input using provided extractor function - 2. Compute relative L1 distance between current and previous modulated inputs - 3. Apply polynomial rescaling with model-specific coefficients to the distance - 4. Accumulate rescaled distances and compare to threshold - 5. If below threshold: reuse cached residual (fast path, skip transformer computation) - 6. If above threshold: compute full transformer blocks and cache new residual (slow path) - - The first and last timesteps are always computed fully (never cached) to ensure maximum quality. + r""" + Hook implementing [TeaCache](https://huggingface.co/papers/2411.19108) for transformer models. - Attributes: - config (TeaCacheConfig): - Configuration containing threshold, polynomial coefficients, and optional callbacks. - coefficients (List[float]): - Polynomial coefficients for rescaling L1 distances (auto-detected or user-provided). - state_manager (StateManager): - Manages TeaCacheState across forward passes, maintaining counters and cached values. + Intercepts transformer forward pass and implements adaptive caching based on timestep embedding similarity. + First and last timesteps are always computed fully (never cached) to ensure maximum quality. """ _is_stateful = True @@ -415,16 +344,8 @@ def __init__(self, config: TeaCacheConfig): def _maybe_reset_state_for_new_inference( self, state: TeaCacheState, module: torch.nn.Module, reset_encoder_residual: bool = False ) -> None: - """Reset state if we've completed all steps (start of new inference run). - - Also initializes num_steps on first timestep if not set. - - Args: - state: TeaCacheState instance. - module: The transformer module. - reset_encoder_residual: If True, also reset previous_residual_encoder (for CogVideoX). - """ - # Reset counter if we've completed all steps (new inference run) + """Reset state if inference run completed. 
Initialize num_steps on first timestep if not set.""" + # Reset if we've completed all steps (new inference run) if state.cnt == state.num_steps and state.num_steps > 0: logger.debug("TeaCache: Inference run completed, resetting state") state.cnt = 0 @@ -434,9 +355,8 @@ def _maybe_reset_state_for_new_inference( if reset_encoder_residual: state.previous_residual_encoder = None - # Set num_steps on first timestep if not already set + # Set num_steps on first timestep (priority: config > callback > module attribute) if state.cnt == 0 and state.num_steps == 0: - # Priority: config value > callback > module attribute if self.config.num_inference_steps is not None: state.num_steps = self.config.num_inference_steps elif self.config.num_inference_steps_callback is not None: @@ -465,7 +385,7 @@ def initialize_hook(self, module): raise ValueError(f"TeaCache: {self.model_type} transformer_blocks[0] missing norm1") elif self.model_type == "Lumina2": if not hasattr(module, "layers") or len(module.layers) == 0: - raise ValueError(f"TeaCache: Lumina2 model missing layers") + raise ValueError("TeaCache: Lumina2 model missing layers") elif "CogVideoX" in self.model_type: if not hasattr(module, "transformer_blocks") or len(module.transformer_blocks) == 0: raise ValueError(f"TeaCache: {self.model_type} model missing transformer_blocks") @@ -501,7 +421,6 @@ def _flux_teacache_forward( **kwargs, ): """TeaCache forward for Flux models.""" - # Handle accelerate CPU offload compatibility - moves module and inputs to GPU if needed args, extra_kwargs = _handle_accelerate_hook( module, hidden_states, @@ -529,12 +448,9 @@ def _flux_teacache_forward( else: temb = module.time_text_embed(timestep_scaled, pooled_projections) - # Inline extractor: Flux uses transformer_blocks[0].norm1 modulated_inp = module.transformer_blocks[0].norm1(hidden_states, emb=temb)[0] - # Caching decision and execution if _should_compute(state, modulated_inp, hook.coefficients, hook.config.rel_l1_thresh): - # Full computation path ori_hs = hidden_states.clone() enc = module.context_embedder(encoder_hidden_states) @@ -562,7 +478,6 @@ def _flux_teacache_forward( ) _update_state(state, hidden_states, ori_hs, modulated_inp) else: - # Cached path hidden_states = _apply_cached_residual(state, hidden_states, modulated_inp) hidden_states = module.norm_out(hidden_states, temb) @@ -629,12 +544,9 @@ def _mochi_teacache_forward( dtype=torch.float32, ) - # Inline extractor: Mochi norm1 returns tuple (modulated_inp, gate_msa, scale_mlp, gate_mlp) modulated_inp = module.transformer_blocks[0].norm1(hidden_states, temb)[0] - # Caching decision and execution if _should_compute(state, modulated_inp, hook.coefficients, hook.config.rel_l1_thresh): - # Full computation path ori_hs = hidden_states.clone() enc = encoder_hidden_states for block in module.transformer_blocks: @@ -655,11 +567,9 @@ def _mochi_teacache_forward( encoder_attention_mask=encoder_attention_mask, image_rotary_emb=image_rotary_emb, ) - # norm_out is included in residual (matches original TeaCache implementation) hidden_states = module.norm_out(hidden_states, temb) _update_state(state, hidden_states, ori_hs, modulated_inp) else: - # Cached path - residual already includes norm_out effect hidden_states = _apply_cached_residual(state, hidden_states, modulated_inp) hidden_states = module.proj_out(hidden_states) @@ -685,7 +595,7 @@ def _lumina2_teacache_forward( attention_kwargs: Optional[Dict[str, Any]] = None, return_dict: bool = True, ): - """TeaCache forward for Lumina2 models (handles variable seq 
lens + per-len caches).""" + """TeaCache forward for Lumina2 models.""" args, kwargs = _handle_accelerate_hook( module, hidden_states, @@ -740,7 +650,6 @@ def _lumina2_teacache_forward( mask[i, :seq_len_val] = True attention_mask_for_main_loop_arg = mask - # Inline extractor: Lumina2 uses layers[0].norm1 modulated_inp = module.layers[0].norm1(input_to_main_loop, temb)[0] # Per-sequence-length caching for variable sequence lengths @@ -753,7 +662,6 @@ def _lumina2_teacache_forward( } cache = state.cache_dict[cache_key] - # Determine if computation is needed is_boundary_step = state.cnt == 0 or state.cnt == state.num_steps - 1 has_previous = cache["previous_modulated_input"] is not None @@ -771,7 +679,7 @@ def _lumina2_teacache_forward( cache["previous_modulated_input"] = modulated_inp.clone() - # Track sequence length for step counting (CFG handling) + # Track sequence length for step counting (CFG) if state.uncond_seq_len is None: state.uncond_seq_len = cache_key if cache_key != state.uncond_seq_len: @@ -822,7 +730,7 @@ def _cogvideox_teacache_forward( attention_kwargs: Optional[Dict[str, Any]] = None, return_dict: bool = True, ): - """TeaCache forward for CogVideoX models (handles dual residual caching).""" + """TeaCache forward for CogVideoX models.""" args, kwargs = _handle_accelerate_hook( module, hidden_states, @@ -867,12 +775,9 @@ def _cogvideox_teacache_forward( enc = hs[:, :text_seq_length] hs = hs[:, text_seq_length:] - # Inline extractor: CogVideoX uses timestep embedding directly modulated_inp = emb - # Caching decision and execution (dual residual) if _should_compute_dual(state, modulated_inp, hook.coefficients, hook.config.rel_l1_thresh): - # Full computation path ori_hs = hs.clone() ori_enc = enc.clone() for block in module.transformer_blocks: @@ -895,7 +800,6 @@ def _cogvideox_teacache_forward( ) _update_state_dual(state, hs, enc, ori_hs, ori_enc, modulated_inp) else: - # Cached path hs, enc = _apply_cached_residual_dual(state, hs, enc, modulated_inp) if not module.config.use_rotary_positional_embeddings: @@ -926,44 +830,31 @@ def _cogvideox_teacache_forward( def apply_teacache(module: torch.nn.Module, config: TeaCacheConfig) -> None: - """ - Apply TeaCache optimization to a transformer model. - - This function registers a TeaCacheHook on the provided transformer, enabling adaptive caching of transformer block - computations based on timestep embedding similarity. The hook intercepts the forward pass and implements the - TeaCache algorithm to achieve 1.5x-2x speedup with minimal quality loss. + r""" + Applies [TeaCache](https://huggingface.co/papers/2411.19108) to a given module. - Reference: [TeaCache: Timestep Embedding Aware Cache for Efficient Diffusion Model Inference](https://arxiv.org/abs/2411.19108) + TeaCache speeds up diffusion model inference (1.5x-2x) by caching transformer block computations when consecutive + timestep embeddings are similar. Model type is auto-detected based on the module class name. Args: module (`torch.nn.Module`): The transformer model to optimize (e.g., FluxTransformer2DModel, CogVideoXTransformer3DModel). config (`TeaCacheConfig`): - Configuration specifying caching threshold and optional callbacks. + The configuration to use for TeaCache. 
Example: ```python - from diffusers import FluxPipeline - from diffusers.hooks import TeaCacheConfig + >>> import torch + >>> from diffusers import FluxPipeline + >>> from diffusers.hooks import TeaCacheConfig, apply_teacache - # Load FLUX pipeline - pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16) - pipe.to("cuda") + >>> pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16) + >>> pipe.to("cuda") - # Enable TeaCache via CacheMixin (recommended) - config = TeaCacheConfig(rel_l1_thresh=0.2) - pipe.transformer.enable_cache(config) + >>> apply_teacache(pipe.transformer, TeaCacheConfig(rel_l1_thresh=0.2)) - # Generate with caching enabled - image = pipe("A cat on a windowsill", num_inference_steps=4).images[0] - - # Disable caching - pipe.transformer.disable_cache() + >>> image = pipe("A cat on a windowsill", num_inference_steps=4).images[0] ``` - - Note: - For most use cases, it's recommended to use the CacheMixin interface: `pipe.transformer.enable_cache(...)` - which provides additional convenience methods like `disable_cache()` for easy toggling. """ # Register hook on main transformer registry = HookRegistry.check_if_exists_or_initialize(module) diff --git a/tests/hooks/test_teacache.py b/tests/hooks/test_teacache.py index bc74ad849f2e..5549260d2171 100644 --- a/tests/hooks/test_teacache.py +++ b/tests/hooks/test_teacache.py @@ -55,13 +55,11 @@ def test_invalid_coefficients_type(self): def test_very_low_threshold_accepted(self): """Test very low threshold is accepted (with logging warning).""" - # Very low threshold should be accepted but logged as warning config = TeaCacheConfig(rel_l1_thresh=0.01) self.assertEqual(config.rel_l1_thresh, 0.01) def test_very_high_threshold_accepted(self): """Test very high threshold is accepted (with logging warning).""" - # Very high threshold should be accepted but logged as warning config = TeaCacheConfig(rel_l1_thresh=1.5) self.assertEqual(config.rel_l1_thresh, 1.5) @@ -98,17 +96,14 @@ def test_state_reset(self): from diffusers.hooks.teacache import TeaCacheState state = TeaCacheState() - # Modify state state.cnt = 5 state.num_steps = 10 state.accumulated_rel_l1_distance = 0.5 state.previous_modulated_input = torch.randn(1, 10) state.previous_residual = torch.randn(1, 10) - # Reset state.reset() - # Verify reset self.assertEqual(state.cnt, 0) self.assertEqual(state.num_steps, 0) self.assertEqual(state.accumulated_rel_l1_distance, 0.0) From 90eb7469562203674f0877ba7ec2496936015c23 Mon Sep 17 00:00:00 2001 From: Prajwal A Date: Mon, 12 Jan 2026 10:53:35 +0000 Subject: [PATCH 23/25] simplify TeaCache tests with module-level imports and model factory helpers Signed-off-by: Prajwal A --- tests/hooks/test_teacache.py | 229 ++++++++--------------------------- 1 file changed, 53 insertions(+), 176 deletions(-) diff --git a/tests/hooks/test_teacache.py b/tests/hooks/test_teacache.py index 5549260d2171..90edfd733f89 100644 --- a/tests/hooks/test_teacache.py +++ b/tests/hooks/test_teacache.py @@ -16,7 +16,45 @@ import torch -from diffusers.hooks import HookRegistry, TeaCacheConfig +from diffusers import CogVideoXTransformer3DModel, Lumina2Transformer2DModel, MochiTransformer3DModel +from diffusers.hooks import HookRegistry, StateManager, TeaCacheConfig, apply_teacache +from diffusers.hooks.teacache import TeaCacheHook, TeaCacheState, _get_model_config, _should_compute + + +def _create_mochi_model() -> MochiTransformer3DModel: + return MochiTransformer3DModel( + 
patch_size=2, + num_attention_heads=2, + attention_head_dim=8, + num_layers=2, + in_channels=4, + text_embed_dim=16, + time_embed_dim=4, + ) + + +def _create_lumina2_model() -> Lumina2Transformer2DModel: + return Lumina2Transformer2DModel( + sample_size=16, + patch_size=2, + in_channels=4, + hidden_size=24, + num_layers=2, + num_refiner_layers=1, + num_attention_heads=3, + num_kv_heads=1, + ) + + +def _create_cogvideox_model() -> CogVideoXTransformer3DModel: + return CogVideoXTransformer3DModel( + num_attention_heads=2, + attention_head_dim=8, + in_channels=4, + num_layers=2, + text_embed_dim=16, + time_embed_dim=4, + ) class TeaCacheConfigTests(unittest.TestCase): @@ -82,8 +120,6 @@ class TeaCacheStateTests(unittest.TestCase): def test_state_initialization(self): """Test state initializes with correct default values.""" - from diffusers.hooks.teacache import TeaCacheState - state = TeaCacheState() self.assertEqual(state.cnt, 0) self.assertEqual(state.num_steps, 0) @@ -93,8 +129,6 @@ def test_state_initialization(self): def test_state_reset(self): """Test state reset clears all values.""" - from diffusers.hooks.teacache import TeaCacheState - state = TeaCacheState() state.cnt = 5 state.num_steps = 10 @@ -112,8 +146,6 @@ def test_state_reset(self): def test_state_repr(self): """Test __repr__ method works correctly.""" - from diffusers.hooks.teacache import TeaCacheState - state = TeaCacheState() state.cnt = 3 state.num_steps = 10 @@ -128,20 +160,15 @@ class TeaCacheHookTests(unittest.TestCase): def test_hook_initialization(self): """Test hook initializes correctly with config.""" - from diffusers.hooks.teacache import TeaCacheHook - config = TeaCacheConfig(rel_l1_thresh=0.2) hook = TeaCacheHook(config) self.assertEqual(hook.config.rel_l1_thresh, 0.2) - # coefficients is None until initialize_hook() is called with a model (auto-detection) self.assertIsNone(hook.coefficients) self.assertIsNotNone(hook.state_manager) def test_should_compute_logic(self): """Test _should_compute decision logic.""" - from diffusers.hooks.teacache import TeaCacheState, _should_compute - coefficients = [1, 0, 0, 0, 0] rel_l1_thresh = 1.0 state = TeaCacheState() @@ -149,24 +176,20 @@ def test_should_compute_logic(self): x0 = torch.ones(1, 4) x1 = torch.ones(1, 4) * 1.1 - # First step should always compute self.assertTrue(_should_compute(state, x0, coefficients, rel_l1_thresh)) state.previous_modulated_input = x0 - state.previous_residual = torch.zeros(1, 4) # Need residual to skip compute + state.previous_residual = torch.zeros(1, 4) state.cnt = 1 state.num_steps = 4 - # Middle step: accumulate distance and stay below threshold => reuse cache self.assertFalse(_should_compute(state, x1, coefficients, rel_l1_thresh)) - # Last step: must compute regardless of distance state.cnt = state.num_steps - 1 self.assertTrue(_should_compute(state, x1, coefficients, rel_l1_thresh)) def test_apply_teacache_unsupported_model_raises_error(self): """Test that apply_teacache raises error for unsupported models.""" - from diffusers.hooks import apply_teacache from diffusers.models import CacheMixin class UnsupportedModule(torch.nn.Module, CacheMixin): @@ -177,7 +200,6 @@ def __init__(self): module = UnsupportedModule() config = TeaCacheConfig(rel_l1_thresh=0.2) - # Should raise ValueError for unsupported model type with self.assertRaises(ValueError) as context: apply_teacache(module, config) self.assertIn("Unsupported model", str(context.exception)) @@ -189,8 +211,6 @@ class TeaCacheMultiModelTests(unittest.TestCase): def 
test_model_coefficient_registry(self): """Test that model coefficients are properly registered.""" - from diffusers.hooks.teacache import _get_model_config - model_config = _get_model_config() self.assertIn("Flux", model_config) @@ -198,7 +218,6 @@ def test_model_coefficient_registry(self): self.assertIn("Lumina2", model_config) self.assertIn("CogVideoX", model_config) - # Verify all coefficients are 5-element lists for model_name, config in model_config.items(): coeffs = config["coefficients"] self.assertEqual(len(coeffs), 5, f"{model_name} coefficients should have 5 elements") @@ -207,26 +226,14 @@ def test_model_coefficient_registry(self): ) def test_mochi_extractor(self): - """Test Mochi modulated input extraction (now inlined in forward function).""" - from diffusers import MochiTransformer3DModel - - # Create a minimal Mochi model for testing - model = MochiTransformer3DModel( - patch_size=2, - num_attention_heads=2, - attention_head_dim=8, - num_layers=2, - in_channels=4, - text_embed_dim=16, - time_embed_dim=4, - ) + """Test Mochi modulated input extraction.""" + model = _create_mochi_model() hidden_states = torch.randn(2, 4, 2, 8, 8) timestep = torch.randint(0, 1000, (2,)) encoder_hidden_states = torch.randn(2, 16, 16) encoder_attention_mask = torch.ones(2, 16).bool() - # Get timestep embedding temb, _ = model.time_embed( timestep, encoder_hidden_states, encoder_attention_mask, hidden_dtype=hidden_states.dtype ) @@ -234,116 +241,58 @@ def test_mochi_extractor(self): hidden_states = model.patch_embed(hidden_states) hidden_states = hidden_states.unflatten(0, (2, -1)).flatten(1, 2) - # Test inline extraction logic: Mochi norm1 returns tuple (modulated_inp, ...) modulated_inp = model.transformer_blocks[0].norm1(hidden_states, temb)[0] self.assertIsInstance(modulated_inp, torch.Tensor) self.assertEqual(modulated_inp.shape[0], hidden_states.shape[0]) def test_lumina2_extractor(self): - """Test Lumina2 modulated input extraction (now inlined in forward function).""" - from diffusers import Lumina2Transformer2DModel - - # Create a minimal Lumina2 model for testing - model = Lumina2Transformer2DModel( - sample_size=16, - patch_size=2, - in_channels=4, - hidden_size=24, - num_layers=2, - num_refiner_layers=1, - num_attention_heads=3, - num_kv_heads=1, - ) + """Test Lumina2 modulated input extraction.""" + model = _create_lumina2_model() - # Create properly shaped inputs that match what the extractor expects - # The extractor expects input_to_main_loop (already preprocessed concatenated text+image tokens) batch_size = 2 - seq_len = 100 # combined text + image sequence + seq_len = 100 hidden_size = model.config.hidden_size - # Simulate input_to_main_loop (already preprocessed) input_to_main_loop = torch.randn(batch_size, seq_len, hidden_size) temb = torch.randn(batch_size, hidden_size) - # Test inline extraction logic: Lumina2 uses layers[0].norm1 modulated_inp = model.layers[0].norm1(input_to_main_loop, temb)[0] self.assertIsInstance(modulated_inp, torch.Tensor) self.assertEqual(modulated_inp.shape[0], batch_size) def test_cogvideox_extractor(self): - """Test CogVideoX modulated input extraction (now inlined in forward function).""" - from diffusers import CogVideoXTransformer3DModel - - # Create a minimal CogVideoX model for testing - model = CogVideoXTransformer3DModel( - num_attention_heads=2, - attention_head_dim=8, - in_channels=4, - num_layers=2, - text_embed_dim=16, - time_embed_dim=4, - ) + """Test CogVideoX modulated input extraction.""" + model = _create_cogvideox_model() 
hidden_states = torch.randn(2, 2, 4, 8, 8) timestep = torch.randint(0, 1000, (2,)) - # Get timestep embedding t_emb = model.time_proj(timestep) t_emb = t_emb.to(dtype=hidden_states.dtype) emb = model.time_embedding(t_emb, None) - # Test inline extraction logic: CogVideoX uses timestep embedding directly modulated_inp = emb self.assertIsInstance(modulated_inp, torch.Tensor) self.assertEqual(modulated_inp.shape, emb.shape) def test_auto_detect_mochi(self): """Test auto-detection for Mochi models.""" - from diffusers import MochiTransformer3DModel - from diffusers.hooks import TeaCacheConfig, apply_teacache - from diffusers.hooks.teacache import _get_model_config - - model = MochiTransformer3DModel( - patch_size=2, - num_attention_heads=2, - attention_head_dim=8, - num_layers=2, - in_channels=4, - text_embed_dim=16, - time_embed_dim=4, - ) - + model = _create_mochi_model() model_config = _get_model_config() - # Test coefficient auto-detection config = TeaCacheConfig(rel_l1_thresh=0.2) apply_teacache(model, config) registry = HookRegistry.check_if_exists_or_initialize(model) hook = registry.get_hook("teacache") self.assertIsNotNone(hook) - # Verify coefficients were auto-set self.assertEqual(hook.coefficients, model_config["Mochi"]["coefficients"]) model.disable_cache() def test_auto_detect_lumina2(self): """Test auto-detection for Lumina2 models.""" - from diffusers import Lumina2Transformer2DModel - from diffusers.hooks import TeaCacheConfig, apply_teacache - from diffusers.hooks.teacache import _get_model_config - - model = Lumina2Transformer2DModel( - sample_size=16, - patch_size=2, - in_channels=4, - hidden_size=24, - num_layers=2, - num_refiner_layers=1, - num_attention_heads=3, - num_kv_heads=1, - ) - + model = _create_lumina2_model() model_config = _get_model_config() config = TeaCacheConfig(rel_l1_thresh=0.2) @@ -352,27 +301,13 @@ def test_auto_detect_lumina2(self): registry = HookRegistry.check_if_exists_or_initialize(model) hook = registry.get_hook("teacache") self.assertIsNotNone(hook) - # Verify coefficients were auto-set self.assertEqual(hook.coefficients, model_config["Lumina2"]["coefficients"]) - # Lumina2 doesn't have CacheMixin, manually remove hook instead registry.remove_hook("teacache") def test_auto_detect_cogvideox(self): """Test auto-detection for CogVideoX models.""" - from diffusers import CogVideoXTransformer3DModel - from diffusers.hooks import TeaCacheConfig, apply_teacache - from diffusers.hooks.teacache import _get_model_config - - model = CogVideoXTransformer3DModel( - num_attention_heads=2, - attention_head_dim=8, - in_channels=4, - num_layers=2, - text_embed_dim=16, - time_embed_dim=4, - ) - + model = _create_cogvideox_model() model_config = _get_model_config() config = TeaCacheConfig(rel_l1_thresh=0.2) @@ -381,73 +316,35 @@ def test_auto_detect_cogvideox(self): registry = HookRegistry.check_if_exists_or_initialize(model) hook = registry.get_hook("teacache") self.assertIsNotNone(hook) - # Verify coefficients were auto-set self.assertEqual(hook.coefficients, model_config["CogVideoX"]["coefficients"]) model.disable_cache() def test_teacache_state_encoder_residual(self): """Test that TeaCacheState supports encoder residual for CogVideoX.""" - from diffusers.hooks.teacache import TeaCacheState - state = TeaCacheState() self.assertIsNone(state.previous_residual_encoder) - # Set encoder residual state.previous_residual_encoder = torch.randn(2, 10, 16) self.assertIsNotNone(state.previous_residual_encoder) - # Reset should clear it state.reset() 
self.assertIsNone(state.previous_residual_encoder) def test_model_routing(self): """Test that new_forward routes to correct handler based on model type.""" - from diffusers import CogVideoXTransformer3DModel, Lumina2Transformer2DModel, MochiTransformer3DModel - from diffusers.hooks.teacache import TeaCacheConfig, TeaCacheHook - config = TeaCacheConfig(rel_l1_thresh=0.2) - # Test Mochi routing - mochi_model = MochiTransformer3DModel( - patch_size=2, - num_attention_heads=2, - attention_head_dim=8, - num_layers=2, - in_channels=4, - text_embed_dim=16, - time_embed_dim=4, - ) mochi_hook = TeaCacheHook(config) - mochi_hook.initialize_hook(mochi_model) + mochi_hook.initialize_hook(_create_mochi_model()) self.assertEqual(mochi_hook.model_type, "Mochi") - # Test Lumina2 routing - lumina_model = Lumina2Transformer2DModel( - sample_size=16, - patch_size=2, - in_channels=4, - hidden_size=24, - num_layers=2, - num_refiner_layers=1, - num_attention_heads=3, - num_kv_heads=1, - ) lumina_hook = TeaCacheHook(config) - lumina_hook.initialize_hook(lumina_model) + lumina_hook.initialize_hook(_create_lumina2_model()) self.assertEqual(lumina_hook.model_type, "Lumina2") - # Test CogVideoX routing - cogvideox_model = CogVideoXTransformer3DModel( - num_attention_heads=2, - attention_head_dim=8, - in_channels=4, - num_layers=2, - text_embed_dim=16, - time_embed_dim=4, - ) cogvideox_hook = TeaCacheHook(config) - cogvideox_hook.initialize_hook(cogvideox_model) + cogvideox_hook.initialize_hook(_create_cogvideox_model()) self.assertEqual(cogvideox_hook.model_type, "CogVideoX") @@ -456,73 +353,53 @@ class StateManagerContextTests(unittest.TestCase): def test_context_isolation(self): """Test that different contexts maintain separate states.""" - from diffusers.hooks import StateManager - from diffusers.hooks.teacache import TeaCacheState - state_manager = StateManager(TeaCacheState, (), {}) - # Set context "cond" and modify state state_manager.set_context("cond") cond_state = state_manager.get_state() cond_state.cnt = 5 cond_state.accumulated_rel_l1_distance = 0.3 - # Set context "uncond" and modify state state_manager.set_context("uncond") uncond_state = state_manager.get_state() uncond_state.cnt = 10 uncond_state.accumulated_rel_l1_distance = 0.7 - # Verify isolation - switch back to "cond" state_manager.set_context("cond") self.assertEqual(state_manager.get_state().cnt, 5) self.assertEqual(state_manager.get_state().accumulated_rel_l1_distance, 0.3) - # Verify isolation - switch back to "uncond" state_manager.set_context("uncond") self.assertEqual(state_manager.get_state().cnt, 10) self.assertEqual(state_manager.get_state().accumulated_rel_l1_distance, 0.7) def test_default_context_fallback(self): """Test that state works without explicit context (backward compatibility).""" - from diffusers.hooks import StateManager - from diffusers.hooks.teacache import TeaCacheState - state_manager = StateManager(TeaCacheState, (), {}) - # Don't set context - should use "_default" fallback state = state_manager.get_state() self.assertIsNotNone(state) self.assertEqual(state.cnt, 0) - # Modify state state.cnt = 42 - # Should still get the same state via default context state2 = state_manager.get_state() self.assertEqual(state2.cnt, 42) def test_default_context_separate_from_named(self): """Test that default context is separate from named contexts.""" - from diffusers.hooks import StateManager - from diffusers.hooks.teacache import TeaCacheState - state_manager = StateManager(TeaCacheState, (), {}) - # Use default context (no 
explicit set_context) default_state = state_manager.get_state() default_state.cnt = 100 - # Now set a named context state_manager.set_context("named") named_state = state_manager.get_state() named_state.cnt = 200 - # Clear context to use default again state_manager._current_context = None self.assertEqual(state_manager.get_state().cnt, 100) - # Named context should still have its value state_manager.set_context("named") self.assertEqual(state_manager.get_state().cnt, 200) From b38315d78ecfa84b763732c486e3f7f44093b0e6 Mon Sep 17 00:00:00 2001 From: Prajwal A Date: Mon, 12 Jan 2026 11:15:44 +0000 Subject: [PATCH 24/25] refactor: enhance type hints and add ControlNet support in TeaCache forward methods Signed-off-by: Prajwal A --- src/diffusers/hooks/teacache.py | 53 ++++++++++++++++++++++++++++----- 1 file changed, 46 insertions(+), 7 deletions(-) diff --git a/src/diffusers/hooks/teacache.py b/src/diffusers/hooks/teacache.py index 379d0a5479f6..cc23132c3f6b 100644 --- a/src/diffusers/hooks/teacache.py +++ b/src/diffusers/hooks/teacache.py @@ -42,7 +42,7 @@ def _extract_lora_scale(attention_kwargs: Optional[Dict[str, Any]]) -> Tuple[Opt return attention_kwargs, attention_kwargs.pop("scale", 1.0) -def _get_model_config(): +def _get_model_config() -> Dict[str, Dict[str, Any]]: """Get model configuration mapping. Order matters: more specific variants before generic ones.""" return { "FluxKontext": { @@ -84,7 +84,7 @@ def _get_model_config(): } -def _auto_detect_model_type(module): +def _auto_detect_model_type(module: torch.nn.Module) -> str: """Auto-detect model type from class name and config path.""" class_name = module.__class__.__name__ config_path = getattr(getattr(module, "config", None), "_name_or_path", "").lower() @@ -97,7 +97,7 @@ def _auto_detect_model_type(module): raise ValueError(f"TeaCache: Unsupported model '{class_name}'. 
Supported: {', '.join(model_config.keys())}") -def _rescale_distance(coefficients, x): +def _rescale_distance(coefficients: List[float], x: float) -> float: """Polynomial rescaling: c[0]*x^4 + c[1]*x^3 + c[2]*x^2 + c[3]*x + c[4]""" c = coefficients return c[0] * x**4 + c[1] * x**3 + c[2] * x**2 + c[3] * x + c[4] @@ -417,10 +417,13 @@ def _flux_teacache_forward( encoder_hidden_states: torch.Tensor, txt_ids: torch.Tensor, img_ids: torch.Tensor, + controlnet_block_samples: Optional[List[torch.Tensor]] = None, + controlnet_single_block_samples: Optional[List[torch.Tensor]] = None, return_dict: bool = True, + controlnet_blocks_repeat: bool = False, **kwargs, ): - """TeaCache forward for Flux models.""" + """TeaCache forward for Flux models with ControlNet support.""" args, extra_kwargs = _handle_accelerate_hook( module, hidden_states, @@ -429,10 +432,18 @@ def _flux_teacache_forward( encoder_hidden_states, txt_ids, img_ids, + controlnet_block_samples=controlnet_block_samples, + controlnet_single_block_samples=controlnet_single_block_samples, return_dict=return_dict, + controlnet_blocks_repeat=controlnet_blocks_repeat, **kwargs, ) hidden_states, timestep, pooled_projections, encoder_hidden_states, txt_ids, img_ids = args + controlnet_block_samples = extra_kwargs.pop("controlnet_block_samples", controlnet_block_samples) + controlnet_single_block_samples = extra_kwargs.pop( + "controlnet_single_block_samples", controlnet_single_block_samples + ) + controlnet_blocks_repeat = extra_kwargs.pop("controlnet_blocks_repeat", controlnet_blocks_repeat) return_dict = extra_kwargs.pop("return_dict", return_dict) kwargs = extra_kwargs @@ -460,7 +471,7 @@ def _flux_teacache_forward( image_rotary_emb = module.pos_embed(ids) joint_attention_kwargs = kwargs.get("joint_attention_kwargs") - for block in module.transformer_blocks: + for index_block, block in enumerate(module.transformer_blocks): enc, hidden_states = block( hidden_states=hidden_states, encoder_hidden_states=enc, @@ -468,7 +479,20 @@ def _flux_teacache_forward( image_rotary_emb=image_rotary_emb, joint_attention_kwargs=joint_attention_kwargs, ) - for block in module.single_transformer_blocks: + # ControlNet residual + if controlnet_block_samples is not None: + interval_control = len(module.transformer_blocks) / len(controlnet_block_samples) + interval_control = ( + int(interval_control) if interval_control == int(interval_control) else int(interval_control) + 1 + ) + if controlnet_blocks_repeat: + hidden_states = ( + hidden_states + controlnet_block_samples[index_block % len(controlnet_block_samples)] + ) + else: + hidden_states = hidden_states + controlnet_block_samples[index_block // interval_control] + + for index_block, block in enumerate(module.single_transformer_blocks): enc, hidden_states = block( hidden_states=hidden_states, encoder_hidden_states=enc, @@ -476,6 +500,14 @@ def _flux_teacache_forward( image_rotary_emb=image_rotary_emb, joint_attention_kwargs=joint_attention_kwargs, ) + # ControlNet residual + if controlnet_single_block_samples is not None: + interval_control = len(module.single_transformer_blocks) / len(controlnet_single_block_samples) + interval_control = ( + int(interval_control) if interval_control == int(interval_control) else int(interval_control) + 1 + ) + hidden_states = hidden_states + controlnet_single_block_samples[index_block // interval_control] + _update_state(state, hidden_states, ori_hs, modulated_inp) else: hidden_states = _apply_cached_residual(state, hidden_states, modulated_inp) @@ -595,7 +627,14 @@ def 
_lumina2_teacache_forward( attention_kwargs: Optional[Dict[str, Any]] = None, return_dict: bool = True, ): - """TeaCache forward for Lumina2 models.""" + """TeaCache forward for Lumina2 models. + + Note: Lumina2 uses inline caching logic instead of `_should_compute()` because it requires + per-sequence-length caching for variable sequence lengths (CFG batches have different lengths). + Each sequence length gets its own cache entry in `state.cache_dict`. + + Note: Gradient checkpointing is not supported in this TeaCache implementation for Lumina2. + """ args, kwargs = _handle_accelerate_hook( module, hidden_states, From 08deb44cd4dab20f4ddf3edeb21fdc9d8acf3768 Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" Date: Tue, 20 Jan 2026 09:10:59 +0000 Subject: [PATCH 25/25] Apply style fixes --- src/diffusers/hooks/teacache.py | 14 +++++++------- .../pipeline_controlnet_sd_xl_img2img.py | 1 - .../hidream_image/pipeline_hidream_image.py | 1 - .../pag/pipeline_pag_controlnet_sd_xl_img2img.py | 1 - .../pipeline_stable_diffusion_latent_upscale.py | 1 - 5 files changed, 7 insertions(+), 11 deletions(-) diff --git a/src/diffusers/hooks/teacache.py b/src/diffusers/hooks/teacache.py index cc23132c3f6b..201d798aeebf 100644 --- a/src/diffusers/hooks/teacache.py +++ b/src/diffusers/hooks/teacache.py @@ -177,8 +177,8 @@ class TeaCacheConfig: Configuration for [TeaCache](https://huggingface.co/papers/2411.19108). TeaCache (Timestep Embedding Aware Cache) speeds up diffusion model inference by reusing transformer block - computations when consecutive timestep embeddings are similar. It uses polynomial rescaling of L1 distances - between modulated inputs to decide when to cache. + computations when consecutive timestep embeddings are similar. It uses polynomial rescaling of L1 distances between + modulated inputs to decide when to cache. Currently supports: FLUX, FLUX-Kontext, Mochi, Lumina2, and CogVideoX models. Model type is auto-detected. @@ -326,8 +326,8 @@ class TeaCacheHook(ModelHook): r""" Hook implementing [TeaCache](https://huggingface.co/papers/2411.19108) for transformer models. - Intercepts transformer forward pass and implements adaptive caching based on timestep embedding similarity. - First and last timesteps are always computed fully (never cached) to ensure maximum quality. + Intercepts transformer forward pass and implements adaptive caching based on timestep embedding similarity. First + and last timesteps are always computed fully (never cached) to ensure maximum quality. """ _is_stateful = True @@ -629,9 +629,9 @@ def _lumina2_teacache_forward( ): """TeaCache forward for Lumina2 models. - Note: Lumina2 uses inline caching logic instead of `_should_compute()` because it requires - per-sequence-length caching for variable sequence lengths (CFG batches have different lengths). - Each sequence length gets its own cache entry in `state.cache_dict`. + Note: Lumina2 uses inline caching logic instead of `_should_compute()` because it requires per-sequence-length + caching for variable sequence lengths (CFG batches have different lengths). Each sequence length gets its own cache + entry in `state.cache_dict`. Note: Gradient checkpointing is not supported in this TeaCache implementation for Lumina2. 
""" diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py index 94c4c394465b..2ea7307fec32 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py @@ -84,7 +84,6 @@ >>> from diffusers import ControlNetModel, StableDiffusionXLControlNetImg2ImgPipeline, AutoencoderKL >>> from diffusers.utils import load_image - >>> depth_estimator = DPTForDepthEstimation.from_pretrained("Intel/dpt-hybrid-midas").to("cuda") >>> feature_extractor = DPTImageProcessor.from_pretrained("Intel/dpt-hybrid-midas") >>> controlnet = ControlNetModel.from_pretrained( diff --git a/src/diffusers/pipelines/hidream_image/pipeline_hidream_image.py b/src/diffusers/pipelines/hidream_image/pipeline_hidream_image.py index d259f7ee7865..b41d9772a7cc 100644 --- a/src/diffusers/pipelines/hidream_image/pipeline_hidream_image.py +++ b/src/diffusers/pipelines/hidream_image/pipeline_hidream_image.py @@ -53,7 +53,6 @@ >>> from transformers import AutoTokenizer, LlamaForCausalLM >>> from diffusers import HiDreamImagePipeline - >>> tokenizer_4 = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3.1-8B-Instruct") >>> text_encoder_4 = LlamaForCausalLM.from_pretrained( ... "meta-llama/Meta-Llama-3.1-8B-Instruct", diff --git a/src/diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py b/src/diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py index df5b3f5c10a5..5a6b8d5e9f37 100644 --- a/src/diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py +++ b/src/diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py @@ -85,7 +85,6 @@ >>> from diffusers import ControlNetModel, StableDiffusionXLControlNetPAGImg2ImgPipeline, AutoencoderKL >>> from diffusers.utils import load_image - >>> depth_estimator = DPTForDepthEstimation.from_pretrained("Intel/dpt-hybrid-midas").to("cuda") >>> feature_extractor = DPTFeatureExtractor.from_pretrained("Intel/dpt-hybrid-midas") >>> controlnet = ControlNetModel.from_pretrained( diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py index 66d5ffa6b849..a1d0407caf5e 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py @@ -459,7 +459,6 @@ def __call__( >>> from diffusers import StableDiffusionLatentUpscalePipeline, StableDiffusionPipeline >>> import torch - >>> pipeline = StableDiffusionPipeline.from_pretrained( ... "CompVis/stable-diffusion-v1-4", torch_dtype=torch.float16 ... )