From a4bfa451fe38496dfd2a48a18076d0baf12b0999 Mon Sep 17 00:00:00 2001 From: toilaluan Date: Thu, 13 Nov 2025 15:06:36 +0700 Subject: [PATCH 01/29] init taylor_seer cache --- src/diffusers/hooks/taylorseer_cache.py | 118 ++++++++++++++++++++++++ 1 file changed, 118 insertions(+) create mode 100644 src/diffusers/hooks/taylorseer_cache.py diff --git a/src/diffusers/hooks/taylorseer_cache.py b/src/diffusers/hooks/taylorseer_cache.py new file mode 100644 index 000000000000..2f8f6a4b476e --- /dev/null +++ b/src/diffusers/hooks/taylorseer_cache.py @@ -0,0 +1,118 @@ +# Experimental hook for TaylorSeer cache +# Supports Flux only for now + +import torch +from dataclasses import dataclass +from typing import Callable +from .hooks import ModelHook +import math +from ..models.attention import Attention +from ..models.attention import AttentionModuleMixin +from ._common import ( + _ATTENTION_CLASSES, +) +from ..hooks import HookRegistry + +_TAYLORSEER_ATTENTION_CACHE_HOOK = "taylorseer_attention_cache" + +@dataclass +class TaylorSeerCacheConfig: + fresh_threshold: int = 5 # interleave cache and compute: `fresh_threshold` steps are cached, then 1 full compute step is performed + max_order: int = 1 # order of Taylor series expansion + current_timestep_callback: Callable[[], int] = None + +class TaylorSeerState: + def __init__(self): + self.predict_counter: int = 1 + self.last_step: int = 1000 + self.taylor_factors: dict[int, torch.Tensor] = {} + + def reset(self): + self.predict_counter = 1 + self.last_step = 1000 + self.taylor_factors = {} + + def update(self, features: torch.Tensor, current_step: int, max_order: int, refresh_threshold: int): + N = math.abs(current_step - self.last_step) + # initialize the first order taylor factors + new_taylor_factors = {0: features} + for i in range(max_order): + if (self.taylor_factors.get(i) is not None) and current_step > 1: + new_taylor_factors[i+1] = (self.taylor_factors[i] - new_taylor_factors[i]) / N + else: + break + self.taylor_factors = new_taylor_factors + self.last_step = current_step + self.predict_counter = (self.predict_counter + 1) % refresh_threshold + + def predict(self, current_step: int, refresh_threshold: int): + k = current_step - self.last_step + device = self.taylor_factors[0].device + output = torch.zeros_like(self.taylor_factors[0], device=device) + for i in range(len(self.taylor_factors)): + output += self.taylor_factors[i] * (k ** i) / math.factorial(i) + self.predict_counter = (self.predict_counter + 1) % refresh_threshold + return output + +class TaylorSeerAttentionCacheHook(ModelHook): + _is_stateful = True + + def __init__(self, fresh_threshold: int, max_order: int, current_timestep_callback: Callable[[], int]): + super().__init__() + self.fresh_threshold = fresh_threshold + self.max_order = max_order + self.current_timestep_callback = current_timestep_callback + + def initialize_hook(self, module): + self.img_state = TaylorSeerState() + self.txt_state = TaylorSeerState() + self.ip_state = TaylorSeerState() + return module + + def new_forward(self, module: torch.nn.Module, *args, **kwargs): + current_step = self.current_timestep_callback() + assert current_step is not None, "timestep is required for TaylorSeerAttentionCacheHook" + should_predict = self.img_state.predict_counter > 0 + + if not should_predict: + attention_outputs = self.fn_ref.original_forward(*args, **kwargs) + if len(attention_outputs) == 2: + attn_output, context_attn_output = attention_outputs + self.img_state.update(attn_output, current_step, self.max_order, 
self.fresh_threshold) + self.txt_state.update(context_attn_output, current_step, self.max_order, self.fresh_threshold) + elif len(attention_outputs) == 3: + attn_output, context_attn_output, ip_attn_output = attention_outputs + self.img_state.update(attn_output, current_step, self.max_order, self.fresh_threshold) + self.txt_state.update(context_attn_output, current_step, self.max_order, self.fresh_threshold) + self.ip_state.update(ip_attn_output, current_step, self.max_order, self.fresh_threshold) + else: + attn_output = self.img_state.predict(current_step, self.fresh_threshold) + context_attn_output = self.txt_state.predict(current_step, self.fresh_threshold) + ip_attn_output = self.ip_state.predict(current_step, self.fresh_threshold) + attention_outputs = (attn_output, context_attn_output, ip_attn_output) + return attention_outputs + + def reset_state(self, module: torch.nn.Module) -> None: + self.img_state.reset() + self.txt_state.reset() + self.ip_state.reset() + return module + +def apply_taylorseer_cache(module: torch.nn.Module, config: TaylorSeerCacheConfig): + for name, submodule in module.named_modules(): + if not isinstance(submodule, (*_ATTENTION_CLASSES, AttentionModuleMixin)): + # PAB has been implemented specific to Diffusers' Attention classes. However, this does not mean that PAB + # cannot be applied to this layer. For custom layers, users can extend this functionality and implement + # their own PAB logic similar to `_apply_pyramid_attention_broadcast_on_attention_class`. + continue + _apply_taylorseer_cache_on_attention_class(name, submodule, config) + + +def _apply_taylorseer_cache_on_attention_class(name: str, module: Attention, config: TaylorSeerCacheConfig): + _apply_taylorseer_cache_hook(module, config) + + +def _apply_taylorseer_cache_hook(module: Attention, config: TaylorSeerCacheConfig): + registry = HookRegistry.check_if_exists_or_initialize(module) + hook = TaylorSeerAttentionCacheHook(config.fresh_threshold, config.max_order, config.current_timestep_callback) + registry.register_hook(hook, _TAYLORSEER_ATTENTION_CACHE_HOOK) \ No newline at end of file From 8f495b607f1176d1dd11101c21bf12f35892f945 Mon Sep 17 00:00:00 2001 From: toilaluan Date: Thu, 13 Nov 2025 11:37:54 +0000 Subject: [PATCH 02/29] make compatible with any tuple size returned --- src/diffusers/hooks/taylorseer_cache.py | 55 ++++++++++++------------- 1 file changed, 26 insertions(+), 29 deletions(-) diff --git a/src/diffusers/hooks/taylorseer_cache.py b/src/diffusers/hooks/taylorseer_cache.py index 2f8f6a4b476e..b339ee1d6b9f 100644 --- a/src/diffusers/hooks/taylorseer_cache.py +++ b/src/diffusers/hooks/taylorseer_cache.py @@ -23,12 +23,12 @@ class TaylorSeerCacheConfig: class TaylorSeerState: def __init__(self): - self.predict_counter: int = 1 + self.predict_counter: int = 0 self.last_step: int = 1000 self.taylor_factors: dict[int, torch.Tensor] = {} def reset(self): - self.predict_counter = 1 + self.predict_counter = 0 self.last_step = 1000 self.taylor_factors = {} @@ -43,15 +43,15 @@ def update(self, features: torch.Tensor, current_step: int, max_order: int, refr break self.taylor_factors = new_taylor_factors self.last_step = current_step - self.predict_counter = (self.predict_counter + 1) % refresh_threshold + self.predict_counter = refresh_threshold - def predict(self, current_step: int, refresh_threshold: int): + def predict(self, current_step: int): k = current_step - self.last_step device = self.taylor_factors[0].device output = torch.zeros_like(self.taylor_factors[0], device=device) for i 
in range(len(self.taylor_factors)): output += self.taylor_factors[i] * (k ** i) / math.factorial(i) - self.predict_counter = (self.predict_counter + 1) % refresh_threshold + self.predict_counter -= 1 return output class TaylorSeerAttentionCacheHook(ModelHook): @@ -64,47 +64,44 @@ def __init__(self, fresh_threshold: int, max_order: int, current_timestep_callba self.current_timestep_callback = current_timestep_callback def initialize_hook(self, module): - self.img_state = TaylorSeerState() - self.txt_state = TaylorSeerState() - self.ip_state = TaylorSeerState() + self.states = None + self.num_outputs = None return module def new_forward(self, module: torch.nn.Module, *args, **kwargs): current_step = self.current_timestep_callback() assert current_step is not None, "timestep is required for TaylorSeerAttentionCacheHook" - should_predict = self.img_state.predict_counter > 0 + + if self.states is None: + attention_outputs = self.fn_ref.original_forward(*args, **kwargs) + self.num_outputs = len(attention_outputs) + self.states = [TaylorSeerState() for _ in range(self.num_outputs)] + for i, feat in enumerate(attention_outputs): + self.states[i].update(feat, current_step, self.max_order, self.fresh_threshold) + return attention_outputs + + should_predict = self.states[0].predict_counter > 0 if not should_predict: attention_outputs = self.fn_ref.original_forward(*args, **kwargs) - if len(attention_outputs) == 2: - attn_output, context_attn_output = attention_outputs - self.img_state.update(attn_output, current_step, self.max_order, self.fresh_threshold) - self.txt_state.update(context_attn_output, current_step, self.max_order, self.fresh_threshold) - elif len(attention_outputs) == 3: - attn_output, context_attn_output, ip_attn_output = attention_outputs - self.img_state.update(attn_output, current_step, self.max_order, self.fresh_threshold) - self.txt_state.update(context_attn_output, current_step, self.max_order, self.fresh_threshold) - self.ip_state.update(ip_attn_output, current_step, self.max_order, self.fresh_threshold) - else: - attn_output = self.img_state.predict(current_step, self.fresh_threshold) - context_attn_output = self.txt_state.predict(current_step, self.fresh_threshold) - ip_attn_output = self.ip_state.predict(current_step, self.fresh_threshold) - attention_outputs = (attn_output, context_attn_output, ip_attn_output) + for i, feat in enumerate(attention_outputs): + self.states[i].update(feat, current_step, self.max_order, self.fresh_threshold) return attention_outputs + else: + predicted_outputs = [state.predict(current_step) for state in self.states] + return tuple(predicted_outputs) def reset_state(self, module: torch.nn.Module) -> None: - self.img_state.reset() - self.txt_state.reset() - self.ip_state.reset() + if self.states is not None: + for state in self.states: + state.reset() return module def apply_taylorseer_cache(module: torch.nn.Module, config: TaylorSeerCacheConfig): for name, submodule in module.named_modules(): if not isinstance(submodule, (*_ATTENTION_CLASSES, AttentionModuleMixin)): - # PAB has been implemented specific to Diffusers' Attention classes. However, this does not mean that PAB - # cannot be applied to this layer. For custom layers, users can extend this functionality and implement - # their own PAB logic similar to `_apply_pyramid_attention_broadcast_on_attention_class`. 
continue + print(f"Applying TaylorSeer cache to {name}") _apply_taylorseer_cache_on_attention_class(name, submodule, config) From 8f8007261844069068ca70ba5d3497b66b1be526 Mon Sep 17 00:00:00 2001 From: toilaluan Date: Thu, 13 Nov 2025 13:11:29 +0000 Subject: [PATCH 03/29] use logger for printing, add warmup feature --- src/diffusers/hooks/taylorseer_cache.py | 35 ++++++++++++++++++++----- 1 file changed, 28 insertions(+), 7 deletions(-) diff --git a/src/diffusers/hooks/taylorseer_cache.py b/src/diffusers/hooks/taylorseer_cache.py index b339ee1d6b9f..8c3c6a7c3614 100644 --- a/src/diffusers/hooks/taylorseer_cache.py +++ b/src/diffusers/hooks/taylorseer_cache.py @@ -12,11 +12,13 @@ _ATTENTION_CLASSES, ) from ..hooks import HookRegistry - +from ..utils import logging +logger = logging.get_logger(__name__) # pylint: disable=invalid-name _TAYLORSEER_ATTENTION_CACHE_HOOK = "taylorseer_attention_cache" @dataclass class TaylorSeerCacheConfig: + warmup_steps: int = 3 # full compute some first steps fresh_threshold: int = 5 # interleave cache and compute: `fresh_threshold` steps are cached, then 1 full compute step is performed max_order: int = 1 # order of Taylor series expansion current_timestep_callback: Callable[[], int] = None @@ -33,7 +35,9 @@ def reset(self): self.taylor_factors = {} def update(self, features: torch.Tensor, current_step: int, max_order: int, refresh_threshold: int): - N = math.abs(current_step - self.last_step) + logger.debug("="*10) + N = self.last_step - current_step + logger.debug(f"update: N: {N}, current_step: {current_step}, last_step: {self.last_step}") # initialize the first order taylor factors new_taylor_factors = {0: features} for i in range(max_order): @@ -44,6 +48,9 @@ def update(self, features: torch.Tensor, current_step: int, max_order: int, refr self.taylor_factors = new_taylor_factors self.last_step = current_step self.predict_counter = refresh_threshold + logger.debug(f"last_step: {self.last_step}") + logger.debug(f"predict_counter: {self.predict_counter}") + logger.debug("="*10) def predict(self, current_step: int): k = current_step - self.last_step @@ -52,20 +59,24 @@ def predict(self, current_step: int): for i in range(len(self.taylor_factors)): output += self.taylor_factors[i] * (k ** i) / math.factorial(i) self.predict_counter -= 1 + logger.debug(f"predict_counter: {self.predict_counter}") + logger.debug(f"k: {k}") return output class TaylorSeerAttentionCacheHook(ModelHook): _is_stateful = True - def __init__(self, fresh_threshold: int, max_order: int, current_timestep_callback: Callable[[], int]): + def __init__(self, fresh_threshold: int, max_order: int, current_timestep_callback: Callable[[], int], warmup_steps: int): super().__init__() self.fresh_threshold = fresh_threshold self.max_order = max_order self.current_timestep_callback = current_timestep_callback + self.warmup_steps = warmup_steps def initialize_hook(self, module): self.states = None self.num_outputs = None + self.warmup_steps_counter = 0 return module def new_forward(self, module: torch.nn.Module, *args, **kwargs): @@ -74,21 +85,31 @@ def new_forward(self, module: torch.nn.Module, *args, **kwargs): if self.states is None: attention_outputs = self.fn_ref.original_forward(*args, **kwargs) + if self.warmup_steps_counter < self.warmup_steps: + logger.debug(f"warmup_steps_counter: {self.warmup_steps_counter}") + self.warmup_steps_counter += 1 + return attention_outputs + if isinstance(attention_outputs, torch.Tensor): + attention_outputs = [attention_outputs] self.num_outputs = 
len(attention_outputs) self.states = [TaylorSeerState() for _ in range(self.num_outputs)] for i, feat in enumerate(attention_outputs): self.states[i].update(feat, current_step, self.max_order, self.fresh_threshold) - return attention_outputs + return attention_outputs[0] if len(attention_outputs) == 1 else attention_outputs should_predict = self.states[0].predict_counter > 0 if not should_predict: attention_outputs = self.fn_ref.original_forward(*args, **kwargs) + if isinstance(attention_outputs, torch.Tensor): + attention_outputs = [attention_outputs] for i, feat in enumerate(attention_outputs): self.states[i].update(feat, current_step, self.max_order, self.fresh_threshold) - return attention_outputs + return attention_outputs[0] if len(attention_outputs) == 1 else attention_outputs else: predicted_outputs = [state.predict(current_step) for state in self.states] + if len(predicted_outputs) == 1: + return predicted_outputs[0] return tuple(predicted_outputs) def reset_state(self, module: torch.nn.Module) -> None: @@ -101,7 +122,7 @@ def apply_taylorseer_cache(module: torch.nn.Module, config: TaylorSeerCacheConfi for name, submodule in module.named_modules(): if not isinstance(submodule, (*_ATTENTION_CLASSES, AttentionModuleMixin)): continue - print(f"Applying TaylorSeer cache to {name}") + logger.debug(f"Applying TaylorSeer cache to {name}") _apply_taylorseer_cache_on_attention_class(name, submodule, config) @@ -111,5 +132,5 @@ def _apply_taylorseer_cache_on_attention_class(name: str, module: Attention, con def _apply_taylorseer_cache_hook(module: Attention, config: TaylorSeerCacheConfig): registry = HookRegistry.check_if_exists_or_initialize(module) - hook = TaylorSeerAttentionCacheHook(config.fresh_threshold, config.max_order, config.current_timestep_callback) + hook = TaylorSeerAttentionCacheHook(config.fresh_threshold, config.max_order, config.current_timestep_callback, config.warmup_steps) registry.register_hook(hook, _TAYLORSEER_ATTENTION_CACHE_HOOK) \ No newline at end of file From 0602044da71913832c2e81350d76f0327567efa2 Mon Sep 17 00:00:00 2001 From: toilaluan Date: Thu, 13 Nov 2025 17:03:35 +0000 Subject: [PATCH 04/29] still update in warmup steps --- src/diffusers/hooks/taylorseer_cache.py | 10 ++-------- 1 file changed, 2 insertions(+), 8 deletions(-) diff --git a/src/diffusers/hooks/taylorseer_cache.py b/src/diffusers/hooks/taylorseer_cache.py index 8c3c6a7c3614..6c99f095e26f 100644 --- a/src/diffusers/hooks/taylorseer_cache.py +++ b/src/diffusers/hooks/taylorseer_cache.py @@ -85,10 +85,6 @@ def new_forward(self, module: torch.nn.Module, *args, **kwargs): if self.states is None: attention_outputs = self.fn_ref.original_forward(*args, **kwargs) - if self.warmup_steps_counter < self.warmup_steps: - logger.debug(f"warmup_steps_counter: {self.warmup_steps_counter}") - self.warmup_steps_counter += 1 - return attention_outputs if isinstance(attention_outputs, torch.Tensor): attention_outputs = [attention_outputs] self.num_outputs = len(attention_outputs) @@ -97,7 +93,7 @@ def new_forward(self, module: torch.nn.Module, *args, **kwargs): self.states[i].update(feat, current_step, self.max_order, self.fresh_threshold) return attention_outputs[0] if len(attention_outputs) == 1 else attention_outputs - should_predict = self.states[0].predict_counter > 0 + should_predict = self.states[0].predict_counter > 0 and self.warmup_steps_counter > self.warmup_steps if not should_predict: attention_outputs = self.fn_ref.original_forward(*args, **kwargs) @@ -108,9 +104,7 @@ def new_forward(self, 
module: torch.nn.Module, *args, **kwargs): return attention_outputs[0] if len(attention_outputs) == 1 else attention_outputs else: predicted_outputs = [state.predict(current_step) for state in self.states] - if len(predicted_outputs) == 1: - return predicted_outputs[0] - return tuple(predicted_outputs) + return predicted_outputs[0] if len(predicted_outputs) == 1 else predicted_outputs def reset_state(self, module: torch.nn.Module) -> None: if self.states is not None: From 1099e493e635526c8ecbc8ebca0f57e4bea2a0d8 Mon Sep 17 00:00:00 2001 From: toilaluan Date: Fri, 14 Nov 2025 07:00:12 +0000 Subject: [PATCH 05/29] refractor, add docs --- src/diffusers/__init__.py | 4 + src/diffusers/hooks/__init__.py | 1 + src/diffusers/hooks/taylorseer_cache.py | 246 +++++++++++++++++------- src/diffusers/models/cache_utils.py | 9 +- 4 files changed, 185 insertions(+), 75 deletions(-) diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 02df34c07e8e..69d4aa4ba345 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -169,10 +169,12 @@ "LayerSkipConfig", "PyramidAttentionBroadcastConfig", "SmoothedEnergyGuidanceConfig", + "TaylorSeerCacheConfig", "apply_faster_cache", "apply_first_block_cache", "apply_layer_skip", "apply_pyramid_attention_broadcast", + "apply_taylorseer_cache", ] ) _import_structure["models"].extend( @@ -883,10 +885,12 @@ LayerSkipConfig, PyramidAttentionBroadcastConfig, SmoothedEnergyGuidanceConfig, + TaylorSeerCacheConfig, apply_faster_cache, apply_first_block_cache, apply_layer_skip, apply_pyramid_attention_broadcast, + apply_taylorseer_cache, ) from .models import ( AllegroTransformer3DModel, diff --git a/src/diffusers/hooks/__init__.py b/src/diffusers/hooks/__init__.py index 524a92ea9966..1d9d43d96b2a 100644 --- a/src/diffusers/hooks/__init__.py +++ b/src/diffusers/hooks/__init__.py @@ -25,3 +25,4 @@ from .layerwise_casting import apply_layerwise_casting, apply_layerwise_casting_hook from .pyramid_attention_broadcast import PyramidAttentionBroadcastConfig, apply_pyramid_attention_broadcast from .smoothed_energy_guidance_utils import SmoothedEnergyGuidanceConfig + from .taylorseer_cache import TaylorSeerCacheConfig, apply_taylorseer_cache \ No newline at end of file diff --git a/src/diffusers/hooks/taylorseer_cache.py b/src/diffusers/hooks/taylorseer_cache.py index 6c99f095e26f..509f6ba1179d 100644 --- a/src/diffusers/hooks/taylorseer_cache.py +++ b/src/diffusers/hooks/taylorseer_cache.py @@ -1,9 +1,6 @@ -# Experimental hook for TaylorSeer cache -# Supports Flux only for now - import torch from dataclasses import dataclass -from typing import Callable +from typing import Callable, Optional, List, Dict from .hooks import ModelHook import math from ..models.attention import Attention @@ -13,118 +10,219 @@ ) from ..hooks import HookRegistry from ..utils import logging + logger = logging.get_logger(__name__) # pylint: disable=invalid-name _TAYLORSEER_ATTENTION_CACHE_HOOK = "taylorseer_attention_cache" @dataclass class TaylorSeerCacheConfig: - warmup_steps: int = 3 # full compute some first steps - fresh_threshold: int = 5 # interleave cache and compute: `fresh_threshold` steps are cached, then 1 full compute step is performed - max_order: int = 1 # order of Taylor series expansion - current_timestep_callback: Callable[[], int] = None - -class TaylorSeerState: - def __init__(self): - self.predict_counter: int = 0 - self.last_step: int = 1000 - self.taylor_factors: dict[int, torch.Tensor] = {} + """ + Configuration for TaylorSeer cache. 
+ See: https://huggingface.co/papers/2503.06923 + + Attributes: + warmup_steps (int, defaults to 3): Number of warmup steps without caching. + predict_steps (int, defaults to 5): Number of prediction (cache) steps between non-cached steps. + max_order (int, defaults to 1): Maximum order of Taylor series expansion to approximate the features. + taylor_factors_dtype (torch.dtype, defaults to torch.float32): Data type for Taylor series expansion factors. + """ + warmup_steps: int = 3 + predict_steps: int = 5 + max_order: int = 1 + taylor_factors_dtype: torch.dtype = torch.float32 + + def __repr__(self) -> str: + return f"TaylorSeerCacheConfig(warmup_steps={self.warmup_steps}, predict_steps={self.predict_steps}, max_order={self.max_order}, taylor_factors_dtype={self.taylor_factors_dtype})" + +class TaylorSeerOutputState: + """ + Manages the state for Taylor series-based prediction of a single attention output. + Tracks Taylor expansion factors, last update step, and remaining prediction steps. + The Taylor expansion uses the timestep as the independent variable for approximation. + """ + + def __init__(self, module_name: str, taylor_factors_dtype: torch.dtype, module_dtype: torch.dtype): + self.module_name = module_name + self.remaining_predictions: int = 0 + self.last_update_step: Optional[int] = None + self.taylor_factors: Dict[int, torch.Tensor] = {} + self.taylor_factors_dtype = taylor_factors_dtype + self.module_dtype = module_dtype def reset(self): - self.predict_counter = 0 - self.last_step = 1000 + self.remaining_predictions = 0 + self.last_update_step = None self.taylor_factors = {} - def update(self, features: torch.Tensor, current_step: int, max_order: int, refresh_threshold: int): - logger.debug("="*10) - N = self.last_step - current_step - logger.debug(f"update: N: {N}, current_step: {current_step}, last_step: {self.last_step}") - # initialize the first order taylor factors - new_taylor_factors = {0: features} - for i in range(max_order): - if (self.taylor_factors.get(i) is not None) and current_step > 1: - new_taylor_factors[i+1] = (self.taylor_factors[i] - new_taylor_factors[i]) / N - else: - break - self.taylor_factors = new_taylor_factors - self.last_step = current_step - self.predict_counter = refresh_threshold - logger.debug(f"last_step: {self.last_step}") - logger.debug(f"predict_counter: {self.predict_counter}") - logger.debug("="*10) - - def predict(self, current_step: int): - k = current_step - self.last_step + def update(self, features: torch.Tensor, current_step: int, max_order: int, predict_steps: int, is_first_update: bool): + """ + Updates the Taylor factors based on the current features and timestep. + Computes finite difference approximations for derivatives using recursive divided differences. + + Args: + features (torch.Tensor): The attention output features to update with. + current_step (int): The current timestep or step number from the diffusion model. + max_order (int): Maximum order of the Taylor expansion. + predict_steps (int): Number of prediction steps to set after update. + is_first_update (bool): Whether this is the initial update (skips difference computation). 
+ """ + features = features.to(self.taylor_factors_dtype) + new_factors = {0: features} + if not is_first_update: + if self.last_update_step is None: + raise ValueError("Cannot update without prior initialization.") + delta_step = current_step - self.last_update_step + if delta_step == 0: + raise ValueError("Delta step cannot be zero for updates.") + for i in range(max_order): + if i in self.taylor_factors: + # Finite difference: (current - previous) / delta for forward approximation + new_factors[i + 1] = (new_factors[i] - self.taylor_factors[i].to(self.taylor_factors_dtype)) / delta_step + + # taylor factors will be kept in the taylor_factors_dtype + self.taylor_factors = new_factors + self.last_update_step = current_step + self.remaining_predictions = predict_steps + + def predict(self, current_step: int) -> torch.Tensor: + """ + Predicts the features using the Taylor series expansion at the given timestep. + + Args: + current_step (int): The current timestep for prediction. + + Returns: + torch.Tensor: The predicted features in the module's dtype. + """ + if self.last_update_step is None: + raise ValueError("Cannot predict without prior update.") + step_offset = current_step - self.last_update_step device = self.taylor_factors[0].device - output = torch.zeros_like(self.taylor_factors[0], device=device) - for i in range(len(self.taylor_factors)): - output += self.taylor_factors[i] * (k ** i) / math.factorial(i) - self.predict_counter -= 1 - logger.debug(f"predict_counter: {self.predict_counter}") - logger.debug(f"k: {k}") - return output + output = torch.zeros_like(self.taylor_factors[0], device=device, dtype=self.taylor_factors_dtype) + for order in range(len(self.taylor_factors)): + output += self.taylor_factors[order] * (step_offset ** order) / math.factorial(order) + self.remaining_predictions -= 1 + # output will be converted to the module's dtype + return output.to(self.module_dtype) class TaylorSeerAttentionCacheHook(ModelHook): + """ + Hook for caching and predicting attention outputs using Taylor series approximations. + Applies to attention modules in diffusion models (e.g., Flux). + Performs full computations during warmup, then alternates between predictions and refreshes. 
+ """ _is_stateful = True - def __init__(self, fresh_threshold: int, max_order: int, current_timestep_callback: Callable[[], int], warmup_steps: int): + def __init__( + self, + module_name: str, + predict_steps: int, + max_order: int, + warmup_steps: int, + taylor_factors_dtype: torch.dtype, + module_dtype: torch.dtype = None, + ): super().__init__() - self.fresh_threshold = fresh_threshold + self.module_name = module_name + self.predict_steps = predict_steps self.max_order = max_order - self.current_timestep_callback = current_timestep_callback self.warmup_steps = warmup_steps - - def initialize_hook(self, module): + self.step_counter = -1 + self.states: Optional[List[TaylorSeerOutputState]] = None + self.num_outputs: Optional[int] = None + self.taylor_factors_dtype = taylor_factors_dtype + self.module_dtype = module_dtype + + def initialize_hook(self, module: torch.nn.Module): + self.step_counter = -1 self.states = None self.num_outputs = None - self.warmup_steps_counter = 0 + self.module_dtype = None return module def new_forward(self, module: torch.nn.Module, *args, **kwargs): - current_step = self.current_timestep_callback() - assert current_step is not None, "timestep is required for TaylorSeerAttentionCacheHook" + self.step_counter += 1 + is_warmup_phase = self.step_counter < self.warmup_steps if self.states is None: + # First step: always full compute and initialize attention_outputs = self.fn_ref.original_forward(*args, **kwargs) if isinstance(attention_outputs, torch.Tensor): attention_outputs = [attention_outputs] + else: + attention_outputs = list(attention_outputs) + module_dtype = attention_outputs[0].dtype self.num_outputs = len(attention_outputs) - self.states = [TaylorSeerState() for _ in range(self.num_outputs)] - for i, feat in enumerate(attention_outputs): - self.states[i].update(feat, current_step, self.max_order, self.fresh_threshold) - return attention_outputs[0] if len(attention_outputs) == 1 else attention_outputs - - should_predict = self.states[0].predict_counter > 0 and self.warmup_steps_counter > self.warmup_steps - - if not should_predict: + self.states = [ + TaylorSeerOutputState(self.module_name, self.taylor_factors_dtype, module_dtype) + for _ in range(self.num_outputs) + ] + for i, features in enumerate(attention_outputs): + self.states[i].update(features, self.step_counter, self.max_order, self.predict_steps, is_first_update=True) + return attention_outputs[0] if self.num_outputs == 1 else tuple(attention_outputs) + + should_predict = self.states[0].remaining_predictions > 0 + if is_warmup_phase or not should_predict: + # Full compute during warmup or when refresh needed attention_outputs = self.fn_ref.original_forward(*args, **kwargs) if isinstance(attention_outputs, torch.Tensor): attention_outputs = [attention_outputs] - for i, feat in enumerate(attention_outputs): - self.states[i].update(feat, current_step, self.max_order, self.fresh_threshold) - return attention_outputs[0] if len(attention_outputs) == 1 else attention_outputs + else: + attention_outputs = list(attention_outputs) + is_first_update = self.step_counter == 0 # Only True for the very first step + for i, features in enumerate(attention_outputs): + self.states[i].update(features, self.step_counter, self.max_order, self.predict_steps, is_first_update) + return attention_outputs[0] if self.num_outputs == 1 else tuple(attention_outputs) else: - predicted_outputs = [state.predict(current_step) for state in self.states] - return predicted_outputs[0] if len(predicted_outputs) == 1 else 
predicted_outputs + # Predict using Taylor series + predicted_outputs = [state.predict(self.step_counter) for state in self.states] + return predicted_outputs[0] if self.num_outputs == 1 else tuple(predicted_outputs) def reset_state(self, module: torch.nn.Module) -> None: if self.states is not None: for state in self.states: state.reset() - return module def apply_taylorseer_cache(module: torch.nn.Module, config: TaylorSeerCacheConfig): - for name, submodule in module.named_modules(): - if not isinstance(submodule, (*_ATTENTION_CLASSES, AttentionModuleMixin)): - continue - logger.debug(f"Applying TaylorSeer cache to {name}") - _apply_taylorseer_cache_on_attention_class(name, submodule, config) + """ + Applies the TaylorSeer cache to given pipeline. + Args: + module (torch.nn.Module): The model to apply the hook to. + config (TaylorSeerCacheConfig): Configuration for the cache. -def _apply_taylorseer_cache_on_attention_class(name: str, module: Attention, config: TaylorSeerCacheConfig): - _apply_taylorseer_cache_hook(module, config) + Example: + ```python + >>> import torch + >>> from diffusers import FluxPipeline, TaylorSeerCacheConfig, apply_taylorseer_cache + >>> pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16) + >>> pipe.to("cuda") -def _apply_taylorseer_cache_hook(module: Attention, config: TaylorSeerCacheConfig): + >>> config = TaylorSeerCacheConfig(predict_steps=5, max_order=1, warmup_steps=3, taylor_factors_dtype=torch.float32) + >>> apply_taylorseer_cache(pipe.transformer, config) + ``` + """ + for name, submodule in module.named_modules(): + if isinstance(submodule, (*_ATTENTION_CLASSES, AttentionModuleMixin)): + logger.debug(f"Applying TaylorSeer cache to {name}") + _apply_taylorseer_cache_hook(name, submodule, config) + +def _apply_taylorseer_cache_hook(name: str, module: Attention, config: TaylorSeerCacheConfig): + """ + Registers the TaylorSeer hook on the specified attention module. + + Args: + name (str): Name of the module. + module (Attention): The attention module. + config (TaylorSeerCacheConfig): Configuration for the cache. 
+ """ registry = HookRegistry.check_if_exists_or_initialize(module) - hook = TaylorSeerAttentionCacheHook(config.fresh_threshold, config.max_order, config.current_timestep_callback, config.warmup_steps) + hook = TaylorSeerAttentionCacheHook( + name, + config.predict_steps, + config.max_order, + config.warmup_steps, + config.taylor_factors_dtype, + ) registry.register_hook(hook, _TAYLORSEER_ATTENTION_CACHE_HOOK) \ No newline at end of file diff --git a/src/diffusers/models/cache_utils.py b/src/diffusers/models/cache_utils.py index 605c0d588c8c..ffbf296ff617 100644 --- a/src/diffusers/models/cache_utils.py +++ b/src/diffusers/models/cache_utils.py @@ -67,9 +67,11 @@ def enable_cache(self, config) -> None: FasterCacheConfig, FirstBlockCacheConfig, PyramidAttentionBroadcastConfig, + TaylorSeerCacheConfig, apply_faster_cache, apply_first_block_cache, apply_pyramid_attention_broadcast, + apply_taylorseer_cache, ) if self.is_cache_enabled: @@ -83,16 +85,19 @@ def enable_cache(self, config) -> None: apply_first_block_cache(self, config) elif isinstance(config, PyramidAttentionBroadcastConfig): apply_pyramid_attention_broadcast(self, config) + elif isinstance(config, TaylorSeerCacheConfig): + apply_taylorseer_cache(self, config) else: raise ValueError(f"Cache config {type(config)} is not supported.") self._cache_config = config def disable_cache(self) -> None: - from ..hooks import FasterCacheConfig, FirstBlockCacheConfig, HookRegistry, PyramidAttentionBroadcastConfig + from ..hooks import FasterCacheConfig, FirstBlockCacheConfig, HookRegistry, PyramidAttentionBroadcastConfig, TaylorSeerCacheConfig from ..hooks.faster_cache import _FASTER_CACHE_BLOCK_HOOK, _FASTER_CACHE_DENOISER_HOOK from ..hooks.first_block_cache import _FBC_BLOCK_HOOK, _FBC_LEADER_BLOCK_HOOK from ..hooks.pyramid_attention_broadcast import _PYRAMID_ATTENTION_BROADCAST_HOOK + from ..hooks.taylorseer_cache import _TAYLORSEER_ATTENTION_CACHE_HOOK if self._cache_config is None: logger.warning("Caching techniques have not been enabled, so there's nothing to disable.") @@ -107,6 +112,8 @@ def disable_cache(self) -> None: registry.remove_hook(_FBC_BLOCK_HOOK, recurse=True) elif isinstance(self._cache_config, PyramidAttentionBroadcastConfig): registry.remove_hook(_PYRAMID_ATTENTION_BROADCAST_HOOK, recurse=True) + elif isinstance(self._cache_config, TaylorSeerCacheConfig): + registry.remove_hook(_TAYLORSEER_ATTENTION_CACHE_HOOK, recurse=True) else: raise ValueError(f"Cache config {type(self._cache_config)} is not supported.") From 7b4ad2de63c314489b8129a496bea5c67e31cf7e Mon Sep 17 00:00:00 2001 From: toilaluan Date: Fri, 14 Nov 2025 09:09:46 +0000 Subject: [PATCH 06/29] add configurable cache, skip compute module --- src/diffusers/hooks/taylorseer_cache.py | 169 ++++++++++++++++++------ 1 file changed, 126 insertions(+), 43 deletions(-) diff --git a/src/diffusers/hooks/taylorseer_cache.py b/src/diffusers/hooks/taylorseer_cache.py index 509f6ba1179d..89d6da307488 100644 --- a/src/diffusers/hooks/taylorseer_cache.py +++ b/src/diffusers/hooks/taylorseer_cache.py @@ -10,10 +10,28 @@ ) from ..hooks import HookRegistry from ..utils import logging - +import re +from collections import defaultdict logger = logging.get_logger(__name__) # pylint: disable=invalid-name _TAYLORSEER_ATTENTION_CACHE_HOOK = "taylorseer_attention_cache" +SPECIAL_CACHE_IDENTIFIERS = { + "flux": [ + r"transformer_blocks\.\d+\.attn", + r"transformer_blocks\.\d+\.ff", + r"transformer_blocks\.\d+\.ff_context", + r"single_transformer_blocks\.\d+\.proj_out", + ] +} 
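+# The "special cache" patterns above mark modules whose outputs are approximated with a Taylor
+# expansion; the "skip compute" patterns below mark modules whose outputs feed only those cached
+# modules (they are not reused in any residual path), so during prediction steps they can return
+# placeholder tensors instead of being recomputed.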
+SKIP_COMPUTE_IDENTIFIERS = { + "flux": [ + r"single_transformer_blocks\.\d+\.attn", + r"single_transformer_blocks\.\d+\.proj_mlp", + r"single_transformer_blocks\.\d+\.act_mlp", + ] +} + + @dataclass class TaylorSeerCacheConfig: """ @@ -25,14 +43,22 @@ class TaylorSeerCacheConfig: predict_steps (int, defaults to 5): Number of prediction (cache) steps between non-cached steps. max_order (int, defaults to 1): Maximum order of Taylor series expansion to approximate the features. taylor_factors_dtype (torch.dtype, defaults to torch.float32): Data type for Taylor series expansion factors. + architecture (str, defaults to None): Architecture for which the cache is applied. If we know the architecture, we can use the special cache identifiers. + skip_compute_identifiers (List[str], defaults to []): Identifiers for modules to skip computation. + special_cache_identifiers (List[str], defaults to []): Identifiers for modules to use special cache. """ + warmup_steps: int = 3 predict_steps: int = 5 max_order: int = 1 taylor_factors_dtype: torch.dtype = torch.float32 + architecture: str | None = None + skip_compute_identifiers: List[str] = None + special_cache_identifiers: List[str] = None def __repr__(self) -> str: - return f"TaylorSeerCacheConfig(warmup_steps={self.warmup_steps}, predict_steps={self.predict_steps}, max_order={self.max_order}, taylor_factors_dtype={self.taylor_factors_dtype})" + return f"TaylorSeerCacheConfig(warmup_steps={self.warmup_steps}, predict_steps={self.predict_steps}, max_order={self.max_order}, taylor_factors_dtype={self.taylor_factors_dtype}, architecture={self.architecture}, skip_compute_identifiers={self.skip_compute_identifiers}, special_cache_identifiers={self.special_cache_identifiers})" + class TaylorSeerOutputState: """ @@ -41,20 +67,31 @@ class TaylorSeerOutputState: The Taylor expansion uses the timestep as the independent variable for approximation. """ - def __init__(self, module_name: str, taylor_factors_dtype: torch.dtype, module_dtype: torch.dtype): + def __init__( + self, module_name: str, taylor_factors_dtype: torch.dtype, module_dtype: torch.dtype, is_skip: bool = False + ): self.module_name = module_name self.remaining_predictions: int = 0 self.last_update_step: Optional[int] = None self.taylor_factors: Dict[int, torch.Tensor] = {} self.taylor_factors_dtype = taylor_factors_dtype self.module_dtype = module_dtype + self.is_skip = is_skip + self.dummy_shape: Optional[Tuple[int, ...]] = None + self.device: Optional[torch.device] = None + self.dummy_tensor: Optional[torch.Tensor] = None def reset(self): self.remaining_predictions = 0 self.last_update_step = None self.taylor_factors = {} + self.dummy_shape = None + self.device = None + self.dummy_tensor = None - def update(self, features: torch.Tensor, current_step: int, max_order: int, predict_steps: int, is_first_update: bool): + def update( + self, features: torch.Tensor, current_step: int, max_order: int, predict_steps: int, is_first_update: bool + ): """ Updates the Taylor factors based on the current features and timestep. Computes finite difference approximations for derivatives using recursive divided differences. @@ -66,23 +103,33 @@ def update(self, features: torch.Tensor, current_step: int, max_order: int, pred predict_steps (int): Number of prediction steps to set after update. is_first_update (bool): Whether this is the initial update (skips difference computation). 
""" - features = features.to(self.taylor_factors_dtype) - new_factors = {0: features} - if not is_first_update: - if self.last_update_step is None: - raise ValueError("Cannot update without prior initialization.") - delta_step = current_step - self.last_update_step - if delta_step == 0: - raise ValueError("Delta step cannot be zero for updates.") - for i in range(max_order): - if i in self.taylor_factors: - # Finite difference: (current - previous) / delta for forward approximation - new_factors[i + 1] = (new_factors[i] - self.taylor_factors[i].to(self.taylor_factors_dtype)) / delta_step - - # taylor factors will be kept in the taylor_factors_dtype - self.taylor_factors = new_factors - self.last_update_step = current_step - self.remaining_predictions = predict_steps + if self.is_skip: + self.dummy_shape = features.shape + self.device = features.device + self.taylor_factors = {} + self.last_update_step = current_step + self.remaining_predictions = predict_steps + else: + features = features.to(self.taylor_factors_dtype) + new_factors = {0: features} + if not is_first_update: + if self.last_update_step is None: + raise ValueError("Cannot update without prior initialization.") + delta_step = current_step - self.last_update_step + if delta_step == 0: + raise ValueError("Delta step cannot be zero for updates.") + for i in range(max_order): + if i in self.taylor_factors: + new_factors[i + 1] = ( + new_factors[i] - self.taylor_factors[i].to(self.taylor_factors_dtype) + ) / delta_step + else: + break + + # taylor factors will be kept in the taylor_factors_dtype + self.taylor_factors = new_factors + self.last_update_step = current_step + self.remaining_predictions = predict_steps def predict(self, current_step: int) -> torch.Tensor: """ @@ -94,16 +141,22 @@ def predict(self, current_step: int) -> torch.Tensor: Returns: torch.Tensor: The predicted features in the module's dtype. """ - if self.last_update_step is None: - raise ValueError("Cannot predict without prior update.") - step_offset = current_step - self.last_update_step - device = self.taylor_factors[0].device - output = torch.zeros_like(self.taylor_factors[0], device=device, dtype=self.taylor_factors_dtype) - for order in range(len(self.taylor_factors)): - output += self.taylor_factors[order] * (step_offset ** order) / math.factorial(order) - self.remaining_predictions -= 1 - # output will be converted to the module's dtype - return output.to(self.module_dtype) + if self.is_skip: + if self.dummy_shape is None or self.device is None: + raise ValueError("Cannot predict for skip module without prior update.") + self.remaining_predictions -= 1 + return torch.empty(self.dummy_shape, dtype=self.module_dtype, device=self.device) + else: + if self.last_update_step is None: + raise ValueError("Cannot predict without prior update.") + step_offset = current_step - self.last_update_step + output = 0 + for order in range(len(self.taylor_factors)): + output += self.taylor_factors[order] * (step_offset**order) * (1 / math.factorial(order)) + self.remaining_predictions -= 1 + # output will be converted to the module's dtype + return output.to(self.module_dtype) + class TaylorSeerAttentionCacheHook(ModelHook): """ @@ -111,6 +164,7 @@ class TaylorSeerAttentionCacheHook(ModelHook): Applies to attention modules in diffusion models (e.g., Flux). Performs full computations during warmup, then alternates between predictions and refreshes. 
""" + _is_stateful = True def __init__( @@ -120,7 +174,7 @@ def __init__( max_order: int, warmup_steps: int, taylor_factors_dtype: torch.dtype, - module_dtype: torch.dtype = None, + is_skip_compute: bool = False, ): super().__init__() self.module_name = module_name @@ -131,13 +185,12 @@ def __init__( self.states: Optional[List[TaylorSeerOutputState]] = None self.num_outputs: Optional[int] = None self.taylor_factors_dtype = taylor_factors_dtype - self.module_dtype = module_dtype + self.is_skip_compute = is_skip_compute def initialize_hook(self, module: torch.nn.Module): self.step_counter = -1 self.states = None self.num_outputs = None - self.module_dtype = None return module def new_forward(self, module: torch.nn.Module, *args, **kwargs): @@ -154,11 +207,15 @@ def new_forward(self, module: torch.nn.Module, *args, **kwargs): module_dtype = attention_outputs[0].dtype self.num_outputs = len(attention_outputs) self.states = [ - TaylorSeerOutputState(self.module_name, self.taylor_factors_dtype, module_dtype) + TaylorSeerOutputState( + self.module_name, self.taylor_factors_dtype, module_dtype, is_skip=self.is_skip_compute + ) for _ in range(self.num_outputs) ] for i, features in enumerate(attention_outputs): - self.states[i].update(features, self.step_counter, self.max_order, self.predict_steps, is_first_update=True) + self.states[i].update( + features, self.step_counter, self.max_order, self.predict_steps, is_first_update=True + ) return attention_outputs[0] if self.num_outputs == 1 else tuple(attention_outputs) should_predict = self.states[0].remaining_predictions > 0 @@ -179,9 +236,8 @@ def new_forward(self, module: torch.nn.Module, *args, **kwargs): return predicted_outputs[0] if self.num_outputs == 1 else tuple(predicted_outputs) def reset_state(self, module: torch.nn.Module) -> None: - if self.states is not None: - for state in self.states: - state.reset() + self.states = None + def apply_taylorseer_cache(module: torch.nn.Module, config: TaylorSeerCacheConfig): """ @@ -199,30 +255,57 @@ def apply_taylorseer_cache(module: torch.nn.Module, config: TaylorSeerCacheConfi >>> pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16) >>> pipe.to("cuda") - >>> config = TaylorSeerCacheConfig(predict_steps=5, max_order=1, warmup_steps=3, taylor_factors_dtype=torch.float32) + >>> config = TaylorSeerCacheConfig(predict_steps=5, max_order=1, warmup_steps=3, taylor_factors_dtype=torch.float32, architecture="flux") >>> apply_taylorseer_cache(pipe.transformer, config) ``` """ + if config.skip_compute_identifiers: + skip_compute_identifiers = config.skip_compute_identifiers + else: + skip_compute_identifiers = SKIP_COMPUTE_IDENTIFIERS.get(config.architecture, []) + + if config.special_cache_identifiers: + special_cache_identifiers = config.special_cache_identifiers + else: + special_cache_identifiers = SPECIAL_CACHE_IDENTIFIERS.get(config.architecture, []) + + logger.debug(f"Skip compute identifiers: {skip_compute_identifiers}") + logger.debug(f"Special cache identifiers: {special_cache_identifiers}") + for name, submodule in module.named_modules(): - if isinstance(submodule, (*_ATTENTION_CLASSES, AttentionModuleMixin)): + if skip_compute_identifiers and special_cache_identifiers: + if any(re.fullmatch(identifier, name) for identifier in skip_compute_identifiers) or any( + re.fullmatch(identifier, name) for identifier in special_cache_identifiers + ): + logger.debug(f"Applying TaylorSeer cache to {name}") + _apply_taylorseer_cache_hook(name, submodule, config) + elif 
isinstance(submodule, (*_ATTENTION_CLASSES, AttentionModuleMixin)): logger.debug(f"Applying TaylorSeer cache to {name}") _apply_taylorseer_cache_hook(name, submodule, config) + def _apply_taylorseer_cache_hook(name: str, module: Attention, config: TaylorSeerCacheConfig): """ Registers the TaylorSeer hook on the specified attention module. - Args: name (str): Name of the module. module (Attention): The attention module. config (TaylorSeerCacheConfig): Configuration for the cache. """ + + is_skip_compute = any( + re.fullmatch(identifier, name) for identifier in SKIP_COMPUTE_IDENTIFIERS.get(config.architecture, []) + ) + registry = HookRegistry.check_if_exists_or_initialize(module) + hook = TaylorSeerAttentionCacheHook( name, config.predict_steps, config.max_order, config.warmup_steps, config.taylor_factors_dtype, + is_skip_compute=is_skip_compute, ) - registry.register_hook(hook, _TAYLORSEER_ATTENTION_CACHE_HOOK) \ No newline at end of file + + registry.register_hook(hook, _TAYLORSEER_ATTENTION_CACHE_HOOK) From 51b4318a3e5b2dc2b3df93f6e2fc2decc254a320 Mon Sep 17 00:00:00 2001 From: toilaluan Date: Sat, 15 Nov 2025 05:13:33 +0000 Subject: [PATCH 07/29] allow special cache ids only --- src/diffusers/hooks/taylorseer_cache.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/hooks/taylorseer_cache.py b/src/diffusers/hooks/taylorseer_cache.py index 89d6da307488..3c5d0a2f3991 100644 --- a/src/diffusers/hooks/taylorseer_cache.py +++ b/src/diffusers/hooks/taylorseer_cache.py @@ -273,7 +273,7 @@ def apply_taylorseer_cache(module: torch.nn.Module, config: TaylorSeerCacheConfi logger.debug(f"Special cache identifiers: {special_cache_identifiers}") for name, submodule in module.named_modules(): - if skip_compute_identifiers and special_cache_identifiers: + if (skip_compute_identifiers and special_cache_identifiers) or (special_cache_identifiers): if any(re.fullmatch(identifier, name) for identifier in skip_compute_identifiers) or any( re.fullmatch(identifier, name) for identifier in special_cache_identifiers ): From 7238d40dd9859dbcec7ac7ca87c9b13f3aea3558 Mon Sep 17 00:00:00 2001 From: toilaluan Date: Sun, 16 Nov 2025 05:09:44 +0000 Subject: [PATCH 08/29] add stop_predicts (cooldown) --- src/diffusers/hooks/taylorseer_cache.py | 126 +++++++++++++++--------- 1 file changed, 79 insertions(+), 47 deletions(-) diff --git a/src/diffusers/hooks/taylorseer_cache.py b/src/diffusers/hooks/taylorseer_cache.py index 3c5d0a2f3991..cb6b7fedd527 100644 --- a/src/diffusers/hooks/taylorseer_cache.py +++ b/src/diffusers/hooks/taylorseer_cache.py @@ -1,6 +1,6 @@ import torch from dataclasses import dataclass -from typing import Callable, Optional, List, Dict +from typing import Callable, Optional, List, Dict, Tuple from .hooks import ModelHook import math from ..models.attention import Attention @@ -12,23 +12,28 @@ from ..utils import logging import re from collections import defaultdict + + logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + _TAYLORSEER_ATTENTION_CACHE_HOOK = "taylorseer_attention_cache" -SPECIAL_CACHE_IDENTIFIERS = { - "flux": [ - r"transformer_blocks\.\d+\.attn", - r"transformer_blocks\.\d+\.ff", - r"transformer_blocks\.\d+\.ff_context", - r"single_transformer_blocks\.\d+\.proj_out", - ] -} -SKIP_COMPUTE_IDENTIFIERS = { - "flux": [ - r"single_transformer_blocks\.\d+\.attn", - r"single_transformer_blocks\.\d+\.proj_mlp", - r"single_transformer_blocks\.\d+\.act_mlp", - ] +# Predefined cache templates for optimized architectures +_CACHE_TEMPLATES = { + 
"flux": { + "cache": [ + r"transformer_blocks\.\d+\.attn", + r"transformer_blocks\.\d+\.ff", + r"transformer_blocks\.\d+\.ff_context", + r"single_transformer_blocks\.\d+\.proj_out", + ], + "skip": [ + r"single_transformer_blocks\.\d+\.attn", + r"single_transformer_blocks\.\d+\.proj_mlp", + r"single_transformer_blocks\.\d+\.act_mlp", + ], + }, } @@ -41,24 +46,39 @@ class TaylorSeerCacheConfig: Attributes: warmup_steps (int, defaults to 3): Number of warmup steps without caching. predict_steps (int, defaults to 5): Number of prediction (cache) steps between non-cached steps. + stop_predicts (Optional[int], defaults to None): Step after which predictions are stopped and full computation is always performed. max_order (int, defaults to 1): Maximum order of Taylor series expansion to approximate the features. taylor_factors_dtype (torch.dtype, defaults to torch.float32): Data type for Taylor series expansion factors. architecture (str, defaults to None): Architecture for which the cache is applied. If we know the architecture, we can use the special cache identifiers. - skip_compute_identifiers (List[str], defaults to []): Identifiers for modules to skip computation. - special_cache_identifiers (List[str], defaults to []): Identifiers for modules to use special cache. + skip_identifiers (List[str], defaults to []): Identifiers for modules to skip computation. + cache_identifiers (List[str], defaults to []): Identifiers for modules to cache. + + By default, this approximation can be applied to all attention modules, but in some architectures, where the outputs of attention modules are not used for any residual computation, we can skip this attention cache step, so we have to identify the next modules to cache. + Example: + ```python + ... + def forward(self, x: torch.Tensor) -> torch.Tensor: + attn_output = self.attention(x) # mark this attention module to skip computation + ffn_output = self.ffn(attn_output) # ffn_output will be cached + return ffn_output + ``` """ warmup_steps: int = 3 predict_steps: int = 5 + stop_predicts: Optional[int] = None max_order: int = 1 taylor_factors_dtype: torch.dtype = torch.float32 architecture: str | None = None - skip_compute_identifiers: List[str] = None - special_cache_identifiers: List[str] = None + skip_identifiers: List[str] = None + cache_identifiers: List[str] = None def __repr__(self) -> str: - return f"TaylorSeerCacheConfig(warmup_steps={self.warmup_steps}, predict_steps={self.predict_steps}, max_order={self.max_order}, taylor_factors_dtype={self.taylor_factors_dtype}, architecture={self.architecture}, skip_compute_identifiers={self.skip_compute_identifiers}, special_cache_identifiers={self.special_cache_identifiers})" + return f"TaylorSeerCacheConfig(warmup_steps={self.warmup_steps}, predict_steps={self.predict_steps}, stop_predicts={self.stop_predicts}, max_order={self.max_order}, taylor_factors_dtype={self.taylor_factors_dtype}, architecture={self.architecture}, skip_identifiers={self.skip_identifiers}, cache_identifiers={self.cache_identifiers})" + @classmethod + def get_identifiers_template(self) -> Dict[str, Dict[str, List[str]]]: + return _CACHE_TEMPLATES class TaylorSeerOutputState: """ @@ -174,18 +194,20 @@ def __init__( max_order: int, warmup_steps: int, taylor_factors_dtype: torch.dtype, - is_skip_compute: bool = False, + stop_predicts: Optional[int] = None, + is_skip: bool = False, ): super().__init__() self.module_name = module_name self.predict_steps = predict_steps self.max_order = max_order self.warmup_steps = warmup_steps + 
self.stop_predicts = stop_predicts self.step_counter = -1 self.states: Optional[List[TaylorSeerOutputState]] = None self.num_outputs: Optional[int] = None self.taylor_factors_dtype = taylor_factors_dtype - self.is_skip_compute = is_skip_compute + self.is_skip = is_skip def initialize_hook(self, module: torch.nn.Module): self.step_counter = -1 @@ -208,7 +230,7 @@ def new_forward(self, module: torch.nn.Module, *args, **kwargs): self.num_outputs = len(attention_outputs) self.states = [ TaylorSeerOutputState( - self.module_name, self.taylor_factors_dtype, module_dtype, is_skip=self.is_skip_compute + self.module_name, self.taylor_factors_dtype, module_dtype, is_skip=self.is_skip ) for _ in range(self.num_outputs) ] @@ -218,22 +240,31 @@ def new_forward(self, module: torch.nn.Module, *args, **kwargs): ) return attention_outputs[0] if self.num_outputs == 1 else tuple(attention_outputs) - should_predict = self.states[0].remaining_predictions > 0 - if is_warmup_phase or not should_predict: - # Full compute during warmup or when refresh needed + if self.stop_predicts is not None and self.step_counter >= self.stop_predicts: + # After stop_predicts: always full compute without updating state attention_outputs = self.fn_ref.original_forward(*args, **kwargs) if isinstance(attention_outputs, torch.Tensor): attention_outputs = [attention_outputs] else: attention_outputs = list(attention_outputs) - is_first_update = self.step_counter == 0 # Only True for the very first step - for i, features in enumerate(attention_outputs): - self.states[i].update(features, self.step_counter, self.max_order, self.predict_steps, is_first_update) return attention_outputs[0] if self.num_outputs == 1 else tuple(attention_outputs) else: - # Predict using Taylor series - predicted_outputs = [state.predict(self.step_counter) for state in self.states] - return predicted_outputs[0] if self.num_outputs == 1 else tuple(predicted_outputs) + should_predict = self.states[0].remaining_predictions > 0 + if is_warmup_phase or not should_predict: + # Full compute during warmup or when refresh needed + attention_outputs = self.fn_ref.original_forward(*args, **kwargs) + if isinstance(attention_outputs, torch.Tensor): + attention_outputs = [attention_outputs] + else: + attention_outputs = list(attention_outputs) + is_first_update = self.step_counter == 0 # Only True for the very first step + for i, features in enumerate(attention_outputs): + self.states[i].update(features, self.step_counter, self.max_order, self.predict_steps, is_first_update) + return attention_outputs[0] if self.num_outputs == 1 else tuple(attention_outputs) + else: + # Predict using Taylor series + predicted_outputs = [state.predict(self.step_counter) for state in self.states] + return predicted_outputs[0] if self.num_outputs == 1 else tuple(predicted_outputs) def reset_state(self, module: torch.nn.Module) -> None: self.states = None @@ -259,23 +290,23 @@ def apply_taylorseer_cache(module: torch.nn.Module, config: TaylorSeerCacheConfi >>> apply_taylorseer_cache(pipe.transformer, config) ``` """ - if config.skip_compute_identifiers: - skip_compute_identifiers = config.skip_compute_identifiers + if config.skip_identifiers: + skip_identifiers = config.skip_identifiers else: - skip_compute_identifiers = SKIP_COMPUTE_IDENTIFIERS.get(config.architecture, []) + skip_identifiers = _CACHE_TEMPLATES.get(config.architecture, {}).get("skip", []) - if config.special_cache_identifiers: - special_cache_identifiers = config.special_cache_identifiers + if config.cache_identifiers: + 
cache_identifiers = config.cache_identifiers else: - special_cache_identifiers = SPECIAL_CACHE_IDENTIFIERS.get(config.architecture, []) + cache_identifiers = _CACHE_TEMPLATES.get(config.architecture, {}).get("cache", []) - logger.debug(f"Skip compute identifiers: {skip_compute_identifiers}") - logger.debug(f"Special cache identifiers: {special_cache_identifiers}") + logger.debug(f"Skip identifiers: {skip_identifiers}") + logger.debug(f"Cache identifiers: {cache_identifiers}") for name, submodule in module.named_modules(): - if (skip_compute_identifiers and special_cache_identifiers) or (special_cache_identifiers): - if any(re.fullmatch(identifier, name) for identifier in skip_compute_identifiers) or any( - re.fullmatch(identifier, name) for identifier in special_cache_identifiers + if (skip_identifiers and cache_identifiers) or (cache_identifiers): + if any(re.fullmatch(identifier, name) for identifier in skip_identifiers) or any( + re.fullmatch(identifier, name) for identifier in cache_identifiers ): logger.debug(f"Applying TaylorSeer cache to {name}") _apply_taylorseer_cache_hook(name, submodule, config) @@ -293,8 +324,8 @@ def _apply_taylorseer_cache_hook(name: str, module: Attention, config: TaylorSee config (TaylorSeerCacheConfig): Configuration for the cache. """ - is_skip_compute = any( - re.fullmatch(identifier, name) for identifier in SKIP_COMPUTE_IDENTIFIERS.get(config.architecture, []) + is_skip = any( + re.fullmatch(identifier, name) for identifier in _CACHE_TEMPLATES.get(config.architecture, {}).get("skip", []) ) registry = HookRegistry.check_if_exists_or_initialize(module) @@ -305,7 +336,8 @@ def _apply_taylorseer_cache_hook(name: str, module: Attention, config: TaylorSee config.max_order, config.warmup_steps, config.taylor_factors_dtype, - is_skip_compute=is_skip_compute, + stop_predicts=config.stop_predicts, + is_skip=is_skip, ) - registry.register_hook(hook, _TAYLORSEER_ATTENTION_CACHE_HOOK) + registry.register_hook(hook, _TAYLORSEER_ATTENTION_CACHE_HOOK) \ No newline at end of file From acfebfa3f3ed4ea3d3db751500ef50cbc38e398c Mon Sep 17 00:00:00 2001 From: toilaluan Date: Mon, 17 Nov 2025 13:21:01 +0700 Subject: [PATCH 09/29] update docs --- src/diffusers/hooks/taylorseer_cache.py | 26 ++++++++----------------- 1 file changed, 8 insertions(+), 18 deletions(-) diff --git a/src/diffusers/hooks/taylorseer_cache.py b/src/diffusers/hooks/taylorseer_cache.py index cb6b7fedd527..ec7705850e9f 100644 --- a/src/diffusers/hooks/taylorseer_cache.py +++ b/src/diffusers/hooks/taylorseer_cache.py @@ -44,24 +44,14 @@ class TaylorSeerCacheConfig: See: https://huggingface.co/papers/2503.06923 Attributes: - warmup_steps (int, defaults to 3): Number of warmup steps without caching. - predict_steps (int, defaults to 5): Number of prediction (cache) steps between non-cached steps. - stop_predicts (Optional[int], defaults to None): Step after which predictions are stopped and full computation is always performed. - max_order (int, defaults to 1): Maximum order of Taylor series expansion to approximate the features. - taylor_factors_dtype (torch.dtype, defaults to torch.float32): Data type for Taylor series expansion factors. - architecture (str, defaults to None): Architecture for which the cache is applied. If we know the architecture, we can use the special cache identifiers. - skip_identifiers (List[str], defaults to []): Identifiers for modules to skip computation. - cache_identifiers (List[str], defaults to []): Identifiers for modules to cache. 
- - By default, this approximation can be applied to all attention modules, but in some architectures, where the outputs of attention modules are not used for any residual computation, we can skip this attention cache step, so we have to identify the next modules to cache. - Example: - ```python - ... - def forward(self, x: torch.Tensor) -> torch.Tensor: - attn_output = self.attention(x) # mark this attention module to skip computation - ffn_output = self.ffn(attn_output) # ffn_output will be cached - return ffn_output - ``` + warmup_steps (`int`, defaults to `3`): Run `N` full computation steps before applying this caching strategy. Higher `N` gives outputs closer to the fully computed baseline. + predict_steps (`int`, defaults to `5`): Recompute the module states every `N` iterations. If this is set to `N`, the module computation will be skipped `N - 1` times before the module states are computed again. + stop_predicts (`int`, *optional*, defaults to `None`): Disable the caching strategy after this step; this helps produce more fine-grained outputs. If not provided, the caching strategy is applied until the end of inference. + max_order (`int`, defaults to `1`): Maximum order of the Taylor series expansion used to approximate the features. In theory, the higher the order, the closer the output is to the actual value, but the more computation is required. + taylor_factors_dtype (`torch.dtype`, defaults to `torch.float32`): Data type for calculating Taylor series expansion factors. + architecture (`str`, *optional*, defaults to `None`): Option to use a cache strategy optimized for a specific architecture. By default, this cache strategy is applied to all `Attention` modules. + skip_identifiers (`List[str]`, *optional*, defaults to `[]`): Regex patterns identifying modules whose computation is skipped. + cache_identifiers (`List[str]`, *optional*, defaults to `[]`): Regex patterns identifying modules to cache. """ warmup_steps: int = 3 From d929ab28a7c2bb4804c945429fefbc13816d8e86 Mon Sep 17 00:00:00 2001 From: toilaluan Date: Mon, 17 Nov 2025 13:24:20 +0700 Subject: [PATCH 10/29] apply ruff --- src/diffusers/hooks/taylorseer_cache.py | 17 ++++++++--------- 1 file changed, 8 insertions(+), 9 deletions(-) diff --git a/src/diffusers/hooks/taylorseer_cache.py b/src/diffusers/hooks/taylorseer_cache.py index ec7705850e9f..54191c30aab5 100644 --- a/src/diffusers/hooks/taylorseer_cache.py +++ b/src/diffusers/hooks/taylorseer_cache.py @@ -1,6 +1,6 @@ import torch from dataclasses import dataclass -from typing import Callable, Optional, List, Dict, Tuple +from typing import Optional, List, Dict, Tuple from .hooks import ModelHook import math from ..models.attention import Attention @@ -11,11 +11,9 @@ from ..hooks import HookRegistry from ..utils import logging import re -from collections import defaultdict -logger = logging.get_logger(__name__) # pylint: disable=invalid-name - +logger = logging.get_logger(__name__) _TAYLORSEER_ATTENTION_CACHE_HOOK = "taylorseer_attention_cache" @@ -70,6 +68,7 @@ def __repr__(self) -> str: def get_identifiers_template(self) -> Dict[str, Dict[str, List[str]]]: return _CACHE_TEMPLATES + class TaylorSeerOutputState: """ Manages the state for Taylor series-based prediction of a single attention output.
@@ -219,9 +218,7 @@ def new_forward(self, module: torch.nn.Module, *args, **kwargs): module_dtype = attention_outputs[0].dtype self.num_outputs = len(attention_outputs) self.states = [ - TaylorSeerOutputState( - self.module_name, self.taylor_factors_dtype, module_dtype, is_skip=self.is_skip - ) + TaylorSeerOutputState(self.module_name, self.taylor_factors_dtype, module_dtype, is_skip=self.is_skip) for _ in range(self.num_outputs) ] for i, features in enumerate(attention_outputs): @@ -249,7 +246,9 @@ def new_forward(self, module: torch.nn.Module, *args, **kwargs): attention_outputs = list(attention_outputs) is_first_update = self.step_counter == 0 # Only True for the very first step for i, features in enumerate(attention_outputs): - self.states[i].update(features, self.step_counter, self.max_order, self.predict_steps, is_first_update) + self.states[i].update( + features, self.step_counter, self.max_order, self.predict_steps, is_first_update + ) return attention_outputs[0] if self.num_outputs == 1 else tuple(attention_outputs) else: # Predict using Taylor series @@ -330,4 +329,4 @@ def _apply_taylorseer_cache_hook(name: str, module: Attention, config: TaylorSee is_skip=is_skip, ) - registry.register_hook(hook, _TAYLORSEER_ATTENTION_CACHE_HOOK) \ No newline at end of file + registry.register_hook(hook, _TAYLORSEER_ATTENTION_CACHE_HOOK) From 9083e1eba5f8b40e0ce5998ed341051408772aa4 Mon Sep 17 00:00:00 2001 From: toilaluan Date: Thu, 20 Nov 2025 09:54:29 +0000 Subject: [PATCH 11/29] update to handle multple calls per timestep --- src/diffusers/hooks/taylorseer_cache.py | 573 +++++++++++++++++------- 1 file changed, 407 insertions(+), 166 deletions(-) diff --git a/src/diffusers/hooks/taylorseer_cache.py b/src/diffusers/hooks/taylorseer_cache.py index 54191c30aab5..807e15558f70 100644 --- a/src/diffusers/hooks/taylorseer_cache.py +++ b/src/diffusers/hooks/taylorseer_cache.py @@ -1,16 +1,16 @@ -import torch +import math +import re from dataclasses import dataclass from typing import Optional, List, Dict, Tuple + +import torch + from .hooks import ModelHook -import math from ..models.attention import Attention from ..models.attention import AttentionModuleMixin -from ._common import ( - _ATTENTION_CLASSES, -) +from ._common import _ATTENTION_CLASSES from ..hooks import HookRegistry from ..utils import logging -import re logger = logging.get_logger(__name__) @@ -18,7 +18,7 @@ _TAYLORSEER_ATTENTION_CACHE_HOOK = "taylorseer_attention_cache" # Predefined cache templates for optimized architectures -_CACHE_TEMPLATES = { +_CACHE_TEMPLATES: Dict[str, Dict[str, List[str]]] = { "flux": { "cache": [ r"transformer_blocks\.\d+\.attn", @@ -42,55 +42,119 @@ class TaylorSeerCacheConfig: See: https://huggingface.co/papers/2503.06923 Attributes: - warmup_steps (`int`, defaults to `3`): Calculate normal computations `N` times before applying this caching strategy. Higher `N` gives more closed outputs. - predict_steps (`int`, defaults to `5`): Calculate the module states every `N` iterations. If this is set to `N`, the module computation will be skipped `N - 1` times before computing the new module states again. - stop_predicts (`int`, *optional*, defaults to `None`): Disable caching strategy after this step, this feature helps produce fine-grained outputs. If not provided, the caching strategy will be applied until the end of the inference. - max_order (`int`, defaults to `1`): Maximum order of Taylor series expansion to approximate the features. 
In theory, the higher the order, the more closed the output is to the actual value but also the more computation is required. - taylor_factors_dtype (`torch.dtype`, defaults to `torch.float32`): Data type for calculating Taylor series expansion factors. - architecture (`str`, *optional*, defaults to `None`): Option to use cache strategy optimized for specific architectures. By default, this cache strategy will be applied to all `Attention` modules. - skip_identifiers (`List[str]`, *optional*, defaults to `[]`): Regex patterns to identify modules to skip computation. - cache_identifiers (`List[str]`, *optional*, defaults to `[]`): Regex patterns to identify modules to cache. + warmup_steps (`int`, defaults to `3`): + Number of *outer* diffusion steps to run with full computation + before enabling caching. During warmup, the Taylor series factors + are still updated, but no predictions are used. + + predict_steps (`int`, defaults to `5`): + Number of prediction (cached) steps to take between two full + computations. That is, once a module state is refreshed, it will + be reused for `predict_steps` subsequent outer steps, then a new + full forward will be computed on the next step. + + stop_predicts (`int`, *optional*, defaults to `None`): + Outer diffusion step index at which caching is disabled. + If provided, for `true_step >= stop_predicts` all modules are + evaluated normally (no predictions, no state updates). + + max_order (`int`, defaults to `1`): + Maximum order of Taylor series expansion to approximate the + features. Higher order gives closer approximation but more compute. + + num_inner_loops (`int`, defaults to `1`): + Number of inner loops per outer diffusion step. For example, + with classifier-free guidance (CFG) you typically have 2 inner + loops: unconditional and conditional branches. + + taylor_factors_dtype (`torch.dtype`, defaults to `torch.float32`): + Data type for computing Taylor series expansion factors. + + architecture (`str`, *optional*, defaults to `None`): + If provided, will look up default `cache` and `skip` regex + patterns in `_CACHE_TEMPLATES[architecture]`. These can be + overridden by `skip_identifiers` and `cache_identifiers`. + + skip_identifiers (`List[str]`, *optional*, defaults to `None`): + Regex patterns (fullmatch) for module names to be placed in + "skip" mode, where the module is evaluated during warmup / + refresh, but then replaced by a cheap dummy tensor during + prediction steps. + + cache_identifiers (`List[str]`, *optional*, defaults to `None`): + Regex patterns (fullmatch) for module names to be placed in + Taylor-series caching mode. + + Notes: + - Patterns are applied with `re.fullmatch` on `module_name`. + - If either `skip_identifiers` or `cache_identifiers` is provided + (or inferred from `architecture`), only modules matching at least + one of those patterns will be hooked. + - If neither is provided, all attention-like modules will be hooked. 
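To make the interaction of `warmup_steps`, `predict_steps`, and `stop_predicts` concrete, here is a rough sketch of the resulting compute/predict schedule (illustrative only; it ignores per-module state and skip mode):

```python
warmup_steps, predict_steps, stop_predicts = 3, 5, None

remaining = 0
schedule = []
for step in range(16):
    if stop_predicts is not None and step >= stop_predicts:
        schedule.append("compute")  # caching disabled, no state refresh
    elif step < warmup_steps or remaining == 0:
        schedule.append("compute")  # warmup or refresh step: Taylor factors are updated
        remaining = predict_steps
    else:
        schedule.append("predict")  # output is extrapolated from the cached factors
        remaining -= 1

print(schedule)
# first three steps are full computes (warmup), then five predictions per refresh
```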
""" warmup_steps: int = 3 predict_steps: int = 5 stop_predicts: Optional[int] = None max_order: int = 1 + num_inner_loops: int = 1 taylor_factors_dtype: torch.dtype = torch.float32 architecture: str | None = None - skip_identifiers: List[str] = None - cache_identifiers: List[str] = None + skip_identifiers: Optional[List[str]] = None + cache_identifiers: Optional[List[str]] = None def __repr__(self) -> str: - return f"TaylorSeerCacheConfig(warmup_steps={self.warmup_steps}, predict_steps={self.predict_steps}, stop_predicts={self.stop_predicts}, max_order={self.max_order}, taylor_factors_dtype={self.taylor_factors_dtype}, architecture={self.architecture}, skip_identifiers={self.skip_identifiers}, cache_identifiers={self.cache_identifiers})" + return ( + "TaylorSeerCacheConfig(" + f"warmup_steps={self.warmup_steps}, " + f"predict_steps={self.predict_steps}, " + f"stop_predicts={self.stop_predicts}, " + f"max_order={self.max_order}, " + f"num_inner_loops={self.num_inner_loops}, " + f"taylor_factors_dtype={self.taylor_factors_dtype}, " + f"architecture={self.architecture}, " + f"skip_identifiers={self.skip_identifiers}, " + f"cache_identifiers={self.cache_identifiers})" + ) @classmethod - def get_identifiers_template(self) -> Dict[str, Dict[str, List[str]]]: + def get_identifiers_template(cls) -> Dict[str, Dict[str, List[str]]]: return _CACHE_TEMPLATES class TaylorSeerOutputState: """ Manages the state for Taylor series-based prediction of a single attention output. + Tracks Taylor expansion factors, last update step, and remaining prediction steps. - The Taylor expansion uses the timestep as the independent variable for approximation. + The Taylor expansion uses the (outer) timestep as the independent variable. + + This class is designed to handle state for a single inner loop index and a single + output (in cases where the module forward returns multiple tensors). """ def __init__( - self, module_name: str, taylor_factors_dtype: torch.dtype, module_dtype: torch.dtype, is_skip: bool = False + self, + module_name: str, + taylor_factors_dtype: torch.dtype, + module_dtype: torch.dtype, + is_skip: bool = False, ): self.module_name = module_name - self.remaining_predictions: int = 0 - self.last_update_step: Optional[int] = None - self.taylor_factors: Dict[int, torch.Tensor] = {} self.taylor_factors_dtype = taylor_factors_dtype self.module_dtype = module_dtype self.is_skip = is_skip + + self.remaining_predictions: int = 0 + self.last_update_step: Optional[int] = None + self.taylor_factors: Dict[int, torch.Tensor] = {} + + # For skip-mode modules self.dummy_shape: Optional[Tuple[int, ...]] = None self.device: Optional[torch.device] = None self.dummy_tensor: Optional[torch.Tensor] = None - def reset(self): + def reset(self) -> None: self.remaining_predictions = 0 self.last_update_step = None self.taylor_factors = {} @@ -99,79 +163,137 @@ def reset(self): self.dummy_tensor = None def update( - self, features: torch.Tensor, current_step: int, max_order: int, predict_steps: int, is_first_update: bool - ): + self, + features: torch.Tensor, + current_step: int, + max_order: int, + predict_steps: int, + ) -> None: """ - Updates the Taylor factors based on the current features and timestep. - Computes finite difference approximations for derivatives using recursive divided differences. + Update Taylor factors based on the current features and (outer) timestep. + + For non-skip modules, finite difference approximations for derivatives are + computed using recursive divided differences. 
Args: - features (torch.Tensor): The attention output features to update with. - current_step (int): The current timestep or step number from the diffusion model. - max_order (int): Maximum order of the Taylor expansion. - predict_steps (int): Number of prediction steps to set after update. - is_first_update (bool): Whether this is the initial update (skips difference computation). + features: Attention output features to update with. + current_step: Current outer timestep (true diffusion step). + max_order: Maximum Taylor expansion order. + predict_steps: Number of prediction steps to allow after this update. """ if self.is_skip: + # For skip modules we only need shape & device and a dummy tensor. self.dummy_shape = features.shape self.device = features.device + # zero is safer than uninitialized values for a "skipped" module + self.dummy_tensor = torch.zeros( + self.dummy_shape, + dtype=self.module_dtype, + device=self.device, + ) self.taylor_factors = {} self.last_update_step = current_step self.remaining_predictions = predict_steps - else: - features = features.to(self.taylor_factors_dtype) - new_factors = {0: features} - if not is_first_update: - if self.last_update_step is None: - raise ValueError("Cannot update without prior initialization.") - delta_step = current_step - self.last_update_step - if delta_step == 0: - raise ValueError("Delta step cannot be zero for updates.") - for i in range(max_order): - if i in self.taylor_factors: - new_factors[i + 1] = ( - new_factors[i] - self.taylor_factors[i].to(self.taylor_factors_dtype) - ) / delta_step - else: - break - - # taylor factors will be kept in the taylor_factors_dtype - self.taylor_factors = new_factors - self.last_update_step = current_step - self.remaining_predictions = predict_steps + return + + features = features.to(self.taylor_factors_dtype) + new_factors: Dict[int, torch.Tensor] = {0: features} + + is_first_update = self.last_update_step is None + + if not is_first_update: + delta_step = current_step - self.last_update_step + if delta_step == 0: + raise ValueError("Delta step cannot be zero for TaylorSeer update.") + + # Recursive divided differences up to max_order + for i in range(max_order): + prev = self.taylor_factors.get(i) + if prev is None: + break + new_factors[i + 1] = (new_factors[i] - prev.to(self.taylor_factors_dtype)) / delta_step + + # Keep factors in taylor_factors_dtype + self.taylor_factors = new_factors + self.last_update_step = current_step + self.remaining_predictions = predict_steps + + if self.module_name == "proj_out": + logger.debug( + "[UPDATE] module=%s remaining_predictions=%d current_step=%d is_first_update=%s", + self.module_name, + self.remaining_predictions, + current_step, + is_first_update, + ) def predict(self, current_step: int) -> torch.Tensor: """ - Predicts the features using the Taylor series expansion at the given timestep. + Predict features using the Taylor series at the given (outer) timestep. Args: - current_step (int): The current timestep for prediction. + current_step: Current outer timestep for prediction. Returns: - torch.Tensor: The predicted features in the module's dtype. + Predicted features in the module's dtype. 
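The prediction path then evaluates the truncated Taylor series at the requested step offset; a minimal sketch reusing factors shaped like the ones above:

```python
import math

import torch

taylor_factors = {0: torch.tensor([2.0, 5.0]), 1: torch.tensor([0.2, 0.6])}
last_update_step, current_step = 8, 11
step_offset = current_step - last_update_step

# f(t0 + Δt) ≈ Σ_n f^(n)(t0) * Δt^n / n!
prediction = torch.zeros_like(taylor_factors[0])
for order, factor in taylor_factors.items():
    prediction = prediction + factor * (step_offset**order) / math.factorial(order)
```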
""" if self.is_skip: - if self.dummy_shape is None or self.device is None: + if self.dummy_tensor is None: raise ValueError("Cannot predict for skip module without prior update.") self.remaining_predictions -= 1 - return torch.empty(self.dummy_shape, dtype=self.module_dtype, device=self.device) - else: - if self.last_update_step is None: - raise ValueError("Cannot predict without prior update.") - step_offset = current_step - self.last_update_step - output = 0 - for order in range(len(self.taylor_factors)): - output += self.taylor_factors[order] * (step_offset**order) * (1 / math.factorial(order)) - self.remaining_predictions -= 1 - # output will be converted to the module's dtype - return output.to(self.module_dtype) + return self.dummy_tensor + + if self.last_update_step is None: + raise ValueError("Cannot predict without prior initialization/update.") + + step_offset = current_step - self.last_update_step + + output: torch.Tensor + if not self.taylor_factors: + raise ValueError("Taylor factors empty during prediction.") + + # Accumulate Taylor series: f(t0 + Δt) ≈ Σ f^{(n)}(t0) * (Δt^n / n!) + output = torch.zeros_like(self.taylor_factors[0]) + for order, factor in self.taylor_factors.items(): + # Note: order starts at 0 + coeff = (step_offset**order) / math.factorial(order) + output = output + factor * coeff + + self.remaining_predictions -= 1 + out = output.to(self.module_dtype) + + if self.module_name == "proj_out": + logger.debug( + "[PREDICT] module=%s remaining_predictions=%d current_step=%d last_update_step=%s", + self.module_name, + self.remaining_predictions, + current_step, + self.last_update_step, + ) + + return out class TaylorSeerAttentionCacheHook(ModelHook): """ Hook for caching and predicting attention outputs using Taylor series approximations. + Applies to attention modules in diffusion models (e.g., Flux). - Performs full computations during warmup, then alternates between predictions and refreshes. + Performs full computations during warmup, then alternates between blocks of + predictions and refreshes. + + The hook maintains separate states for each inner loop index (e.g., for + classifier-free guidance). Each inner loop has its own list of + `TaylorSeerOutputState` instances, one per output tensor from the module's + forward (typically one). + + The `step_counter` increments on every forward call of this module. + We define: + - `inner_index = step_counter % num_inner_loops` + - `true_step = step_counter // num_inner_loops` + + Warmup, prediction, and updates are handled per inner loop, but use the + shared `true_step` (outer diffusion step). 
""" _is_stateful = True @@ -183,148 +305,267 @@ def __init__( max_order: int, warmup_steps: int, taylor_factors_dtype: torch.dtype, + num_inner_loops: int = 1, stop_predicts: Optional[int] = None, is_skip: bool = False, ): super().__init__() + if num_inner_loops <= 0: + raise ValueError("num_inner_loops must be >= 1") + self.module_name = module_name self.predict_steps = predict_steps self.max_order = max_order self.warmup_steps = warmup_steps self.stop_predicts = stop_predicts - self.step_counter = -1 - self.states: Optional[List[TaylorSeerOutputState]] = None - self.num_outputs: Optional[int] = None + self.num_inner_loops = num_inner_loops self.taylor_factors_dtype = taylor_factors_dtype self.is_skip = is_skip + self.step_counter: int = -1 + self.states: Optional[List[Optional[List[TaylorSeerOutputState]]]] = None + self.num_outputs: Optional[int] = None + def initialize_hook(self, module: torch.nn.Module): self.step_counter = -1 self.states = None self.num_outputs = None return module - def new_forward(self, module: torch.nn.Module, *args, **kwargs): - self.step_counter += 1 - is_warmup_phase = self.step_counter < self.warmup_steps + def reset_state(self, module: torch.nn.Module) -> None: + """ + Reset state between sampling runs. + """ + self.step_counter = -1 + self.states = None + self.num_outputs = None + + @staticmethod + def _listify(outputs): + if isinstance(outputs, torch.Tensor): + return [outputs] + return list(outputs) + def _delistify(self, outputs_list): + if self.num_outputs == 1: + return outputs_list[0] + return tuple(outputs_list) + + def _ensure_states_initialized( + self, + module: torch.nn.Module, + inner_index: int, + true_step: int, + *args, + **kwargs, + ) -> Optional[List[torch.Tensor]]: + """ + Ensure per-inner-loop states exist. If this is the first call for this + inner_index, perform a full forward, initialize states, and return the + outputs. Otherwise, return None. + """ if self.states is None: - # First step: always full compute and initialize - attention_outputs = self.fn_ref.original_forward(*args, **kwargs) - if isinstance(attention_outputs, torch.Tensor): - attention_outputs = [attention_outputs] - else: - attention_outputs = list(attention_outputs) - module_dtype = attention_outputs[0].dtype + self.states = [None for _ in range(self.num_inner_loops)] + + if self.states[inner_index] is not None: + return None + + if self.module_name == "proj_out": + logger.debug( + "[FIRST STEP] Initializing states for %s (inner_index=%d, true_step=%d)", + self.module_name, + inner_index, + true_step, + ) + + # First step for this inner loop: always full compute and initialize. 
+ attention_outputs = self._listify(self.fn_ref.original_forward(*args, **kwargs)) + module_dtype = attention_outputs[0].dtype + + if self.num_outputs is None: self.num_outputs = len(attention_outputs) - self.states = [ - TaylorSeerOutputState(self.module_name, self.taylor_factors_dtype, module_dtype, is_skip=self.is_skip) - for _ in range(self.num_outputs) - ] - for i, features in enumerate(attention_outputs): - self.states[i].update( - features, self.step_counter, self.max_order, self.predict_steps, is_first_update=True - ) - return attention_outputs[0] if self.num_outputs == 1 else tuple(attention_outputs) - - if self.stop_predicts is not None and self.step_counter >= self.stop_predicts: - # After stop_predicts: always full compute without updating state - attention_outputs = self.fn_ref.original_forward(*args, **kwargs) - if isinstance(attention_outputs, torch.Tensor): - attention_outputs = [attention_outputs] - else: - attention_outputs = list(attention_outputs) - return attention_outputs[0] if self.num_outputs == 1 else tuple(attention_outputs) - else: - should_predict = self.states[0].remaining_predictions > 0 - if is_warmup_phase or not should_predict: - # Full compute during warmup or when refresh needed - attention_outputs = self.fn_ref.original_forward(*args, **kwargs) - if isinstance(attention_outputs, torch.Tensor): - attention_outputs = [attention_outputs] - else: - attention_outputs = list(attention_outputs) - is_first_update = self.step_counter == 0 # Only True for the very first step - for i, features in enumerate(attention_outputs): - self.states[i].update( - features, self.step_counter, self.max_order, self.predict_steps, is_first_update - ) - return attention_outputs[0] if self.num_outputs == 1 else tuple(attention_outputs) - else: - # Predict using Taylor series - predicted_outputs = [state.predict(self.step_counter) for state in self.states] - return predicted_outputs[0] if self.num_outputs == 1 else tuple(predicted_outputs) + elif self.num_outputs != len(attention_outputs): + raise ValueError("Output count mismatch across inner loops.") + + self.states[inner_index] = [ + TaylorSeerOutputState( + self.module_name, + self.taylor_factors_dtype, + module_dtype, + is_skip=self.is_skip, + ) + for _ in range(self.num_outputs) + ] + + for i, features in enumerate(attention_outputs): + self.states[inner_index][i].update( + features=features, + current_step=true_step, + max_order=self.max_order, + predict_steps=self.predict_steps, + ) + + return attention_outputs - def reset_state(self, module: torch.nn.Module) -> None: - self.states = None + def new_forward(self, module: torch.nn.Module, *args, **kwargs): + self.step_counter += 1 + inner_index = self.step_counter % self.num_inner_loops + true_step = self.step_counter // self.num_inner_loops + is_warmup_phase = true_step < self.warmup_steps + + if self.module_name == "proj_out": + logger.debug( + "[FORWARD] module=%s step_counter=%d inner_index=%d true_step=%d is_warmup=%s", + self.module_name, + self.step_counter, + inner_index, + true_step, + is_warmup_phase, + ) + + # First-time initialization for this inner loop + maybe_outputs = self._ensure_states_initialized(module, inner_index, true_step, *args, **kwargs) + if maybe_outputs is not None: + return self._delistify(maybe_outputs) + + assert self.states is not None + states = self.states[inner_index] + assert states is not None and len(states) > 0 + + # If stop_predicts is set and we are past that step, always run full forward + if self.stop_predicts is not None and 
true_step >= self.stop_predicts: + attention_outputs = self._listify(self.fn_ref.original_forward(*args, **kwargs)) + return self._delistify(attention_outputs) + + # Decide between prediction vs refresh + # - Never predict during warmup. + # - Otherwise, predict while we still have remaining_predictions. + should_predict = (not is_warmup_phase) and (states[0].remaining_predictions > 0) + + if should_predict: + predicted_outputs = [state.predict(true_step) for state in states] + return self._delistify(predicted_outputs) + + # Full compute: warmup or refresh + attention_outputs = self._listify(self.fn_ref.original_forward(*args, **kwargs)) + for i, features in enumerate(attention_outputs): + states[i].update( + features=features, + current_step=true_step, + max_order=self.max_order, + predict_steps=self.predict_steps, + ) + return self._delistify(attention_outputs) + + +def _resolve_patterns(config: TaylorSeerCacheConfig) -> Tuple[List[str], List[str]]: + """ + Resolve effective skip and cache pattern lists from config + templates. + """ + template = _CACHE_TEMPLATES.get(config.architecture or "", {}) + default_skip = template.get("skip", []) + default_cache = template.get("cache", []) + + skip_patterns = config.skip_identifiers if config.skip_identifiers is not None else default_skip + cache_patterns = config.cache_identifiers if config.cache_identifiers is not None else default_cache + + return skip_patterns or [], cache_patterns or [] def apply_taylorseer_cache(module: torch.nn.Module, config: TaylorSeerCacheConfig): """ - Applies the TaylorSeer cache to given pipeline. + Applies the TaylorSeer cache to a given pipeline (typically the transformer / UNet). Args: - module (torch.nn.Module): The model to apply the hook to. + module (torch.nn.Module): The model subtree to apply the hooks to. config (TaylorSeerCacheConfig): Configuration for the cache. Example: ```python >>> import torch - >>> from diffusers import FluxPipeline, TaylorSeerCacheConfig, apply_taylorseer_cache - - >>> pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16) + >>> from diffusers import FluxPipeline, TaylorSeerCacheConfig + >>> + >>> pipe = FluxPipeline.from_pretrained( + ... "black-forest-labs/FLUX.1-dev", + ... torch_dtype=torch.bfloat16, + ... ) >>> pipe.to("cuda") - - >>> config = TaylorSeerCacheConfig(predict_steps=5, max_order=1, warmup_steps=3, taylor_factors_dtype=torch.float32, architecture="flux") - >>> apply_taylorseer_cache(pipe.transformer, config) + >>> + >>> config = TaylorSeerCacheConfig( + ... predict_steps=5, + ... max_order=1, + ... warmup_steps=3, + ... taylor_factors_dtype=torch.float32, + ... architecture="flux", + ... num_inner_loops=2, # e.g. CFG + ... 
) + >>> pipe.transformer.enable_cache(config) ``` """ - if config.skip_identifiers: - skip_identifiers = config.skip_identifiers - else: - skip_identifiers = _CACHE_TEMPLATES.get(config.architecture, {}).get("skip", []) + skip_patterns, cache_patterns = _resolve_patterns(config) - if config.cache_identifiers: - cache_identifiers = config.cache_identifiers - else: - cache_identifiers = _CACHE_TEMPLATES.get(config.architecture, {}).get("cache", []) + logger.debug("TaylorSeer skip identifiers: %s", skip_patterns) + logger.debug("TaylorSeer cache identifiers: %s", cache_patterns) - logger.debug(f"Skip identifiers: {skip_identifiers}") - logger.debug(f"Cache identifiers: {cache_identifiers}") + use_patterns = bool(skip_patterns or cache_patterns) for name, submodule in module.named_modules(): - if (skip_identifiers and cache_identifiers) or (cache_identifiers): - if any(re.fullmatch(identifier, name) for identifier in skip_identifiers) or any( - re.fullmatch(identifier, name) for identifier in cache_identifiers - ): - logger.debug(f"Applying TaylorSeer cache to {name}") - _apply_taylorseer_cache_hook(name, submodule, config) - elif isinstance(submodule, (*_ATTENTION_CLASSES, AttentionModuleMixin)): - logger.debug(f"Applying TaylorSeer cache to {name}") - _apply_taylorseer_cache_hook(name, submodule, config) - - -def _apply_taylorseer_cache_hook(name: str, module: Attention, config: TaylorSeerCacheConfig): + matches_skip = any(re.fullmatch(pattern, name) for pattern in skip_patterns) + matches_cache = any(re.fullmatch(pattern, name) for pattern in cache_patterns) + + if use_patterns: + # If patterns are configured (either skip or cache), only touch modules + # that explicitly match at least one pattern. + if not (matches_skip or matches_cache): + continue + + logger.debug( + "Applying TaylorSeer cache to %s (mode=%s)", + name, + "skip" if matches_skip else "cache", + ) + _apply_taylorseer_cache_hook( + name=name, + module=submodule, + config=config, + is_skip=matches_skip, + ) + else: + # No patterns configured: fall back to "all attention modules". + if isinstance(submodule, (*_ATTENTION_CLASSES, AttentionModuleMixin)): + logger.debug("Applying TaylorSeer cache to %s (fallback attention mode)", name) + _apply_taylorseer_cache_hook( + name=name, + module=submodule, + config=config, + is_skip=False, + ) + + +def _apply_taylorseer_cache_hook( + name: str, + module: Attention, + config: TaylorSeerCacheConfig, + is_skip: bool, +): """ Registers the TaylorSeer hook on the specified attention module. + Args: - name (str): Name of the module. - module (Attention): The attention module. - config (TaylorSeerCacheConfig): Configuration for the cache. + name: Name of the module. + module: The attention-like module to be hooked. + config: Cache configuration. + is_skip: Whether this module should operate in "skip" mode. 
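For reference, a minimal sketch of how the fullmatch-based selection behaves; the module names and patterns below are hypothetical examples, not part of the predefined templates:

```python
import re

cache_patterns = [r"transformer_blocks\.\d+\.attn"]
skip_patterns = [r"single_transformer_blocks\.\d+\.proj_mlp"]

for name in ["transformer_blocks.0.attn", "single_transformer_blocks.3.proj_mlp", "proj_out"]:
    if any(re.fullmatch(p, name) for p in skip_patterns):
        print(name, "-> skip mode (dummy output during prediction steps)")
    elif any(re.fullmatch(p, name) for p in cache_patterns):
        print(name, "-> Taylor-series caching")
    else:
        print(name, "-> not hooked")
```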
""" - - is_skip = any( - re.fullmatch(identifier, name) for identifier in _CACHE_TEMPLATES.get(config.architecture, {}).get("skip", []) - ) - registry = HookRegistry.check_if_exists_or_initialize(module) hook = TaylorSeerAttentionCacheHook( - name, - config.predict_steps, - config.max_order, - config.warmup_steps, - config.taylor_factors_dtype, + module_name=name, + predict_steps=config.predict_steps, + max_order=config.max_order, + warmup_steps=config.warmup_steps, + taylor_factors_dtype=config.taylor_factors_dtype, + num_inner_loops=config.num_inner_loops, stop_predicts=config.stop_predicts, is_skip=is_skip, ) From a8ea38304424cb2f3f19893223b3c197b755d843 Mon Sep 17 00:00:00 2001 From: toilaluan Date: Tue, 25 Nov 2025 05:28:00 +0000 Subject: [PATCH 12/29] refractor to use state manager --- src/diffusers/hooks/taylorseer_cache.py | 504 +++++++----------------- src/diffusers/models/cache_utils.py | 4 +- 2 files changed, 139 insertions(+), 369 deletions(-) diff --git a/src/diffusers/hooks/taylorseer_cache.py b/src/diffusers/hooks/taylorseer_cache.py index 807e15558f70..f400576fed44 100644 --- a/src/diffusers/hooks/taylorseer_cache.py +++ b/src/diffusers/hooks/taylorseer_cache.py @@ -4,36 +4,25 @@ from typing import Optional, List, Dict, Tuple import torch +import torch.nn as nn -from .hooks import ModelHook -from ..models.attention import Attention -from ..models.attention import AttentionModuleMixin -from ._common import _ATTENTION_CLASSES -from ..hooks import HookRegistry +from .hooks import ModelHook, StateManager, HookRegistry from ..utils import logging logger = logging.get_logger(__name__) - -_TAYLORSEER_ATTENTION_CACHE_HOOK = "taylorseer_attention_cache" - -# Predefined cache templates for optimized architectures -_CACHE_TEMPLATES: Dict[str, Dict[str, List[str]]] = { - "flux": { - "cache": [ - r"transformer_blocks\.\d+\.attn", - r"transformer_blocks\.\d+\.ff", - r"transformer_blocks\.\d+\.ff_context", - r"single_transformer_blocks\.\d+\.proj_out", - ], - "skip": [ - r"single_transformer_blocks\.\d+\.attn", - r"single_transformer_blocks\.\d+\.proj_mlp", - r"single_transformer_blocks\.\d+\.act_mlp", - ], - }, -} - +_TAYLORSEER_CACHE_HOOK = "taylorseer_cache" +_SPATIAL_ATTENTION_BLOCK_IDENTIFIERS = ( + "^blocks.*attn", + "^transformer_blocks.*attn", + "^single_transformer_blocks.*attn", +) +_TEMPORAL_ATTENTION_BLOCK_IDENTIFIERS = ("^temporal_transformer_blocks.*attn",) +_TRANSFORMER_BLOCK_IDENTIFIERS = _SPATIAL_ATTENTION_BLOCK_IDENTIFIERS + _TEMPORAL_ATTENTION_BLOCK_IDENTIFIERS +_BLOCK_IDENTIFIERS = ( + "^[^.]*block[^.]*\\.[^.]+$", +) +_PROJ_OUT_IDENTIFIERS = ("^proj_out$",) @dataclass class TaylorSeerCacheConfig: @@ -43,37 +32,29 @@ class TaylorSeerCacheConfig: Attributes: warmup_steps (`int`, defaults to `3`): - Number of *outer* diffusion steps to run with full computation + Number of denoising steps to run with full computation before enabling caching. During warmup, the Taylor series factors are still updated, but no predictions are used. predict_steps (`int`, defaults to `5`): Number of prediction (cached) steps to take between two full computations. That is, once a module state is refreshed, it will - be reused for `predict_steps` subsequent outer steps, then a new + be reused for `predict_steps` subsequent denoising steps, then a new full forward will be computed on the next step. stop_predicts (`int`, *optional*, defaults to `None`): - Outer diffusion step index at which caching is disabled. 
- If provided, for `true_step >= stop_predicts` all modules are + Denoising step index at which caching is disabled. + If provided, for `self.current_step >= stop_predicts` all modules are evaluated normally (no predictions, no state updates). max_order (`int`, defaults to `1`): Maximum order of Taylor series expansion to approximate the features. Higher order gives closer approximation but more compute. - num_inner_loops (`int`, defaults to `1`): - Number of inner loops per outer diffusion step. For example, - with classifier-free guidance (CFG) you typically have 2 inner - loops: unconditional and conditional branches. - - taylor_factors_dtype (`torch.dtype`, defaults to `torch.float32`): + taylor_factors_dtype (`torch.dtype`, defaults to `torch.bfloat16`): Data type for computing Taylor series expansion factors. - - architecture (`str`, *optional*, defaults to `None`): - If provided, will look up default `cache` and `skip` regex - patterns in `_CACHE_TEMPLATES[architecture]`. These can be - overridden by `skip_identifiers` and `cache_identifiers`. + Use lower precision to reduce memory usage. + Use higher precision to improve numerical stability. skip_identifiers (`List[str]`, *optional*, defaults to `None`): Regex patterns (fullmatch) for module names to be placed in @@ -85,10 +66,12 @@ class TaylorSeerCacheConfig: Regex patterns (fullmatch) for module names to be placed in Taylor-series caching mode. + lite (`bool`, *optional*, defaults to `False`): + Whether to use a TaylorSeer Lite variant that reduces memory usage. This option overrides + any user-provided `skip_identifiers` or `cache_identifiers` patterns. Notes: - Patterns are applied with `re.fullmatch` on `module_name`. - - If either `skip_identifiers` or `cache_identifiers` is provided - (or inferred from `architecture`), only modules matching at least + - If either `skip_identifiers` or `cache_identifiers` is provided, only modules matching at least one of those patterns will be hooked. - If neither is provided, all attention-like modules will be hooked. """ @@ -97,11 +80,10 @@ class TaylorSeerCacheConfig: predict_steps: int = 5 stop_predicts: Optional[int] = None max_order: int = 1 - num_inner_loops: int = 1 - taylor_factors_dtype: torch.dtype = torch.float32 - architecture: str | None = None + taylor_factors_dtype: Optional[torch.dtype] = torch.bfloat16 skip_identifiers: Optional[List[str]] = None cache_identifiers: Optional[List[str]] = None + lite: bool = False def __repr__(self) -> str: return ( @@ -110,364 +92,153 @@ def __repr__(self) -> str: f"predict_steps={self.predict_steps}, " f"stop_predicts={self.stop_predicts}, " f"max_order={self.max_order}, " - f"num_inner_loops={self.num_inner_loops}, " f"taylor_factors_dtype={self.taylor_factors_dtype}, " - f"architecture={self.architecture}, " f"skip_identifiers={self.skip_identifiers}, " - f"cache_identifiers={self.cache_identifiers})" + f"cache_identifiers={self.cache_identifiers}, " + f"lite={self.lite})" ) - @classmethod - def get_identifiers_template(cls) -> Dict[str, Dict[str, List[str]]]: - return _CACHE_TEMPLATES - - -class TaylorSeerOutputState: - """ - Manages the state for Taylor series-based prediction of a single attention output. - - Tracks Taylor expansion factors, last update step, and remaining prediction steps. - The Taylor expansion uses the (outer) timestep as the independent variable. - - This class is designed to handle state for a single inner loop index and a single - output (in cases where the module forward returns multiple tensors). 
- """ +class TaylorSeerState: def __init__( self, - module_name: str, - taylor_factors_dtype: torch.dtype, - module_dtype: torch.dtype, - is_skip: bool = False, + taylor_factors_dtype: Optional[torch.dtype] = torch.bfloat16, + max_order: int = 1, ): - self.module_name = module_name self.taylor_factors_dtype = taylor_factors_dtype - self.module_dtype = module_dtype - self.is_skip = is_skip + self.max_order = max_order - self.remaining_predictions: int = 0 + self.module_dtypes: Tuple[torch.dtype, ...] = () self.last_update_step: Optional[int] = None - self.taylor_factors: Dict[int, torch.Tensor] = {} + self.taylor_factors: Dict[int, Dict[int, torch.Tensor]] = {} # For skip-mode modules - self.dummy_shape: Optional[Tuple[int, ...]] = None self.device: Optional[torch.device] = None - self.dummy_tensor: Optional[torch.Tensor] = None + self.dummy_tensors: Optional[Tuple[torch.Tensor, ...]] = None + + self.current_step = -1 def reset(self) -> None: - self.remaining_predictions = 0 self.last_update_step = None self.taylor_factors = {} - self.dummy_shape = None self.device = None - self.dummy_tensor = None + self.dummy_tensors = None + self.current_step = -1 def update( self, - features: torch.Tensor, + outputs: Tuple[torch.Tensor, ...], current_step: int, - max_order: int, - predict_steps: int, ) -> None: - """ - Update Taylor factors based on the current features and (outer) timestep. - - For non-skip modules, finite difference approximations for derivatives are - computed using recursive divided differences. - - Args: - features: Attention output features to update with. - current_step: Current outer timestep (true diffusion step). - max_order: Maximum Taylor expansion order. - predict_steps: Number of prediction steps to allow after this update. - """ - if self.is_skip: - # For skip modules we only need shape & device and a dummy tensor. 
- self.dummy_shape = features.shape - self.device = features.device - # zero is safer than uninitialized values for a "skipped" module - self.dummy_tensor = torch.zeros( - self.dummy_shape, - dtype=self.module_dtype, - device=self.device, - ) - self.taylor_factors = {} - self.last_update_step = current_step - self.remaining_predictions = predict_steps - return - - features = features.to(self.taylor_factors_dtype) - new_factors: Dict[int, torch.Tensor] = {0: features} - - is_first_update = self.last_update_step is None - - if not is_first_update: - delta_step = current_step - self.last_update_step - if delta_step == 0: - raise ValueError("Delta step cannot be zero for TaylorSeer update.") - - # Recursive divided differences up to max_order - for i in range(max_order): - prev = self.taylor_factors.get(i) - if prev is None: - break - new_factors[i + 1] = (new_factors[i] - prev.to(self.taylor_factors_dtype)) / delta_step - - # Keep factors in taylor_factors_dtype - self.taylor_factors = new_factors + self.module_dtypes = tuple(output.dtype for output in outputs) + for i in range(len(outputs)): + features = outputs[i].to(self.taylor_factors_dtype) + new_factors: Dict[int, torch.Tensor] = {0: features} + is_first_update = self.last_update_step is None + if not is_first_update: + delta_step = current_step - self.last_update_step + if delta_step == 0: + raise ValueError("Delta step cannot be zero for TaylorSeer update.") + + # Recursive divided differences up to max_order + for j in range(self.max_order): + prev = self.taylor_factors[i].get(j) + if prev is None: + break + new_factors[j + 1] = (new_factors[j] - prev.to(self.taylor_factors_dtype)) / delta_step + self.taylor_factors[i] = new_factors self.last_update_step = current_step - self.remaining_predictions = predict_steps - - if self.module_name == "proj_out": - logger.debug( - "[UPDATE] module=%s remaining_predictions=%d current_step=%d is_first_update=%s", - self.module_name, - self.remaining_predictions, - current_step, - is_first_update, - ) def predict(self, current_step: int) -> torch.Tensor: - """ - Predict features using the Taylor series at the given (outer) timestep. - - Args: - current_step: Current outer timestep for prediction. - - Returns: - Predicted features in the module's dtype. - """ - if self.is_skip: - if self.dummy_tensor is None: - raise ValueError("Cannot predict for skip module without prior update.") - self.remaining_predictions -= 1 - return self.dummy_tensor - if self.last_update_step is None: raise ValueError("Cannot predict without prior initialization/update.") step_offset = current_step - self.last_update_step - output: torch.Tensor if not self.taylor_factors: raise ValueError("Taylor factors empty during prediction.") - # Accumulate Taylor series: f(t0 + Δt) ≈ Σ f^{(n)}(t0) * (Δt^n / n!) - output = torch.zeros_like(self.taylor_factors[0]) - for order, factor in self.taylor_factors.items(): - # Note: order starts at 0 - coeff = (step_offset**order) / math.factorial(order) - output = output + factor * coeff - - self.remaining_predictions -= 1 - out = output.to(self.module_dtype) + outputs = [] + for i in range(len(self.module_dtypes)): + taylor_factors = self.taylor_factors[i] + # Accumulate Taylor series: f(t0 + Δt) ≈ Σ f^{(n)}(t0) * (Δt^n / n!) 
+ output = torch.zeros_like(taylor_factors[0]) + for order, factor in taylor_factors.items(): + # Note: order starts at 0 + coeff = (step_offset**order) / math.factorial(order) + output = output + factor * coeff + outputs.append(output.to(self.module_dtypes[i])) - if self.module_name == "proj_out": - logger.debug( - "[PREDICT] module=%s remaining_predictions=%d current_step=%d last_update_step=%s", - self.module_name, - self.remaining_predictions, - current_step, - self.last_update_step, - ) + return outputs - return out - - -class TaylorSeerAttentionCacheHook(ModelHook): - """ - Hook for caching and predicting attention outputs using Taylor series approximations. - - Applies to attention modules in diffusion models (e.g., Flux). - Performs full computations during warmup, then alternates between blocks of - predictions and refreshes. - - The hook maintains separate states for each inner loop index (e.g., for - classifier-free guidance). Each inner loop has its own list of - `TaylorSeerOutputState` instances, one per output tensor from the module's - forward (typically one). - - The `step_counter` increments on every forward call of this module. - We define: - - `inner_index = step_counter % num_inner_loops` - - `true_step = step_counter // num_inner_loops` - - Warmup, prediction, and updates are handled per inner loop, but use the - shared `true_step` (outer diffusion step). - """ +class TaylorSeerCacheHook(ModelHook): _is_stateful = True def __init__( self, module_name: str, predict_steps: int, - max_order: int, warmup_steps: int, taylor_factors_dtype: torch.dtype, - num_inner_loops: int = 1, + state_manager: StateManager, stop_predicts: Optional[int] = None, is_skip: bool = False, ): super().__init__() - if num_inner_loops <= 0: - raise ValueError("num_inner_loops must be >= 1") - self.module_name = module_name self.predict_steps = predict_steps - self.max_order = max_order self.warmup_steps = warmup_steps self.stop_predicts = stop_predicts - self.num_inner_loops = num_inner_loops self.taylor_factors_dtype = taylor_factors_dtype + self.state_manager = state_manager self.is_skip = is_skip - self.step_counter: int = -1 - self.states: Optional[List[Optional[List[TaylorSeerOutputState]]]] = None - self.num_outputs: Optional[int] = None + self.dummy_outputs = None def initialize_hook(self, module: torch.nn.Module): - self.step_counter = -1 - self.states = None - self.num_outputs = None return module def reset_state(self, module: torch.nn.Module) -> None: """ Reset state between sampling runs. """ - self.step_counter = -1 - self.states = None - self.num_outputs = None - - @staticmethod - def _listify(outputs): - if isinstance(outputs, torch.Tensor): - return [outputs] - return list(outputs) - - def _delistify(self, outputs_list): - if self.num_outputs == 1: - return outputs_list[0] - return tuple(outputs_list) - - def _ensure_states_initialized( - self, - module: torch.nn.Module, - inner_index: int, - true_step: int, - *args, - **kwargs, - ) -> Optional[List[torch.Tensor]]: - """ - Ensure per-inner-loop states exist. If this is the first call for this - inner_index, perform a full forward, initialize states, and return the - outputs. Otherwise, return None. 
- """ - if self.states is None: - self.states = [None for _ in range(self.num_inner_loops)] - - if self.states[inner_index] is not None: - return None - - if self.module_name == "proj_out": - logger.debug( - "[FIRST STEP] Initializing states for %s (inner_index=%d, true_step=%d)", - self.module_name, - inner_index, - true_step, - ) - - # First step for this inner loop: always full compute and initialize. - attention_outputs = self._listify(self.fn_ref.original_forward(*args, **kwargs)) - module_dtype = attention_outputs[0].dtype - - if self.num_outputs is None: - self.num_outputs = len(attention_outputs) - elif self.num_outputs != len(attention_outputs): - raise ValueError("Output count mismatch across inner loops.") - - self.states[inner_index] = [ - TaylorSeerOutputState( - self.module_name, - self.taylor_factors_dtype, - module_dtype, - is_skip=self.is_skip, - ) - for _ in range(self.num_outputs) - ] - - for i, features in enumerate(attention_outputs): - self.states[inner_index][i].update( - features=features, - current_step=true_step, - max_order=self.max_order, - predict_steps=self.predict_steps, - ) - - return attention_outputs + self.dummy_outputs = None + self.current_step = -1 + self.state_manager.reset() def new_forward(self, module: torch.nn.Module, *args, **kwargs): - self.step_counter += 1 - inner_index = self.step_counter % self.num_inner_loops - true_step = self.step_counter // self.num_inner_loops - is_warmup_phase = true_step < self.warmup_steps - - if self.module_name == "proj_out": - logger.debug( - "[FORWARD] module=%s step_counter=%d inner_index=%d true_step=%d is_warmup=%s", - self.module_name, - self.step_counter, - inner_index, - true_step, - is_warmup_phase, - ) - - # First-time initialization for this inner loop - maybe_outputs = self._ensure_states_initialized(module, inner_index, true_step, *args, **kwargs) - if maybe_outputs is not None: - return self._delistify(maybe_outputs) - - assert self.states is not None - states = self.states[inner_index] - assert states is not None and len(states) > 0 - - # If stop_predicts is set and we are past that step, always run full forward - if self.stop_predicts is not None and true_step >= self.stop_predicts: - attention_outputs = self._listify(self.fn_ref.original_forward(*args, **kwargs)) - return self._delistify(attention_outputs) - - # Decide between prediction vs refresh - # - Never predict during warmup. - # - Otherwise, predict while we still have remaining_predictions. 
- should_predict = (not is_warmup_phase) and (states[0].remaining_predictions > 0) - - if should_predict: - predicted_outputs = [state.predict(true_step) for state in states] - return self._delistify(predicted_outputs) - - # Full compute: warmup or refresh - attention_outputs = self._listify(self.fn_ref.original_forward(*args, **kwargs)) - for i, features in enumerate(attention_outputs): - states[i].update( - features=features, - current_step=true_step, - max_order=self.max_order, - predict_steps=self.predict_steps, - ) - return self._delistify(attention_outputs) + state: TaylorSeerState = self.state_manager.get_state() + state.current_step += 1 + current_step = state.current_step + is_warmup_phase = current_step < self.warmup_steps + should_compute = ( + is_warmup_phase + or ((current_step - self.warmup_steps - 1) % self.predict_steps == 0) + or (self.stop_predicts is not None and current_step >= self.stop_predicts) + ) + if should_compute: + outputs = self.fn_ref.original_forward(*args, **kwargs) + if not self.is_skip: + state.update((outputs,) if isinstance(outputs, torch.Tensor) else outputs, current_step) + else: + self.dummy_outputs = outputs + return outputs + + if self.is_skip: + return self.dummy_outputs + + outputs = state.predict(current_step) + return outputs[0] if len(outputs) == 1 else outputs def _resolve_patterns(config: TaylorSeerCacheConfig) -> Tuple[List[str], List[str]]: """ Resolve effective skip and cache pattern lists from config + templates. """ - template = _CACHE_TEMPLATES.get(config.architecture or "", {}) - default_skip = template.get("skip", []) - default_cache = template.get("cache", []) - skip_patterns = config.skip_identifiers if config.skip_identifiers is not None else default_skip - cache_patterns = config.cache_identifiers if config.cache_identifiers is not None else default_cache + skip_patterns = config.skip_identifiers if config.skip_identifiers is not None else None + cache_patterns = config.cache_identifiers if config.cache_identifiers is not None else None return skip_patterns or [], cache_patterns or [] @@ -496,8 +267,6 @@ def apply_taylorseer_cache(module: torch.nn.Module, config: TaylorSeerCacheConfi ... max_order=1, ... warmup_steps=3, ... taylor_factors_dtype=torch.float32, - ... architecture="flux", - ... num_inner_loops=2, # e.g. CFG ... ) >>> pipe.transformer.enable_cache(config) ``` @@ -507,67 +276,68 @@ def apply_taylorseer_cache(module: torch.nn.Module, config: TaylorSeerCacheConfi logger.debug("TaylorSeer skip identifiers: %s", skip_patterns) logger.debug("TaylorSeer cache identifiers: %s", cache_patterns) - use_patterns = bool(skip_patterns or cache_patterns) + cache_patterns = cache_patterns or _TRANSFORMER_BLOCK_IDENTIFIERS + + if config.lite: + logger.info("Using TaylorSeer Lite variant for cache.") + cache_patterns = _PROJ_OUT_IDENTIFIERS + skip_patterns = _BLOCK_IDENTIFIERS + if config.skip_identifiers or config.cache_identifiers: + logger.warning("Lite mode overrides user patterns.") for name, submodule in module.named_modules(): matches_skip = any(re.fullmatch(pattern, name) for pattern in skip_patterns) matches_cache = any(re.fullmatch(pattern, name) for pattern in cache_patterns) - - if use_patterns: - # If patterns are configured (either skip or cache), only touch modules - # that explicitly match at least one pattern. 
- if not (matches_skip or matches_cache): - continue - - logger.debug( - "Applying TaylorSeer cache to %s (mode=%s)", - name, - "skip" if matches_skip else "cache", - ) - _apply_taylorseer_cache_hook( - name=name, - module=submodule, - config=config, - is_skip=matches_skip, - ) - else: - # No patterns configured: fall back to "all attention modules". - if isinstance(submodule, (*_ATTENTION_CLASSES, AttentionModuleMixin)): - logger.debug("Applying TaylorSeer cache to %s (fallback attention mode)", name) - _apply_taylorseer_cache_hook( - name=name, - module=submodule, - config=config, - is_skip=False, - ) + if not (matches_skip or matches_cache): + continue + logger.debug( + "Applying TaylorSeer cache to %s (mode=%s)", + name, + "skip" if matches_skip else "cache", + ) + state_manager = StateManager( + TaylorSeerState, + init_kwargs={ + "taylor_factors_dtype": config.taylor_factors_dtype, + "max_order": config.max_order, + }, + ) + _apply_taylorseer_cache_hook( + name=name, + module=submodule, + config=config, + is_skip=matches_skip, + state_manager=state_manager, + ) def _apply_taylorseer_cache_hook( name: str, - module: Attention, + module: nn.Module, config: TaylorSeerCacheConfig, is_skip: bool, + state_manager: StateManager, ): """ - Registers the TaylorSeer hook on the specified attention module. + Registers the TaylorSeer hook on the specified nn.Module. Args: name: Name of the module. - module: The attention-like module to be hooked. + module: The nn.Module to be hooked. config: Cache configuration. is_skip: Whether this module should operate in "skip" mode. + state_manager: The state manager for managing hook state. """ registry = HookRegistry.check_if_exists_or_initialize(module) - hook = TaylorSeerAttentionCacheHook( + hook = TaylorSeerCacheHook( module_name=name, predict_steps=config.predict_steps, - max_order=config.max_order, warmup_steps=config.warmup_steps, taylor_factors_dtype=config.taylor_factors_dtype, - num_inner_loops=config.num_inner_loops, stop_predicts=config.stop_predicts, is_skip=is_skip, + state_manager=state_manager, ) - registry.register_hook(hook, _TAYLORSEER_ATTENTION_CACHE_HOOK) + registry.register_hook(hook, _TAYLORSEER_CACHE_HOOK) diff --git a/src/diffusers/models/cache_utils.py b/src/diffusers/models/cache_utils.py index ffbf296ff617..56de1e646310 100644 --- a/src/diffusers/models/cache_utils.py +++ b/src/diffusers/models/cache_utils.py @@ -97,7 +97,7 @@ def disable_cache(self) -> None: from ..hooks.faster_cache import _FASTER_CACHE_BLOCK_HOOK, _FASTER_CACHE_DENOISER_HOOK from ..hooks.first_block_cache import _FBC_BLOCK_HOOK, _FBC_LEADER_BLOCK_HOOK from ..hooks.pyramid_attention_broadcast import _PYRAMID_ATTENTION_BROADCAST_HOOK - from ..hooks.taylorseer_cache import _TAYLORSEER_ATTENTION_CACHE_HOOK + from ..hooks.taylorseer_cache import _TAYLORSEER_CACHE_HOOK if self._cache_config is None: logger.warning("Caching techniques have not been enabled, so there's nothing to disable.") @@ -113,7 +113,7 @@ def disable_cache(self) -> None: elif isinstance(self._cache_config, PyramidAttentionBroadcastConfig): registry.remove_hook(_PYRAMID_ATTENTION_BROADCAST_HOOK, recurse=True) elif isinstance(self._cache_config, TaylorSeerCacheConfig): - registry.remove_hook(_TAYLORSEER_ATTENTION_CACHE_HOOK, recurse=True) + registry.remove_hook(_TAYLORSEER_CACHE_HOOK, recurse=True) else: raise ValueError(f"Cache config {type(self._cache_config)} is not supported.") From 2be31f856e628fde5dfd8f96fa04401dca2397fa Mon Sep 17 00:00:00 2001 From: toilaluan Date: Tue, 25 Nov 2025 06:02:13 
+0000 Subject: [PATCH 13/29] fix format & doc --- docs/source/en/api/cache.md | 6 +++ src/diffusers/hooks/__init__.py | 2 +- src/diffusers/hooks/taylorseer_cache.py | 59 +++++++++++-------------- src/diffusers/models/cache_utils.py | 8 +++- 4 files changed, 39 insertions(+), 36 deletions(-) diff --git a/docs/source/en/api/cache.md b/docs/source/en/api/cache.md index 9ba474208551..c93dcad43821 100644 --- a/docs/source/en/api/cache.md +++ b/docs/source/en/api/cache.md @@ -34,3 +34,9 @@ Cache methods speedup diffusion transformers by storing and reusing intermediate [[autodoc]] FirstBlockCacheConfig [[autodoc]] apply_first_block_cache + +### TaylorSeerCacheConfig + +[[autodoc]] TaylorSeerCacheConfig + +[[autodoc]] apply_taylorseer_cache diff --git a/src/diffusers/hooks/__init__.py b/src/diffusers/hooks/__init__.py index 1d9d43d96b2a..eb12b8a52a1e 100644 --- a/src/diffusers/hooks/__init__.py +++ b/src/diffusers/hooks/__init__.py @@ -25,4 +25,4 @@ from .layerwise_casting import apply_layerwise_casting, apply_layerwise_casting_hook from .pyramid_attention_broadcast import PyramidAttentionBroadcastConfig, apply_pyramid_attention_broadcast from .smoothed_energy_guidance_utils import SmoothedEnergyGuidanceConfig - from .taylorseer_cache import TaylorSeerCacheConfig, apply_taylorseer_cache \ No newline at end of file + from .taylorseer_cache import TaylorSeerCacheConfig, apply_taylorseer_cache diff --git a/src/diffusers/hooks/taylorseer_cache.py b/src/diffusers/hooks/taylorseer_cache.py index f400576fed44..17d102f589c6 100644 --- a/src/diffusers/hooks/taylorseer_cache.py +++ b/src/diffusers/hooks/taylorseer_cache.py @@ -1,13 +1,13 @@ import math import re from dataclasses import dataclass -from typing import Optional, List, Dict, Tuple +from typing import Dict, List, Optional, Tuple import torch import torch.nn as nn -from .hooks import ModelHook, StateManager, HookRegistry from ..utils import logging +from .hooks import HookRegistry, ModelHook, StateManager logger = logging.get_logger(__name__) @@ -19,60 +19,51 @@ ) _TEMPORAL_ATTENTION_BLOCK_IDENTIFIERS = ("^temporal_transformer_blocks.*attn",) _TRANSFORMER_BLOCK_IDENTIFIERS = _SPATIAL_ATTENTION_BLOCK_IDENTIFIERS + _TEMPORAL_ATTENTION_BLOCK_IDENTIFIERS -_BLOCK_IDENTIFIERS = ( - "^[^.]*block[^.]*\\.[^.]+$", -) +_BLOCK_IDENTIFIERS = ("^[^.]*block[^.]*\\.[^.]+$",) _PROJ_OUT_IDENTIFIERS = ("^proj_out$",) + @dataclass class TaylorSeerCacheConfig: """ - Configuration for TaylorSeer cache. - See: https://huggingface.co/papers/2503.06923 + Configuration for TaylorSeer cache. See: https://huggingface.co/papers/2503.06923 Attributes: warmup_steps (`int`, defaults to `3`): - Number of denoising steps to run with full computation - before enabling caching. During warmup, the Taylor series factors - are still updated, but no predictions are used. + Number of denoising steps to run with full computation before enabling caching. During warmup, the Taylor + series factors are still updated, but no predictions are used. predict_steps (`int`, defaults to `5`): - Number of prediction (cached) steps to take between two full - computations. That is, once a module state is refreshed, it will - be reused for `predict_steps` subsequent denoising steps, then a new - full forward will be computed on the next step. + Number of prediction (cached) steps to take between two full computations. That is, once a module state is + refreshed, it will be reused for `predict_steps` subsequent denoising steps, then a new full forward will + be computed on the next step. 
stop_predicts (`int`, *optional*, defaults to `None`): - Denoising step index at which caching is disabled. - If provided, for `self.current_step >= stop_predicts` all modules are - evaluated normally (no predictions, no state updates). + Denoising step index at which caching is disabled. If provided, for `self.current_step >= stop_predicts` + all modules are evaluated normally (no predictions, no state updates). max_order (`int`, defaults to `1`): - Maximum order of Taylor series expansion to approximate the - features. Higher order gives closer approximation but more compute. + Maximum order of Taylor series expansion to approximate the features. Higher order gives closer + approximation but more compute. taylor_factors_dtype (`torch.dtype`, defaults to `torch.bfloat16`): - Data type for computing Taylor series expansion factors. - Use lower precision to reduce memory usage. - Use higher precision to improve numerical stability. + Data type for computing Taylor series expansion factors. Use lower precision to reduce memory usage. Use + higher precision to improve numerical stability. skip_identifiers (`List[str]`, *optional*, defaults to `None`): - Regex patterns (fullmatch) for module names to be placed in - "skip" mode, where the module is evaluated during warmup / - refresh, but then replaced by a cheap dummy tensor during - prediction steps. + Regex patterns (fullmatch) for module names to be placed in "skip" mode, where the module is evaluated + during warmup / refresh, but then replaced by a cheap dummy tensor during prediction steps. cache_identifiers (`List[str]`, *optional*, defaults to `None`): - Regex patterns (fullmatch) for module names to be placed in - Taylor-series caching mode. + Regex patterns (fullmatch) for module names to be placed in Taylor-series caching mode. lite (`bool`, *optional*, defaults to `False`): - Whether to use a TaylorSeer Lite variant that reduces memory usage. This option overrides - any user-provided `skip_identifiers` or `cache_identifiers` patterns. + Whether to use a TaylorSeer Lite variant that reduces memory usage. This option overrides any user-provided + `skip_identifiers` or `cache_identifiers` patterns. Notes: - Patterns are applied with `re.fullmatch` on `module_name`. - - If either `skip_identifiers` or `cache_identifiers` is provided, only modules matching at least - one of those patterns will be hooked. + - If either `skip_identifiers` or `cache_identifiers` is provided, only modules matching at least one of those + patterns will be hooked. - If neither is provided, all attention-like modules will be hooked. """ @@ -255,13 +246,13 @@ def apply_taylorseer_cache(module: torch.nn.Module, config: TaylorSeerCacheConfi ```python >>> import torch >>> from diffusers import FluxPipeline, TaylorSeerCacheConfig - >>> + >>> pipe = FluxPipeline.from_pretrained( ... "black-forest-labs/FLUX.1-dev", ... torch_dtype=torch.bfloat16, ... ) >>> pipe.to("cuda") - >>> + >>> config = TaylorSeerCacheConfig( ... predict_steps=5, ... 
max_order=1, diff --git a/src/diffusers/models/cache_utils.py b/src/diffusers/models/cache_utils.py index 56de1e646310..f4ad1af278f5 100644 --- a/src/diffusers/models/cache_utils.py +++ b/src/diffusers/models/cache_utils.py @@ -93,7 +93,13 @@ def enable_cache(self, config) -> None: self._cache_config = config def disable_cache(self) -> None: - from ..hooks import FasterCacheConfig, FirstBlockCacheConfig, HookRegistry, PyramidAttentionBroadcastConfig, TaylorSeerCacheConfig + from ..hooks import ( + FasterCacheConfig, + FirstBlockCacheConfig, + HookRegistry, + PyramidAttentionBroadcastConfig, + TaylorSeerCacheConfig, + ) from ..hooks.faster_cache import _FASTER_CACHE_BLOCK_HOOK, _FASTER_CACHE_DENOISER_HOOK from ..hooks.first_block_cache import _FBC_BLOCK_HOOK, _FBC_LEADER_BLOCK_HOOK from ..hooks.pyramid_attention_broadcast import _PYRAMID_ATTENTION_BROADCAST_HOOK From 24267c76defd5ca8bc90a8ce3d974a48d6411d1d Mon Sep 17 00:00:00 2001 From: toilaluan Date: Fri, 28 Nov 2025 07:23:01 +0000 Subject: [PATCH 14/29] chores: naming, remove redundancy --- src/diffusers/hooks/taylorseer_cache.py | 301 ++++++++++++------------ 1 file changed, 154 insertions(+), 147 deletions(-) diff --git a/src/diffusers/hooks/taylorseer_cache.py b/src/diffusers/hooks/taylorseer_cache.py index 17d102f589c6..df37f9251c13 100644 --- a/src/diffusers/hooks/taylorseer_cache.py +++ b/src/diffusers/hooks/taylorseer_cache.py @@ -29,64 +29,74 @@ class TaylorSeerCacheConfig: Configuration for TaylorSeer cache. See: https://huggingface.co/papers/2503.06923 Attributes: - warmup_steps (`int`, defaults to `3`): - Number of denoising steps to run with full computation before enabling caching. During warmup, the Taylor - series factors are still updated, but no predictions are used. + cache_interval (`int`, defaults to `5`): + The interval between full computation steps. After a full computation, the cached (predicted) outputs are reused + for this many subsequent denoising steps before refreshing with a new full forward pass. - predict_steps (`int`, defaults to `5`): - Number of prediction (cached) steps to take between two full computations. That is, once a module state is - refreshed, it will be reused for `predict_steps` subsequent denoising steps, then a new full forward will - be computed on the next step. + disable_cache_before_step (`int`, defaults to `3`): + The denoising step index before which caching is disabled, meaning full computation is performed for the initial + steps (0 to disable_cache_before_step - 1) to gather data for Taylor series approximations. During these steps, + Taylor factors are updated, but caching/predictions are not applied. Caching begins at this step. - stop_predicts (`int`, *optional*, defaults to `None`): - Denoising step index at which caching is disabled. If provided, for `self.current_step >= stop_predicts` - all modules are evaluated normally (no predictions, no state updates). + disable_cache_after_step (`int`, *optional*, defaults to `None`): + The denoising step index after which caching is disabled. If set, for steps >= this value, all modules run full + computations without predictions or state updates, ensuring accuracy in later stages if needed. max_order (`int`, defaults to `1`): - Maximum order of Taylor series expansion to approximate the features. Higher order gives closer - approximation but more compute. + The highest order in the Taylor series expansion for approximating module outputs. Higher orders provide better + approximations but increase computation and memory usage. 
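Editor's aside — to make the `max_order` trade-off concrete, a standalone sketch (not from the patch) of the order-1 divided-difference update and the Δt^n / n! extrapolation that the state class implements, using a scalar "feature" and made-up step indices:

```python
import math

# Order-1 divided difference between the two most recent full computes,
# then a Taylor extrapolation at a later denoising step.
f_prev, t_prev = 1.00, 4   # output of the previous full compute (step 4)
f_curr, t_curr = 1.20, 9   # output of the latest full compute (step 9)

factors = {0: f_curr, 1: (f_curr - f_prev) / (t_curr - t_prev)}

def predict(step):
    dt = step - t_curr
    return sum(factor * dt**order / math.factorial(order) for order, factor in factors.items())

print(predict(11))  # two steps past the refresh -> 1.28
```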
taylor_factors_dtype (`torch.dtype`, defaults to `torch.bfloat16`): - Data type for computing Taylor series expansion factors. Use lower precision to reduce memory usage. Use - higher precision to improve numerical stability. + Data type used for storing and computing Taylor series factors. Lower precision reduces memory but may affect + stability; higher precision improves accuracy at the cost of more memory. - skip_identifiers (`List[str]`, *optional*, defaults to `None`): - Regex patterns (fullmatch) for module names to be placed in "skip" mode, where the module is evaluated - during warmup / refresh, but then replaced by a cheap dummy tensor during prediction steps. + inactive_identifiers (`List[str]`, *optional*, defaults to `None`): + Regex patterns (using `re.fullmatch`) for module names to place in "inactive" mode. In this mode, the module + computes fully during initial or refresh steps but returns a zero tensor (matching recorded shape) during + prediction steps to skip computation cheaply. - cache_identifiers (`List[str]`, *optional*, defaults to `None`): - Regex patterns (fullmatch) for module names to be placed in Taylor-series caching mode. + active_identifiers (`List[str]`, *optional*, defaults to `None`): + Regex patterns (using `re.fullmatch`) for module names to place in Taylor-series caching mode, where outputs + are approximated and cached for reuse. + + use_lite_mode (`bool`, *optional*, defaults to `False`): + Enables a lightweight TaylorSeer variant that minimizes memory usage by applying predefined patterns for + skipping and caching (e.g., skipping blocks and caching projections). This overrides any custom + `inactive_identifiers` or `active_identifiers`. - lite (`bool`, *optional*, defaults to `False`): - Whether to use a TaylorSeer Lite variant that reduces memory usage. This option overrides any user-provided - `skip_identifiers` or `cache_identifiers` patterns. Notes: - - Patterns are applied with `re.fullmatch` on `module_name`. - - If either `skip_identifiers` or `cache_identifiers` is provided, only modules matching at least one of those - patterns will be hooked. - - If neither is provided, all attention-like modules will be hooked. + - Patterns are matched using `re.fullmatch` on the module name. + - If `inactive_identifiers` or `active_identifiers` are provided, only matching modules are hooked. + - If neither is provided, all attention-like modules are hooked by default. 
+ - Example of inactive and active usage: + ``` + def forward(x): + x = self.module1(x) # inactive module: returns zeros tensor based on shape recorded during full compute + x = self.module2(x) # active module: caches output here, avoiding recomputation of prior steps + return x + ``` """ - warmup_steps: int = 3 - predict_steps: int = 5 - stop_predicts: Optional[int] = None + cache_interval: int = 5 + disable_cache_before_step: int = 3 + disable_cache_after_step: Optional[int] = None max_order: int = 1 taylor_factors_dtype: Optional[torch.dtype] = torch.bfloat16 - skip_identifiers: Optional[List[str]] = None - cache_identifiers: Optional[List[str]] = None - lite: bool = False + inactive_identifiers: Optional[List[str]] = None + active_identifiers: Optional[List[str]] = None + use_lite_mode: bool = False def __repr__(self) -> str: return ( "TaylorSeerCacheConfig(" - f"warmup_steps={self.warmup_steps}, " - f"predict_steps={self.predict_steps}, " - f"stop_predicts={self.stop_predicts}, " + f"cache_interval={self.cache_interval}, " + f"disable_cache_before_step={self.disable_cache_before_step}, " + f"disable_cache_after_step={self.disable_cache_after_step}, " f"max_order={self.max_order}, " f"taylor_factors_dtype={self.taylor_factors_dtype}, " - f"skip_identifiers={self.skip_identifiers}, " - f"cache_identifiers={self.cache_identifiers}, " - f"lite={self.lite})" + f"inactive_identifiers={self.inactive_identifiers}, " + f"active_identifiers={self.active_identifiers}, " + f"use_lite_mode={self.use_lite_mode})" ) @@ -95,70 +105,88 @@ def __init__( self, taylor_factors_dtype: Optional[torch.dtype] = torch.bfloat16, max_order: int = 1, + is_inactive: bool = False, ): self.taylor_factors_dtype = taylor_factors_dtype self.max_order = max_order + self.is_inactive = is_inactive self.module_dtypes: Tuple[torch.dtype, ...] 
= () self.last_update_step: Optional[int] = None self.taylor_factors: Dict[int, Dict[int, torch.Tensor]] = {} - - # For skip-mode modules + self.inactive_shapes: Optional[Tuple[Tuple[int, ...], ...]] = None self.device: Optional[torch.device] = None - self.dummy_tensors: Optional[Tuple[torch.Tensor, ...]] = None + self.current_step: int = -1 - self.current_step = -1 def reset(self) -> None: + self.current_step = -1 self.last_update_step = None self.taylor_factors = {} + self.inactive_shapes = None self.device = None - self.dummy_tensors = None - self.current_step = -1 def update( self, outputs: Tuple[torch.Tensor, ...], - current_step: int, ) -> None: self.module_dtypes = tuple(output.dtype for output in outputs) - for i in range(len(outputs)): - features = outputs[i].to(self.taylor_factors_dtype) - new_factors: Dict[int, torch.Tensor] = {0: features} - is_first_update = self.last_update_step is None - if not is_first_update: - delta_step = current_step - self.last_update_step - if delta_step == 0: - raise ValueError("Delta step cannot be zero for TaylorSeer update.") - - # Recursive divided differences up to max_order - for j in range(self.max_order): - prev = self.taylor_factors[i].get(j) - if prev is None: - break - new_factors[j + 1] = (new_factors[j] - prev.to(self.taylor_factors_dtype)) / delta_step - self.taylor_factors[i] = new_factors - self.last_update_step = current_step - - def predict(self, current_step: int) -> torch.Tensor: + self.device = outputs[0].device + + if self.is_inactive: + self.inactive_shapes = tuple(output.shape for output in outputs) + else: + self.taylor_factors = {} + for i, output in enumerate(outputs): + features = output.to(self.taylor_factors_dtype) + new_factors: Dict[int, torch.Tensor] = {0: features} + is_first_update = self.last_update_step is None + if not is_first_update: + delta_step = self.current_step - self.last_update_step + if delta_step == 0: + raise ValueError("Delta step cannot be zero for TaylorSeer update.") + + # Recursive divided differences up to max_order + prev_factors = self.taylor_factors.get(i, {}) + for j in range(self.max_order): + prev = prev_factors.get(j) + if prev is None: + break + new_factors[j + 1] = (new_factors[j] - prev.to(self.taylor_factors_dtype)) / delta_step + self.taylor_factors[i] = new_factors + + self.last_update_step = self.current_step + + def predict(self) -> List[torch.Tensor]: if self.last_update_step is None: raise ValueError("Cannot predict without prior initialization/update.") - step_offset = current_step - self.last_update_step - - if not self.taylor_factors: - raise ValueError("Taylor factors empty during prediction.") + step_offset = self.current_step - self.last_update_step outputs = [] - for i in range(len(self.module_dtypes)): - taylor_factors = self.taylor_factors[i] - # Accumulate Taylor series: f(t0 + Δt) ≈ Σ f^{(n)}(t0) * (Δt^n / n!) 
- output = torch.zeros_like(taylor_factors[0]) - for order, factor in taylor_factors.items(): - # Note: order starts at 0 - coeff = (step_offset**order) / math.factorial(order) - output = output + factor * coeff - outputs.append(output.to(self.module_dtypes[i])) + if self.is_inactive: + if self.inactive_shapes is None: + raise ValueError("Inactive shapes not set during prediction.") + for i in range(len(self.module_dtypes)): + outputs.append( + torch.zeros( + self.inactive_shapes[i], + dtype=self.module_dtypes[i], + device=self.device, + ) + ) + else: + if not self.taylor_factors: + raise ValueError("Taylor factors empty during prediction.") + for i in range(len(self.module_dtypes)): + taylor_factors = self.taylor_factors[i] + # Accumulate Taylor series: f(t0 + Δt) ≈ Σ f^{(n)}(t0) * (Δt^n / n!) + output = torch.zeros_like(taylor_factors[0]) + for order, factor in taylor_factors.items(): + # Note: order starts at 0 + coeff = (step_offset**order) / math.factorial(order) + output = output + factor * coeff + outputs.append(output.to(self.module_dtypes[i])) return outputs @@ -168,24 +196,18 @@ class TaylorSeerCacheHook(ModelHook): def __init__( self, - module_name: str, - predict_steps: int, - warmup_steps: int, + cache_interval: int, + disable_cache_before_step: int, taylor_factors_dtype: torch.dtype, state_manager: StateManager, - stop_predicts: Optional[int] = None, - is_skip: bool = False, + disable_cache_after_step: Optional[int] = None, ): super().__init__() - self.module_name = module_name - self.predict_steps = predict_steps - self.warmup_steps = warmup_steps - self.stop_predicts = stop_predicts + self.cache_interval = cache_interval + self.disable_cache_before_step = disable_cache_before_step + self.disable_cache_after_step = disable_cache_after_step self.taylor_factors_dtype = taylor_factors_dtype self.state_manager = state_manager - self.is_skip = is_skip - - self.dummy_outputs = None def initialize_hook(self, module: torch.nn.Module): return module @@ -194,50 +216,48 @@ def reset_state(self, module: torch.nn.Module) -> None: """ Reset state between sampling runs. 
""" - self.dummy_outputs = None - self.current_step = -1 self.state_manager.reset() def new_forward(self, module: torch.nn.Module, *args, **kwargs): state: TaylorSeerState = self.state_manager.get_state() state.current_step += 1 current_step = state.current_step - is_warmup_phase = current_step < self.warmup_steps + is_warmup_phase = current_step < self.disable_cache_before_step + is_compute_interval = ((current_step - self.disable_cache_before_step - 1) % self.cache_interval == 0) + is_cooldown_phase = self.disable_cache_after_step is not None and current_step >= self.disable_cache_after_step should_compute = ( is_warmup_phase - or ((current_step - self.warmup_steps - 1) % self.predict_steps == 0) - or (self.stop_predicts is not None and current_step >= self.stop_predicts) + or is_compute_interval + or is_cooldown_phase ) if should_compute: outputs = self.fn_ref.original_forward(*args, **kwargs) - if not self.is_skip: - state.update((outputs,) if isinstance(outputs, torch.Tensor) else outputs, current_step) - else: - self.dummy_outputs = outputs + wrapped_outputs = (outputs,) if isinstance(outputs, torch.Tensor) else outputs + state.update(wrapped_outputs) return outputs - if self.is_skip: - return self.dummy_outputs - - outputs = state.predict(current_step) - return outputs[0] if len(outputs) == 1 else outputs + outputs_list = state.predict() + return outputs_list[0] if len(outputs_list) == 1 else tuple(outputs_list) def _resolve_patterns(config: TaylorSeerCacheConfig) -> Tuple[List[str], List[str]]: """ - Resolve effective skip and cache pattern lists from config + templates. + Resolve effective inactive and active pattern lists from config + templates. """ - skip_patterns = config.skip_identifiers if config.skip_identifiers is not None else None - cache_patterns = config.cache_identifiers if config.cache_identifiers is not None else None + inactive_patterns = config.inactive_identifiers if config.inactive_identifiers is not None else None + active_patterns = config.active_identifiers if config.active_identifiers is not None else None - return skip_patterns or [], cache_patterns or [] + return inactive_patterns or [], active_patterns or [] def apply_taylorseer_cache(module: torch.nn.Module, config: TaylorSeerCacheConfig): """ Applies the TaylorSeer cache to a given pipeline (typically the transformer / UNet). + This function hooks selected modules in the model to enable caching or skipping based on the provided configuration, + reducing redundant computations in diffusion denoising loops. + Args: module (torch.nn.Module): The model subtree to apply the hooks to. config (TaylorSeerCacheConfig): Configuration for the cache. @@ -254,60 +274,41 @@ def apply_taylorseer_cache(module: torch.nn.Module, config: TaylorSeerCacheConfi >>> pipe.to("cuda") >>> config = TaylorSeerCacheConfig( - ... predict_steps=5, + ... cache_interval=5, ... max_order=1, - ... warmup_steps=3, + ... disable_cache_before_step=3, ... taylor_factors_dtype=torch.float32, ... 
) >>> pipe.transformer.enable_cache(config) ``` """ - skip_patterns, cache_patterns = _resolve_patterns(config) + inactive_patterns, active_patterns = _resolve_patterns(config) - logger.debug("TaylorSeer skip identifiers: %s", skip_patterns) - logger.debug("TaylorSeer cache identifiers: %s", cache_patterns) + active_patterns = active_patterns or _TRANSFORMER_BLOCK_IDENTIFIERS - cache_patterns = cache_patterns or _TRANSFORMER_BLOCK_IDENTIFIERS - - if config.lite: + if config.use_lite_mode: logger.info("Using TaylorSeer Lite variant for cache.") - cache_patterns = _PROJ_OUT_IDENTIFIERS - skip_patterns = _BLOCK_IDENTIFIERS - if config.skip_identifiers or config.cache_identifiers: + active_patterns = _PROJ_OUT_IDENTIFIERS + inactive_patterns = _BLOCK_IDENTIFIERS + if config.inactive_identifiers or config.active_identifiers: logger.warning("Lite mode overrides user patterns.") for name, submodule in module.named_modules(): - matches_skip = any(re.fullmatch(pattern, name) for pattern in skip_patterns) - matches_cache = any(re.fullmatch(pattern, name) for pattern in cache_patterns) - if not (matches_skip or matches_cache): + matches_inactive = any(re.fullmatch(pattern, name) for pattern in inactive_patterns) + matches_active = any(re.fullmatch(pattern, name) for pattern in active_patterns) + if not (matches_inactive or matches_active): continue - logger.debug( - "Applying TaylorSeer cache to %s (mode=%s)", - name, - "skip" if matches_skip else "cache", - ) - state_manager = StateManager( - TaylorSeerState, - init_kwargs={ - "taylor_factors_dtype": config.taylor_factors_dtype, - "max_order": config.max_order, - }, - ) _apply_taylorseer_cache_hook( - name=name, module=submodule, config=config, - is_skip=matches_skip, - state_manager=state_manager, + is_inactive=matches_inactive, ) def _apply_taylorseer_cache_hook( - name: str, module: nn.Module, config: TaylorSeerCacheConfig, - is_skip: bool, - state_manager: StateManager, + is_inactive: bool, ): """ Registers the TaylorSeer hook on the specified nn.Module. @@ -316,19 +317,25 @@ def _apply_taylorseer_cache_hook( name: Name of the module. module: The nn.Module to be hooked. config: Cache configuration. - is_skip: Whether this module should operate in "skip" mode. - state_manager: The state manager for managing hook state. + is_inactive: Whether this module should operate in "inactive" mode. 
""" + state_manager = StateManager( + TaylorSeerState, + init_kwargs={ + "taylor_factors_dtype": config.taylor_factors_dtype, + "max_order": config.max_order, + "is_inactive": is_inactive, + }, + ) + registry = HookRegistry.check_if_exists_or_initialize(module) hook = TaylorSeerCacheHook( - module_name=name, - predict_steps=config.predict_steps, - warmup_steps=config.warmup_steps, + cache_interval=config.cache_interval, + disable_cache_before_step=config.disable_cache_before_step, taylor_factors_dtype=config.taylor_factors_dtype, - stop_predicts=config.stop_predicts, - is_skip=is_skip, + disable_cache_after_step=config.disable_cache_after_step, state_manager=state_manager, ) - registry.register_hook(hook, _TAYLORSEER_CACHE_HOOK) + registry.register_hook(hook, _TAYLORSEER_CACHE_HOOK) \ No newline at end of file From 83b62531f881bdde885caf806677d5b81b635480 Mon Sep 17 00:00:00 2001 From: toilaluan Date: Fri, 28 Nov 2025 07:23:06 +0000 Subject: [PATCH 15/29] add docs --- docs/source/en/optimization/cache.md | 27 +++++++++++++++++++++++++++ 1 file changed, 27 insertions(+) diff --git a/docs/source/en/optimization/cache.md b/docs/source/en/optimization/cache.md index 881529b27ff1..a73266618ebb 100644 --- a/docs/source/en/optimization/cache.md +++ b/docs/source/en/optimization/cache.md @@ -66,4 +66,31 @@ config = FasterCacheConfig( tensor_format="BFCHW", ) pipeline.transformer.enable_cache(config) +``` + +## TaylorSeer Cache + +[TaylorSeer Cache](https://huggingface.co/papers/2403.06923) accelerates diffusion inference by using Taylor series expansions to approximate and cache intermediate activations across denoising steps. This method predicts future outputs based on past computations, reusing them over specified intervals to reduce redundant calculations. + +It supports selective module skipping (inactive mode), where certain modules return zero tensors during prediction steps to skip computations cheaply, and a lightweight "lite" mode for optimized memory usage with predefined patterns for skipping and caching. + +Set up and pass a [`TaylorSeerCacheConfig`] to a pipeline's transformer to enable it. The `cache_interval` controls how many steps to reuse cached outputs before refreshing with a full forward pass. The `disable_cache_before_step` specifies the initial steps where full computations are performed to gather data for approximations. Higher `max_order` improves approximation accuracy but increases memory usage. + +```python +import torch +from diffusers import FluxPipeline, TaylorSeerCacheConfig + +pipe = FluxPipeline.from_pretrained( + "black-forest-labs/FLUX.1-dev", + torch_dtype=torch.bfloat16, +) +pipe.to("cuda") + +config = TaylorSeerCacheConfig( + cache_interval=5, + max_order=1, + disable_cache_before_step=10, + taylor_factors_dtype=torch.bfloat16, +) +pipe.transformer.enable_cache(config) ``` \ No newline at end of file From 309ce72140f1e8069b27145fa15ddcd492b71f30 Mon Sep 17 00:00:00 2001 From: toilaluan Date: Fri, 28 Nov 2025 07:28:44 +0000 Subject: [PATCH 16/29] quality & style --- src/diffusers/hooks/taylorseer_cache.py | 48 ++++++++++++------------- 1 file changed, 22 insertions(+), 26 deletions(-) diff --git a/src/diffusers/hooks/taylorseer_cache.py b/src/diffusers/hooks/taylorseer_cache.py index df37f9251c13..e1e0bebcf08a 100644 --- a/src/diffusers/hooks/taylorseer_cache.py +++ b/src/diffusers/hooks/taylorseer_cache.py @@ -30,34 +30,35 @@ class TaylorSeerCacheConfig: Attributes: cache_interval (`int`, defaults to `5`): - The interval between full computation steps. 
After a full computation, the cached (predicted) outputs are reused - for this many subsequent denoising steps before refreshing with a new full forward pass. + The interval between full computation steps. After a full computation, the cached (predicted) outputs are + reused for this many subsequent denoising steps before refreshing with a new full forward pass. disable_cache_before_step (`int`, defaults to `3`): - The denoising step index before which caching is disabled, meaning full computation is performed for the initial - steps (0 to disable_cache_before_step - 1) to gather data for Taylor series approximations. During these steps, - Taylor factors are updated, but caching/predictions are not applied. Caching begins at this step. + The denoising step index before which caching is disabled, meaning full computation is performed for the + initial steps (0 to disable_cache_before_step - 1) to gather data for Taylor series approximations. During + these steps, Taylor factors are updated, but caching/predictions are not applied. Caching begins at this + step. disable_cache_after_step (`int`, *optional*, defaults to `None`): - The denoising step index after which caching is disabled. If set, for steps >= this value, all modules run full - computations without predictions or state updates, ensuring accuracy in later stages if needed. + The denoising step index after which caching is disabled. If set, for steps >= this value, all modules run + full computations without predictions or state updates, ensuring accuracy in later stages if needed. max_order (`int`, defaults to `1`): - The highest order in the Taylor series expansion for approximating module outputs. Higher orders provide better - approximations but increase computation and memory usage. + The highest order in the Taylor series expansion for approximating module outputs. Higher orders provide + better approximations but increase computation and memory usage. taylor_factors_dtype (`torch.dtype`, defaults to `torch.bfloat16`): - Data type used for storing and computing Taylor series factors. Lower precision reduces memory but may affect - stability; higher precision improves accuracy at the cost of more memory. + Data type used for storing and computing Taylor series factors. Lower precision reduces memory but may + affect stability; higher precision improves accuracy at the cost of more memory. inactive_identifiers (`List[str]`, *optional*, defaults to `None`): - Regex patterns (using `re.fullmatch`) for module names to place in "inactive" mode. In this mode, the module - computes fully during initial or refresh steps but returns a zero tensor (matching recorded shape) during - prediction steps to skip computation cheaply. + Regex patterns (using `re.fullmatch`) for module names to place in "inactive" mode. In this mode, the + module computes fully during initial or refresh steps but returns a zero tensor (matching recorded shape) + during prediction steps to skip computation cheaply. active_identifiers (`List[str]`, *optional*, defaults to `None`): - Regex patterns (using `re.fullmatch`) for module names to place in Taylor-series caching mode, where outputs - are approximated and cached for reuse. + Regex patterns (using `re.fullmatch`) for module names to place in Taylor-series caching mode, where + outputs are approximated and cached for reuse. 
use_lite_mode (`bool`, *optional*, defaults to `False`): Enables a lightweight TaylorSeer variant that minimizes memory usage by applying predefined patterns for @@ -118,7 +119,6 @@ def __init__( self.device: Optional[torch.device] = None self.current_step: int = -1 - def reset(self) -> None: self.current_step = -1 self.last_update_step = None @@ -223,13 +223,9 @@ def new_forward(self, module: torch.nn.Module, *args, **kwargs): state.current_step += 1 current_step = state.current_step is_warmup_phase = current_step < self.disable_cache_before_step - is_compute_interval = ((current_step - self.disable_cache_before_step - 1) % self.cache_interval == 0) + is_compute_interval = (current_step - self.disable_cache_before_step - 1) % self.cache_interval == 0 is_cooldown_phase = self.disable_cache_after_step is not None and current_step >= self.disable_cache_after_step - should_compute = ( - is_warmup_phase - or is_compute_interval - or is_cooldown_phase - ) + should_compute = is_warmup_phase or is_compute_interval or is_cooldown_phase if should_compute: outputs = self.fn_ref.original_forward(*args, **kwargs) wrapped_outputs = (outputs,) if isinstance(outputs, torch.Tensor) else outputs @@ -255,8 +251,8 @@ def apply_taylorseer_cache(module: torch.nn.Module, config: TaylorSeerCacheConfi """ Applies the TaylorSeer cache to a given pipeline (typically the transformer / UNet). - This function hooks selected modules in the model to enable caching or skipping based on the provided configuration, - reducing redundant computations in diffusion denoising loops. + This function hooks selected modules in the model to enable caching or skipping based on the provided + configuration, reducing redundant computations in diffusion denoising loops. Args: module (torch.nn.Module): The model subtree to apply the hooks to. 
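Editor's aside — since the earlier `src/diffusers/hooks/__init__.py` hunk re-exports `apply_taylorseer_cache`, here is a minimal sketch of calling it directly on a transformer instead of going through `enable_cache`; the checkpoint, prompt, and settings are illustrative, not part of the patch:

```python
import torch
from diffusers import FluxPipeline
from diffusers.hooks import TaylorSeerCacheConfig, apply_taylorseer_cache

pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16).to("cuda")

# With no identifiers set, attention-like transformer blocks are hooked by default.
config = TaylorSeerCacheConfig(cache_interval=5, disable_cache_before_step=3, max_order=1)
apply_taylorseer_cache(pipe.transformer, config)

image = pipe("a photo of a cat", num_inference_steps=28).images[0]
```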
@@ -338,4 +334,4 @@ def _apply_taylorseer_cache_hook( state_manager=state_manager, ) - registry.register_hook(hook, _TAYLORSEER_CACHE_HOOK) \ No newline at end of file + registry.register_hook(hook, _TAYLORSEER_CACHE_HOOK) From d06c6bc6c2a84a3db8b41542357cf488869f9c19 Mon Sep 17 00:00:00 2001 From: toilaluan Date: Fri, 28 Nov 2025 08:14:41 +0000 Subject: [PATCH 17/29] fix taylor precision --- src/diffusers/hooks/taylorseer_cache.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/src/diffusers/hooks/taylorseer_cache.py b/src/diffusers/hooks/taylorseer_cache.py index e1e0bebcf08a..3554255744c7 100644 --- a/src/diffusers/hooks/taylorseer_cache.py +++ b/src/diffusers/hooks/taylorseer_cache.py @@ -137,8 +137,7 @@ def update( self.inactive_shapes = tuple(output.shape for output in outputs) else: self.taylor_factors = {} - for i, output in enumerate(outputs): - features = output.to(self.taylor_factors_dtype) + for i, features in enumerate(outputs): new_factors: Dict[int, torch.Tensor] = {0: features} is_first_update = self.last_update_step is None if not is_first_update: @@ -152,8 +151,8 @@ def update( prev = prev_factors.get(j) if prev is None: break - new_factors[j + 1] = (new_factors[j] - prev.to(self.taylor_factors_dtype)) / delta_step - self.taylor_factors[i] = new_factors + new_factors[j + 1] = (new_factors[j] - prev.to(features.dtype)) / delta_step + self.taylor_factors[i] = {order: factor.to(self.taylor_factors_dtype) for order, factor in new_factors.items()} self.last_update_step = self.current_step @@ -179,14 +178,15 @@ def predict(self) -> List[torch.Tensor]: if not self.taylor_factors: raise ValueError("Taylor factors empty during prediction.") for i in range(len(self.module_dtypes)): + output_dtype = self.module_dtypes[i] taylor_factors = self.taylor_factors[i] # Accumulate Taylor series: f(t0 + Δt) ≈ Σ f^{(n)}(t0) * (Δt^n / n!) 
- output = torch.zeros_like(taylor_factors[0]) + output = torch.zeros_like(taylor_factors[0], dtype=output_dtype) for order, factor in taylor_factors.items(): # Note: order starts at 0 coeff = (step_offset**order) / math.factorial(order) - output = output + factor * coeff - outputs.append(output.to(self.module_dtypes[i])) + output = output + factor.to(output_dtype) * coeff + outputs.append(output) return outputs From 716dfe1468b024a08e440f221aab67b3d75ca91c Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" Date: Fri, 28 Nov 2025 12:52:07 +0000 Subject: [PATCH 18/29] Apply style fixes --- src/diffusers/hooks/taylorseer_cache.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/diffusers/hooks/taylorseer_cache.py b/src/diffusers/hooks/taylorseer_cache.py index 3554255744c7..4b9f4dd47996 100644 --- a/src/diffusers/hooks/taylorseer_cache.py +++ b/src/diffusers/hooks/taylorseer_cache.py @@ -152,7 +152,9 @@ def update( if prev is None: break new_factors[j + 1] = (new_factors[j] - prev.to(features.dtype)) / delta_step - self.taylor_factors[i] = {order: factor.to(self.taylor_factors_dtype) for order, factor in new_factors.items()} + self.taylor_factors[i] = { + order: factor.to(self.taylor_factors_dtype) for order, factor in new_factors.items() + } self.last_update_step = self.current_step From e2dae7e43282691bfc3a8b2a0f1ec3fc19b1fe9b Mon Sep 17 00:00:00 2001 From: toilaluan Date: Sat, 29 Nov 2025 07:21:01 +0000 Subject: [PATCH 19/29] add tests --- tests/pipelines/flux/test_pipeline_flux.py | 2 + .../flux/test_pipeline_flux_kontext.py | 2 + .../test_pipeline_flux_kontext_inpaint.py | 2 + tests/pipelines/flux2/test_pipeline_flux2.py | 3 +- .../hunyuan_video/test_hunyuan_video.py | 2 + tests/pipelines/test_pipelines_common.py | 51 +++++++++++++++++++ 6 files changed, 61 insertions(+), 1 deletion(-) diff --git a/tests/pipelines/flux/test_pipeline_flux.py b/tests/pipelines/flux/test_pipeline_flux.py index 1ddbd4ba3df8..74499bfa607a 100644 --- a/tests/pipelines/flux/test_pipeline_flux.py +++ b/tests/pipelines/flux/test_pipeline_flux.py @@ -29,6 +29,7 @@ FluxIPAdapterTesterMixin, PipelineTesterMixin, PyramidAttentionBroadcastTesterMixin, + TaylorSeerCacheTesterMixin, check_qkv_fused_layers_exist, ) @@ -39,6 +40,7 @@ class FluxPipelineFastTests( PyramidAttentionBroadcastTesterMixin, FasterCacheTesterMixin, FirstBlockCacheTesterMixin, + TaylorSeerCacheTesterMixin, unittest.TestCase, ): pipeline_class = FluxPipeline diff --git a/tests/pipelines/flux/test_pipeline_flux_kontext.py b/tests/pipelines/flux/test_pipeline_flux_kontext.py index 5c78964ea54f..06a9f1346cda 100644 --- a/tests/pipelines/flux/test_pipeline_flux_kontext.py +++ b/tests/pipelines/flux/test_pipeline_flux_kontext.py @@ -19,6 +19,7 @@ FluxIPAdapterTesterMixin, PipelineTesterMixin, PyramidAttentionBroadcastTesterMixin, + TaylorSeerCacheTesterMixin, ) @@ -28,6 +29,7 @@ class FluxKontextPipelineFastTests( FluxIPAdapterTesterMixin, PyramidAttentionBroadcastTesterMixin, FasterCacheTesterMixin, + TaylorSeerCacheTesterMixin, ): pipeline_class = FluxKontextPipeline params = frozenset( diff --git a/tests/pipelines/flux/test_pipeline_flux_kontext_inpaint.py b/tests/pipelines/flux/test_pipeline_flux_kontext_inpaint.py index 9a2e32056dcb..106a8cacf276 100644 --- a/tests/pipelines/flux/test_pipeline_flux_kontext_inpaint.py +++ b/tests/pipelines/flux/test_pipeline_flux_kontext_inpaint.py @@ -19,6 +19,7 @@ FluxIPAdapterTesterMixin, PipelineTesterMixin, PyramidAttentionBroadcastTesterMixin, + TaylorSeerCacheTesterMixin, ) @@ -28,6 
+29,7 @@ class FluxKontextInpaintPipelineFastTests( FluxIPAdapterTesterMixin, PyramidAttentionBroadcastTesterMixin, FasterCacheTesterMixin, + TaylorSeerCacheTesterMixin, ): pipeline_class = FluxKontextInpaintPipeline params = frozenset( diff --git a/tests/pipelines/flux2/test_pipeline_flux2.py b/tests/pipelines/flux2/test_pipeline_flux2.py index 4404dbc51047..bb77e1943dbb 100644 --- a/tests/pipelines/flux2/test_pipeline_flux2.py +++ b/tests/pipelines/flux2/test_pipeline_flux2.py @@ -16,11 +16,12 @@ ) from ..test_pipelines_common import ( PipelineTesterMixin, + TaylorSeerCacheTesterMixin, check_qkv_fused_layers_exist, ) -class Flux2PipelineFastTests(PipelineTesterMixin, unittest.TestCase): +class Flux2PipelineFastTests(PipelineTesterMixin, TaylorSeerCacheTesterMixin, unittest.TestCase): pipeline_class = Flux2Pipeline params = frozenset(["prompt", "height", "width", "guidance_scale", "prompt_embeds"]) batch_params = frozenset(["prompt"]) diff --git a/tests/pipelines/hunyuan_video/test_hunyuan_video.py b/tests/pipelines/hunyuan_video/test_hunyuan_video.py index 4bdf3ee20e1b..57a6daebad1f 100644 --- a/tests/pipelines/hunyuan_video/test_hunyuan_video.py +++ b/tests/pipelines/hunyuan_video/test_hunyuan_video.py @@ -33,6 +33,7 @@ FirstBlockCacheTesterMixin, PipelineTesterMixin, PyramidAttentionBroadcastTesterMixin, + TaylorSeerCacheTesterMixin, to_np, ) @@ -45,6 +46,7 @@ class HunyuanVideoPipelineFastTests( PyramidAttentionBroadcastTesterMixin, FasterCacheTesterMixin, FirstBlockCacheTesterMixin, + TaylorSeerCacheTesterMixin, unittest.TestCase, ): pipeline_class = HunyuanVideoPipeline diff --git a/tests/pipelines/test_pipelines_common.py b/tests/pipelines/test_pipelines_common.py index 22570b28841e..0a107dd5f909 100644 --- a/tests/pipelines/test_pipelines_common.py +++ b/tests/pipelines/test_pipelines_common.py @@ -36,6 +36,7 @@ from diffusers.hooks.faster_cache import FasterCacheBlockHook, FasterCacheDenoiserHook from diffusers.hooks.first_block_cache import FirstBlockCacheConfig from diffusers.hooks.pyramid_attention_broadcast import PyramidAttentionBroadcastHook +from diffusers.hooks.taylorseer_cache import TaylorSeerCacheConfig from diffusers.image_processor import VaeImageProcessor from diffusers.loaders import FluxIPAdapterMixin, IPAdapterMixin from diffusers.models.attention import AttentionModuleMixin @@ -2923,6 +2924,56 @@ def run_forward(pipe): "Outputs from normal inference and after disabling cache should not differ." 
) +class TaylorSeerCacheTesterMixin: + taylorseer_cache_config = TaylorSeerCacheConfig( + cache_interval=5, + disable_cache_before_step=10, + max_order=1, + taylor_factors_dtype=torch.bfloat16, + use_lite_mode=True, + ) + + def test_taylorseer_cache_inference(self, expected_atol: float = 0.1): + device = "cpu" # ensure determinism for the device-dependent torch.Generator + + def create_pipe(): + torch.manual_seed(0) + num_layers = 2 + components = self.get_dummy_components(num_layers=num_layers) + pipe = self.pipeline_class(**components) + pipe = pipe.to(device) + pipe.set_progress_bar_config(disable=None) + return pipe + + def run_forward(pipe): + torch.manual_seed(0) + inputs = self.get_dummy_inputs(device) + inputs["num_inference_steps"] = 50 + return pipe(**inputs)[0] + + # Run inference without TaylorSeerCache + pipe = create_pipe() + output = run_forward(pipe).flatten() + original_image_slice = np.concatenate((output[:8], output[-8:])) + + # Run inference with TaylorSeerCache enabled + pipe = create_pipe() + pipe.transformer.enable_cache(self.taylorseer_cache_config) + output = run_forward(pipe).flatten() + image_slice_fbc_enabled = np.concatenate((output[:8], output[-8:])) + + # Run inference with TaylorSeerCache disabled + pipe.transformer.disable_cache() + output = run_forward(pipe).flatten() + image_slice_fbc_disabled = np.concatenate((output[:8], output[-8:])) + + assert np.allclose(original_image_slice, image_slice_fbc_enabled, atol=expected_atol), ( + "TaylorSeerCache outputs should not differ much." + ) + assert np.allclose(original_image_slice, image_slice_fbc_disabled, atol=1e-4), ( + "Outputs from normal inference and after disabling cache should not differ." + ) + # Some models (e.g. unCLIP) are extremely likely to significantly deviate depending on which hardware is used. # This helper function is used to check that the image doesn't deviate on average more than 10 pixels from a From 289146e73ebe4c769ca6c98ab0abccd323d30b9a Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" Date: Sun, 30 Nov 2025 21:40:03 +0000 Subject: [PATCH 20/29] Apply style fixes --- tests/pipelines/test_pipelines_common.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/pipelines/test_pipelines_common.py b/tests/pipelines/test_pipelines_common.py index 0a107dd5f909..7db5f4da89ca 100644 --- a/tests/pipelines/test_pipelines_common.py +++ b/tests/pipelines/test_pipelines_common.py @@ -2924,6 +2924,7 @@ def run_forward(pipe): "Outputs from normal inference and after disabling cache should not differ." 
) + class TaylorSeerCacheTesterMixin: taylorseer_cache_config = TaylorSeerCacheConfig( cache_interval=5, From 475ec02d8c34861f9302fcb24a706c745f16577b Mon Sep 17 00:00:00 2001 From: Tran Thanh Luan <92072154+toilaluan@users.noreply.github.com> Date: Mon, 1 Dec 2025 09:31:23 +0700 Subject: [PATCH 21/29] Remove TaylorSeerCacheTesterMixin from flux2 tests --- tests/pipelines/flux2/test_pipeline_flux2.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tests/pipelines/flux2/test_pipeline_flux2.py b/tests/pipelines/flux2/test_pipeline_flux2.py index bb77e1943dbb..4404dbc51047 100644 --- a/tests/pipelines/flux2/test_pipeline_flux2.py +++ b/tests/pipelines/flux2/test_pipeline_flux2.py @@ -16,12 +16,11 @@ ) from ..test_pipelines_common import ( PipelineTesterMixin, - TaylorSeerCacheTesterMixin, check_qkv_fused_layers_exist, ) -class Flux2PipelineFastTests(PipelineTesterMixin, TaylorSeerCacheTesterMixin, unittest.TestCase): +class Flux2PipelineFastTests(PipelineTesterMixin, unittest.TestCase): pipeline_class = Flux2Pipeline params = frozenset(["prompt", "height", "width", "guidance_scale", "prompt_embeds"]) batch_params = frozenset(["prompt"]) From 4fb3f53b6c49da08d8d0171c8fe120daf81edcc1 Mon Sep 17 00:00:00 2001 From: toilaluan Date: Wed, 3 Dec 2025 10:43:01 +0000 Subject: [PATCH 22/29] rename identifiers, use more expressive taylor predict loop --- src/diffusers/hooks/taylorseer_cache.py | 31 ++++++++++++------------- 1 file changed, 15 insertions(+), 16 deletions(-) diff --git a/src/diffusers/hooks/taylorseer_cache.py b/src/diffusers/hooks/taylorseer_cache.py index 4b9f4dd47996..607d652f4a21 100644 --- a/src/diffusers/hooks/taylorseer_cache.py +++ b/src/diffusers/hooks/taylorseer_cache.py @@ -51,12 +51,12 @@ class TaylorSeerCacheConfig: Data type used for storing and computing Taylor series factors. Lower precision reduces memory but may affect stability; higher precision improves accuracy at the cost of more memory. - inactive_identifiers (`List[str]`, *optional*, defaults to `None`): - Regex patterns (using `re.fullmatch`) for module names to place in "inactive" mode. In this mode, the + skip_predict_identifiers (`List[str]`, *optional*, defaults to `None`): + Regex patterns (using `re.fullmatch`) for module names to place as "skip" in "cache" mode. In this mode, the module computes fully during initial or refresh steps but returns a zero tensor (matching recorded shape) during prediction steps to skip computation cheaply. - active_identifiers (`List[str]`, *optional*, defaults to `None`): + cache_identifiers (`List[str]`, *optional*, defaults to `None`): Regex patterns (using `re.fullmatch`) for module names to place in Taylor-series caching mode, where outputs are approximated and cached for reuse. 
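Editor's aside — to illustrate the renamed fields, a hedged sketch of a config with explicit patterns; the module-name regexes below are hypothetical and should be adapted to whatever `transformer.named_modules()` reports for the model being hooked:

```python
import torch
from diffusers import TaylorSeerCacheConfig

# Hypothetical patterns -- adjust to the names printed by transformer.named_modules().
config = TaylorSeerCacheConfig(
    cache_interval=5,
    disable_cache_before_step=3,
    max_order=1,
    taylor_factors_dtype=torch.bfloat16,
    cache_identifiers=[r"^transformer_blocks\.\d+\.attn$"],          # Taylor-series cached outputs
    skip_predict_identifiers=[r"^single_transformer_blocks\.\d+$"],  # zero tensors on predict steps
)
```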
@@ -83,8 +83,8 @@ def forward(x): disable_cache_after_step: Optional[int] = None max_order: int = 1 taylor_factors_dtype: Optional[torch.dtype] = torch.bfloat16 - inactive_identifiers: Optional[List[str]] = None - active_identifiers: Optional[List[str]] = None + skip_predict_identifiers: Optional[List[str]] = None + cache_identifiers: Optional[List[str]] = None use_lite_mode: bool = False def __repr__(self) -> str: @@ -95,8 +95,8 @@ def __repr__(self) -> str: f"disable_cache_after_step={self.disable_cache_after_step}, " f"max_order={self.max_order}, " f"taylor_factors_dtype={self.taylor_factors_dtype}, " - f"inactive_identifiers={self.inactive_identifiers}, " - f"active_identifiers={self.active_identifiers}, " + f"skip_predict_identifiers={self.skip_predict_identifiers}, " + f"cache_identifiers={self.cache_identifiers}, " f"use_lite_mode={self.use_lite_mode})" ) @@ -136,7 +136,6 @@ def update( if self.is_inactive: self.inactive_shapes = tuple(output.shape for output in outputs) else: - self.taylor_factors = {} for i, features in enumerate(outputs): new_factors: Dict[int, torch.Tensor] = {0: features} is_first_update = self.last_update_step is None @@ -179,17 +178,17 @@ def predict(self) -> List[torch.Tensor]: else: if not self.taylor_factors: raise ValueError("Taylor factors empty during prediction.") - for i in range(len(self.module_dtypes)): + num_outputs = len(self.taylor_factors) + num_orders = len(self.taylor_factors[0]) + for i in range(num_outputs): output_dtype = self.module_dtypes[i] taylor_factors = self.taylor_factors[i] - # Accumulate Taylor series: f(t0 + Δt) ≈ Σ f^{(n)}(t0) * (Δt^n / n!) output = torch.zeros_like(taylor_factors[0], dtype=output_dtype) - for order, factor in taylor_factors.items(): - # Note: order starts at 0 + for order in range(num_orders): coeff = (step_offset**order) / math.factorial(order) + factor = taylor_factors[order] output = output + factor.to(output_dtype) * coeff outputs.append(output) - return outputs @@ -243,8 +242,8 @@ def _resolve_patterns(config: TaylorSeerCacheConfig) -> Tuple[List[str], List[st Resolve effective inactive and active pattern lists from config + templates. 
""" - inactive_patterns = config.inactive_identifiers if config.inactive_identifiers is not None else None - active_patterns = config.active_identifiers if config.active_identifiers is not None else None + inactive_patterns = config.skip_predict_identifiers if config.skip_predict_identifiers is not None else None + active_patterns = config.cache_identifiers if config.cache_identifiers is not None else None return inactive_patterns or [], active_patterns or [] @@ -288,7 +287,7 @@ def apply_taylorseer_cache(module: torch.nn.Module, config: TaylorSeerCacheConfi logger.info("Using TaylorSeer Lite variant for cache.") active_patterns = _PROJ_OUT_IDENTIFIERS inactive_patterns = _BLOCK_IDENTIFIERS - if config.inactive_identifiers or config.active_identifiers: + if config.skip_predict_identifiers or config.cache_identifiers: logger.warning("Lite mode overrides user patterns.") for name, submodule in module.named_modules(): From 76494ca09897979b90384e91978f4949f03df55a Mon Sep 17 00:00:00 2001 From: toilaluan Date: Thu, 4 Dec 2025 06:12:37 +0000 Subject: [PATCH 23/29] torch compile compatible --- src/diffusers/hooks/taylorseer_cache.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/src/diffusers/hooks/taylorseer_cache.py b/src/diffusers/hooks/taylorseer_cache.py index 607d652f4a21..bbca345e8b0b 100644 --- a/src/diffusers/hooks/taylorseer_cache.py +++ b/src/diffusers/hooks/taylorseer_cache.py @@ -157,6 +157,7 @@ def update( self.last_update_step = self.current_step + @torch.compiler.disable def predict(self) -> List[torch.Tensor]: if self.last_update_step is None: raise ValueError("Cannot predict without prior initialization/update.") @@ -219,7 +220,8 @@ def reset_state(self, module: torch.nn.Module) -> None: """ self.state_manager.reset() - def new_forward(self, module: torch.nn.Module, *args, **kwargs): + @torch.compiler.disable + def _measure_should_compute(self) -> bool: state: TaylorSeerState = self.state_manager.get_state() state.current_step += 1 current_step = state.current_step @@ -227,6 +229,10 @@ def new_forward(self, module: torch.nn.Module, *args, **kwargs): is_compute_interval = (current_step - self.disable_cache_before_step - 1) % self.cache_interval == 0 is_cooldown_phase = self.disable_cache_after_step is not None and current_step >= self.disable_cache_after_step should_compute = is_warmup_phase or is_compute_interval or is_cooldown_phase + return should_compute, state + + def new_forward(self, module: torch.nn.Module, *args, **kwargs): + should_compute, state = self._measure_should_compute() if should_compute: outputs = self.fn_ref.original_forward(*args, **kwargs) wrapped_outputs = (outputs,) if isinstance(outputs, torch.Tensor) else outputs From ca24569a2a069c791216490dc2cac05cd3f8f2ca Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" Date: Thu, 4 Dec 2025 11:16:52 +0000 Subject: [PATCH 24/29] Apply style fixes --- src/diffusers/hooks/taylorseer_cache.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/diffusers/hooks/taylorseer_cache.py b/src/diffusers/hooks/taylorseer_cache.py index bbca345e8b0b..969f43f70e7e 100644 --- a/src/diffusers/hooks/taylorseer_cache.py +++ b/src/diffusers/hooks/taylorseer_cache.py @@ -52,9 +52,9 @@ class TaylorSeerCacheConfig: affect stability; higher precision improves accuracy at the cost of more memory. skip_predict_identifiers (`List[str]`, *optional*, defaults to `None`): - Regex patterns (using `re.fullmatch`) for module names to place as "skip" in "cache" mode. 
In this mode, the - module computes fully during initial or refresh steps but returns a zero tensor (matching recorded shape) - during prediction steps to skip computation cheaply. + Regex patterns (using `re.fullmatch`) for module names to place as "skip" in "cache" mode. In this mode, + the module computes fully during initial or refresh steps but returns a zero tensor (matching recorded + shape) during prediction steps to skip computation cheaply. cache_identifiers (`List[str]`, *optional*, defaults to `None`): Regex patterns (using `re.fullmatch`) for module names to place in Taylor-series caching mode, where From d009d451c29d03969556813d48bb24efb1d666f1 Mon Sep 17 00:00:00 2001 From: Tran Thanh Luan <92072154+toilaluan@users.noreply.github.com> Date: Fri, 5 Dec 2025 14:26:25 +0700 Subject: [PATCH 25/29] Update src/diffusers/hooks/taylorseer_cache.py Co-authored-by: Dhruv Nair --- src/diffusers/hooks/taylorseer_cache.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/hooks/taylorseer_cache.py b/src/diffusers/hooks/taylorseer_cache.py index 969f43f70e7e..3c5a606bd2ed 100644 --- a/src/diffusers/hooks/taylorseer_cache.py +++ b/src/diffusers/hooks/taylorseer_cache.py @@ -67,7 +67,7 @@ class TaylorSeerCacheConfig: Notes: - Patterns are matched using `re.fullmatch` on the module name. - - If `inactive_identifiers` or `active_identifiers` are provided, only matching modules are hooked. + - If `skip_predict_identifiers` or `cache_identifiers` are provided, only matching modules are hooked. - If neither is provided, all attention-like modules are hooked by default. - Example of inactive and active usage: ``` From 5229769a94599da35f595a944b4e8e5cd80c4972 Mon Sep 17 00:00:00 2001 From: toilaluan Date: Fri, 5 Dec 2025 07:31:57 +0000 Subject: [PATCH 26/29] update docs --- docs/source/en/optimization/cache.md | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/docs/source/en/optimization/cache.md b/docs/source/en/optimization/cache.md index a73266618ebb..6397c7d4cd2e 100644 --- a/docs/source/en/optimization/cache.md +++ b/docs/source/en/optimization/cache.md @@ -70,11 +70,15 @@ pipeline.transformer.enable_cache(config) ## TaylorSeer Cache -[TaylorSeer Cache](https://huggingface.co/papers/2403.06923) accelerates diffusion inference by using Taylor series expansions to approximate and cache intermediate activations across denoising steps. This method predicts future outputs based on past computations, reusing them over specified intervals to reduce redundant calculations. +[TaylorSeer Cache](https://huggingface.co/papers/2403.06923) accelerates diffusion inference by using Taylor series expansions to approximate and cache intermediate activations across denoising steps. The method predicts future outputs based on past computations, reusing them at specified intervals to reduce redundant calculations. -It supports selective module skipping (inactive mode), where certain modules return zero tensors during prediction steps to skip computations cheaply, and a lightweight "lite" mode for optimized memory usage with predefined patterns for skipping and caching. +This caching mechanism delivers strong results with minimal additional memory overhead. For detailed performance analysis, see [our findings here](https://github.com/huggingface/diffusers/pull/12648#issuecomment-3610615080). -Set up and pass a [`TaylorSeerCacheConfig`] to a pipeline's transformer to enable it. 
The `cache_interval` controls how many steps to reuse cached outputs before refreshing with a full forward pass. The `disable_cache_before_step` specifies the initial steps where full computations are performed to gather data for approximations. Higher `max_order` improves approximation accuracy but increases memory usage. +To enable TaylorSeer Cache, create a [`TaylorSeerCacheConfig`] and pass it to your pipeline's transformer: + +- `cache_interval`: Number of steps to reuse cached outputs before performing a full forward pass +- `disable_cache_before_step`: Initial steps that use full computations to gather data for approximations +- `max_order`: Approximation accuracy (in theory, higher values improve quality but increase memory usage but we recommend it should be set to `1`) ```python import torch From a9c59b7de6bc3dd1bc71764162450ba7092d8df5 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Fri, 5 Dec 2025 20:31:31 +0700 Subject: [PATCH 27/29] make fix-copies --- src/diffusers/utils/dummy_pt_objects.py | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/src/diffusers/utils/dummy_pt_objects.py b/src/diffusers/utils/dummy_pt_objects.py index 6be7618fcd5e..8628893200fe 100644 --- a/src/diffusers/utils/dummy_pt_objects.py +++ b/src/diffusers/utils/dummy_pt_objects.py @@ -257,6 +257,21 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) +class TaylorSeerCacheConfig(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + def apply_faster_cache(*args, **kwargs): requires_backends(apply_faster_cache, ["torch"]) @@ -273,6 +288,10 @@ def apply_pyramid_attention_broadcast(*args, **kwargs): requires_backends(apply_pyramid_attention_broadcast, ["torch"]) +def apply_taylorseer_cache(*args, **kwargs): + requires_backends(apply_taylorseer_cache, ["torch"]) + + class AllegroTransformer3DModel(metaclass=DummyObject): _backends = ["torch"] From 9a91821bd2f1cb7ae92d8a50a26213bb84f38b96 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Fri, 5 Dec 2025 20:34:09 +0700 Subject: [PATCH 28/29] fix example usage. --- src/diffusers/hooks/taylorseer_cache.py | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/src/diffusers/hooks/taylorseer_cache.py b/src/diffusers/hooks/taylorseer_cache.py index 3c5a606bd2ed..7cad9f4fa161 100644 --- a/src/diffusers/hooks/taylorseer_cache.py +++ b/src/diffusers/hooks/taylorseer_cache.py @@ -69,13 +69,15 @@ class TaylorSeerCacheConfig: - Patterns are matched using `re.fullmatch` on the module name. - If `skip_predict_identifiers` or `cache_identifiers` are provided, only matching modules are hooked. - If neither is provided, all attention-like modules are hooked by default. 
- - Example of inactive and active usage: - ``` - def forward(x): - x = self.module1(x) # inactive module: returns zeros tensor based on shape recorded during full compute - x = self.module2(x) # active module: caches output here, avoiding recomputation of prior steps - return x - ``` + + Example of inactive and active usage: + + ```py + def forward(x): + x = self.module1(x) # inactive module: returns zeros tensor based on shape recorded during full compute + x = self.module2(x) # active module: caches output here, avoiding recomputation of prior steps + return x + ``` """ cache_interval: int = 5 From 46d7bdbcb13625bd8063466f5b9649d1da099c18 Mon Sep 17 00:00:00 2001 From: toilaluan Date: Fri, 5 Dec 2025 21:22:03 +0700 Subject: [PATCH 29/29] remove tests on flux kontext --- tests/pipelines/flux/test_pipeline_flux_kontext.py | 2 -- tests/pipelines/flux/test_pipeline_flux_kontext_inpaint.py | 2 -- 2 files changed, 4 deletions(-) diff --git a/tests/pipelines/flux/test_pipeline_flux_kontext.py b/tests/pipelines/flux/test_pipeline_flux_kontext.py index 06a9f1346cda..5c78964ea54f 100644 --- a/tests/pipelines/flux/test_pipeline_flux_kontext.py +++ b/tests/pipelines/flux/test_pipeline_flux_kontext.py @@ -19,7 +19,6 @@ FluxIPAdapterTesterMixin, PipelineTesterMixin, PyramidAttentionBroadcastTesterMixin, - TaylorSeerCacheTesterMixin, ) @@ -29,7 +28,6 @@ class FluxKontextPipelineFastTests( FluxIPAdapterTesterMixin, PyramidAttentionBroadcastTesterMixin, FasterCacheTesterMixin, - TaylorSeerCacheTesterMixin, ): pipeline_class = FluxKontextPipeline params = frozenset( diff --git a/tests/pipelines/flux/test_pipeline_flux_kontext_inpaint.py b/tests/pipelines/flux/test_pipeline_flux_kontext_inpaint.py index 106a8cacf276..9a2e32056dcb 100644 --- a/tests/pipelines/flux/test_pipeline_flux_kontext_inpaint.py +++ b/tests/pipelines/flux/test_pipeline_flux_kontext_inpaint.py @@ -19,7 +19,6 @@ FluxIPAdapterTesterMixin, PipelineTesterMixin, PyramidAttentionBroadcastTesterMixin, - TaylorSeerCacheTesterMixin, ) @@ -29,7 +28,6 @@ class FluxKontextInpaintPipelineFastTests( FluxIPAdapterTesterMixin, PyramidAttentionBroadcastTesterMixin, FasterCacheTesterMixin, - TaylorSeerCacheTesterMixin, ): pipeline_class = FluxKontextInpaintPipeline params = frozenset(