diff --git a/src/pruna/algorithms/caching/deepcache.py b/src/pruna/algorithms/caching/deepcache.py
index 26d4c2b8..63b2de9e 100644
--- a/src/pruna/algorithms/caching/deepcache.py
+++ b/src/pruna/algorithms/caching/deepcache.py
@@ -80,12 +80,12 @@ def model_check_fn(self, model: Any) -> bool:
 
     def _apply(self, model: Any, smash_config: SmashConfigPrefixWrapper) -> Any:
         """
-        Apply the step caching algorithm to the model.
+        Apply the deepcache algorithm to the model.
 
         Parameters
         ----------
         model : Any
-            The model to apply the step caching algorithm to.
+            The model to apply the deepcache algorithm to.
         smash_config : SmashConfigPrefixWrapper
             The configuration for the caching.
 
diff --git a/src/pruna/algorithms/caching/fastercache.py b/src/pruna/algorithms/caching/fastercache.py
new file mode 100644
index 00000000..754b8043
--- /dev/null
+++ b/src/pruna/algorithms/caching/fastercache.py
@@ -0,0 +1,210 @@
+# Copyright 2025 - Pruna AI GmbH. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from __future__ import annotations
+
+from typing import Any, Dict, Optional, Tuple
+
+from ConfigSpace import OrdinalHyperparameter
+
+from pruna.algorithms.caching import PrunaCacher
+from pruna.config.smash_config import SmashConfigPrefixWrapper
+from pruna.engine.model_checks import (
+    is_allegro_pipeline,
+    is_cogvideo_pipeline,
+    is_flux_pipeline,
+    is_hunyuan_pipeline,
+    is_latte_pipeline,
+    is_mochi_pipeline,
+    is_wan_pipeline,
+)
+from pruna.logging.logger import pruna_logger
+
+
+class FasterCacheCacher(PrunaCacher):
+    """
+    Implement FasterCache.
+
+    FasterCache is a method that speeds up inference in diffusion transformers by:
+    - Reusing attention states between successive inference steps, since consecutive states are highly similar
+    - Skipping the unconditional branch of classifier-free guidance, exploiting the redundancy between the
+      unconditional and conditional branch outputs at the same timestep and approximating the unconditional
+      branch output from the conditional one
+    """
+
+    algorithm_name = "fastercache"
+    references = {"GitHub": "https://github.com/Vchitect/FasterCache", "Paper": "https://arxiv.org/abs/2410.19355"}
+    tokenizer_required = False
+    processor_required = False
+    dataset_required = False
+    run_on_cpu = True
+    run_on_cuda = True
+    compatible_algorithms = dict(quantizer=["hqq_diffusers", "diffusers_int8"])
+
+    def get_hyperparameters(self) -> list:
+        """
+        Get the hyperparameters for the algorithm.
+
+        Returns
+        -------
+        list
+            The hyperparameters.
+        """
+        return [
+            OrdinalHyperparameter(
+                "interval",
+                sequence=[2, 3, 4, 5],
+                default_value=2,
+                meta=dict(
+                    desc="Interval at which to cache spatial attention blocks. "
+                    "Higher is faster but might degrade quality."
+                ),
+            )
+        ]
+
+    def model_check_fn(self, model: Any) -> bool:
+        """
+        Check if the model is a valid model for the algorithm.
+
+        Parameters
+        ----------
+        model : Any
+            The model to check.
+
+        Returns
+        -------
+        bool
+            True if the model is a valid model for the algorithm, False otherwise.
+        """
+        pipeline_check_fns = [
+            is_allegro_pipeline,
+            is_cogvideo_pipeline,
+            is_flux_pipeline,
+            is_hunyuan_pipeline,
+            is_mochi_pipeline,
+            is_wan_pipeline,
+        ]
+        return any(is_pipeline(model) for is_pipeline in pipeline_check_fns)
+
+    def _apply(self, model: Any, smash_config: SmashConfigPrefixWrapper) -> Any:
+        """
+        Apply the fastercache algorithm to the model.
+
+        Parameters
+        ----------
+        model : Any
+            The model to apply the fastercache algorithm to.
+        smash_config : SmashConfigPrefixWrapper
+            The configuration for the caching.
+
+        Returns
+        -------
+        Any
+            The smashed model.
+        """
+        imported_modules = self.import_algorithm_packages()
+        # set default values
+        temporal_attention_block_skip_range: Optional[int] = None
+        spatial_attention_timestep_skip_range: Tuple[int, int] = (-1, 681)
+        temporal_attention_timestep_skip_range: Optional[Tuple[int, int]] = None
+        low_frequency_weight_update_timestep_range: Tuple[int, int] = (99, 901)
+        high_frequency_weight_update_timestep_range: Tuple[int, int] = (-1, 301)
+        unconditional_batch_skip_range: int = 5
+        unconditional_batch_timestep_skip_range: Tuple[int, int] = (-1, 641)
+        spatial_attention_block_identifiers: Tuple[str, ...] = (
+            "blocks.*attn1",
+            "transformer_blocks.*attn1",
+            "single_transformer_blocks.*attn1"
+        )
+        temporal_attention_block_identifiers: Tuple[str, ...] = ("temporal_transformer_blocks.*attn1",)
+        attention_weight_callback = lambda _: 0.5  # noqa: E731
+        tensor_format: str = "BFCHW"
+        is_guidance_distilled: bool = False
+
+        # set configs according to https://github.com/huggingface/diffusers/pull/9562
+        if is_allegro_pipeline(model):
+            low_frequency_weight_update_timestep_range = (99, 641)
+            spatial_attention_block_identifiers = ("transformer_blocks",)
+        elif is_cogvideo_pipeline(model):
+            low_frequency_weight_update_timestep_range = (99, 641)
+            spatial_attention_block_identifiers = ("transformer_blocks",)
+            attention_weight_callback = lambda _: 0.3  # noqa: E731
+        elif is_flux_pipeline(model):
+            spatial_attention_timestep_skip_range = (-1, 961)
+            spatial_attention_block_identifiers = ("transformer_blocks", "single_transformer_blocks",)
+            tensor_format = "BCHW"
+            is_guidance_distilled = True
+        elif is_hunyuan_pipeline(model):
+            spatial_attention_timestep_skip_range = (99, 941)
+            spatial_attention_block_identifiers = ("transformer_blocks", "single_transformer_blocks",)
+            tensor_format = "BCFHW"
+            is_guidance_distilled = True
+        elif is_latte_pipeline(model):
+            temporal_attention_block_skip_range = 2
+            temporal_attention_timestep_skip_range = (-1, 681)
+            low_frequency_weight_update_timestep_range = (99, 641)
+            spatial_attention_block_identifiers = ("transformer_blocks.*attn1",)
+            temporal_attention_block_identifiers = ("temporal_transformer_blocks",)
+        elif is_mochi_pipeline(model):
+            spatial_attention_timestep_skip_range = (-1, 981)
+            low_frequency_weight_update_timestep_range = (301, 961)
+            high_frequency_weight_update_timestep_range = (-1, 851)
+            unconditional_batch_skip_range = 4
+            unconditional_batch_timestep_skip_range = (-1, 975)
+            spatial_attention_block_identifiers = ("transformer_blocks",)
+            attention_weight_callback = lambda _: 0.6  # noqa: E731
+        elif is_wan_pipeline(model):
+            spatial_attention_block_identifiers = ("blocks",)
+            tensor_format = "BCFHW"
+            is_guidance_distilled = True
+
+        fastercache_config = imported_modules["FasterCacheConfig"](
+            spatial_attention_block_skip_range=smash_config["interval"],
+            temporal_attention_block_skip_range=temporal_attention_block_skip_range,
+            spatial_attention_timestep_skip_range=spatial_attention_timestep_skip_range,
+            temporal_attention_timestep_skip_range=temporal_attention_timestep_skip_range,
+            low_frequency_weight_update_timestep_range=low_frequency_weight_update_timestep_range,
+            high_frequency_weight_update_timestep_range=high_frequency_weight_update_timestep_range,
+            alpha_low_frequency=1.1,
+            alpha_high_frequency=1.1,
+            unconditional_batch_skip_range=unconditional_batch_skip_range,
+            unconditional_batch_timestep_skip_range=unconditional_batch_timestep_skip_range,
+            spatial_attention_block_identifiers=spatial_attention_block_identifiers,
+            temporal_attention_block_identifiers=temporal_attention_block_identifiers,
+            attention_weight_callback=attention_weight_callback,
+            tensor_format=tensor_format,
+            current_timestep_callback=lambda: model.current_timestep,
+            is_guidance_distilled=is_guidance_distilled,
+        )
+        model.transformer.enable_cache(fastercache_config)
+        return model
+
+    def import_algorithm_packages(self) -> Dict[str, Any]:
+        """
+        Import the algorithm packages.
+
+        Returns
+        -------
+        Dict[str, Any]
+            The algorithm packages.
+        """
+        try:
+            from diffusers import FasterCacheConfig
+        except ModuleNotFoundError:
+            pruna_logger.error(
+                "You are trying to use FasterCache, but the FasterCacheConfig cannot be imported from diffusers. "
+                "This is likely because your diffusers version is too old."
+            )
+            raise
+
+        return dict(FasterCacheConfig=FasterCacheConfig)
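Note: a minimal end-to-end sketch of enabling the new cacher, assuming pruna's public `smash`/`SmashConfig` entry points and the usual `<algorithm>_<hyperparameter>` key prefix (so `fastercache_interval` below); the model id and step count are only examples, not part of this diff.

# Hedged usage sketch -- adjust key names to the installed pruna version.
import torch
from diffusers import FluxPipeline

from pruna import SmashConfig, smash

pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16)

smash_config = SmashConfig()
smash_config["cacher"] = "fastercache"
smash_config["fastercache_interval"] = 2  # reuse spatial attention states every 2 steps

smashed_pipe = smash(model=pipe, smash_config=smash_config)
image = smashed_pipe("an astronaut riding a horse", num_inference_steps=4).images[0]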
diff --git a/src/pruna/algorithms/caching/pab.py b/src/pruna/algorithms/caching/pab.py
new file mode 100644
index 00000000..bedc49b0
--- /dev/null
+++ b/src/pruna/algorithms/caching/pab.py
@@ -0,0 +1,180 @@
+# Copyright 2025 - Pruna AI GmbH. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from __future__ import annotations
+
+from typing import Any, Dict, Optional, Tuple
+
+from ConfigSpace import OrdinalHyperparameter
+
+from pruna.algorithms.caching import PrunaCacher
+from pruna.config.smash_config import SmashConfigPrefixWrapper
+from pruna.engine.model_checks import (
+    is_allegro_pipeline,
+    is_cogvideo_pipeline,
+    is_flux_pipeline,
+    is_hunyuan_pipeline,
+    is_latte_pipeline,
+    is_mochi_pipeline,
+    is_wan_pipeline,
+)
+from pruna.logging.logger import pruna_logger
+
+
+class PABCacher(PrunaCacher):
+    """
+    Implement PAB.
+
+    Pyramid Attention Broadcast (PAB) is a method that speeds up inference in diffusion models by systematically
+    skipping attention computations between successive inference steps and reusing cached attention states.
+    """
+
+    algorithm_name = "pab"
+    references = {"Paper": "https://arxiv.org/abs/2408.12588", "HuggingFace": "https://huggingface.co/docs/diffusers/main/api/cache#pyramid-attention-broadcast"}
+    tokenizer_required = False
+    processor_required = False
+    dataset_required = False
+    run_on_cpu = True
+    run_on_cuda = True
+    compatible_algorithms = dict(quantizer=["hqq_diffusers", "diffusers_int8"])
+
+    def get_hyperparameters(self) -> list:
+        """
+        Get the hyperparameters for the algorithm.
+
+        Returns
+        -------
+        list
+            The hyperparameters.
+        """
+        return [
+            OrdinalHyperparameter(
+                "interval",
+                sequence=[2, 3, 4, 5],
+                default_value=2,
+                meta=dict(
+                    desc="Interval at which to cache spatial attention blocks. "
+                    "Higher is faster but might degrade quality."
+                ),
+            )
+        ]
+
+    def model_check_fn(self, model: Any) -> bool:
+        """
+        Check if the model is a valid model for the algorithm.
+
+        Parameters
+        ----------
+        model : Any
+            The model to check.
+
+        Returns
+        -------
+        bool
+            True if the model is a valid model for the algorithm, False otherwise.
+        """
+        pipeline_check_fns = [
+            is_allegro_pipeline,
+            is_cogvideo_pipeline,
+            is_flux_pipeline,
+            is_hunyuan_pipeline,
+            is_mochi_pipeline,
+            is_wan_pipeline
+        ]
+        return any(is_pipeline(model) for is_pipeline in pipeline_check_fns)
+
+    def _apply(self, model: Any, smash_config: SmashConfigPrefixWrapper) -> Any:
+        """
+        Apply the PAB algorithm to the model.
+
+        Parameters
+        ----------
+        model : Any
+            The model to apply the PAB algorithm to.
+        smash_config : SmashConfigPrefixWrapper
+            The configuration for the caching.
+
+        Returns
+        -------
+        Any
+            The smashed model.
+        """
+        imported_modules = self.import_algorithm_packages()
+        # set default values
+        temporal_attention_block_skip_range: Optional[int] = None
+        cross_attention_block_skip_range: Optional[int] = None
+        spatial_attention_timestep_skip_range: Tuple[int, int] = (100, 800)
+        temporal_attention_timestep_skip_range: Tuple[int, int] = (100, 800)
+        cross_attention_timestep_skip_range: Tuple[int, int] = (100, 800)
+        spatial_attention_block_identifiers: Tuple[str, ...] = ("blocks", "transformer_blocks",)
+        temporal_attention_block_identifiers: Tuple[str, ...] = ("temporal_transformer_blocks",)
+        cross_attention_block_identifiers: Tuple[str, ...] = ("blocks", "transformer_blocks",)
+
+        # set configs according to https://github.com/huggingface/diffusers/pull/9562
+        if is_allegro_pipeline(model):
+            cross_attention_block_skip_range = 6
+            spatial_attention_timestep_skip_range = (100, 700)
+            cross_attention_block_identifiers = ("transformer_blocks",)
+        elif is_cogvideo_pipeline(model):
+            spatial_attention_block_identifiers = ("transformer_blocks",)
+        elif is_flux_pipeline(model):
+            spatial_attention_timestep_skip_range = (100, 950)
+            spatial_attention_block_identifiers = ("transformer_blocks", "single_transformer_blocks",)
+        elif is_hunyuan_pipeline(model):
+            spatial_attention_block_identifiers = ("transformer_blocks", "single_transformer_blocks",)
+        elif is_latte_pipeline(model):
+            temporal_attention_block_skip_range = None
+            cross_attention_block_skip_range = None
+            spatial_attention_timestep_skip_range = (100, 700)
+            spatial_attention_block_identifiers = ("transformer_blocks",)
+            cross_attention_block_identifiers = ("transformer_blocks",)
+        elif is_mochi_pipeline(model):
+            spatial_attention_timestep_skip_range = (400, 987)
+            spatial_attention_block_identifiers = ("transformer_blocks",)
+        elif is_wan_pipeline(model):
+            spatial_attention_block_identifiers = ("blocks",)
+
+        pab_config = imported_modules["pab_config"](
+            spatial_attention_block_skip_range=smash_config["interval"],
+            temporal_attention_block_skip_range=temporal_attention_block_skip_range,
+            cross_attention_block_skip_range=cross_attention_block_skip_range,
+            spatial_attention_timestep_skip_range=spatial_attention_timestep_skip_range,
+            temporal_attention_timestep_skip_range=temporal_attention_timestep_skip_range,
+            cross_attention_timestep_skip_range=cross_attention_timestep_skip_range,
+            spatial_attention_block_identifiers=spatial_attention_block_identifiers,
+            temporal_attention_block_identifiers=temporal_attention_block_identifiers,
+            cross_attention_block_identifiers=cross_attention_block_identifiers,
+            current_timestep_callback=lambda: model.current_timestep
+        )
+        model.transformer.enable_cache(pab_config)
+        return model
+
+    def import_algorithm_packages(self) -> Dict[str, Any]:
+        """
+        Import the algorithm packages.
+
+        Returns
+        -------
+        Dict[str, Any]
+            The algorithm packages.
+        """
+        try:
+            from diffusers import PyramidAttentionBroadcastConfig
+        except ModuleNotFoundError:
+            pruna_logger.error(
+                "You are trying to use PAB, but the PyramidAttentionBroadcastConfig cannot be imported from diffusers. "
+                "This is likely because your diffusers version is too old."
+            )
+            raise
+
+        return dict(pab_config=PyramidAttentionBroadcastConfig)
diff --git a/src/pruna/engine/model_checks.py b/src/pruna/engine/model_checks.py
index 79045686..a00da0f4 100644
--- a/src/pruna/engine/model_checks.py
+++ b/src/pruna/engine/model_checks.py
@@ -262,6 +262,40 @@ def is_sdxl_pipeline(model: Any) -> bool:
     return _check_pipeline_type(model, diffusers.pipelines.stable_diffusion_xl, "StableDiffusionXL")
 
 
+def is_mochi_pipeline(model: Any) -> bool:
+    """
+    Check if model is a Mochi pipeline.
+
+    Parameters
+    ----------
+    model : Any
+        The model to check.
+
+    Returns
+    -------
+    bool
+        True if model is a Mochi pipeline, False otherwise.
+    """
+    return _check_pipeline_type(model, diffusers.pipelines.mochi, "Mochi")
+
+
+def is_cogvideo_pipeline(model: Any) -> bool:
+    """
+    Check if model is a CogVideoX pipeline.
+
+    Parameters
+    ----------
+    model : Any
+        The model to check.
+
+    Returns
+    -------
+    bool
+        True if model is a CogVideoX pipeline, False otherwise.
+    """
+    return _check_pipeline_type(model, diffusers.pipelines.cogvideo, "CogVideoX")
+
+
 def is_sd_pipeline(model: Any) -> bool:
     """
     Check if model is a Stable Diffusion pipeline.
@@ -279,6 +313,23 @@ def is_sd_pipeline(model: Any) -> bool:
     return _check_pipeline_type(model, diffusers.pipelines.stable_diffusion, "StableDiffusion")
 
 
+def is_wan_pipeline(model: Any) -> bool:
+    """
+    Check if model is a Wan pipeline.
+
+    Parameters
+    ----------
+    model : Any
+        The model to check.
+
+    Returns
+    -------
+    bool
+        True if model is a Wan pipeline, False otherwise.
+    """
+    return _check_pipeline_type(model, diffusers.pipelines.wan, "Wan")
+
+
 def is_sd_3_pipeline(model: Any) -> bool:
     """
     Check if model is a Stable Diffusion 3 pipeline.
@@ -330,6 +381,40 @@ def is_sana_pipeline(model: Any) -> bool:
     return _check_pipeline_type(model, diffusers.pipelines.sana, "Sana")
 
 
+def is_latte_pipeline(model: Any) -> bool:
+    """
+    Check if model is a Latte pipeline.
+
+    Parameters
+    ----------
+    model : Any
+        The model to check.
+
+    Returns
+    -------
+    bool
+        True if model is a Latte pipeline, False otherwise.
+    """
+    return _check_pipeline_type(model, diffusers.pipelines.latte, "Latte")
+
+
+def is_allegro_pipeline(model: Any) -> bool:
+    """
+    Check if model is an Allegro pipeline.
+
+    Parameters
+    ----------
+    model : Any
+        The model to check.
+
+    Returns
+    -------
+    bool
+        True if model is an Allegro pipeline, False otherwise.
+    """
+    return _check_pipeline_type(model, diffusers.pipelines.allegro, "Allegro")
+
+
 def get_helpers(model: Any) -> List[str]:
     """
     Retrieve a list of helper attributes from the model.
diff --git a/tests/algorithms/test_combinations.py b/tests/algorithms/test_combinations.py
index 2a7455fa..29022832 100644
--- a/tests/algorithms/test_combinations.py
+++ b/tests/algorithms/test_combinations.py
@@ -41,6 +41,10 @@ def prepare_smash_config(self, smash_config: SmashConfig, device: str) -> None:
         ("stable_diffusion_v1_4", dict(quantizer="hqq_diffusers", compiler="torch_compile"), False),
         ("sana", dict(quantizer="hqq_diffusers", compiler="torch_compile"), False),
         ("stable_diffusion_v1_4", dict(quantizer="diffusers_int8", compiler="torch_compile", torch_compile_fullgraph=False), False),
+        ("flux_tiny_random", dict(cacher="pab", quantizer="hqq_diffusers"), False),
+        ("flux_tiny_random", dict(cacher="pab", quantizer="diffusers_int8"), False),
+        ("flux_tiny_random", dict(cacher="fastercache", quantizer="hqq_diffusers"), False),
+        ("flux_tiny_random", dict(cacher="fastercache", quantizer="diffusers_int8"), False),
     ],
     indirect=["model_fixture"],
 )
diff --git a/tests/algorithms/testers/caching.py b/tests/algorithms/testers/caching.py
index 985bc714..a662890a 100644
--- a/tests/algorithms/testers/caching.py
+++ b/tests/algorithms/testers/caching.py
@@ -1,5 +1,7 @@
 from pruna import PrunaModel
 from pruna.algorithms.caching.deepcache import DeepCacheCacher
+from pruna.algorithms.caching.fastercache import FasterCacheCacher
+from pruna.algorithms.caching.pab import PABCacher
 
 from .base_tester import AlgorithmTesterBase
 
@@ -15,3 +17,29 @@ class TestDeepCache(AlgorithmTesterBase):
     def post_smash_hook(self, model: PrunaModel) -> None:
         """Hook to modify the model after smashing."""
         assert hasattr(model, "deepcache_unet_helper")
+
+
+class TestFasterCache(AlgorithmTesterBase):
+    """Test the fastercache algorithm."""
+
+    models = ["flux_tiny_random"]
+    reject_models = ["opt_125m"]
+    allow_pickle_files = False
+    algorithm_class = FasterCacheCacher
+
+    def post_smash_hook(self, model: PrunaModel) -> None:
+        """Hook to modify the model after smashing."""
+        assert model.transformer.is_cache_enabled
+
+
+class TestPAB(AlgorithmTesterBase):
+    """Test the PAB algorithm."""
+
+    models = ["flux_tiny_random"]
+    reject_models = ["opt_125m"]
+    allow_pickle_files = False
+    algorithm_class = PABCacher
+
+    def post_smash_hook(self, model: PrunaModel) -> None:
+        """Hook to modify the model after smashing."""
+        assert model.transformer.is_cache_enabled
diff --git a/tests/fixtures.py b/tests/fixtures.py
index 6e02e5ec..eb3f4f2b 100644
--- a/tests/fixtures.py
+++ b/tests/fixtures.py
@@ -6,6 +6,7 @@ import torch
 from diffusers import (
     DDIMPipeline,
+    FluxPipeline,
     SanaPipeline,
     StableDiffusion3Pipeline,
     StableDiffusionPipeline,
@@ -143,4 +144,5 @@ def get_torchvision_model(name: str) -> tuple[Any, SmashConfig]:
     "ddpm-cifar10": partial(get_diffusers_model, DDIMPipeline, "google/ddpm-cifar10-32"),
     "smollm_135m": partial(get_automodel_transformers, "HuggingFaceTB/SmolLM2-135M"),
     "dummy_lambda": dummy_model,
+    "flux_tiny_random": partial(get_diffusers_model, FluxPipeline, "katuni4ka/tiny-random-flux"),
 }