@yiyixuxu
Is it possible to refactor the Flux positional embeddings so that we can fully make use of CUDAGRAPHs?
Currently, compilation skips CUDA graphs with "skipping cudagraphs due to cpu device (device_put)". Found from:
File "/home/sayak/diffusers/src/diffusers/models/transformers/transformer_flux.py", line 469, in forward
image_rotary_emb = self.pos_embed(ids)
File "/home/sayak/.pyenv/versions/diffusers/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
return forward_call(*args, **kwargs)
File "/home/sayak/diffusers/src/diffusers/models/embeddings.py", line 630, in forward
self.axes_dim[i], pos[:, i], repeat_interleave_real=True, use_real=True, freqs_dtype=freqs_dtype
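From the trace, the rotary frequencies seem to be computed on the CPU inside FluxPosEmbed.forward and only then moved to the GPU; that device_put is what makes Inductor bail out of CUDA graphs. Below is a minimal sketch of the kind of refactor this asks for, assuming the current implementation round-trips ids through the CPU and that get_1d_rotary_pos_embed accepts a torch.Tensor for pos. It is illustrative, not the actual diffusers code.

import torch
from torch import nn
from typing import List, Tuple
from diffusers.models.embeddings import get_1d_rotary_pos_embed

class FluxPosEmbed(nn.Module):
    # Sketch only: keeps every op on ids.device so no CPU->GPU copy
    # (device_put) lands inside the compiled region.
    def __init__(self, theta: int, axes_dim: List[int]):
        super().__init__()
        self.theta = theta
        self.axes_dim = axes_dim

    def forward(self, ids: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        n_axes = ids.shape[-1]
        cos_out, sin_out = [], []
        pos = ids.float()  # stay on ids.device; no .cpu()/.numpy() round trip
        # float64 is unsupported on MPS, hence the fallback
        freqs_dtype = torch.float32 if ids.device.type == "mps" else torch.float64
        for i in range(n_axes):
            cos, sin = get_1d_rotary_pos_embed(
                self.axes_dim[i],
                pos[:, i],
                theta=self.theta,
                repeat_interleave_real=True,
                use_real=True,
                freqs_dtype=freqs_dtype,
            )
            cos_out.append(cos)
            sin_out.append(sin)
        freqs_cos = torch.cat(cos_out, dim=-1).to(ids.device)
        freqs_sin = torch.cat(sin_out, dim=-1).to(ids.device)
        return freqs_cos, freqs_sin

With this, the ids coming from the pipeline just need to be created on the transformer's device so nothing in the compiled forward touches the CPU.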
Code
import torch

torch.set_float32_matmul_precision("high")

# These Inductor knobs live on torch._inductor.config, not torch._inductor;
# setting them on the bare module silently does nothing.
torch._inductor.config.conv_1x1_as_mm = True
torch._inductor.config.coordinate_descent_tuning = True
torch._inductor.config.epilogue_fusion = False
torch._inductor.config.coordinate_descent_check_all_directions = True

import diffusers
from platform import python_version
from diffusers import DiffusionPipeline

print(diffusers.__version__)
print(torch.__version__)
print(python_version())

pipe = DiffusionPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
).to("cuda")
pipe.transformer.to(memory_format=torch.channels_last)
pipe.vae.to(memory_format=torch.channels_last)

# "max-autotune" lets Inductor use CUDA graphs when nothing forces a skip.
pipe.transformer = torch.compile(pipe.transformer, mode="max-autotune", fullgraph=True)
pipe.vae.decode = torch.compile(pipe.vae.decode, mode="max-autotune", fullgraph=True)

# A few iterations so the compiled runs are warmed up.
for _ in range(5):
    image = pipe(
        "Happy bear",
        num_inference_steps=5,
        guidance_scale=3.5,
        max_sequence_length=512,
        generator=torch.manual_seed(42),
        height=1024,
        width=1024,
    ).images[0]
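To confirm whether CUDA graphs actually engage after a refactor, the skip reasons can be surfaced while the repro runs. This assumes a recent PyTorch where Inductor logs them under the perf_hints logging artifact:

import torch._logging

# Prints messages like the "skipping cudagraphs due to ..." line above.
# Equivalent to running the script with TORCH_LOGS="perf_hints".
torch._logging.set_logs(perf_hints=True)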
If we can fully make use of CUDAGRAPHs, torch.compile() would be faster.