diff --git a/QEfficient/diffusers/models/transformers/transformer_flux.py b/QEfficient/diffusers/models/transformers/transformer_flux.py index 40b7e3e7e..0492669db 100644 --- a/QEfficient/diffusers/models/transformers/transformer_flux.py +++ b/QEfficient/diffusers/models/transformers/transformer_flux.py @@ -4,10 +4,11 @@ # SPDX-License-Identifier: BSD-3-Clause # # ---------------------------------------------------------------------------- -from typing import Any, Dict, Optional, Tuple, Union +from typing import Any, Dict, Optional, Set, Tuple, Type, Union import numpy as np import torch +import torch.nn as nn from diffusers.models.modeling_outputs import Transformer2DModelOutput from diffusers.models.transformers.transformer_flux import ( FluxAttention, @@ -221,6 +222,15 @@ class QEffFluxTransformer2DModel(FluxTransformer2DModel): + def get_submodules_for_export(self) -> Set[Type[nn.Module]]: + """ + Return the set of classes used as the repeated layers across the model for subfunction extraction. + Notes: + This method should return *class objects* (not instances). + Downstream code can use this to find/build subfunctions for repeated blocks. + """ + return {QEffFluxTransformerBlock, QEffFluxSingleTransformerBlock} + def forward( self, hidden_states: torch.Tensor, diff --git a/QEfficient/diffusers/models/transformers/transformer_wan.py b/QEfficient/diffusers/models/transformers/transformer_wan.py index 31d3be2ce..9200997d7 100644 --- a/QEfficient/diffusers/models/transformers/transformer_wan.py +++ b/QEfficient/diffusers/models/transformers/transformer_wan.py @@ -13,15 +13,17 @@ and combined QKV-blocking. 
""" -from typing import Any, Dict, List, Optional, Tuple, Union +from typing import Any, Dict, List, Optional, Tuple, Type, Union import torch +import torch.nn as nn from diffusers.loaders.peft import _SET_ADAPTER_SCALE_FN_MAPPING from diffusers.models.modeling_outputs import Transformer2DModelOutput from diffusers.models.transformers.transformer_wan import ( WanAttention, WanAttnProcessor, WanTransformer3DModel, + WanTransformerBlock, _get_qkv_projections, ) from diffusers.utils import set_weights_and_activate_adapters @@ -289,3 +291,78 @@ def forward( return (output,) return Transformer2DModelOutput(sample=output) + + +class QEffWanUnifiedWrapper(nn.Module): + """ + A wrapper class that combines WAN high and low noise transformers into a single unified transformer. + + This wrapper dynamically selects between high and low noise transformers based on the timestep shape + in the ONNX graph during inference. This approach enables efficient deployment of both transformer + variants in a single model. + + Attributes: + transformer_high(nn.Module): The high noise transformer component + transformer_low(nn.Module): The low noise transformer component + config: Configuration shared between both transformers (from high noise transformer) + """ + + def __init__(self, transformer_high, transformer_low): + super().__init__() + self.transformer_high = transformer_high + self.transformer_low = transformer_low + # Both high and low noise transformers share the same configuration + self.config = transformer_high.config + + def get_submodules_for_export(self) -> Type[nn.Module]: + """ + Return the set of class used as the repeated layer across the model for subfunction extraction. + Notes: + This method should return the *class object* (not an instance). + Downstream code can use this to find/build subfunctions for repeated blocks. 
+ """ + return {WanTransformerBlock} + + def forward( + self, + hidden_states, + encoder_hidden_states, + rotary_emb, + temb, + timestep_proj, + tsp, + attention_kwargs=None, + return_dict=False, + ): + # Condition based on timestep shape + is_high_noise = tsp.shape[0] == torch.tensor(1) + + high_hs = hidden_states.detach() + ehs = encoder_hidden_states.detach() + rhs = rotary_emb.detach() + ths = temb.detach() + projhs = timestep_proj.detach() + + noise_pred_high = self.transformer_high( + hidden_states=high_hs, + encoder_hidden_states=ehs, + rotary_emb=rhs, + temb=ths, + timestep_proj=projhs, + attention_kwargs=attention_kwargs, + return_dict=return_dict, + )[0] + + noise_pred_low = self.transformer_low( + hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + rotary_emb=rotary_emb, + temb=temb, + timestep_proj=timestep_proj, + attention_kwargs=attention_kwargs, + return_dict=return_dict, + )[0] + + # Select based on timestep condition + noise_pred = torch.where(is_high_noise, noise_pred_high, noise_pred_low) + return noise_pred diff --git a/QEfficient/diffusers/pipelines/pipeline_module.py b/QEfficient/diffusers/pipelines/pipeline_module.py index 4cc70d056..9b4ca89d8 100644 --- a/QEfficient/diffusers/pipelines/pipeline_module.py +++ b/QEfficient/diffusers/pipelines/pipeline_module.py @@ -9,7 +9,6 @@ import torch import torch.nn as nn -from diffusers.models.transformers.transformer_wan import WanTransformerBlock from QEfficient.base.modeling_qeff import QEFFBaseModel from QEfficient.base.onnx_transforms import FP16ClipTransform, SplitTensorsTransform @@ -18,10 +17,6 @@ CustomOpsTransform, NormalizationTransform, ) -from QEfficient.diffusers.models.transformers.transformer_flux import ( - QEffFluxSingleTransformerBlock, - QEffFluxTransformerBlock, -) from QEfficient.transformers.models.pytorch_transforms import ( T5ModelTransform, ) @@ -475,7 +470,6 @@ def export( output_names: List[str], dynamic_axes: Dict, export_dir: str = None, - 
export_kwargs: Dict = {}, use_onnx_subfunctions: bool = False, ) -> str: """ @@ -486,7 +480,6 @@ def export( output_names (List[str]): Names of model outputs dynamic_axes (Dict): Specification of dynamic dimensions export_dir (str, optional): Directory to save ONNX model - export_kwargs (Dict, optional): Additional export arguments (e.g., export_modules_as_functions) use_onnx_subfunctions (bool): Whether to export transformer blocks as ONNX functions for better modularity and potential optimization @@ -494,22 +487,15 @@ def export( str: Path to the exported ONNX model """ - if use_onnx_subfunctions: - export_kwargs = { - "export_modules_as_functions": {QEffFluxTransformerBlock, QEffFluxSingleTransformerBlock}, - "use_onnx_subfunctions": True, - } - # Sort _use_default_values in config to ensure consistent hash generation during export self.model.config["_use_default_values"].sort() - return self._export( example_inputs=inputs, output_names=output_names, dynamic_axes=dynamic_axes, export_dir=export_dir, + use_onnx_subfunctions=use_onnx_subfunctions, offload_pt_weights=False, # As weights are needed with AdaLN changes - **export_kwargs, ) def compile(self, specializations: List[Dict], **compiler_options) -> None: @@ -631,7 +617,6 @@ def export( output_names: List[str], dynamic_axes: Dict, export_dir: str = None, - export_kwargs: Dict = {}, use_onnx_subfunctions: bool = False, ) -> str: """Export the Wan transformer model to ONNX format. 
@@ -641,14 +626,11 @@ def export( output_names (List[str]): Names of model outputs dynamic_axes (Dict): Specification of dynamic dimensions export_dir (str, optional): Directory to save ONNX model - export_kwargs (Dict, optional): Additional export arguments (e.g., export_modules_as_functions) use_onnx_subfunctions (bool): Whether to export transformer blocks as ONNX functions for better modularity and potential optimization Returns: str: Path to the exported ONNX model """ - if use_onnx_subfunctions: - export_kwargs = {"export_modules_as_functions": {WanTransformerBlock}, "use_onnx_subfunctions": True} return self._export( example_inputs=inputs, @@ -656,7 +638,7 @@ def export( dynamic_axes=dynamic_axes, export_dir=export_dir, offload_pt_weights=True, - **export_kwargs, + use_onnx_subfunctions=use_onnx_subfunctions, ) def compile(self, specializations, **compiler_options) -> None: diff --git a/QEfficient/diffusers/pipelines/pipeline_utils.py b/QEfficient/diffusers/pipelines/pipeline_utils.py index 7ffa4b043..b69e4d49d 100644 --- a/QEfficient/diffusers/pipelines/pipeline_utils.py +++ b/QEfficient/diffusers/pipelines/pipeline_utils.py @@ -13,8 +13,6 @@ import numpy as np import PIL.Image -import torch -import torch.nn as nn from tqdm import tqdm from QEfficient.utils._utils import load_json @@ -297,69 +295,3 @@ def __repr__(self): # List of module name that require special handling during export # when use_onnx_subfunctions is enabled ONNX_SUBFUNCTION_MODULE = ["transformer"] - - -class QEffWanUnifiedWrapper(nn.Module): - """ - A wrapper class that combines WAN high and low noise transformers into a single unified transformer. - - This wrapper dynamically selects between high and low noise transformers based on the timestep shape - in the ONNX graph during inference. This approach enables efficient deployment of both transformer - variants in a single model. 
- - Attributes: - transformer_high(nn.Module): The high noise transformer component - transformer_low(nn.Module): The low noise transformer component - config: Configuration shared between both transformers (from high noise transformer) - """ - - def __init__(self, transformer_high, transformer_low): - super().__init__() - self.transformer_high = transformer_high - self.transformer_low = transformer_low - # Both high and low noise transformers share the same configuration - self.config = transformer_high.config - - def forward( - self, - hidden_states, - encoder_hidden_states, - rotary_emb, - temb, - timestep_proj, - tsp, - attention_kwargs=None, - return_dict=False, - ): - # Condition based on timestep shape - is_high_noise = tsp.shape[0] == torch.tensor(1) - - high_hs = hidden_states.detach() - ehs = encoder_hidden_states.detach() - rhs = rotary_emb.detach() - ths = temb.detach() - projhs = timestep_proj.detach() - - noise_pred_high = self.transformer_high( - hidden_states=high_hs, - encoder_hidden_states=ehs, - rotary_emb=rhs, - temb=ths, - timestep_proj=projhs, - attention_kwargs=attention_kwargs, - return_dict=return_dict, - )[0] - - noise_pred_low = self.transformer_low( - hidden_states=hidden_states, - encoder_hidden_states=encoder_hidden_states, - rotary_emb=rotary_emb, - temb=temb, - timestep_proj=timestep_proj, - attention_kwargs=attention_kwargs, - return_dict=return_dict, - )[0] - - # Select based on timestep condition - noise_pred = torch.where(is_high_noise, noise_pred_high, noise_pred_low) - return noise_pred diff --git a/QEfficient/diffusers/pipelines/wan/pipeline_wan.py b/QEfficient/diffusers/pipelines/wan/pipeline_wan.py index ca0444406..74512ac24 100644 --- a/QEfficient/diffusers/pipelines/wan/pipeline_wan.py +++ b/QEfficient/diffusers/pipelines/wan/pipeline_wan.py @@ -23,12 +23,12 @@ from diffusers import WanPipeline from tqdm import tqdm +from QEfficient.diffusers.models.transformers.transformer_wan import QEffWanUnifiedWrapper from 
QEfficient.diffusers.pipelines.pipeline_module import QEffVAE, QEffWanUnifiedTransformer from QEfficient.diffusers.pipelines.pipeline_utils import ( ONNX_SUBFUNCTION_MODULE, ModulePerf, QEffPipelineOutput, - QEffWanUnifiedWrapper, calculate_latent_dimensions_with_frames, compile_modules_parallel, compile_modules_sequential, diff --git a/QEfficient/utils/export_utils.py b/QEfficient/utils/export_utils.py index 3a954556f..da3231190 100644 --- a/QEfficient/utils/export_utils.py +++ b/QEfficient/utils/export_utils.py @@ -179,7 +179,6 @@ def _setup_onnx_subfunctions(qeff_model, args, kwargs): qeff_model._onnx_transforms.append(RenameFunctionOutputsTransform) qeff_model._onnx_transforms.append(CustomOpTransform) - # TODO: Handle this in the modelling class QEFFTransformersBase,remove from here. Refer diffusers implementation submodule_classes = qeff_model.model.get_submodules_for_export() if submodule_classes: kwargs["export_modules_as_functions"] = submodule_classes diff --git a/QEfficient/utils/torch_patches.py b/QEfficient/utils/torch_patches.py index 46485920c..b0fbcc45e 100644 --- a/QEfficient/utils/torch_patches.py +++ b/QEfficient/utils/torch_patches.py @@ -40,6 +40,7 @@ def _track_module_attributes_forward_hook(module, input, output): onnx_attrs = getattr(module, attr_name) delattr(module, attr_name) try: + onnx_attrs = {} # HACK: to reduce export time # TODO: study behaviour across models _C._jit_pass_onnx_track_scope_attributes(graph, onnx_attrs) except Exception: logger.warning("Failed to track ONNX scope attributes, Skipping this step.") diff --git a/examples/diffusers/wan/wan_lightning.py b/examples/diffusers/wan/wan_lightning.py index aca2b9754..def5cc29a 100644 --- a/examples/diffusers/wan/wan_lightning.py +++ b/examples/diffusers/wan/wan_lightning.py @@ -52,7 +52,7 @@ def load_wan_lora(path: str): generator=torch.manual_seed(0), height=480, width=832, - use_onnx_subfunctions=False, + use_onnx_subfunctions=True, parallel_compile=True, ) frames = 
output.images[0] diff --git a/tests/diffusers/flux_test_config.json b/tests/diffusers/flux_test_config.json index 6d22986ce..581a2dd99 100644 --- a/tests/diffusers/flux_test_config.json +++ b/tests/diffusers/flux_test_config.json @@ -3,8 +3,7 @@ "height": 256, "width": 256, "num_transformer_layers": 2, - "num_single_layers": 2, - "use_onnx_subfunctions": false + "num_single_layers": 2 }, "mad_validation": { "tolerances": { @@ -21,7 +20,8 @@ "max_sequence_length": 256, "validate_gen_img": true, "min_image_variance": 1.0, - "custom_config_path": null + "custom_config_path": null, + "use_onnx_subfunctions": true }, "validation_checks": { "image_generation": true, diff --git a/tests/diffusers/test_flux.py b/tests/diffusers/test_flux.py index 6c33540c3..3d3d753ff 100644 --- a/tests/diffusers/test_flux.py +++ b/tests/diffusers/test_flux.py @@ -56,6 +56,7 @@ def flux_pipeline_call_with_mad_validation( callback_on_step_end_tensor_inputs: List[str] = ["latents"], max_sequence_length: int = 512, custom_config_path: Optional[str] = None, + use_onnx_subfunctions: bool = False, parallel_compile: bool = False, mad_tolerances: Dict[str, float] = None, ): @@ -72,7 +73,13 @@ def flux_pipeline_call_with_mad_validation( device = "cpu" # Step 1: Load configuration, compile models - pipeline.compile(compile_config=custom_config_path, parallel=parallel_compile, height=height, width=width) + pipeline.compile( + compile_config=custom_config_path, + parallel=parallel_compile, + use_onnx_subfunctions=use_onnx_subfunctions, + height=height, + width=width, + ) # Validate all inputs pipeline.model.check_inputs( @@ -307,10 +314,7 @@ def flux_pipeline(): """Setup compiled Flux pipeline for testing""" config = INITIAL_TEST_CONFIG["model_setup"] - pipeline = QEffFluxPipeline.from_pretrained( - "black-forest-labs/FLUX.1-schnell", - use_onnx_subfunctions=config["use_onnx_subfunctions"], - ) + pipeline = QEffFluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell") # Reduce to 2 layers for 
testing original_blocks = pipeline.transformer.model.transformer_blocks @@ -382,6 +386,7 @@ def test_flux_pipeline(flux_pipeline): custom_config_path=CONFIG_PATH, generator=generator, mad_tolerances=config["mad_validation"]["tolerances"], + use_onnx_subfunctions=config["pipeline_params"]["use_onnx_subfunctions"], parallel_compile=True, return_dict=True, )