CUDAGRAPHs for Flux position embeddings #9299

@sayakpaul

@yiyixuxu

Is it possible to refactor the Flux positional embeddings so that we can fully make use of CUDAGRAPHs?

skipping cudagraphs due to skipping cudagraphs due to cpu device (device_put). Found from : 
   File "/home/sayak/diffusers/src/diffusers/models/transformers/transformer_flux.py", line 469, in forward
    image_rotary_emb = self.pos_embed(ids)
  File "/home/sayak/.pyenv/versions/diffusers/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/sayak/diffusers/src/diffusers/models/embeddings.py", line 630, in forward
    self.axes_dim[i], pos[:, i], repeat_interleave_real=True, use_real=True, freqs_dtype=freqs_dtype
Code
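import torch
import torch._inductor.config

torch.set_float32_matmul_precision("high")
# Inductor flags live on torch._inductor.config; setting them directly on
# torch._inductor only creates unused module attributes.
torch._inductor.config.conv_1x1_as_mm = True
torch._inductor.config.coordinate_descent_tuning = True
torch._inductor.config.epilogue_fusion = False
torch._inductor.config.coordinate_descent_check_all_directions = True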

import diffusers
from platform import python_version
from diffusers import DiffusionPipeline

print(diffusers.__version__)
print(torch.__version__)
print(python_version())


pipe = DiffusionPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16).to("cuda")
pipe.transformer.to(memory_format=torch.channels_last)
pipe.vae.to(memory_format=torch.channels_last)

pipe.transformer = torch.compile(pipe.transformer, mode="max-autotune", fullgraph=True)
pipe.vae.decode = torch.compile(pipe.vae.decode, mode="max-autotune", fullgraph=True)

for _ in range(5):
    image = pipe(
        "Happy bear",
        num_inference_steps=5,
        guidance_scale=3.5,
        max_sequence_length=512,
        generator=torch.manual_seed(42),
        height=1024,
        width=1024,
    ).images[0]

If we can make full use of CUDAGRAPHs, torch.compile() would be faster.
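
For illustration, here is a rough sketch of what a device-resident version of the rotary frequency computation could look like. This is not the current diffusers code: `rope_1d_on_device` and `pos_embed_on_device` are hypothetical names, `theta=10000.0` and the `(16, 56, 56)` axes split in the usage note are assumed values, and I have not verified that this alone removes the cudagraphs skip. The only idea is that every intermediate tensor is created on `pos.device`, so no CPU tensor enters the compiled graph.

```python
import torch


def rope_1d_on_device(dim, pos, theta=10000.0, freqs_dtype=torch.float32):
    """Hypothetical helper: 1D rotary cos/sin tables built on pos.device."""
    assert dim % 2 == 0, "rotary dimension must be even"
    # Create the exponent range directly on the device of `pos`,
    # instead of building it on CPU and moving it afterwards.
    exponents = torch.arange(0, dim, 2, dtype=freqs_dtype, device=pos.device) / dim
    inv_freq = 1.0 / (theta ** exponents)                 # [dim / 2]
    angles = torch.outer(pos.to(freqs_dtype), inv_freq)   # [S, dim / 2]
    # Repeat each frequency twice so cos/sin line up with (real, imag) pairs.
    cos = angles.cos().repeat_interleave(2, dim=1)        # [S, dim]
    sin = angles.sin().repeat_interleave(2, dim=1)        # [S, dim]
    return cos, sin


def pos_embed_on_device(ids, axes_dim):
    """Hypothetical stand-in for the per-axis loop in the traceback above."""
    cos_parts, sin_parts = [], []
    for i, dim in enumerate(axes_dim):
        cos, sin = rope_1d_on_device(dim, ids[:, i])
        cos_parts.append(cos)
        sin_parts.append(sin)
    # Concatenate the per-axis tables along the feature dimension.
    return torch.cat(cos_parts, dim=-1), torch.cat(sin_parts, dim=-1)
```

With `ids` of shape `[S, 3]` already on the GPU, `pos_embed_on_device(ids, (16, 56, 56))` would return two `[S, 128]` tensors on the same device as `ids`.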
