diff --git a/src/diffusers/hooks/__init__.py b/src/diffusers/hooks/__init__.py index eb12b8a52a1e..f8bd28464cfd 100644 --- a/src/diffusers/hooks/__init__.py +++ b/src/diffusers/hooks/__init__.py @@ -20,9 +20,10 @@ from .faster_cache import FasterCacheConfig, apply_faster_cache from .first_block_cache import FirstBlockCacheConfig, apply_first_block_cache from .group_offloading import apply_group_offloading - from .hooks import HookRegistry, ModelHook + from .hooks import HookRegistry, ModelHook, StateManager from .layer_skip import LayerSkipConfig, apply_layer_skip from .layerwise_casting import apply_layerwise_casting, apply_layerwise_casting_hook from .pyramid_attention_broadcast import PyramidAttentionBroadcastConfig, apply_pyramid_attention_broadcast from .smoothed_energy_guidance_utils import SmoothedEnergyGuidanceConfig from .taylorseer_cache import TaylorSeerCacheConfig, apply_taylorseer_cache + from .teacache import TeaCacheConfig, apply_teacache diff --git a/src/diffusers/hooks/hooks.py b/src/diffusers/hooks/hooks.py index 6e097e5882a0..b461ed18a696 100644 --- a/src/diffusers/hooks/hooks.py +++ b/src/diffusers/hooks/hooks.py @@ -40,11 +40,14 @@ def __init__(self, state_cls: BaseState, init_args=None, init_kwargs=None): self._current_context = None def get_state(self): - if self._current_context is None: - raise ValueError("No context is set. Please set a context before retrieving the state.") - if self._current_context not in self._state_cache.keys(): - self._state_cache[self._current_context] = self._state_cls(*self._init_args, **self._init_kwargs) - return self._state_cache[self._current_context] + context = self._current_context + if context is None: + # Fallback to default context for backward compatibility with + # pipelines that don't call cache_context() + context = "_default" + if context not in self._state_cache: + self._state_cache[context] = self._state_cls(*self._init_args, **self._init_kwargs) + return self._state_cache[context] def set_context(self, name: str) -> None: self._current_context = name @@ -133,7 +136,7 @@ def reset_state(self, module: torch.nn.Module): return module def _set_context(self, module: torch.nn.Module, name: str) -> None: - # Iterate over all attributes of the hook to see if any of them have the type `StateManager`. If so, call `set_context` on them. + """Set context on all StateManager attributes of this hook.""" for attr_name in dir(self): attr = getattr(self, attr_name) if isinstance(attr, StateManager): @@ -142,22 +145,12 @@ def _set_context(self, module: torch.nn.Module, name: str) -> None: class HookFunctionReference: - def __init__(self) -> None: - """A container class that maintains mutable references to forward pass functions in a hook chain. - - Its mutable nature allows the hook system to modify the execution chain dynamically without rebuilding the - entire forward pass structure. + """Mutable container for forward pass function references in a hook chain. - Attributes: - pre_forward: A callable that processes inputs before the main forward pass. - post_forward: A callable that processes outputs after the main forward pass. - forward: The current forward function in the hook chain. - original_forward: The original forward function, stored when a hook provides a custom new_forward. + Enables dynamic modification of the execution chain without rebuilding. + """ - The class enables hook removal by allowing updates to the forward chain through reference modification rather - than requiring reconstruction of the entire chain. 
When a hook is removed, only the relevant references need to - be updated, preserving the execution order of the remaining hooks. - """ + def __init__(self) -> None: self.pre_forward = None self.post_forward = None self.forward = None diff --git a/src/diffusers/hooks/teacache.py b/src/diffusers/hooks/teacache.py new file mode 100644 index 000000000000..201d798aeebf --- /dev/null +++ b/src/diffusers/hooks/teacache.py @@ -0,0 +1,901 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import torch + +from ..models.modeling_outputs import Transformer2DModelOutput +from ..utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers +from .hooks import BaseState, HookRegistry, ModelHook, StateManager + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +_TEACACHE_HOOK = "teacache" + + +def _handle_accelerate_hook(module: torch.nn.Module, *args, **kwargs) -> Tuple[tuple, dict]: + """Handle accelerate CPU offload hook compatibility.""" + if hasattr(module, "_hf_hook") and hasattr(module._hf_hook, "pre_forward"): + args, kwargs = module._hf_hook.pre_forward(module, *args, **kwargs) + return args, kwargs + + +def _extract_lora_scale(attention_kwargs: Optional[Dict[str, Any]]) -> Tuple[Optional[Dict[str, Any]], float]: + """Extract LoRA scale from attention kwargs.""" + if attention_kwargs is None: + return None, 1.0 + attention_kwargs = attention_kwargs.copy() + return attention_kwargs, attention_kwargs.pop("scale", 1.0) + + +def _get_model_config() -> Dict[str, Dict[str, Any]]: + """Get model configuration mapping. 
Order matters: more specific variants before generic ones.""" + return { + "FluxKontext": { + "forward_func": _flux_teacache_forward, + "coefficients": [-1.04655119e03, 3.12563399e02, -1.69500694e01, 4.10995971e-01, 3.74537863e-02], + }, + "Flux": { + "forward_func": _flux_teacache_forward, + "coefficients": [4.98651651e02, -2.83781631e02, 5.58554382e01, -3.82021401e00, 2.64230861e-01], + }, + "Mochi": { + "forward_func": _mochi_teacache_forward, + "coefficients": [-3.51241319e03, 8.11675948e02, -6.09400215e01, 2.42429681e00, 3.05291719e-03], + }, + "Lumina2": { + "forward_func": _lumina2_teacache_forward, + "coefficients": [393.76566581, -603.50993606, 209.10239044, -23.00726601, 0.86377344], + }, + "CogVideoX1.5-5B-I2V": { + "forward_func": _cogvideox_teacache_forward, + "coefficients": [1.22842302e02, -1.04088754e02, 2.62981677e01, -3.06001e-01, 3.71213220e-02], + }, + "CogVideoX1.5-5B": { + "forward_func": _cogvideox_teacache_forward, + "coefficients": [2.50210439e02, -1.65061612e02, 3.57804877e01, -7.81551492e-01, 3.58559703e-02], + }, + "CogVideoX-2b": { + "forward_func": _cogvideox_teacache_forward, + "coefficients": [-3.10658903e01, 2.54732368e01, -5.92380459e00, 1.75769064e00, -3.61568434e-03], + }, + "CogVideoX-5b": { + "forward_func": _cogvideox_teacache_forward, + "coefficients": [-1.53880483e03, 8.43202495e02, -1.34363087e02, 7.97131516e00, -5.23162339e-02], + }, + "CogVideoX": { + "forward_func": _cogvideox_teacache_forward, + "coefficients": [-1.53880483e03, 8.43202495e02, -1.34363087e02, 7.97131516e00, -5.23162339e-02], + }, + } + + +def _auto_detect_model_type(module: torch.nn.Module) -> str: + """Auto-detect model type from class name and config path.""" + class_name = module.__class__.__name__ + config_path = getattr(getattr(module, "config", None), "_name_or_path", "").lower() + model_config = _get_model_config() + + for model_type in model_config: + if model_type.lower() in config_path or model_type in class_name: + return model_type + + raise ValueError(f"TeaCache: Unsupported model '{class_name}'. 
Supported: {', '.join(model_config.keys())}") + + +def _rescale_distance(coefficients: List[float], x: float) -> float: + """Polynomial rescaling: c[0]*x^4 + c[1]*x^3 + c[2]*x^2 + c[3]*x + c[4]""" + c = coefficients + return c[0] * x**4 + c[1] * x**3 + c[2] * x**2 + c[3] * x + c[4] + + +def _compute_rel_l1_distance(current: torch.Tensor, previous: torch.Tensor) -> float: + """Compute relative L1 distance between tensors.""" + prev_mean = previous.abs().mean() + if prev_mean.item() > 1e-9: + return ((current - previous).abs().mean() / prev_mean).item() + return 0.0 if current.abs().mean().item() < 1e-9 else float("inf") + + +@torch.compiler.disable +def _should_compute(state, modulated_inp, coefficients, rel_l1_thresh): + """Determine if full computation is needed (single residual models).""" + is_first_step = state.cnt == 0 + is_last_step = state.num_steps > 0 and state.cnt == state.num_steps - 1 + missing_state = state.previous_modulated_input is None or state.previous_residual is None + + if is_first_step or is_last_step or missing_state: + state.accumulated_rel_l1_distance = 0 + return True + + rel_distance = _compute_rel_l1_distance(modulated_inp, state.previous_modulated_input) + state.accumulated_rel_l1_distance += _rescale_distance(coefficients, rel_distance) + + if state.accumulated_rel_l1_distance >= rel_l1_thresh: + state.accumulated_rel_l1_distance = 0 + return True + return False + + +@torch.compiler.disable +def _should_compute_dual(state, modulated_inp, coefficients, rel_l1_thresh): + """Determine if full computation is needed (dual residual models like CogVideoX).""" + if state.previous_residual is None or state.previous_residual_encoder is None: + return True + return _should_compute(state, modulated_inp, coefficients, rel_l1_thresh) + + +def _update_state(state, output, original_input, modulated_inp): + """Update cache state after full computation.""" + state.previous_residual = output - original_input + state.previous_modulated_input = modulated_inp + state.cnt += 1 + + +def _update_state_dual(state, hs_output, enc_output, hs_original, enc_original, modulated_inp): + """Update cache state after full computation (dual residual for CogVideoX).""" + state.previous_residual = hs_output - hs_original + state.previous_residual_encoder = enc_output - enc_original + state.previous_modulated_input = modulated_inp + state.cnt += 1 + + +def _apply_cached_residual(state, input_tensor, modulated_inp): + """Apply cached residual (fast path).""" + output = input_tensor + state.previous_residual + state.previous_modulated_input = modulated_inp + state.cnt += 1 + return output + + +def _apply_cached_residual_dual(state, hs, enc, modulated_inp): + """Apply cached residuals (fast path for CogVideoX).""" + hs_out = hs + state.previous_residual + enc_out = enc + state.previous_residual_encoder + state.previous_modulated_input = modulated_inp + state.cnt += 1 + return hs_out, enc_out + + +@dataclass +class TeaCacheConfig: + r""" + Configuration for [TeaCache](https://huggingface.co/papers/2411.19108). + + TeaCache (Timestep Embedding Aware Cache) speeds up diffusion model inference by reusing transformer block + computations when consecutive timestep embeddings are similar. It uses polynomial rescaling of L1 distances between + modulated inputs to decide when to cache. + + Currently supports: FLUX, FLUX-Kontext, Mochi, Lumina2, and CogVideoX models. Model type is auto-detected. + + Attributes: + rel_l1_thresh (`float`, defaults to `0.2`): + Threshold for accumulated relative L1 distance. 
When below this threshold, the cached residual is reused.
+            Recommended values: 0.25 for ~1.5x speedup, 0.4 for ~1.8x, 0.6 for ~2.0x. Mochi models require lower
+            thresholds (0.06-0.09).
+        coefficients (`List[float]`, *optional*):
+            Polynomial coefficients for rescaling L1 distance. Auto-detected based on model type if not provided.
+            Evaluated as: `c[0]*x^4 + c[1]*x^3 + c[2]*x^2 + c[3]*x + c[4]`.
+        current_timestep_callback (`Callable[[], int]`, *optional*):
+            Callback returning current timestep. Used for debugging/statistics.
+        num_inference_steps (`int`, *optional*):
+            Total inference steps. Ensures first/last timesteps are always computed. Auto-detected if not provided.
+        num_inference_steps_callback (`Callable[[], int]`, *optional*):
+            Callback returning total inference steps. Alternative to `num_inference_steps`.
+
+    Example:
+        ```python
+        >>> import torch
+        >>> from diffusers import FluxPipeline
+        >>> from diffusers.hooks import TeaCacheConfig
+
+        >>> pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16)
+        >>> pipe.to("cuda")
+
+        >>> config = TeaCacheConfig(rel_l1_thresh=0.2)
+        >>> pipe.transformer.enable_cache(config)
+
+        >>> image = pipe("A cat sitting on a windowsill", num_inference_steps=4).images[0]
+        >>> pipe.transformer.disable_cache()
+        ```
+    """
+
+    rel_l1_thresh: float = 0.2
+    coefficients: Optional[List[float]] = None
+    current_timestep_callback: Optional[Callable[[], int]] = None
+    num_inference_steps: Optional[int] = None
+    num_inference_steps_callback: Optional[Callable[[], int]] = None
+
+    def __post_init__(self):
+        self._validate_threshold()
+        self._validate_coefficients()
+
+    def _validate_threshold(self):
+        """Validate rel_l1_thresh parameter."""
+        if not isinstance(self.rel_l1_thresh, (int, float)):
+            raise TypeError(
+                f"rel_l1_thresh must be a number, got {type(self.rel_l1_thresh).__name__}. "
+                f"Please provide a float value between 0.1 and 1.0."
+            )
+        if self.rel_l1_thresh <= 0:
+            raise ValueError(
+                f"rel_l1_thresh must be positive, got {self.rel_l1_thresh}. "
+                f"Based on the TeaCache paper, values between 0.1 and 0.3 work best; "
+                f"try 0.25 for ~1.5x speedup, or up to 0.6 for ~2x speedup at some quality cost."
+            )
+        if self.rel_l1_thresh < 0.05:
+            logger.warning(
+                f"rel_l1_thresh={self.rel_l1_thresh} is very low and may result in minimal caching. "
+                f"Consider using values between 0.1 and 0.3 for optimal performance."
+            )
+        if self.rel_l1_thresh > 1.0:
+            logger.warning(
+                f"rel_l1_thresh={self.rel_l1_thresh} is very high and may cause quality degradation. "
+                f"Consider using values between 0.1 and 0.6 for better quality-speed tradeoff."
+            )
+
+    def _validate_coefficients(self):
+        """Validate coefficients parameter if provided."""
+        if self.coefficients is None:
+            return
+        if not isinstance(self.coefficients, (list, tuple)):
+            raise TypeError(
+                f"coefficients must be a list or tuple, got {type(self.coefficients).__name__}. "
+                f"Please provide a list of 5 polynomial coefficients."
+            )
+        if len(self.coefficients) != 5:
+            raise ValueError(
+                f"coefficients must contain exactly 5 elements for 4th-degree polynomial, "
+                f"got {len(self.coefficients)}. The polynomial is evaluated as: "
+                f"c[0]*x^4 + c[1]*x^3 + c[2]*x^2 + c[3]*x + c[4]"
+            )
+        if not all(isinstance(c, (int, float)) for c in self.coefficients):
+            raise TypeError(
+                f"All coefficients must be numbers.
Got types: {[type(c).__name__ for c in self.coefficients]}" + ) + + def __repr__(self) -> str: + return ( + f"TeaCacheConfig(\n" + f" rel_l1_thresh={self.rel_l1_thresh},\n" + f" coefficients={self.coefficients},\n" + f" current_timestep_callback={self.current_timestep_callback},\n" + f" num_inference_steps={self.num_inference_steps},\n" + f" num_inference_steps_callback={self.num_inference_steps_callback}\n" + f")" + ) + + +class TeaCacheState(BaseState): + r""" + State for [TeaCache](https://huggingface.co/papers/2411.19108). + + Tracks caching state across diffusion timesteps, including counters, accumulated distances, and cached residuals. + """ + + def __init__(self): + self.cnt = 0 + self.num_steps = 0 + self.accumulated_rel_l1_distance = 0.0 + self.previous_modulated_input = None + self.previous_residual = None + # CogVideoX-specific: dual residual caching (encoder + hidden_states) + self.previous_residual_encoder = None # Only used by CogVideoX + # Lumina2-specific: per-sequence-length caching for variable sequence lengths + # Other models don't use these fields but they're allocated for simplicity + self.cache_dict = {} # Only used by Lumina2 + self.uncond_seq_len = None # Only used by Lumina2 + + def reset(self): + """Reset all state variables to initial values for a new inference run.""" + self.cnt = 0 + self.num_steps = 0 + self.accumulated_rel_l1_distance = 0.0 + self.previous_modulated_input = None + self.previous_residual = None + self.previous_residual_encoder = None + self.cache_dict = {} + self.uncond_seq_len = None + + def __repr__(self) -> str: + return ( + f"TeaCacheState(\n" + f" cnt={self.cnt},\n" + f" num_steps={self.num_steps},\n" + f" accumulated_rel_l1_distance={self.accumulated_rel_l1_distance:.6f},\n" + f" previous_modulated_input={'cached' if self.previous_modulated_input is not None else 'None'},\n" + f" previous_residual={'cached' if self.previous_residual is not None else 'None'}\n" + f")" + ) + + +class TeaCacheHook(ModelHook): + r""" + Hook implementing [TeaCache](https://huggingface.co/papers/2411.19108) for transformer models. + + Intercepts transformer forward pass and implements adaptive caching based on timestep embedding similarity. First + and last timesteps are always computed fully (never cached) to ensure maximum quality. + """ + + _is_stateful = True + + def __init__(self, config: TeaCacheConfig): + super().__init__() + self.config = config + # Coefficients will be set in initialize_hook() via auto-detection or user config + self.coefficients: Optional[List[float]] = config.coefficients + self.state_manager = StateManager(TeaCacheState, (), {}) + self.model_type: Optional[str] = None + self._forward_func: Optional[Callable] = None + + def _maybe_reset_state_for_new_inference( + self, state: TeaCacheState, module: torch.nn.Module, reset_encoder_residual: bool = False + ) -> None: + """Reset state if inference run completed. 
Initialize num_steps on first timestep if not set.""" + # Reset if we've completed all steps (new inference run) + if state.cnt == state.num_steps and state.num_steps > 0: + logger.debug("TeaCache: Inference run completed, resetting state") + state.cnt = 0 + state.accumulated_rel_l1_distance = 0.0 + state.previous_modulated_input = None + state.previous_residual = None + if reset_encoder_residual: + state.previous_residual_encoder = None + + # Set num_steps on first timestep (priority: config > callback > module attribute) + if state.cnt == 0 and state.num_steps == 0: + if self.config.num_inference_steps is not None: + state.num_steps = self.config.num_inference_steps + elif self.config.num_inference_steps_callback is not None: + state.num_steps = self.config.num_inference_steps_callback() + elif hasattr(module, "num_steps"): + state.num_steps = module.num_steps + + if state.num_steps > 0: + logger.debug(f"TeaCache: Using {state.num_steps} inference steps") + + def initialize_hook(self, module): + # Context is set by pipeline's cache_context() calls in the denoising loop. + # This enables proper state isolation between cond/uncond branches. + # See PR #12652 for discussion on this design decision. + model_config = _get_model_config() + + # Auto-detect model type and get forward function + self.model_type = _auto_detect_model_type(module) + self._forward_func = model_config[self.model_type]["forward_func"] + + # Validate model has required attributes for TeaCache + if self.model_type in ("Flux", "FluxKontext", "Mochi"): + if not hasattr(module, "transformer_blocks") or len(module.transformer_blocks) == 0: + raise ValueError(f"TeaCache: {self.model_type} model missing transformer_blocks") + if not hasattr(module.transformer_blocks[0], "norm1"): + raise ValueError(f"TeaCache: {self.model_type} transformer_blocks[0] missing norm1") + elif self.model_type == "Lumina2": + if not hasattr(module, "layers") or len(module.layers) == 0: + raise ValueError("TeaCache: Lumina2 model missing layers") + elif "CogVideoX" in self.model_type: + if not hasattr(module, "transformer_blocks") or len(module.transformer_blocks) == 0: + raise ValueError(f"TeaCache: {self.model_type} model missing transformer_blocks") + + # Auto-detect coefficients if not provided by user + if self.config.coefficients is None: + self.coefficients = model_config[self.model_type]["coefficients"] + logger.debug(f"TeaCache: Using {self.model_type} coefficients") + else: + self.coefficients = self.config.coefficients + logger.debug("TeaCache: Using user-provided coefficients") + + return module + + def new_forward(self, module, *args, **kwargs): + return self._forward_func(self, module, *args, **kwargs) + + def reset_state(self, module): + self.state_manager.reset() + return module + + +def _flux_teacache_forward( + hook: "TeaCacheHook", + module: torch.nn.Module, + hidden_states: torch.Tensor, + timestep: torch.Tensor, + pooled_projections: torch.Tensor, + encoder_hidden_states: torch.Tensor, + txt_ids: torch.Tensor, + img_ids: torch.Tensor, + controlnet_block_samples: Optional[List[torch.Tensor]] = None, + controlnet_single_block_samples: Optional[List[torch.Tensor]] = None, + return_dict: bool = True, + controlnet_blocks_repeat: bool = False, + **kwargs, +): + """TeaCache forward for Flux models with ControlNet support.""" + args, extra_kwargs = _handle_accelerate_hook( + module, + hidden_states, + timestep, + pooled_projections, + encoder_hidden_states, + txt_ids, + img_ids, + controlnet_block_samples=controlnet_block_samples, + 
controlnet_single_block_samples=controlnet_single_block_samples, + return_dict=return_dict, + controlnet_blocks_repeat=controlnet_blocks_repeat, + **kwargs, + ) + hidden_states, timestep, pooled_projections, encoder_hidden_states, txt_ids, img_ids = args + controlnet_block_samples = extra_kwargs.pop("controlnet_block_samples", controlnet_block_samples) + controlnet_single_block_samples = extra_kwargs.pop( + "controlnet_single_block_samples", controlnet_single_block_samples + ) + controlnet_blocks_repeat = extra_kwargs.pop("controlnet_blocks_repeat", controlnet_blocks_repeat) + return_dict = extra_kwargs.pop("return_dict", return_dict) + kwargs = extra_kwargs + + state = hook.state_manager.get_state() + hook._maybe_reset_state_for_new_inference(state, module) + + hidden_states = module.x_embedder(hidden_states) + + timestep_scaled = timestep.to(hidden_states.dtype) * 1000 + if kwargs.get("guidance") is not None: + guidance = kwargs["guidance"].to(hidden_states.dtype) * 1000 + temb = module.time_text_embed(timestep_scaled, guidance, pooled_projections) + else: + temb = module.time_text_embed(timestep_scaled, pooled_projections) + + modulated_inp = module.transformer_blocks[0].norm1(hidden_states, emb=temb)[0] + + if _should_compute(state, modulated_inp, hook.coefficients, hook.config.rel_l1_thresh): + ori_hs = hidden_states.clone() + enc = module.context_embedder(encoder_hidden_states) + + txt = txt_ids[0] if txt_ids.ndim == 3 else txt_ids + img = img_ids[0] if img_ids.ndim == 3 else img_ids + ids = torch.cat((txt, img), dim=0) + image_rotary_emb = module.pos_embed(ids) + + joint_attention_kwargs = kwargs.get("joint_attention_kwargs") + for index_block, block in enumerate(module.transformer_blocks): + enc, hidden_states = block( + hidden_states=hidden_states, + encoder_hidden_states=enc, + temb=temb, + image_rotary_emb=image_rotary_emb, + joint_attention_kwargs=joint_attention_kwargs, + ) + # ControlNet residual + if controlnet_block_samples is not None: + interval_control = len(module.transformer_blocks) / len(controlnet_block_samples) + interval_control = ( + int(interval_control) if interval_control == int(interval_control) else int(interval_control) + 1 + ) + if controlnet_blocks_repeat: + hidden_states = ( + hidden_states + controlnet_block_samples[index_block % len(controlnet_block_samples)] + ) + else: + hidden_states = hidden_states + controlnet_block_samples[index_block // interval_control] + + for index_block, block in enumerate(module.single_transformer_blocks): + enc, hidden_states = block( + hidden_states=hidden_states, + encoder_hidden_states=enc, + temb=temb, + image_rotary_emb=image_rotary_emb, + joint_attention_kwargs=joint_attention_kwargs, + ) + # ControlNet residual + if controlnet_single_block_samples is not None: + interval_control = len(module.single_transformer_blocks) / len(controlnet_single_block_samples) + interval_control = ( + int(interval_control) if interval_control == int(interval_control) else int(interval_control) + 1 + ) + hidden_states = hidden_states + controlnet_single_block_samples[index_block // interval_control] + + _update_state(state, hidden_states, ori_hs, modulated_inp) + else: + hidden_states = _apply_cached_residual(state, hidden_states, modulated_inp) + + hidden_states = module.norm_out(hidden_states, temb) + output = module.proj_out(hidden_states) + + if not return_dict: + return (output,) + return Transformer2DModelOutput(sample=output) + + +def _mochi_teacache_forward( + hook: "TeaCacheHook", + module: torch.nn.Module, + hidden_states: 
torch.Tensor, + encoder_hidden_states: torch.Tensor, + timestep: torch.Tensor, + encoder_attention_mask: torch.Tensor, + attention_kwargs: Optional[Dict[str, Any]] = None, + return_dict: bool = True, +): + """TeaCache forward for Mochi models.""" + args, kwargs = _handle_accelerate_hook( + module, + hidden_states, + encoder_hidden_states, + timestep, + encoder_attention_mask, + attention_kwargs=attention_kwargs, + return_dict=return_dict, + ) + hidden_states, encoder_hidden_states, timestep, encoder_attention_mask = args + attention_kwargs = kwargs.get("attention_kwargs", attention_kwargs) + return_dict = kwargs.get("return_dict", return_dict) + + attention_kwargs, lora_scale = _extract_lora_scale(attention_kwargs) + if USE_PEFT_BACKEND: + scale_lora_layers(module, lora_scale) + + state = hook.state_manager.get_state() + hook._maybe_reset_state_for_new_inference(state, module) + + batch_size, _, num_frames, height, width = hidden_states.shape + p = module.config.patch_size + post_patch_height = height // p + post_patch_width = width // p + + temb, encoder_hidden_states = module.time_embed( + timestep, + encoder_hidden_states, + encoder_attention_mask, + hidden_dtype=hidden_states.dtype, + ) + + hidden_states = hidden_states.permute(0, 2, 1, 3, 4).flatten(0, 1) + hidden_states = module.patch_embed(hidden_states) + hidden_states = hidden_states.unflatten(0, (batch_size, -1)).flatten(1, 2) + + image_rotary_emb = module.rope( + module.pos_frequencies, + num_frames, + post_patch_height, + post_patch_width, + device=hidden_states.device, + dtype=torch.float32, + ) + + modulated_inp = module.transformer_blocks[0].norm1(hidden_states, temb)[0] + + if _should_compute(state, modulated_inp, hook.coefficients, hook.config.rel_l1_thresh): + ori_hs = hidden_states.clone() + enc = encoder_hidden_states + for block in module.transformer_blocks: + if torch.is_grad_enabled() and module.gradient_checkpointing: + hidden_states, enc = module._gradient_checkpointing_func( + block, + hidden_states, + enc, + temb, + encoder_attention_mask, + image_rotary_emb, + ) + else: + hidden_states, enc = block( + hidden_states=hidden_states, + encoder_hidden_states=enc, + temb=temb, + encoder_attention_mask=encoder_attention_mask, + image_rotary_emb=image_rotary_emb, + ) + hidden_states = module.norm_out(hidden_states, temb) + _update_state(state, hidden_states, ori_hs, modulated_inp) + else: + hidden_states = _apply_cached_residual(state, hidden_states, modulated_inp) + + hidden_states = module.proj_out(hidden_states) + hidden_states = hidden_states.reshape(batch_size, num_frames, post_patch_height, post_patch_width, p, p, -1) + hidden_states = hidden_states.permute(0, 6, 1, 2, 4, 3, 5) + output = hidden_states.reshape(batch_size, -1, num_frames, height, width) + + if USE_PEFT_BACKEND: + unscale_lora_layers(module, lora_scale) + + if not return_dict: + return (output,) + return Transformer2DModelOutput(sample=output) + + +def _lumina2_teacache_forward( + hook: "TeaCacheHook", + module: torch.nn.Module, + hidden_states: torch.Tensor, + timestep: torch.Tensor, + encoder_hidden_states: torch.Tensor, + encoder_attention_mask: torch.Tensor, + attention_kwargs: Optional[Dict[str, Any]] = None, + return_dict: bool = True, +): + """TeaCache forward for Lumina2 models. + + Note: Lumina2 uses inline caching logic instead of `_should_compute()` because it requires per-sequence-length + caching for variable sequence lengths (CFG batches have different lengths). Each sequence length gets its own cache + entry in `state.cache_dict`. 
+ + Note: Gradient checkpointing is not supported in this TeaCache implementation for Lumina2. + """ + args, kwargs = _handle_accelerate_hook( + module, + hidden_states, + timestep, + encoder_hidden_states, + encoder_attention_mask, + attention_kwargs=attention_kwargs, + return_dict=return_dict, + ) + hidden_states, timestep, encoder_hidden_states, encoder_attention_mask = args + attention_kwargs = kwargs.get("attention_kwargs", attention_kwargs) + return_dict = kwargs.get("return_dict", return_dict) + + attention_kwargs, lora_scale = _extract_lora_scale(attention_kwargs) + if USE_PEFT_BACKEND: + scale_lora_layers(module, lora_scale) + + state = hook.state_manager.get_state() + hook._maybe_reset_state_for_new_inference(state, module) + + batch_size, _, height, width = hidden_states.shape + + temb, encoder_hidden_states_processed = module.time_caption_embed(hidden_states, timestep, encoder_hidden_states) + ( + image_patch_embeddings, + context_rotary_emb, + noise_rotary_emb, + joint_rotary_emb, + encoder_seq_lengths, + seq_lengths, + ) = module.rope_embedder(hidden_states, encoder_attention_mask) + image_patch_embeddings = module.x_embedder(image_patch_embeddings) + + for layer in module.context_refiner: + encoder_hidden_states_processed = layer( + encoder_hidden_states_processed, encoder_attention_mask, context_rotary_emb + ) + for layer in module.noise_refiner: + image_patch_embeddings = layer(image_patch_embeddings, None, noise_rotary_emb, temb) + + max_seq_len = max(seq_lengths) + input_to_main_loop = image_patch_embeddings.new_zeros(batch_size, max_seq_len, module.config.hidden_size) + for i, (enc_len, seq_len_val) in enumerate(zip(encoder_seq_lengths, seq_lengths)): + input_to_main_loop[i, :enc_len] = encoder_hidden_states_processed[i, :enc_len] + input_to_main_loop[i, enc_len:seq_len_val] = image_patch_embeddings[i] + + use_mask = len(set(seq_lengths)) > 1 + attention_mask_for_main_loop_arg = None + if use_mask: + mask = input_to_main_loop.new_zeros(batch_size, max_seq_len, dtype=torch.bool) + for i, (enc_len, seq_len_val) in enumerate(zip(encoder_seq_lengths, seq_lengths)): + mask[i, :seq_len_val] = True + attention_mask_for_main_loop_arg = mask + + modulated_inp = module.layers[0].norm1(input_to_main_loop, temb)[0] + + # Per-sequence-length caching for variable sequence lengths + cache_key = max_seq_len + if cache_key not in state.cache_dict: + state.cache_dict[cache_key] = { + "previous_modulated_input": None, + "previous_residual": None, + "accumulated_rel_l1_distance": 0.0, + } + cache = state.cache_dict[cache_key] + + is_boundary_step = state.cnt == 0 or state.cnt == state.num_steps - 1 + has_previous = cache["previous_modulated_input"] is not None + + if is_boundary_step or not has_previous: + should_calc = True + cache["accumulated_rel_l1_distance"] = 0.0 + else: + rel_distance = _compute_rel_l1_distance(modulated_inp, cache["previous_modulated_input"]) + cache["accumulated_rel_l1_distance"] += _rescale_distance(hook.coefficients, rel_distance) + if cache["accumulated_rel_l1_distance"] >= hook.config.rel_l1_thresh: + should_calc = True + cache["accumulated_rel_l1_distance"] = 0.0 + else: + should_calc = False + + cache["previous_modulated_input"] = modulated_inp.clone() + + # Track sequence length for step counting (CFG) + if state.uncond_seq_len is None: + state.uncond_seq_len = cache_key + if cache_key != state.uncond_seq_len: + state.cnt += 1 + if state.cnt >= state.num_steps: + state.cnt = 0 + + # Apply cached residual or compute full forward + if not should_calc and 
cache["previous_residual"] is not None: + processed_hidden_states = input_to_main_loop + cache["previous_residual"] + else: + processed_hidden_states = input_to_main_loop + for layer in module.layers: + processed_hidden_states = layer( + processed_hidden_states, attention_mask_for_main_loop_arg, joint_rotary_emb, temb + ) + cache["previous_residual"] = processed_hidden_states - input_to_main_loop + + output_after_norm = module.norm_out(processed_hidden_states, temb) + p = module.config.patch_size + final_output_list = [] + for i, (enc_len, seq_len_val) in enumerate(zip(encoder_seq_lengths, seq_lengths)): + image_part = output_after_norm[i][enc_len:seq_len_val] + h_p, w_p = height // p, width // p + reconstructed_image = ( + image_part.view(h_p, w_p, p, p, module.out_channels).permute(4, 0, 2, 1, 3).flatten(3, 4).flatten(1, 2) + ) + final_output_list.append(reconstructed_image) + final_output_tensor = torch.stack(final_output_list, dim=0) + + if USE_PEFT_BACKEND: + unscale_lora_layers(module, lora_scale) + + if not return_dict: + return (final_output_tensor,) + return Transformer2DModelOutput(sample=final_output_tensor) + + +def _cogvideox_teacache_forward( + hook: "TeaCacheHook", + module: torch.nn.Module, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + timestep: Union[int, float, torch.LongTensor], + timestep_cond: Optional[torch.Tensor] = None, + ofs: Optional[Union[int, float, torch.LongTensor]] = None, + image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + attention_kwargs: Optional[Dict[str, Any]] = None, + return_dict: bool = True, +): + """TeaCache forward for CogVideoX models.""" + args, kwargs = _handle_accelerate_hook( + module, + hidden_states, + encoder_hidden_states, + timestep, + timestep_cond=timestep_cond, + ofs=ofs, + image_rotary_emb=image_rotary_emb, + attention_kwargs=attention_kwargs, + return_dict=return_dict, + ) + hidden_states, encoder_hidden_states, timestep = args + timestep_cond = kwargs.get("timestep_cond", timestep_cond) + ofs = kwargs.get("ofs", ofs) + image_rotary_emb = kwargs.get("image_rotary_emb", image_rotary_emb) + attention_kwargs = kwargs.get("attention_kwargs", attention_kwargs) + return_dict = kwargs.get("return_dict", return_dict) + + attention_kwargs, lora_scale = _extract_lora_scale(attention_kwargs) + if USE_PEFT_BACKEND: + scale_lora_layers(module, lora_scale) + + state = hook.state_manager.get_state() + hook._maybe_reset_state_for_new_inference(state, module, reset_encoder_residual=True) + + batch_size, num_frames, _, height, width = hidden_states.shape + + t_emb = module.time_proj(timestep) + t_emb = t_emb.to(dtype=hidden_states.dtype) + emb = module.time_embedding(t_emb, timestep_cond) + + if module.ofs_embedding is not None: + ofs_emb = module.ofs_proj(ofs) + ofs_emb = ofs_emb.to(dtype=hidden_states.dtype) + ofs_emb = module.ofs_embedding(ofs_emb) + emb = emb + ofs_emb + + hs = module.patch_embed(encoder_hidden_states, hidden_states) + hs = module.embedding_dropout(hs) + + text_seq_length = encoder_hidden_states.shape[1] + enc = hs[:, :text_seq_length] + hs = hs[:, text_seq_length:] + + modulated_inp = emb + + if _should_compute_dual(state, modulated_inp, hook.coefficients, hook.config.rel_l1_thresh): + ori_hs = hs.clone() + ori_enc = enc.clone() + for block in module.transformer_blocks: + if torch.is_grad_enabled() and module.gradient_checkpointing: + ckpt_kwargs = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + hs, enc = torch.utils.checkpoint.checkpoint( + lambda *a: 
block(*a), + hs, + enc, + emb, + image_rotary_emb, + **ckpt_kwargs, + ) + else: + hs, enc = block( + hidden_states=hs, + encoder_hidden_states=enc, + temb=emb, + image_rotary_emb=image_rotary_emb, + ) + _update_state_dual(state, hs, enc, ori_hs, ori_enc, modulated_inp) + else: + hs, enc = _apply_cached_residual_dual(state, hs, enc, modulated_inp) + + if not module.config.use_rotary_positional_embeddings: + hs = module.norm_final(hs) + else: + hs_cat = torch.cat([enc, hs], dim=1) + hs_cat = module.norm_final(hs_cat) + hs = hs_cat[:, text_seq_length:] + + hs = module.norm_out(hs, temb=emb) + hs = module.proj_out(hs) + + p = module.config.patch_size + p_t = module.config.patch_size_t + if p_t is None: + output = hs.reshape(batch_size, num_frames, height // p, width // p, -1, p, p) + output = output.permute(0, 1, 4, 2, 5, 3, 6).flatten(5, 6).flatten(3, 4) + else: + output = hs.reshape(batch_size, (num_frames + p_t - 1) // p_t, height // p, width // p, -1, p_t, p, p) + output = output.permute(0, 1, 5, 4, 2, 6, 3, 7).flatten(6, 7).flatten(4, 5).flatten(1, 2) + + if USE_PEFT_BACKEND: + unscale_lora_layers(module, lora_scale) + + if not return_dict: + return (output,) + return Transformer2DModelOutput(sample=output) + + +def apply_teacache(module: torch.nn.Module, config: TeaCacheConfig) -> None: + r""" + Applies [TeaCache](https://huggingface.co/papers/2411.19108) to a given module. + + TeaCache speeds up diffusion model inference (1.5x-2x) by caching transformer block computations when consecutive + timestep embeddings are similar. Model type is auto-detected based on the module class name. + + Args: + module (`torch.nn.Module`): + The transformer model to optimize (e.g., FluxTransformer2DModel, CogVideoXTransformer3DModel). + config (`TeaCacheConfig`): + The configuration to use for TeaCache. 
+ + Example: + ```python + >>> import torch + >>> from diffusers import FluxPipeline + >>> from diffusers.hooks import TeaCacheConfig, apply_teacache + + >>> pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16) + >>> pipe.to("cuda") + + >>> apply_teacache(pipe.transformer, TeaCacheConfig(rel_l1_thresh=0.2)) + + >>> image = pipe("A cat on a windowsill", num_inference_steps=4).images[0] + ``` + """ + # Register hook on main transformer + registry = HookRegistry.check_if_exists_or_initialize(module) + hook = TeaCacheHook(config) + registry.register_hook(hook, _TEACACHE_HOOK) diff --git a/src/diffusers/models/cache_utils.py b/src/diffusers/models/cache_utils.py index 153608bb2bf8..26e8c9eb4840 100644 --- a/src/diffusers/models/cache_utils.py +++ b/src/diffusers/models/cache_utils.py @@ -28,6 +28,7 @@ class CacheMixin: - [Pyramid Attention Broadcast](https://huggingface.co/papers/2408.12588) - [FasterCache](https://huggingface.co/papers/2410.19355) - [FirstBlockCache](https://github.com/chengzeyi/ParaAttention/blob/7a266123671b55e7e5a2fe9af3121f07a36afc78/README.md#first-block-cache-our-dynamic-caching) + - [TeaCache](https://huggingface.co/papers/2411.19108) """ _cache_config = None @@ -70,10 +71,12 @@ def enable_cache(self, config) -> None: FirstBlockCacheConfig, PyramidAttentionBroadcastConfig, TaylorSeerCacheConfig, + TeaCacheConfig, apply_faster_cache, apply_first_block_cache, apply_pyramid_attention_broadcast, apply_taylorseer_cache, + apply_teacache, ) if self.is_cache_enabled: @@ -87,6 +90,8 @@ def enable_cache(self, config) -> None: apply_first_block_cache(self, config) elif isinstance(config, PyramidAttentionBroadcastConfig): apply_pyramid_attention_broadcast(self, config) + elif isinstance(config, TeaCacheConfig): + apply_teacache(self, config) elif isinstance(config, TaylorSeerCacheConfig): apply_taylorseer_cache(self, config) else: @@ -101,11 +106,13 @@ def disable_cache(self) -> None: HookRegistry, PyramidAttentionBroadcastConfig, TaylorSeerCacheConfig, + TeaCacheConfig, ) from ..hooks.faster_cache import _FASTER_CACHE_BLOCK_HOOK, _FASTER_CACHE_DENOISER_HOOK from ..hooks.first_block_cache import _FBC_BLOCK_HOOK, _FBC_LEADER_BLOCK_HOOK from ..hooks.pyramid_attention_broadcast import _PYRAMID_ATTENTION_BROADCAST_HOOK from ..hooks.taylorseer_cache import _TAYLORSEER_CACHE_HOOK + from ..hooks.teacache import _TEACACHE_HOOK if self._cache_config is None: logger.warning("Caching techniques have not been enabled, so there's nothing to disable.") @@ -120,6 +127,8 @@ def disable_cache(self) -> None: registry.remove_hook(_FBC_BLOCK_HOOK, recurse=True) elif isinstance(self._cache_config, PyramidAttentionBroadcastConfig): registry.remove_hook(_PYRAMID_ATTENTION_BROADCAST_HOOK, recurse=True) + elif isinstance(self._cache_config, TeaCacheConfig): + registry.remove_hook(_TEACACHE_HOOK, recurse=True) elif isinstance(self._cache_config, TaylorSeerCacheConfig): registry.remove_hook(_TAYLORSEER_CACHE_HOOK, recurse=True) else: diff --git a/src/diffusers/models/transformers/transformer_lumina2.py b/src/diffusers/models/transformers/transformer_lumina2.py index 77121edb9fc9..a08d33dbfa77 100644 --- a/src/diffusers/models/transformers/transformer_lumina2.py +++ b/src/diffusers/models/transformers/transformer_lumina2.py @@ -25,6 +25,7 @@ from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers from ..attention import LuminaFeedForward from ..attention_processor import Attention +from ..cache_utils import CacheMixin from 
..embeddings import TimestepEmbedding, Timesteps, apply_rotary_emb, get_1d_rotary_pos_embed from ..modeling_outputs import Transformer2DModelOutput from ..modeling_utils import ModelMixin @@ -322,7 +323,7 @@ def forward(self, hidden_states: torch.Tensor, attention_mask: torch.Tensor): return hidden_states, cap_freqs_cis, img_freqs_cis, freqs_cis, l_effective_cap_len, seq_lengths -class Lumina2Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin): +class Lumina2Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, CacheMixin): r""" Lumina2NextDiT: Diffusion model with a Transformer backbone. diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py index 94c4c394465b..2ea7307fec32 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py @@ -84,7 +84,6 @@ >>> from diffusers import ControlNetModel, StableDiffusionXLControlNetImg2ImgPipeline, AutoencoderKL >>> from diffusers.utils import load_image - >>> depth_estimator = DPTForDepthEstimation.from_pretrained("Intel/dpt-hybrid-midas").to("cuda") >>> feature_extractor = DPTImageProcessor.from_pretrained("Intel/dpt-hybrid-midas") >>> controlnet = ControlNetModel.from_pretrained( diff --git a/src/diffusers/pipelines/hidream_image/pipeline_hidream_image.py b/src/diffusers/pipelines/hidream_image/pipeline_hidream_image.py index d259f7ee7865..b41d9772a7cc 100644 --- a/src/diffusers/pipelines/hidream_image/pipeline_hidream_image.py +++ b/src/diffusers/pipelines/hidream_image/pipeline_hidream_image.py @@ -53,7 +53,6 @@ >>> from transformers import AutoTokenizer, LlamaForCausalLM >>> from diffusers import HiDreamImagePipeline - >>> tokenizer_4 = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3.1-8B-Instruct") >>> text_encoder_4 = LlamaForCausalLM.from_pretrained( ... "meta-llama/Meta-Llama-3.1-8B-Instruct", diff --git a/src/diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py b/src/diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py index df5b3f5c10a5..5a6b8d5e9f37 100644 --- a/src/diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py +++ b/src/diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py @@ -85,7 +85,6 @@ >>> from diffusers import ControlNetModel, StableDiffusionXLControlNetPAGImg2ImgPipeline, AutoencoderKL >>> from diffusers.utils import load_image - >>> depth_estimator = DPTForDepthEstimation.from_pretrained("Intel/dpt-hybrid-midas").to("cuda") >>> feature_extractor = DPTFeatureExtractor.from_pretrained("Intel/dpt-hybrid-midas") >>> controlnet = ControlNetModel.from_pretrained( diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py index 66d5ffa6b849..a1d0407caf5e 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py @@ -459,7 +459,6 @@ def __call__( >>> from diffusers import StableDiffusionLatentUpscalePipeline, StableDiffusionPipeline >>> import torch - >>> pipeline = StableDiffusionPipeline.from_pretrained( ... "CompVis/stable-diffusion-v1-4", torch_dtype=torch.float16 ... 
) diff --git a/tests/hooks/test_teacache.py b/tests/hooks/test_teacache.py new file mode 100644 index 000000000000..90edfd733f89 --- /dev/null +++ b/tests/hooks/test_teacache.py @@ -0,0 +1,408 @@ +# Copyright 2025 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import torch + +from diffusers import CogVideoXTransformer3DModel, Lumina2Transformer2DModel, MochiTransformer3DModel +from diffusers.hooks import HookRegistry, StateManager, TeaCacheConfig, apply_teacache +from diffusers.hooks.teacache import TeaCacheHook, TeaCacheState, _get_model_config, _should_compute + + +def _create_mochi_model() -> MochiTransformer3DModel: + return MochiTransformer3DModel( + patch_size=2, + num_attention_heads=2, + attention_head_dim=8, + num_layers=2, + in_channels=4, + text_embed_dim=16, + time_embed_dim=4, + ) + + +def _create_lumina2_model() -> Lumina2Transformer2DModel: + return Lumina2Transformer2DModel( + sample_size=16, + patch_size=2, + in_channels=4, + hidden_size=24, + num_layers=2, + num_refiner_layers=1, + num_attention_heads=3, + num_kv_heads=1, + ) + + +def _create_cogvideox_model() -> CogVideoXTransformer3DModel: + return CogVideoXTransformer3DModel( + num_attention_heads=2, + attention_head_dim=8, + in_channels=4, + num_layers=2, + text_embed_dim=16, + time_embed_dim=4, + ) + + +class TeaCacheConfigTests(unittest.TestCase): + """Tests for TeaCacheConfig parameter validation.""" + + def test_valid_config(self): + """Test valid configuration is accepted.""" + config = TeaCacheConfig(rel_l1_thresh=0.2) + self.assertEqual(config.rel_l1_thresh, 0.2) + # coefficients is None by default (auto-detected during hook initialization) + self.assertIsNone(config.coefficients) + + def test_invalid_type(self): + """Test invalid type for rel_l1_thresh raises TypeError.""" + with self.assertRaises(TypeError) as context: + TeaCacheConfig(rel_l1_thresh="invalid") + self.assertIn("must be a number", str(context.exception)) + + def test_negative_value(self): + """Test negative threshold raises ValueError.""" + with self.assertRaises(ValueError) as context: + TeaCacheConfig(rel_l1_thresh=-0.5) + self.assertIn("must be positive", str(context.exception)) + + def test_invalid_coefficients_length(self): + """Test wrong coefficient count raises ValueError.""" + with self.assertRaises(ValueError) as context: + TeaCacheConfig(rel_l1_thresh=0.2, coefficients=[1.0, 2.0, 3.0]) + self.assertIn("exactly 5 elements", str(context.exception)) + + def test_invalid_coefficients_type(self): + """Test invalid coefficient types raise TypeError.""" + with self.assertRaises(TypeError) as context: + TeaCacheConfig(rel_l1_thresh=0.2, coefficients=[1.0, 2.0, "invalid", 4.0, 5.0]) + self.assertIn("must be numbers", str(context.exception)) + + def test_very_low_threshold_accepted(self): + """Test very low threshold is accepted (with logging warning).""" + config = TeaCacheConfig(rel_l1_thresh=0.01) + self.assertEqual(config.rel_l1_thresh, 0.01) + + def test_very_high_threshold_accepted(self): + """Test 
very high threshold is accepted (with logging warning).""" + config = TeaCacheConfig(rel_l1_thresh=1.5) + self.assertEqual(config.rel_l1_thresh, 1.5) + + def test_config_repr(self): + """Test __repr__ method works correctly.""" + config = TeaCacheConfig(rel_l1_thresh=0.25) + repr_str = repr(config) + self.assertIn("TeaCacheConfig", repr_str) + self.assertIn("0.25", repr_str) + + def test_custom_coefficients(self): + """Test custom coefficients are accepted.""" + custom_coeffs = [1.0, 2.0, 3.0, 4.0, 5.0] + config = TeaCacheConfig(rel_l1_thresh=0.2, coefficients=custom_coeffs) + self.assertEqual(config.coefficients, custom_coeffs) + + +class TeaCacheStateTests(unittest.TestCase): + """Tests for TeaCacheState.""" + + def test_state_initialization(self): + """Test state initializes with correct default values.""" + state = TeaCacheState() + self.assertEqual(state.cnt, 0) + self.assertEqual(state.num_steps, 0) + self.assertEqual(state.accumulated_rel_l1_distance, 0.0) + self.assertIsNone(state.previous_modulated_input) + self.assertIsNone(state.previous_residual) + + def test_state_reset(self): + """Test state reset clears all values.""" + state = TeaCacheState() + state.cnt = 5 + state.num_steps = 10 + state.accumulated_rel_l1_distance = 0.5 + state.previous_modulated_input = torch.randn(1, 10) + state.previous_residual = torch.randn(1, 10) + + state.reset() + + self.assertEqual(state.cnt, 0) + self.assertEqual(state.num_steps, 0) + self.assertEqual(state.accumulated_rel_l1_distance, 0.0) + self.assertIsNone(state.previous_modulated_input) + self.assertIsNone(state.previous_residual) + + def test_state_repr(self): + """Test __repr__ method works correctly.""" + state = TeaCacheState() + state.cnt = 3 + state.num_steps = 10 + repr_str = repr(state) + self.assertIn("TeaCacheState", repr_str) + self.assertIn("cnt=3", repr_str) + self.assertIn("num_steps=10", repr_str) + + +class TeaCacheHookTests(unittest.TestCase): + """Tests for TeaCacheHook functionality.""" + + def test_hook_initialization(self): + """Test hook initializes correctly with config.""" + config = TeaCacheConfig(rel_l1_thresh=0.2) + hook = TeaCacheHook(config) + + self.assertEqual(hook.config.rel_l1_thresh, 0.2) + self.assertIsNone(hook.coefficients) + self.assertIsNotNone(hook.state_manager) + + def test_should_compute_logic(self): + """Test _should_compute decision logic.""" + coefficients = [1, 0, 0, 0, 0] + rel_l1_thresh = 1.0 + state = TeaCacheState() + + x0 = torch.ones(1, 4) + x1 = torch.ones(1, 4) * 1.1 + + self.assertTrue(_should_compute(state, x0, coefficients, rel_l1_thresh)) + + state.previous_modulated_input = x0 + state.previous_residual = torch.zeros(1, 4) + state.cnt = 1 + state.num_steps = 4 + + self.assertFalse(_should_compute(state, x1, coefficients, rel_l1_thresh)) + + state.cnt = state.num_steps - 1 + self.assertTrue(_should_compute(state, x1, coefficients, rel_l1_thresh)) + + def test_apply_teacache_unsupported_model_raises_error(self): + """Test that apply_teacache raises error for unsupported models.""" + from diffusers.models import CacheMixin + + class UnsupportedModule(torch.nn.Module, CacheMixin): + def __init__(self): + super().__init__() + self.dummy = torch.nn.Linear(4, 4) + + module = UnsupportedModule() + config = TeaCacheConfig(rel_l1_thresh=0.2) + + with self.assertRaises(ValueError) as context: + apply_teacache(module, config) + self.assertIn("Unsupported model", str(context.exception)) + self.assertIn("UnsupportedModule", str(context.exception)) + + +class 
TeaCacheMultiModelTests(unittest.TestCase): + """Tests for TeaCache multi-model support (Mochi, Lumina2, CogVideoX).""" + + def test_model_coefficient_registry(self): + """Test that model coefficients are properly registered.""" + model_config = _get_model_config() + + self.assertIn("Flux", model_config) + self.assertIn("Mochi", model_config) + self.assertIn("Lumina2", model_config) + self.assertIn("CogVideoX", model_config) + + for model_name, config in model_config.items(): + coeffs = config["coefficients"] + self.assertEqual(len(coeffs), 5, f"{model_name} coefficients should have 5 elements") + self.assertTrue( + all(isinstance(c, (int, float)) for c in coeffs), f"{model_name} coefficients should be numbers" + ) + + def test_mochi_extractor(self): + """Test Mochi modulated input extraction.""" + model = _create_mochi_model() + + hidden_states = torch.randn(2, 4, 2, 8, 8) + timestep = torch.randint(0, 1000, (2,)) + encoder_hidden_states = torch.randn(2, 16, 16) + encoder_attention_mask = torch.ones(2, 16).bool() + + temb, _ = model.time_embed( + timestep, encoder_hidden_states, encoder_attention_mask, hidden_dtype=hidden_states.dtype + ) + hidden_states = hidden_states.permute(0, 2, 1, 3, 4).flatten(0, 1) + hidden_states = model.patch_embed(hidden_states) + hidden_states = hidden_states.unflatten(0, (2, -1)).flatten(1, 2) + + modulated_inp = model.transformer_blocks[0].norm1(hidden_states, temb)[0] + self.assertIsInstance(modulated_inp, torch.Tensor) + self.assertEqual(modulated_inp.shape[0], hidden_states.shape[0]) + + def test_lumina2_extractor(self): + """Test Lumina2 modulated input extraction.""" + model = _create_lumina2_model() + + batch_size = 2 + seq_len = 100 + hidden_size = model.config.hidden_size + + input_to_main_loop = torch.randn(batch_size, seq_len, hidden_size) + temb = torch.randn(batch_size, hidden_size) + + modulated_inp = model.layers[0].norm1(input_to_main_loop, temb)[0] + self.assertIsInstance(modulated_inp, torch.Tensor) + self.assertEqual(modulated_inp.shape[0], batch_size) + + def test_cogvideox_extractor(self): + """Test CogVideoX modulated input extraction.""" + model = _create_cogvideox_model() + + hidden_states = torch.randn(2, 2, 4, 8, 8) + timestep = torch.randint(0, 1000, (2,)) + + t_emb = model.time_proj(timestep) + t_emb = t_emb.to(dtype=hidden_states.dtype) + emb = model.time_embedding(t_emb, None) + + modulated_inp = emb + self.assertIsInstance(modulated_inp, torch.Tensor) + self.assertEqual(modulated_inp.shape, emb.shape) + + def test_auto_detect_mochi(self): + """Test auto-detection for Mochi models.""" + model = _create_mochi_model() + model_config = _get_model_config() + + config = TeaCacheConfig(rel_l1_thresh=0.2) + apply_teacache(model, config) + + registry = HookRegistry.check_if_exists_or_initialize(model) + hook = registry.get_hook("teacache") + self.assertIsNotNone(hook) + self.assertEqual(hook.coefficients, model_config["Mochi"]["coefficients"]) + + model.disable_cache() + + def test_auto_detect_lumina2(self): + """Test auto-detection for Lumina2 models.""" + model = _create_lumina2_model() + model_config = _get_model_config() + + config = TeaCacheConfig(rel_l1_thresh=0.2) + apply_teacache(model, config) + + registry = HookRegistry.check_if_exists_or_initialize(model) + hook = registry.get_hook("teacache") + self.assertIsNotNone(hook) + self.assertEqual(hook.coefficients, model_config["Lumina2"]["coefficients"]) + + registry.remove_hook("teacache") + + def test_auto_detect_cogvideox(self): + """Test auto-detection for CogVideoX models.""" 
+ model = _create_cogvideox_model() + model_config = _get_model_config() + + config = TeaCacheConfig(rel_l1_thresh=0.2) + apply_teacache(model, config) + + registry = HookRegistry.check_if_exists_or_initialize(model) + hook = registry.get_hook("teacache") + self.assertIsNotNone(hook) + self.assertEqual(hook.coefficients, model_config["CogVideoX"]["coefficients"]) + + model.disable_cache() + + def test_teacache_state_encoder_residual(self): + """Test that TeaCacheState supports encoder residual for CogVideoX.""" + state = TeaCacheState() + self.assertIsNone(state.previous_residual_encoder) + + state.previous_residual_encoder = torch.randn(2, 10, 16) + self.assertIsNotNone(state.previous_residual_encoder) + + state.reset() + self.assertIsNone(state.previous_residual_encoder) + + def test_model_routing(self): + """Test that new_forward routes to correct handler based on model type.""" + config = TeaCacheConfig(rel_l1_thresh=0.2) + + mochi_hook = TeaCacheHook(config) + mochi_hook.initialize_hook(_create_mochi_model()) + self.assertEqual(mochi_hook.model_type, "Mochi") + + lumina_hook = TeaCacheHook(config) + lumina_hook.initialize_hook(_create_lumina2_model()) + self.assertEqual(lumina_hook.model_type, "Lumina2") + + cogvideox_hook = TeaCacheHook(config) + cogvideox_hook.initialize_hook(_create_cogvideox_model()) + self.assertEqual(cogvideox_hook.model_type, "CogVideoX") + + +class StateManagerContextTests(unittest.TestCase): + """Tests for StateManager context isolation and backward compatibility.""" + + def test_context_isolation(self): + """Test that different contexts maintain separate states.""" + state_manager = StateManager(TeaCacheState, (), {}) + + state_manager.set_context("cond") + cond_state = state_manager.get_state() + cond_state.cnt = 5 + cond_state.accumulated_rel_l1_distance = 0.3 + + state_manager.set_context("uncond") + uncond_state = state_manager.get_state() + uncond_state.cnt = 10 + uncond_state.accumulated_rel_l1_distance = 0.7 + + state_manager.set_context("cond") + self.assertEqual(state_manager.get_state().cnt, 5) + self.assertEqual(state_manager.get_state().accumulated_rel_l1_distance, 0.3) + + state_manager.set_context("uncond") + self.assertEqual(state_manager.get_state().cnt, 10) + self.assertEqual(state_manager.get_state().accumulated_rel_l1_distance, 0.7) + + def test_default_context_fallback(self): + """Test that state works without explicit context (backward compatibility).""" + state_manager = StateManager(TeaCacheState, (), {}) + + state = state_manager.get_state() + self.assertIsNotNone(state) + self.assertEqual(state.cnt, 0) + + state.cnt = 42 + + state2 = state_manager.get_state() + self.assertEqual(state2.cnt, 42) + + def test_default_context_separate_from_named(self): + """Test that default context is separate from named contexts.""" + state_manager = StateManager(TeaCacheState, (), {}) + + default_state = state_manager.get_state() + default_state.cnt = 100 + + state_manager.set_context("named") + named_state = state_manager.get_state() + named_state.cnt = 200 + + state_manager._current_context = None + self.assertEqual(state_manager.get_state().cnt, 100) + + state_manager.set_context("named") + self.assertEqual(state_manager.get_state().cnt, 200) + + +if __name__ == "__main__": + unittest.main()
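The caching rule itself can be sanity-checked in isolation. Below is a minimal, self-contained sketch (an illustration only, not code from this PR) of the accumulate-and-threshold decision implemented by `_compute_rel_l1_distance`, `_rescale_distance`, and `_should_compute` in `src/diffusers/hooks/teacache.py`, using the Flux coefficients from `_get_model_config` and toy tensors. The real hook additionally forces full computation on the first and last timesteps and stores the cached residual per cache context.

```python
# Sketch of TeaCache's skip decision: rel-L1 distance -> polynomial rescale -> accumulate -> compare.
import torch

coeffs = [4.98651651e02, -2.83781631e02, 5.58554382e01, -3.82021401e00, 2.64230861e-01]  # Flux entry above
rel_l1_thresh = 0.2


def rescale(x: float) -> float:
    # c[0]*x^4 + c[1]*x^3 + c[2]*x^2 + c[3]*x + c[4], as in _rescale_distance.
    c = coeffs
    return c[0] * x**4 + c[1] * x**3 + c[2] * x**2 + c[3] * x + c[4]


torch.manual_seed(0)
previous = torch.randn(2, 16, 64)  # stand-in for the previous modulated input
accumulated = 0.0

for step, delta in enumerate([0.07, 0.06, 0.08, 0.05], start=1):
    current = previous * (1.0 + delta)  # drift the input by a known relative amount
    rel_l1 = ((current - previous).abs().mean() / previous.abs().mean()).item()
    accumulated += rescale(rel_l1)
    if accumulated >= rel_l1_thresh:
        decision, accumulated = "recompute transformer blocks", 0.0
    else:
        decision = "reuse cached residual"
    print(f"step {step}: rel_l1={rel_l1:.3f}, decision={decision}")
    previous = current
```

The polynomial, fitted per model family (hence the per-model coefficient tables in `_get_model_config`), maps the raw distance between consecutive modulated inputs to an estimate of how much the block outputs will change, and it is this rescaled, accumulated value that is compared against `rel_l1_thresh`.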