Skip to content
Merged
12 changes: 11 additions & 1 deletion QEfficient/diffusers/models/transformers/transformer_flux.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,11 @@
# SPDX-License-Identifier: BSD-3-Clause
#
# ----------------------------------------------------------------------------
from typing import Any, Dict, Optional, Tuple, Union
from typing import Any, Dict, Optional, Tuple, Type, Union

import numpy as np
import torch
import torch.nn as nn
from diffusers.models.modeling_outputs import Transformer2DModelOutput
from diffusers.models.transformers.transformer_flux import (
FluxAttention,
Expand Down Expand Up @@ -221,6 +222,15 @@ def forward(


class QEffFluxTransformer2DModel(FluxTransformer2DModel):
def get_submodules_for_export(self) -> "set[Type[nn.Module]]":
    """Return the module classes that repeat across the model, for ONNX
    subfunction extraction.

    Returns:
        set[Type[nn.Module]]: The *class objects* (not instances) of the
        repeated transformer blocks. Downstream export code uses these to
        find/build ONNX subfunctions for the repeated blocks.
    """
    # Fix: the original annotation claimed a single class (Type[nn.Module])
    # but a set of classes is returned.
    return {QEffFluxTransformerBlock, QEffFluxSingleTransformerBlock}

def forward(
self,
hidden_states: torch.Tensor,
Expand Down
79 changes: 78 additions & 1 deletion QEfficient/diffusers/models/transformers/transformer_wan.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,17 @@
and combined QKV-blocking.
"""

from typing import Any, Dict, List, Optional, Tuple, Union
from typing import Any, Dict, List, Optional, Tuple, Type, Union

import torch
import torch.nn as nn
from diffusers.loaders.peft import _SET_ADAPTER_SCALE_FN_MAPPING
from diffusers.models.modeling_outputs import Transformer2DModelOutput
from diffusers.models.transformers.transformer_wan import (
WanAttention,
WanAttnProcessor,
WanTransformer3DModel,
WanTransformerBlock,
_get_qkv_projections,
)
from diffusers.utils import set_weights_and_activate_adapters
Expand Down Expand Up @@ -289,3 +291,78 @@ def forward(
return (output,)

return Transformer2DModelOutput(sample=output)


class QEffWanUnifiedWrapper(nn.Module):
    """
    A wrapper that combines the WAN high- and low-noise transformers into a
    single unified transformer.

    Both wrapped transformers are *always* evaluated; the output is selected
    with a tensor-level ``torch.where`` keyed on the shape of the ``tsp``
    timestep tensor. Because the selection is a graph node rather than Python
    control flow, the choice survives ONNX tracing and is resolved at
    inference time inside a single exported model.

    Attributes:
        transformer_high (nn.Module): The high-noise transformer component.
        transformer_low (nn.Module): The low-noise transformer component.
        config: Configuration shared between both transformers (taken from
            the high-noise transformer).
    """

    def __init__(self, transformer_high, transformer_low):
        super().__init__()
        self.transformer_high = transformer_high
        self.transformer_low = transformer_low
        # Both high- and low-noise transformers share the same configuration.
        self.config = transformer_high.config

    def get_submodules_for_export(self) -> "set[Type[nn.Module]]":
        """Return the module classes that repeat across the model, for ONNX
        subfunction extraction.

        Returns:
            set[Type[nn.Module]]: The *class objects* (not instances) of the
            repeated transformer blocks; downstream code uses these to
            find/build ONNX subfunctions for the repeated blocks.
        """
        # Fix: the original annotation claimed a single class
        # (Type[nn.Module]) but a set of classes is returned.
        return {WanTransformerBlock}

    def forward(
        self,
        hidden_states,
        encoder_hidden_states,
        rotary_emb,
        temb,
        timestep_proj,
        tsp,
        attention_kwargs=None,
        return_dict=False,
    ):
        """Run both transformers and select one output by timestep shape.

        Args:
            hidden_states: Latent input tensor.
            encoder_hidden_states: Text/conditioning embeddings.
            rotary_emb: Rotary positional embeddings.
            temb: Time embedding.
            timestep_proj: Projected timestep embedding.
            tsp: Timestep tensor used only for branch selection — a first
                dimension of 1 selects the high-noise transformer.
            attention_kwargs (dict, optional): Forwarded to both transformers.
            return_dict (bool): Forwarded to both transformers; the first
                element of each result is used regardless.

        Returns:
            torch.Tensor: The selected noise prediction.
        """
        # Comparing against a tensor keeps the condition as a graph node so
        # the selection is resolved at runtime in the exported ONNX model.
        is_high_noise = tsp.shape[0] == torch.tensor(1)

        # NOTE(review): the high-noise branch consumes detached copies —
        # presumably to keep the two sub-graphs independent during export;
        # confirm this is still required.
        high_hs = hidden_states.detach()
        ehs = encoder_hidden_states.detach()
        rhs = rotary_emb.detach()
        ths = temb.detach()
        projhs = timestep_proj.detach()

        # Both branches are always evaluated; torch.where picks one below.
        noise_pred_high = self.transformer_high(
            hidden_states=high_hs,
            encoder_hidden_states=ehs,
            rotary_emb=rhs,
            temb=ths,
            timestep_proj=projhs,
            attention_kwargs=attention_kwargs,
            return_dict=return_dict,
        )[0]

        noise_pred_low = self.transformer_low(
            hidden_states=hidden_states,
            encoder_hidden_states=encoder_hidden_states,
            rotary_emb=rotary_emb,
            temb=temb,
            timestep_proj=timestep_proj,
            attention_kwargs=attention_kwargs,
            return_dict=return_dict,
        )[0]

        # Element-wise select based on the timestep condition.
        noise_pred = torch.where(is_high_noise, noise_pred_high, noise_pred_low)
        return noise_pred
22 changes: 2 additions & 20 deletions QEfficient/diffusers/pipelines/pipeline_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@

import torch
import torch.nn as nn
from diffusers.models.transformers.transformer_wan import WanTransformerBlock

from QEfficient.base.modeling_qeff import QEFFBaseModel
from QEfficient.base.onnx_transforms import FP16ClipTransform, SplitTensorsTransform
Expand All @@ -18,10 +17,6 @@
CustomOpsTransform,
NormalizationTransform,
)
from QEfficient.diffusers.models.transformers.transformer_flux import (
QEffFluxSingleTransformerBlock,
QEffFluxTransformerBlock,
)
from QEfficient.transformers.models.pytorch_transforms import (
T5ModelTransform,
)
Expand Down Expand Up @@ -475,7 +470,6 @@ def export(
output_names: List[str],
dynamic_axes: Dict,
export_dir: str = None,
export_kwargs: Dict = {},
use_onnx_subfunctions: bool = False,
) -> str:
"""
Expand All @@ -486,30 +480,22 @@ def export(
output_names (List[str]): Names of model outputs
dynamic_axes (Dict): Specification of dynamic dimensions
export_dir (str, optional): Directory to save ONNX model
export_kwargs (Dict, optional): Additional export arguments (e.g., export_modules_as_functions)
use_onnx_subfunctions (bool): Whether to export transformer blocks as ONNX functions
for better modularity and potential optimization

Returns:
str: Path to the exported ONNX model
"""

if use_onnx_subfunctions:
export_kwargs = {
"export_modules_as_functions": {QEffFluxTransformerBlock, QEffFluxSingleTransformerBlock},
"use_onnx_subfunctions": True,
}

# Sort _use_default_values in config to ensure consistent hash generation during export
self.model.config["_use_default_values"].sort()

return self._export(
example_inputs=inputs,
output_names=output_names,
dynamic_axes=dynamic_axes,
export_dir=export_dir,
use_onnx_subfunctions=use_onnx_subfunctions,
offload_pt_weights=False, # As weights are needed with AdaLN changes
**export_kwargs,
)

def compile(self, specializations: List[Dict], **compiler_options) -> None:
Expand Down Expand Up @@ -631,7 +617,6 @@ def export(
output_names: List[str],
dynamic_axes: Dict,
export_dir: str = None,
export_kwargs: Dict = {},
use_onnx_subfunctions: bool = False,
) -> str:
"""Export the Wan transformer model to ONNX format.
Expand All @@ -641,22 +626,19 @@ def export(
output_names (List[str]): Names of model outputs
dynamic_axes (Dict): Specification of dynamic dimensions
export_dir (str, optional): Directory to save ONNX model
export_kwargs (Dict, optional): Additional export arguments (e.g., export_modules_as_functions)
use_onnx_subfunctions (bool): Whether to export transformer blocks as ONNX functions
for better modularity and potential optimization
Returns:
str: Path to the exported ONNX model
"""
if use_onnx_subfunctions:
export_kwargs = {"export_modules_as_functions": {WanTransformerBlock}, "use_onnx_subfunctions": True}

return self._export(
example_inputs=inputs,
output_names=output_names,
dynamic_axes=dynamic_axes,
export_dir=export_dir,
offload_pt_weights=True,
**export_kwargs,
use_onnx_subfunctions=use_onnx_subfunctions,
)

def compile(self, specializations, **compiler_options) -> None:
Expand Down
68 changes: 0 additions & 68 deletions QEfficient/diffusers/pipelines/pipeline_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,6 @@

import numpy as np
import PIL.Image
import torch
import torch.nn as nn
from tqdm import tqdm

from QEfficient.utils._utils import load_json
Expand Down Expand Up @@ -297,69 +295,3 @@ def __repr__(self):
# List of module name that require special handling during export
# when use_onnx_subfunctions is enabled
ONNX_SUBFUNCTION_MODULE = ["transformer"]


class QEffWanUnifiedWrapper(nn.Module):
    """
    Wraps the WAN high-noise and low-noise transformers as one module.

    Both wrapped transformers are invoked on every call; the result is chosen
    with a tensor-level ``torch.where`` keyed on the shape of the ``tsp``
    timestep tensor, so the selection remains a node in the traced ONNX graph
    instead of Python control flow.

    Attributes:
        transformer_high (nn.Module): High-noise transformer component.
        transformer_low (nn.Module): Low-noise transformer component.
        config: Shared configuration (taken from the high-noise transformer).
    """

    def __init__(self, transformer_high, transformer_low):
        super().__init__()
        self.transformer_high = transformer_high
        self.transformer_low = transformer_low
        # The two transformers share one configuration object.
        self.config = transformer_high.config

    def forward(
        self,
        hidden_states,
        encoder_hidden_states,
        rotary_emb,
        temb,
        timestep_proj,
        tsp,
        attention_kwargs=None,
        return_dict=False,
    ):
        # A length-1 timestep tensor means the high-noise branch applies;
        # comparing against a tensor keeps the condition in the graph.
        use_high = tsp.shape[0] == torch.tensor(1)

        # The high-noise branch consumes detached copies of every input.
        det_hs, det_ehs, det_rot, det_temb, det_proj = (
            t.detach()
            for t in (hidden_states, encoder_hidden_states, rotary_emb, temb, timestep_proj)
        )

        pred_high = self.transformer_high(
            hidden_states=det_hs,
            encoder_hidden_states=det_ehs,
            rotary_emb=det_rot,
            temb=det_temb,
            timestep_proj=det_proj,
            attention_kwargs=attention_kwargs,
            return_dict=return_dict,
        )[0]

        pred_low = self.transformer_low(
            hidden_states=hidden_states,
            encoder_hidden_states=encoder_hidden_states,
            rotary_emb=rotary_emb,
            temb=temb,
            timestep_proj=timestep_proj,
            attention_kwargs=attention_kwargs,
            return_dict=return_dict,
        )[0]

        # Runtime selection between the two always-computed branches.
        return torch.where(use_high, pred_high, pred_low)
2 changes: 1 addition & 1 deletion QEfficient/diffusers/pipelines/wan/pipeline_wan.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,12 @@
from diffusers import WanPipeline
from tqdm import tqdm

from QEfficient.diffusers.models.transformers.transformer_wan import QEffWanUnifiedWrapper
from QEfficient.diffusers.pipelines.pipeline_module import QEffVAE, QEffWanUnifiedTransformer
from QEfficient.diffusers.pipelines.pipeline_utils import (
ONNX_SUBFUNCTION_MODULE,
ModulePerf,
QEffPipelineOutput,
QEffWanUnifiedWrapper,
calculate_latent_dimensions_with_frames,
compile_modules_parallel,
compile_modules_sequential,
Expand Down
1 change: 0 additions & 1 deletion QEfficient/utils/export_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,6 @@ def _setup_onnx_subfunctions(qeff_model, args, kwargs):
qeff_model._onnx_transforms.append(RenameFunctionOutputsTransform)
qeff_model._onnx_transforms.append(CustomOpTransform)

# TODO: Handle this in the modelling class QEFFTransformersBase,remove from here. Refer diffusers implementation
submodule_classes = qeff_model.model.get_submodules_for_export()
if submodule_classes:
kwargs["export_modules_as_functions"] = submodule_classes
Expand Down
1 change: 1 addition & 0 deletions QEfficient/utils/torch_patches.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ def _track_module_attributes_forward_hook(module, input, output):
onnx_attrs = getattr(module, attr_name)
delattr(module, attr_name)
try:
onnx_attrs = {} # HACK: to reduce export time # TODO: study behaviour across models
_C._jit_pass_onnx_track_scope_attributes(graph, onnx_attrs)
except Exception:
logger.warning("Failed to track ONNX scope attributes, Skipping this step.")
Expand Down
2 changes: 1 addition & 1 deletion examples/diffusers/wan/wan_lightning.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ def load_wan_lora(path: str):
generator=torch.manual_seed(0),
height=480,
width=832,
use_onnx_subfunctions=False,
use_onnx_subfunctions=True,
parallel_compile=True,
)
frames = output.images[0]
Expand Down
6 changes: 3 additions & 3 deletions tests/diffusers/flux_test_config.json
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,7 @@
"height": 256,
"width": 256,
"num_transformer_layers": 2,
"num_single_layers": 2,
"use_onnx_subfunctions": false
"num_single_layers": 2
},
"mad_validation": {
"tolerances": {
Expand All @@ -21,7 +20,8 @@
"max_sequence_length": 256,
"validate_gen_img": true,
"min_image_variance": 1.0,
"custom_config_path": null
"custom_config_path": null,
"use_onnx_subfunctions": true
},
"validation_checks": {
"image_generation": true,
Expand Down
15 changes: 10 additions & 5 deletions tests/diffusers/test_flux.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ def flux_pipeline_call_with_mad_validation(
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
max_sequence_length: int = 512,
custom_config_path: Optional[str] = None,
use_onnx_subfunctions: bool = False,
parallel_compile: bool = False,
mad_tolerances: Dict[str, float] = None,
):
Expand All @@ -72,7 +73,13 @@ def flux_pipeline_call_with_mad_validation(
device = "cpu"

# Step 1: Load configuration, compile models
pipeline.compile(compile_config=custom_config_path, parallel=parallel_compile, height=height, width=width)
pipeline.compile(
compile_config=custom_config_path,
parallel=parallel_compile,
use_onnx_subfunctions=use_onnx_subfunctions,
height=height,
width=width,
)

# Validate all inputs
pipeline.model.check_inputs(
Expand Down Expand Up @@ -307,10 +314,7 @@ def flux_pipeline():
"""Setup compiled Flux pipeline for testing"""
config = INITIAL_TEST_CONFIG["model_setup"]

pipeline = QEffFluxPipeline.from_pretrained(
"black-forest-labs/FLUX.1-schnell",
use_onnx_subfunctions=config["use_onnx_subfunctions"],
)
pipeline = QEffFluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell")

# Reduce to 2 layers for testing
original_blocks = pipeline.transformer.model.transformer_blocks
Expand Down Expand Up @@ -382,6 +386,7 @@ def test_flux_pipeline(flux_pipeline):
custom_config_path=CONFIG_PATH,
generator=generator,
mad_tolerances=config["mad_validation"]["tolerances"],
use_onnx_subfunctions=config["pipeline_params"]["use_onnx_subfunctions"],
parallel_compile=True,
return_dict=True,
)
Expand Down
Loading