feat: add fastercache and pab #92
Merged
Commits (8, all by nifleisch):
8dded7e  fix: correct docstring in deepcache
c14a071  feat: add model checks
cbd0e71  feat: add pyramid attention broadcast (pab) cacher
9661554  feat: add fastercache cacher
327526b  tests: add flux tiny random fixture
9bba3ae  tests: add algorithms tests for pab and fastercache
3743bfc  tests: add combination tests for pab and fastercache
3f2213f  fix: add 1 as value for interval parameter
The PR adds the following new file (212 added lines):

```python
# Copyright 2025 - Pruna AI GmbH. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations

from typing import Any, Dict, Optional, Tuple

from ConfigSpace import OrdinalHyperparameter

from pruna.algorithms.caching import PrunaCacher
from pruna.config.smash_config import SmashConfigPrefixWrapper
from pruna.engine.model_checks import (
    is_allegro_pipeline,
    is_cogvideo_pipeline,
    is_flux_pipeline,
    is_hunyuan_pipeline,
    is_latte_pipeline,
    is_mochi_pipeline,
    is_wan_pipeline,
)
from pruna.logging.logger import pruna_logger


class FasterCacheCacher(PrunaCacher):
    """
    Implement FasterCache.

    FasterCache is a method that speeds up inference in diffusion transformers by:
    - Reusing attention states between successive inference steps, due to the high similarity between them
    - Skipping the unconditional branch prediction used in classifier-free guidance by exploiting redundancies
      between unconditional and conditional branch outputs for the same timestep, and therefore approximating
      the unconditional branch output from the conditional branch output

    This implementation reduces the number of tunable parameters by setting pipeline-specific parameters
    according to https://github.com/huggingface/diffusers/pull/9562.
    """

    algorithm_name = "fastercache"
    references = {"GitHub": "https://github.com/Vchitect/FasterCache", "Paper": "https://arxiv.org/abs/2410.19355"}
    tokenizer_required = False
    processor_required = False
    dataset_required = False
    run_on_cpu = True
    run_on_cuda = True
    compatible_algorithms = dict(quantizer=["hqq_diffusers", "diffusers_int8"])

    def get_hyperparameters(self) -> list:
        """
        Get the hyperparameters for the algorithm.

        Returns
        -------
        list
            The hyperparameters.
        """
        return [
            OrdinalHyperparameter(
                "interval",
                sequence=[1, 2, 3, 4, 5],
                default_value=2,
                meta=dict(
                    desc="Interval at which to cache spatial attention blocks - 1 disables caching. "
                    "Higher is faster but might degrade quality."
                ),
            ),
        ]

    def model_check_fn(self, model: Any) -> bool:
        """
        Check if the model is a valid model for the algorithm.

        Parameters
        ----------
        model : Any
            The model to check.

        Returns
        -------
        bool
            True if the model is a valid model for the algorithm, False otherwise.
        """
        pipeline_check_fns = [
            is_allegro_pipeline,
            is_cogvideo_pipeline,
            is_flux_pipeline,
            is_hunyuan_pipeline,
            is_mochi_pipeline,
            is_wan_pipeline,
        ]
        return any(is_pipeline(model) for is_pipeline in pipeline_check_fns)

    def _apply(self, model: Any, smash_config: SmashConfigPrefixWrapper) -> Any:
        """
        Apply the fastercache algorithm to the model.

        Parameters
        ----------
        model : Any
            The model to apply the fastercache algorithm to.
        smash_config : SmashConfigPrefixWrapper
            The configuration for the caching.

        Returns
        -------
        Any
            The smashed model.
        """
        imported_modules = self.import_algorithm_packages()
        # set default values according to https://huggingface.co/docs/diffusers/en/api/cache
        temporal_attention_block_skip_range: Optional[int] = None
        spatial_attention_timestep_skip_range: Tuple[int, int] = (-1, 681)
        temporal_attention_timestep_skip_range: Optional[Tuple[int, int]] = None
        low_frequency_weight_update_timestep_range: Tuple[int, int] = (99, 901)
        high_frequency_weight_update_timestep_range: Tuple[int, int] = (-1, 301)
        unconditional_batch_skip_range: int = 5
        unconditional_batch_timestep_skip_range: Tuple[int, int] = (-1, 641)
        spatial_attention_block_identifiers: Tuple[str, ...] = (
            "blocks.*attn1",
            "transformer_blocks.*attn1",
            "single_transformer_blocks.*attn1",
        )
        temporal_attention_block_identifiers: Tuple[str, ...] = ("temporal_transformer_blocks.*attn1",)
        attention_weight_callback = lambda _: 0.5  # noqa: E731
        tensor_format: str = "BFCHW"
        is_guidance_distilled: bool = False

        # set configs according to https://github.com/huggingface/diffusers/pull/9562
        if is_allegro_pipeline(model):
            low_frequency_weight_update_timestep_range = (99, 641)
            spatial_attention_block_identifiers = ("transformer_blocks",)
        elif is_cogvideo_pipeline(model):
            low_frequency_weight_update_timestep_range = (99, 641)
            spatial_attention_block_identifiers = ("transformer_blocks",)
            attention_weight_callback = lambda _: 0.3  # noqa: E731
        elif is_flux_pipeline(model):
            spatial_attention_timestep_skip_range = (-1, 961)
            spatial_attention_block_identifiers = ("transformer_blocks", "single_transformer_blocks")
            tensor_format = "BCHW"
            is_guidance_distilled = True
        elif is_hunyuan_pipeline(model):
            spatial_attention_timestep_skip_range = (99, 941)
            spatial_attention_block_identifiers = ("transformer_blocks", "single_transformer_blocks")
            tensor_format = "BCFHW"
            is_guidance_distilled = True
        elif is_latte_pipeline(model):
            temporal_attention_block_skip_range = 2
            temporal_attention_timestep_skip_range = (-1, 681)
            low_frequency_weight_update_timestep_range = (99, 641)
            spatial_attention_block_identifiers = ("transformer_blocks.*attn1",)
            temporal_attention_block_identifiers = ("temporal_transformer_blocks",)
        elif is_mochi_pipeline(model):
            spatial_attention_timestep_skip_range = (-1, 981)
            low_frequency_weight_update_timestep_range = (301, 961)
            high_frequency_weight_update_timestep_range = (-1, 851)
            unconditional_batch_skip_range = 4
            unconditional_batch_timestep_skip_range = (-1, 975)
            spatial_attention_block_identifiers = ("transformer_blocks",)
            attention_weight_callback = lambda _: 0.6  # noqa: E731
        elif is_wan_pipeline(model):
            spatial_attention_block_identifiers = ("blocks",)
            tensor_format = "BCFHW"
            is_guidance_distilled = True

        fastercache_config = imported_modules["FasterCacheConfig"](
            spatial_attention_block_skip_range=smash_config["interval"],
            temporal_attention_block_skip_range=temporal_attention_block_skip_range,
            spatial_attention_timestep_skip_range=spatial_attention_timestep_skip_range,
            temporal_attention_timestep_skip_range=temporal_attention_timestep_skip_range,
            low_frequency_weight_update_timestep_range=low_frequency_weight_update_timestep_range,
            high_frequency_weight_update_timestep_range=high_frequency_weight_update_timestep_range,
            alpha_low_frequency=1.1,
            alpha_high_frequency=1.1,
            unconditional_batch_skip_range=unconditional_batch_skip_range,
            unconditional_batch_timestep_skip_range=unconditional_batch_timestep_skip_range,
            spatial_attention_block_identifiers=spatial_attention_block_identifiers,
            temporal_attention_block_identifiers=temporal_attention_block_identifiers,
            attention_weight_callback=attention_weight_callback,
            tensor_format=tensor_format,
            current_timestep_callback=lambda: model.current_timestep,
            is_guidance_distilled=is_guidance_distilled,
        )
        model.transformer.enable_cache(fastercache_config)
        return model

    def import_algorithm_packages(self) -> Dict[str, Any]:
        """
        Import the algorithm packages.

        Returns
        -------
        Dict[str, Any]
            The algorithm packages.
        """
        try:
            from diffusers import FasterCacheConfig
        except ModuleNotFoundError:
            pruna_logger.error(
                "You are trying to use FasterCache, but the FasterCacheConfig cannot be imported from diffusers. "
                "This is likely because your diffusers version is too old."
            )
            raise

        return dict(FasterCacheConfig=FasterCacheConfig)
```
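For orientation, here is a minimal usage sketch of how this cacher would be enabled end to end. It assumes Pruna's public smash/SmashConfig entry points and that the interval hyperparameter is exposed under a fastercache_ prefix, as with the existing cachers; none of these user-facing names appear in this diff, so treat them as illustrative.

```python
import torch
from diffusers import FluxPipeline

# Assumed public entry points; the diff only shows the cacher class itself.
from pruna import SmashConfig, smash

# Flux is one of the pipelines accepted by model_check_fn above.
pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16
)

smash_config = SmashConfig()
smash_config["cacher"] = "fastercache"    # assumed key for selecting the algorithm
smash_config["fastercache_interval"] = 2  # assumed prefixed name of the "interval" hyperparameter

# smash() would dispatch to FasterCacheCacher._apply, which builds the
# FasterCacheConfig and calls model.transformer.enable_cache(...).
smashed_pipe = smash(model=pipe, smash_config=smash_config)
image = smashed_pipe("a red fox in fresh snow", num_inference_steps=4).images[0]
```

Setting the interval to 1 would keep the algorithm attached but effectively disable the reuse, matching the "1 disables caching" note in the hyperparameter description.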
Review comment: I would still recommend putting them in the smash config as constants, and mentioning that they can be overwritten for different architectures with a link to the code file or the diffusers PR, so that the documentation is complete :)
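One way to read this suggestion is sketched below. It is only an illustration: ConfigSpace does provide a Constant hyperparameter, but whether Pruna's smash config is meant to carry such documentation-only constants, and under which names, is an assumption.

```python
from ConfigSpace import Constant, OrdinalHyperparameter


def get_hyperparameters(self) -> list:
    """Expose the tunable interval plus the fixed FasterCache defaults."""
    return [
        OrdinalHyperparameter(
            "interval",
            sequence=[1, 2, 3, 4, 5],
            default_value=2,
            meta=dict(desc="Interval at which to cache spatial attention blocks - 1 disables caching."),
        ),
        # Documentation-only constants (hypothetical): they mirror the defaults
        # hard-coded in _apply and are overwritten per architecture there,
        # following https://github.com/huggingface/diffusers/pull/9562.
        Constant("unconditional_batch_skip_range", 5),
        Constant("alpha_low_frequency", 1.1),
        Constant("alpha_high_frequency", 1.1),
        Constant("tensor_format", "BFCHW"),
    ]
```

Tuple-valued ranges such as spatial_attention_timestep_skip_range do not map onto ConfigSpace's scalar Constant, so those would either stay in _apply or need a string encoding, which is why linking to the code file or the diffusers PR from the documentation may be the simpler route.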