From 8dded7eaec27103b1b546c5fc38b6af482a0a196 Mon Sep 17 00:00:00 2001 From: nifleisch Date: Mon, 5 May 2025 09:09:31 +0000 Subject: [PATCH 1/8] fix: correct docstring in deepcache --- src/pruna/algorithms/caching/deepcache.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/pruna/algorithms/caching/deepcache.py b/src/pruna/algorithms/caching/deepcache.py index 26d4c2b8..63b2de9e 100644 --- a/src/pruna/algorithms/caching/deepcache.py +++ b/src/pruna/algorithms/caching/deepcache.py @@ -80,12 +80,12 @@ def model_check_fn(self, model: Any) -> bool: def _apply(self, model: Any, smash_config: SmashConfigPrefixWrapper) -> Any: """ - Apply the step caching algorithm to the model. + Apply the deepcache algorithm to the model. Parameters ---------- model : Any - The model to apply the step caching algorithm to. + The model to apply the deepcache algorithm to. smash_config : SmashConfigPrefixWrapper The configuration for the caching. From c14a071deda69e3f168228c96d2cad363539e387 Mon Sep 17 00:00:00 2001 From: nifleisch Date: Tue, 6 May 2025 07:33:03 +0000 Subject: [PATCH 2/8] feat: add model checks --- src/pruna/engine/model_checks.py | 85 ++++++++++++++++++++++++++++++++ 1 file changed, 85 insertions(+) diff --git a/src/pruna/engine/model_checks.py b/src/pruna/engine/model_checks.py index 98165f22..f33dddaa 100644 --- a/src/pruna/engine/model_checks.py +++ b/src/pruna/engine/model_checks.py @@ -262,6 +262,40 @@ def is_sdxl_pipeline(model: Any) -> bool: return _check_pipeline_type(model, diffusers.pipelines.stable_diffusion_xl, "StableDiffusionXL") +def is_mochi_pipeline(model: Any) -> bool: + """ + Check if model is a Mochi pipeline. + + Parameters + ---------- + model : Any + The model to check. + + Returns + ------- + bool + True if model is a Mochi pipeline, False otherwise. + """ + return _check_pipeline_type(model, diffusers.pipelines.mochi, "Mochi") + + +def is_cogvideo_pipeline(model: Any) -> bool: + """ + Check if model is a CogVideoX pipeline. + + Parameters + ---------- + model : Any + The model to check. + + Returns + ------- + bool + True if model is a CogVideoX pipeline, False otherwise. + """ + return _check_pipeline_type(model, diffusers.pipelines.cogvideo, "CogVideoX") + + def is_sd_pipeline(model: Any) -> bool: """ Check if model is a Stable Diffusion pipeline. @@ -279,6 +313,23 @@ def is_sd_pipeline(model: Any) -> bool: return _check_pipeline_type(model, diffusers.pipelines.stable_diffusion, "StableDiffusion") +def is_wan_pipeline(model: Any) -> bool: + """ + Check if model is a WAN pipeline. + + Parameters + ---------- + model : Any + The model to check. + + Returns + ------- + bool + True if model is a WAN pipeline, False otherwise. + """ + return _check_pipeline_type(model, diffusers.pipelines.wan, "Wan") + + def is_sd_3_pipeline(model: Any) -> bool: """ Check if model is a Stable Diffusion 3 pipeline. @@ -330,6 +381,40 @@ def is_sana_pipeline(model: Any) -> bool: return _check_pipeline_type(model, diffusers.pipelines.sana, "Sana") +def is_latte_pipeline(model: Any) -> bool: + """ + Check if model is a Latte pipeline. + + Parameters + ---------- + model : Any + The model to check. + + Returns + ------- + bool + True if model is a Latte pipeline, False otherwise. + """ + return _check_pipeline_type(model, diffusers.pipelines.latte, "Latte") + + +def is_allegro_pipeline(model: Any) -> bool: + """ + Check if model is an Allegro pipeline. + + Parameters + ---------- + model : Any + The model to check. + + Returns + ------- + bool + True if model is an Allegro pipeline, False otherwise. + """ + return _check_pipeline_type(model, diffusers.pipelines.allegro, "Allegro") + + def is_comfy_model(model: Any) -> bool: """ Check if the model is a ComfyUI model. From cbd0e7131410eaa91f648935bc6f9f00a7172555 Mon Sep 17 00:00:00 2001 From: nifleisch Date: Mon, 5 May 2025 09:46:23 +0000 Subject: [PATCH 3/8] feat: add pyramid attention broadcast (pab) cacher --- src/pruna/algorithms/caching/pab.py | 180 ++++++++++++++++++++++++++++ 1 file changed, 180 insertions(+) create mode 100644 src/pruna/algorithms/caching/pab.py diff --git a/src/pruna/algorithms/caching/pab.py b/src/pruna/algorithms/caching/pab.py new file mode 100644 index 00000000..bedc49b0 --- /dev/null +++ b/src/pruna/algorithms/caching/pab.py @@ -0,0 +1,180 @@ +# Copyright 2025 - Pruna AI GmbH. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from __future__ import annotations + +from typing import Any, Dict, Optional, Tuple + +from ConfigSpace import OrdinalHyperparameter + +from pruna.algorithms.caching import PrunaCacher +from pruna.config.smash_config import SmashConfigPrefixWrapper +from pruna.engine.model_checks import ( + is_allegro_pipeline, + is_cogvideo_pipeline, + is_flux_pipeline, + is_hunyuan_pipeline, + is_latte_pipeline, + is_mochi_pipeline, + is_wan_pipeline, +) +from pruna.logging.logger import pruna_logger + + +class PABCacher(PrunaCacher): + """ + Implement PAB. + + Pyramid Attention Broadcast (PAB) is a method that speeds up inference in diffusion models by systematically skipping + attention computations between successive inference steps and reusing cached attention states. + """ + + algorithm_name = "pab" + references = {"Paper": "https://arxiv.org/abs/2408.12588", "HuggingFace": "https://huggingface.co/docs/diffusers/main/api/cache#pyramid-attention-broadcast"} + tokenizer_required = False + processor_required = False + dataset_required = False + run_on_cpu = True + run_on_cuda = True + compatible_algorithms = dict(quantizer=["hqq_diffusers", "diffusers_int8"]) + + def get_hyperparameters(self) -> list: + """ + Get the hyperparameters for the algorithm. + + Returns + ------- + list + The hyperparameters. + """ + return [ + OrdinalHyperparameter( + "interval", + sequence=[2, 3, 4, 5], + default_value=2, + meta=dict( + desc="Interval at which to cache spatial attention blocks." + "Higher is faster but might degrade quality." + ), + ) + ] + + def model_check_fn(self, model: Any) -> bool: + """ + Check if the model is a valid model for the algorithm. + + Parameters + ---------- + model : Any + The model to check. + + Returns + ------- + bool + True if the model is a valid model for the algorithm, False otherwise. + """ + pipeline_check_fns = [ + is_allegro_pipeline, + is_cogvideo_pipeline, + is_flux_pipeline, + is_hunyuan_pipeline, + is_mochi_pipeline, + is_wan_pipeline + ] + return any(is_pipeline(model) for is_pipeline in pipeline_check_fns) + + def _apply(self, model: Any, smash_config: SmashConfigPrefixWrapper) -> Any: + """ + Apply the PAB algorithm to the model. + + Parameters + ---------- + model : Any + The model to apply the PAB algorithm to. + smash_config : SmashConfigPrefixWrapper + The configuration for the caching. + + Returns + ------- + Any + The smashed model. + """ + imported_modules = self.import_algorithm_packages() + # set default values + temporal_attention_block_skip_range: Optional[int] = None + cross_attention_block_skip_range: Optional[int] = None + spatial_attention_timestep_skip_range: Tuple[int, int] = (100, 800) + temporal_attention_timestep_skip_range: Tuple[int, int] = (100, 800) + cross_attention_timestep_skip_range: Tuple[int, int] = (100, 800) + spatial_attention_block_identifiers: Tuple[str, ...] = ("blocks", "transformer_blocks",) + temporal_attention_block_identifiers: Tuple[str, ...] = ("temporal_transformer_blocks",) + cross_attention_block_identifiers: Tuple[str, ...] = ("blocks", "transformer_blocks",) + + # set configs according to https://github.com/huggingface/diffusers/pull/9562 + if is_allegro_pipeline(model): + cross_attention_block_skip_range = 6 + spatial_attention_timestep_skip_range = (100, 700) + cross_attention_block_identifiers = ("transformer_blocks",) + elif is_cogvideo_pipeline(model): + spatial_attention_block_identifiers = ("transformer_blocks",) + elif is_flux_pipeline(model): + spatial_attention_timestep_skip_range = (100, 950) + spatial_attention_block_identifiers = ("transformer_blocks", "single_transformer_blocks",) + elif is_hunyuan_pipeline(model): + spatial_attention_block_identifiers = ("transformer_blocks", "single_transformer_blocks",) + elif is_latte_pipeline(model): + temporal_attention_block_skip_range = None + cross_attention_block_skip_range = None + spatial_attention_timestep_skip_range = (100, 700) + spatial_attention_block_identifiers = ("transformer_blocks",) + cross_attention_block_identifiers = ("transformer_blocks",) + elif is_mochi_pipeline(model): + spatial_attention_timestep_skip_range = (400, 987) + spatial_attention_block_identifiers = ("transformer_blocks",) + elif is_wan_pipeline(model): + spatial_attention_block_identifiers = ("blocks",) + + pab_config = imported_modules["pab_config"]( + spatial_attention_block_skip_range=smash_config["interval"], + temporal_attention_block_skip_range=temporal_attention_block_skip_range, + cross_attention_block_skip_range=cross_attention_block_skip_range, + spatial_attention_timestep_skip_range=spatial_attention_timestep_skip_range, + temporal_attention_timestep_skip_range=temporal_attention_timestep_skip_range, + cross_attention_timestep_skip_range=cross_attention_timestep_skip_range, + spatial_attention_block_identifiers=spatial_attention_block_identifiers, + temporal_attention_block_identifiers=temporal_attention_block_identifiers, + cross_attention_block_identifiers=cross_attention_block_identifiers, + current_timestep_callback=lambda: model.current_timestep + ) + model.transformer.enable_cache(pab_config) + return model + + def import_algorithm_packages(self) -> Dict[str, Any]: + """ + Import the algorithm packages. + + Returns + ------- + Dict[str, Any] + The algorithm packages. + """ + try: + from diffusers import PyramidAttentionBroadcastConfig + except ModuleNotFoundError: + pruna_logger.error( + "You are trying to use PAB, but the PyramidAttentionBroadcastConfig can not be imported from diffusers. " + "This is likely because your diffusers version is too old." + ) + raise + + return dict(pab_config=PyramidAttentionBroadcastConfig) From 9661554873793ed38aa2de1e742a2a58bcf0cba4 Mon Sep 17 00:00:00 2001 From: nifleisch Date: Mon, 5 May 2025 09:50:32 +0000 Subject: [PATCH 4/8] feat: add fastercache cacher --- src/pruna/algorithms/caching/fastercache.py | 210 ++++++++++++++++++++ 1 file changed, 210 insertions(+) create mode 100644 src/pruna/algorithms/caching/fastercache.py diff --git a/src/pruna/algorithms/caching/fastercache.py b/src/pruna/algorithms/caching/fastercache.py new file mode 100644 index 00000000..754b8043 --- /dev/null +++ b/src/pruna/algorithms/caching/fastercache.py @@ -0,0 +1,210 @@ +# Copyright 2025 - Pruna AI GmbH. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from __future__ import annotations + +from typing import Any, Dict, Optional, Tuple + +from ConfigSpace import OrdinalHyperparameter + +from pruna.algorithms.caching import PrunaCacher +from pruna.config.smash_config import SmashConfigPrefixWrapper +from pruna.engine.model_checks import ( + is_allegro_pipeline, + is_cogvideo_pipeline, + is_flux_pipeline, + is_hunyuan_pipeline, + is_latte_pipeline, + is_mochi_pipeline, + is_wan_pipeline, +) +from pruna.logging.logger import pruna_logger + + +class FasterCacheCacher(PrunaCacher): + """ + Implement FasterCache. + + FasterCache is a method that speeds up inference in diffusion transformers by: + - Reusing attention states between successive inference steps, due to high similarity between them + - Skipping unconditional branch prediction used in classifier-free guidance by revealing redundancies between + unconditional and conditional branch outputs for the same timestep, and therefore approximating the unconditional + branch output using the conditional branch output + """ + + algorithm_name = "fastercache" + references = {"GitHub": "https://github.com/Vchitect/FasterCache", "Paper": "https://arxiv.org/abs/2410.19355"} + tokenizer_required = False + processor_required = False + dataset_required = False + run_on_cpu = True + run_on_cuda = True + compatible_algorithms = dict(quantizer=["hqq_diffusers", "diffusers_int8"]) + + def get_hyperparameters(self) -> list: + """ + Get the hyperparameters for the algorithm. + + Returns + ------- + list + The hyperparameters. + """ + return [ + OrdinalHyperparameter( + "interval", + sequence=[2, 3, 4, 5], + default_value=2, + meta=dict( + desc="Interval at which to cache spatial attention blocks." + "Higher is faster but might degrade quality." + ), + ) + ] + + def model_check_fn(self, model: Any) -> bool: + """ + Check if the model is a valid model for the algorithm. + + Parameters + ---------- + model : Any + The model to check. + + Returns + ------- + bool + True if the model is a valid model for the algorithm, False otherwise. + """ + pipeline_check_fns = [ + is_allegro_pipeline, + is_cogvideo_pipeline, + is_flux_pipeline, + is_hunyuan_pipeline, + is_mochi_pipeline, + is_wan_pipeline, + ] + return any(is_pipeline(model) for is_pipeline in pipeline_check_fns) + + def _apply(self, model: Any, smash_config: SmashConfigPrefixWrapper) -> Any: + """ + Apply the fastercache algorithm to the model. + + Parameters + ---------- + model : Any + The model to apply the fastercache algorithm to. + smash_config : SmashConfigPrefixWrapper + The configuration for the caching. + + Returns + ------- + Any + The smashed model. + """ + imported_modules = self.import_algorithm_packages() + # set default values + temporal_attention_block_skip_range: Optional[int] = None + spatial_attention_timestep_skip_range: Tuple[int, int] = (-1, 681) + temporal_attention_timestep_skip_range: Optional[Tuple[int, int]] = None + low_frequency_weight_update_timestep_range: Tuple[int, int] = (99, 901) + high_frequency_weight_update_timestep_range: Tuple[int, int] = (-1, 301) + unconditional_batch_skip_range: int = 5 + unconditional_batch_timestep_skip_range: Tuple[int, int] = (-1, 641) + spatial_attention_block_identifiers: Tuple[str, ...] = ( + "blocks.*attn1", + "transformer_blocks.*attn1", + "single_transformer_blocks.*attn1" + ) + temporal_attention_block_identifiers: Tuple[str, ...] = ("temporal_transformer_blocks.*attn1",) + attention_weight_callback = lambda _: 0.5 # noqa: E731 + tensor_format: str = "BFCHW" + is_guidance_distilled: bool = False + + # set configs according to https://github.com/huggingface/diffusers/pull/9562 + if is_allegro_pipeline(model): + low_frequency_weight_update_timestep_range = (99, 641) + spatial_attention_block_identifiers = ("transformer_blocks",) + elif is_cogvideo_pipeline(model): + low_frequency_weight_update_timestep_range = (99, 641) + spatial_attention_block_identifiers = ("transformer_blocks",) + attention_weight_callback = lambda _: 0.3 # noqa: E731 + elif is_flux_pipeline(model): + spatial_attention_timestep_skip_range = (-1, 961) + spatial_attention_block_identifiers = ("transformer_blocks", "single_transformer_blocks",) + tensor_format = "BCHW" + is_guidance_distilled = True + elif is_hunyuan_pipeline(model): + spatial_attention_timestep_skip_range = (99, 941) + spatial_attention_block_identifiers = ("transformer_blocks", "single_transformer_blocks",) + tensor_format = "BCFHW" + is_guidance_distilled = True + elif is_latte_pipeline(model): + temporal_attention_block_skip_range = 2 + temporal_attention_timestep_skip_range = (-1, 681) + low_frequency_weight_update_timestep_range = (99, 641) + spatial_attention_block_identifiers = ("transformer_blocks.*attn1",) + temporal_attention_block_identifiers = ("temporal_transformer_blocks",) + elif is_mochi_pipeline(model): + spatial_attention_timestep_skip_range = (-1, 981) + low_frequency_weight_update_timestep_range = (301, 961) + high_frequency_weight_update_timestep_range = (-1, 851) + unconditional_batch_skip_range = 4 + unconditional_batch_timestep_skip_range = (-1, 975) + spatial_attention_block_identifiers = ("transformer_blocks",) + attention_weight_callback = lambda _: 0.6 # noqa: E731 + elif is_wan_pipeline(model): + spatial_attention_block_identifiers = ("blocks",) + tensor_format = "BCFHW" + is_guidance_distilled = True + + fastercache_config = imported_modules["FasterCacheConfig"]( + spatial_attention_block_skip_range=smash_config["interval"], + temporal_attention_block_skip_range=temporal_attention_block_skip_range, + spatial_attention_timestep_skip_range=spatial_attention_timestep_skip_range, + temporal_attention_timestep_skip_range=temporal_attention_timestep_skip_range, + low_frequency_weight_update_timestep_range=low_frequency_weight_update_timestep_range, + high_frequency_weight_update_timestep_range=high_frequency_weight_update_timestep_range, + alpha_low_frequency=1.1, + alpha_high_frequency=1.1, + unconditional_batch_skip_range=unconditional_batch_skip_range, + unconditional_batch_timestep_skip_range=unconditional_batch_timestep_skip_range, + spatial_attention_block_identifiers=spatial_attention_block_identifiers, + temporal_attention_block_identifiers=temporal_attention_block_identifiers, + attention_weight_callback=attention_weight_callback, + tensor_format=tensor_format, + current_timestep_callback=lambda: model.current_timestep, + is_guidance_distilled=is_guidance_distilled, + ) + model.transformer.enable_cache(fastercache_config) + return model + + def import_algorithm_packages(self) -> Dict[str, Any]: + """ + Import the algorithm packages. + + Returns + ------- + Dict[str, Any] + The algorithm packages. + """ + try: + from diffusers import FasterCacheConfig + except ModuleNotFoundError: + pruna_logger.error( + "You are trying to use FasterCache, but the FasterCacheConfig can not be imported from diffusers. " + "This is likely because your diffusers version is too old." + ) + raise + + return dict(FasterCacheConfig=FasterCacheConfig) From 327526b0f06ff51a73c46cedd97bce59ac3a1174 Mon Sep 17 00:00:00 2001 From: nifleisch Date: Mon, 5 May 2025 09:51:10 +0000 Subject: [PATCH 5/8] tests: add flux tiny random fixture --- tests/fixtures.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/fixtures.py b/tests/fixtures.py index 0e31616b..a587f7cb 100644 --- a/tests/fixtures.py +++ b/tests/fixtures.py @@ -6,6 +6,7 @@ import torch from diffusers import ( DDIMPipeline, + FluxPipeline, SanaPipeline, StableDiffusion3Pipeline, StableDiffusionPipeline, @@ -149,4 +150,5 @@ def get_torchvision_model(name: str) -> tuple[Any, SmashConfig]: "ddpm-cifar10": partial(get_diffusers_model, DDIMPipeline, "google/ddpm-cifar10-32"), "smollm_135m": partial(get_automodel_transformers, "HuggingFaceTB/SmolLM2-135M"), "dummy_lambda": dummy_model, + "flux_tiny_random": partial(get_diffusers_model, FluxPipeline, "katuni4ka/tiny-random-flux"), } From 9bba3ae3dfa5c58a058a8e94ae21e0ae66ac930a Mon Sep 17 00:00:00 2001 From: nifleisch Date: Mon, 5 May 2025 09:52:08 +0000 Subject: [PATCH 6/8] tests: add algorithms tests for pab and fastercache --- tests/algorithms/testers/caching.py | 28 ++++++++++++++++++++++++++++ 1 file changed, 28 insertions(+) diff --git a/tests/algorithms/testers/caching.py b/tests/algorithms/testers/caching.py index 985bc714..a662890a 100644 --- a/tests/algorithms/testers/caching.py +++ b/tests/algorithms/testers/caching.py @@ -1,5 +1,7 @@ from pruna import PrunaModel from pruna.algorithms.caching.deepcache import DeepCacheCacher +from pruna.algorithms.caching.fastercache import FasterCacheCacher +from pruna.algorithms.caching.pab import PABCacher from .base_tester import AlgorithmTesterBase @@ -15,3 +17,29 @@ class TestDeepCache(AlgorithmTesterBase): def post_smash_hook(self, model: PrunaModel) -> None: """Hook to modify the model after smashing.""" assert hasattr(model, "deepcache_unet_helper") + + +class TestFasterCache(AlgorithmTesterBase): + """Test the fastercache algorithm.""" + + models = ["flux_tiny_random"] + reject_models = ["opt_125m"] + allow_pickle_files = False + algorithm_class = FasterCacheCacher + + def post_smash_hook(self, model: PrunaModel) -> None: + """Hook to modify the model after smashing.""" + assert model.transformer.is_cache_enabled + + +class TestPAB(AlgorithmTesterBase): + """Test the PAB algorithm.""" + + models = ["flux_tiny_random"] + reject_models = ["opt_125m"] + allow_pickle_files = False + algorithm_class = PABCacher + + def post_smash_hook(self, model: PrunaModel) -> None: + """Hook to modify the model after smashing.""" + assert model.transformer.is_cache_enabled From 3743bfc9bb9c0a0866698d9a3a36ac2bac5cdb84 Mon Sep 17 00:00:00 2001 From: nifleisch Date: Tue, 6 May 2025 07:35:10 +0000 Subject: [PATCH 7/8] tests: add combination tests for pab and fastercache --- tests/algorithms/test_combinations.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/tests/algorithms/test_combinations.py b/tests/algorithms/test_combinations.py index 43776aaf..c625b206 100644 --- a/tests/algorithms/test_combinations.py +++ b/tests/algorithms/test_combinations.py @@ -43,6 +43,10 @@ def prepare_smash_config(self, smash_config: SmashConfig, device: str) -> None: ("stable_diffusion_v1_4", dict(quantizer="diffusers_int8", compiler="torch_compile", torch_compile_fullgraph=False), False), ("llama_3_2_1b", dict(quantizer="gptq", compiler="torch_compile"), True), ("llama_3_2_1b", dict(quantizer="llm_int8", compiler="torch_compile", torch_compile_fullgraph=False), True), + ("flux_tiny_random", dict(cacher="pab", quantizer="hqq_diffusers"), False), + ("flux_tiny_random", dict(cacher="pab", quantizer="diffusers_int8"), False), + ("flux_tiny_random", dict(cacher="fastercache", quantizer="hqq_diffusers"), False), + ("flux_tiny_random", dict(cacher="fastercache", quantizer="diffusers_int8"), False), ], indirect=["model_fixture"], ) From 3f2213f912c1b9a1ba52063424b67b32f37c1b57 Mon Sep 17 00:00:00 2001 From: nifleisch Date: Mon, 12 May 2025 15:22:27 +0000 Subject: [PATCH 8/8] fix: add 1 as value for interval parameter --- src/pruna/algorithms/caching/fastercache.py | 10 ++++++---- src/pruna/algorithms/caching/pab.py | 9 +++++---- 2 files changed, 11 insertions(+), 8 deletions(-) diff --git a/src/pruna/algorithms/caching/fastercache.py b/src/pruna/algorithms/caching/fastercache.py index 754b8043..af63e5a8 100644 --- a/src/pruna/algorithms/caching/fastercache.py +++ b/src/pruna/algorithms/caching/fastercache.py @@ -40,6 +40,8 @@ class FasterCacheCacher(PrunaCacher): - Skipping unconditional branch prediction used in classifier-free guidance by revealing redundancies between unconditional and conditional branch outputs for the same timestep, and therefore approximating the unconditional branch output using the conditional branch output + This implementation reduces the number of tunable parameters by setting pipeline specific parameters according to + https://github.com/huggingface/diffusers/pull/9562. """ algorithm_name = "fastercache" @@ -63,13 +65,13 @@ def get_hyperparameters(self) -> list: return [ OrdinalHyperparameter( "interval", - sequence=[2, 3, 4, 5], + sequence=[1, 2, 3, 4, 5], default_value=2, meta=dict( - desc="Interval at which to cache spatial attention blocks." + desc="Interval at which to cache spatial attention blocks - 1 disables caching." "Higher is faster but might degrade quality." ), - ) + ), ] def model_check_fn(self, model: Any) -> bool: @@ -113,7 +115,7 @@ def _apply(self, model: Any, smash_config: SmashConfigPrefixWrapper) -> Any: The smashed model. """ imported_modules = self.import_algorithm_packages() - # set default values + # set default values according to https://huggingface.co/docs/diffusers/en/api/cache temporal_attention_block_skip_range: Optional[int] = None spatial_attention_timestep_skip_range: Tuple[int, int] = (-1, 681) temporal_attention_timestep_skip_range: Optional[Tuple[int, int]] = None diff --git a/src/pruna/algorithms/caching/pab.py b/src/pruna/algorithms/caching/pab.py index bedc49b0..62825b46 100644 --- a/src/pruna/algorithms/caching/pab.py +++ b/src/pruna/algorithms/caching/pab.py @@ -36,7 +36,8 @@ class PABCacher(PrunaCacher): Implement PAB. Pyramid Attention Broadcast (PAB) is a method that speeds up inference in diffusion models by systematically skipping - attention computations between successive inference steps and reusing cached attention states. + attention computations between successive inference steps and reusing cached attention states. This implementation + reduces the number of tunable parameters by setting pipeline specific parameters according to https://github.com/huggingface/diffusers/pull/9562. """ algorithm_name = "pab" @@ -60,10 +61,10 @@ def get_hyperparameters(self) -> list: return [ OrdinalHyperparameter( "interval", - sequence=[2, 3, 4, 5], + sequence=[1, 2, 3, 4, 5], default_value=2, meta=dict( - desc="Interval at which to cache spatial attention blocks." + desc="Interval at which to cache spatial attention blocks - 1 disables caching." "Higher is faster but might degrade quality." ), ) @@ -110,7 +111,7 @@ def _apply(self, model: Any, smash_config: SmashConfigPrefixWrapper) -> Any: The smashed model. """ imported_modules = self.import_algorithm_packages() - # set default values + # set default values according to https://huggingface.co/docs/diffusers/en/api/cache temporal_attention_block_skip_range: Optional[int] = None cross_attention_block_skip_range: Optional[int] = None spatial_attention_timestep_skip_range: Tuple[int, int] = (100, 800)