
Flash/Sage varlen does not work with torch.compile #11957

Description

@a-r-r-o-w @DN6 @tolgacangoz
I got a bit excited about this PR and wanted to give it a go. I love the syntax, both the setter function and the context manager. Great work!

I also wanted to see whether it would still work with torch.compile, but I got the following error logs:

[t+28s648ms]   0%|          | 0/30 [00:00<?, ?it/s]/inferencesh/apps/gpu/31s80r2z5hw6jskcs081xh8wdb/venv/3.12/lib/python3.12/site-packages/torch/_dynamo/variables/functions.py:1601: UserWarning: Dynamo detected a call to a `functools.lru_cache`-wrapped function. Dynamo ignores the cache wrapper and directly traces the wrapped function. Silent incorrectness is only a *potential* risk, not something we have observed. Enable TORCH_LOGS="+dynamo" for a DEBUG stack trace.
[t+28s650ms]   torch._dynamo.utils.warn_once(msg)
[t+28s666ms] W0718 20:32:55.420000 1 /inferencesh/cache/gpu/uv/archive-v0/bZ7CJbyNBkaHsGR2jn6Vv/torch/_dynamo/variables/tensor.py:1048] [0/0] Graph break from `Tensor.item()`, consider setting:
[t+28s667ms] W0718 20:32:55.420000 1 /inferencesh/cache/gpu/uv/archive-v0/bZ7CJbyNBkaHsGR2jn6Vv/torch/_dynamo/variables/tensor.py:1048] [0/0]     torch._dynamo.config.capture_scalar_outputs = True
[t+28s667ms] W0718 20:32:55.420000 1 /inferencesh/cache/gpu/uv/archive-v0/bZ7CJbyNBkaHsGR2jn6Vv/torch/_dynamo/variables/tensor.py:1048] [0/0] or:
[t+28s667ms] W0718 20:32:55.420000 1 /inferencesh/cache/gpu/uv/archive-v0/bZ7CJbyNBkaHsGR2jn6Vv/torch/_dynamo/variables/tensor.py:1048] [0/0]     env TORCHDYNAMO_CAPTURE_SCALAR_OUTPUTS=1
[t+28s667ms] W0718 20:32:55.420000 1 /inferencesh/cache/gpu/uv/archive-v0/bZ7CJbyNBkaHsGR2jn6Vv/torch/_dynamo/variables/tensor.py:1048] [0/0] to include these operations in the captured graph.
[t+28s667ms] W0718 20:32:55.420000 1 /inferencesh/cache/gpu/uv/archive-v0/bZ7CJbyNBkaHsGR2jn6Vv/torch/_dynamo/variables/tensor.py:1048] [0/0]
[t+28s667ms] W0718 20:32:55.420000 1 /inferencesh/cache/gpu/uv/archive-v0/bZ7CJbyNBkaHsGR2jn6Vv/torch/_dynamo/variables/tensor.py:1048] [0/0] Graph break: from user code at:
[t+28s667ms] W0718 20:32:55.420000 1 /inferencesh/cache/gpu/uv/archive-v0/bZ7CJbyNBkaHsGR2jn6Vv/torch/_dynamo/variables/tensor.py:1048] [0/0]   File "/inferencesh/apps/gpu/31s80r2z5hw6jskcs081xh8wdb/venv/3.12/lib/python3.12/site-packages/diffusers/models/transformers/transformer_flux.py", line 733, in forward
[t+28s667ms] W0718 20:32:55.420000 1 /inferencesh/cache/gpu/uv/archive-v0/bZ7CJbyNBkaHsGR2jn6Vv/torch/_dynamo/variables/tensor.py:1048] [0/0]     encoder_hidden_states, hidden_states = block(
[t+28s667ms] W0718 20:32:55.420000 1 /inferencesh/cache/gpu/uv/archive-v0/bZ7CJbyNBkaHsGR2jn6Vv/torch/_dynamo/variables/tensor.py:1048] [0/0]   File "/inferencesh/apps/gpu/31s80r2z5hw6jskcs081xh8wdb/venv/3.12/lib/python3.12/site-packages/diffusers/models/transformers/transformer_flux.py", line 456, in forward
[t+28s667ms] W0718 20:32:55.420000 1 /inferencesh/cache/gpu/uv/archive-v0/bZ7CJbyNBkaHsGR2jn6Vv/torch/_dynamo/variables/tensor.py:1048] [0/0]     attention_outputs = self.attn(
[t+28s667ms] W0718 20:32:55.420000 1 /inferencesh/cache/gpu/uv/archive-v0/bZ7CJbyNBkaHsGR2jn6Vv/torch/_dynamo/variables/tensor.py:1048] [0/0]   File "/inferencesh/apps/gpu/31s80r2z5hw6jskcs081xh8wdb/venv/3.12/lib/python3.12/site-packages/diffusers/models/transformers/transformer_flux.py", line 343, in forward
[t+28s667ms] W0718 20:32:55.420000 1 /inferencesh/cache/gpu/uv/archive-v0/bZ7CJbyNBkaHsGR2jn6Vv/torch/_dynamo/variables/tensor.py:1048] [0/0]     return self.processor(self, hidden_states, encoder_hidden_states, attention_mask, image_rotary_emb, **kwargs)
[t+28s667ms] W0718 20:32:55.420000 1 /inferencesh/cache/gpu/uv/archive-v0/bZ7CJbyNBkaHsGR2jn6Vv/torch/_dynamo/variables/tensor.py:1048] [0/0]   File "/inferencesh/apps/gpu/31s80r2z5hw6jskcs081xh8wdb/venv/3.12/lib/python3.12/site-packages/diffusers/models/transformers/transformer_flux.py", line 117, in __call__
[t+28s667ms] W0718 20:32:55.420000 1 /inferencesh/cache/gpu/uv/archive-v0/bZ7CJbyNBkaHsGR2jn6Vv/torch/_dynamo/variables/tensor.py:1048] [0/0]     hidden_states = dispatch_attention_fn(
[t+28s667ms] W0718 20:32:55.420000 1 /inferencesh/cache/gpu/uv/archive-v0/bZ7CJbyNBkaHsGR2jn6Vv/torch/_dynamo/variables/tensor.py:1048] [0/0]   File "/inferencesh/apps/gpu/31s80r2z5hw6jskcs081xh8wdb/venv/3.12/lib/python3.12/site-packages/diffusers/models/attention_dispatch.py", line 241, in dispatch_attention_fn
[t+28s667ms] W0718 20:32:55.420000 1 /inferencesh/cache/gpu/uv/archive-v0/bZ7CJbyNBkaHsGR2jn6Vv/torch/_dynamo/variables/tensor.py:1048] [0/0]     return backend_fn(**kwargs)
[t+28s667ms] W0718 20:32:55.420000 1 /inferencesh/cache/gpu/uv/archive-v0/bZ7CJbyNBkaHsGR2jn6Vv/torch/_dynamo/variables/tensor.py:1048] [0/0]   File "/inferencesh/apps/gpu/31s80r2z5hw6jskcs081xh8wdb/venv/3.12/lib/python3.12/site-packages/diffusers/models/attention_dispatch.py", line 962, in _sage_varlen_attention
[t+28s667ms] W0718 20:32:55.420000 1 /inferencesh/cache/gpu/uv/archive-v0/bZ7CJbyNBkaHsGR2jn6Vv/torch/_dynamo/variables/tensor.py:1048] [0/0]     _prepare_for_flash_attn_or_sage_varlen(
[t+28s667ms] W0718 20:32:55.420000 1 /inferencesh/cache/gpu/uv/archive-v0/bZ7CJbyNBkaHsGR2jn6Vv/torch/_dynamo/variables/tensor.py:1048] [0/0]   File "/inferencesh/apps/gpu/31s80r2z5hw6jskcs081xh8wdb/venv/3.12/lib/python3.12/site-packages/diffusers/models/attention_dispatch.py", line 351, in _prepare_for_flash_attn_or_sage_varlen
[t+28s667ms] W0718 20:32:55.420000 1 /inferencesh/cache/gpu/uv/archive-v0/bZ7CJbyNBkaHsGR2jn6Vv/torch/_dynamo/variables/tensor.py:1048] [0/0]     return _prepare_for_flash_attn_or_sage_varlen_without_mask(batch_size, seq_len_q, seq_len_kv, device)
[t+28s667ms] W0718 20:32:55.420000 1 /inferencesh/cache/gpu/uv/archive-v0/bZ7CJbyNBkaHsGR2jn6Vv/torch/_dynamo/variables/tensor.py:1048] [0/0]   File "/inferencesh/apps/gpu/31s80r2z5hw6jskcs081xh8wdb/venv/3.12/lib/python3.12/site-packages/torch/_dynamo/polyfills/__init__.py", line 253, in getattr_and_trace
[t+28s667ms] W0718 20:32:55.420000 1 /inferencesh/cache/gpu/uv/archive-v0/bZ7CJbyNBkaHsGR2jn6Vv/torch/_dynamo/variables/tensor.py:1048] [0/0]     return fn(*args[2:], **kwargs)
[t+28s667ms] W0718 20:32:55.420000 1 /inferencesh/cache/gpu/uv/archive-v0/bZ7CJbyNBkaHsGR2jn6Vv/torch/_dynamo/variables/tensor.py:1048] [0/0]   File "/inferencesh/apps/gpu/31s80r2z5hw6jskcs081xh8wdb/venv/3.12/lib/python3.12/site-packages/diffusers/models/attention_dispatch.py", line 321, in _prepare_for_flash_attn_or_sage_varlen_without_mask
[t+28s667ms] W0718 20:32:55.420000 1 /inferencesh/cache/gpu/uv/archive-v0/bZ7CJbyNBkaHsGR2jn6Vv/torch/_dynamo/variables/tensor.py:1048] [0/0]     max_seqlen_q = seqlens_q.max().item()
[t+28s667ms] W0718 20:32:55.420000 1 /inferencesh/cache/gpu/uv/archive-v0/bZ7CJbyNBkaHsGR2jn6Vv/torch/_dynamo/variables/tensor.py:1048] [0/0]
[t+28s667ms] W0718 20:32:55.420000 1 /inferencesh/cache/gpu/uv/archive-v0/bZ7CJbyNBkaHsGR2jn6Vv/torch/_dynamo/variables/tensor.py:1048] [0/0]
[t+28s669ms]   0%|          | 0/30 [00:00<?, ?it/s]
[t+28s670ms] [ERROR] Traceback (most recent call last):
[t+28s671ms]   File "/server/tasks.py", line 50, in run_task
[t+28s671ms]     output = await result
[t+28s671ms]              ^^^^^^^^^^^^
[t+28s671ms]   File "/inferencesh/apps/gpu/31s80r2z5hw6jskcs081xh8wdb/src/inference.py", line 248, in run
[t+28s671ms]     result = self.pipeline(
[t+28s671ms]              ^^^^^^^^^^^^^^
[t+28s671ms]   File "/inferencesh/apps/gpu/31s80r2z5hw6jskcs081xh8wdb/venv/3.12/lib/python3.12/site-packages/torch/utils/_contextlib.py", line 120, in decorate_context
[t+28s671ms]     return func(*args, **kwargs)
[t+28s671ms]            ^^^^^^^^^^^^^^^^^^^^^
[t+28s671ms]   File "/inferencesh/apps/gpu/31s80r2z5hw6jskcs081xh8wdb/venv/3.12/lib/python3.12/site-packages/diffusers/pipelines/flux/pipeline_flux_kontext.py", line 1063, in __call__
[t+28s671ms]     noise_pred = self.transformer(
[t+28s671ms]                  ^^^^^^^^^^^^^^^^^
[t+28s671ms]   File "/inferencesh/apps/gpu/31s80r2z5hw6jskcs081xh8wdb/venv/3.12/lib/python3.12/site-packages/torch/_dynamo/eval_frame.py", line 411, in __call__
[t+28s671ms]     return super().__call__(*args, **kwargs)
[t+28s671ms]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[t+28s671ms]   File "/inferencesh/apps/gpu/31s80r2z5hw6jskcs081xh8wdb/venv/3.12/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
[t+28s671ms]     return self._call_impl(*args, **kwargs)
[t+28s671ms]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[t+28s671ms]   File "/inferencesh/apps/gpu/31s80r2z5hw6jskcs081xh8wdb/venv/3.12/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1784, in _call_impl
[t+28s671ms]     return forward_call(*args, **kwargs)
[t+28s671ms]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[t+28s672ms]   File "/inferencesh/apps/gpu/31s80r2z5hw6jskcs081xh8wdb/venv/3.12/lib/python3.12/site-packages/torch/_dynamo/eval_frame.py", line 812, in compile_wrapper
[t+28s672ms]     raise e.with_traceback(None) from e.__cause__  # User compiler error
[t+28s672ms]     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[t+28s672ms] torch._dynamo.exc.Unsupported: Unsupported Tensor.item() call with capture_scalar_outputs=False
[t+28s672ms]   Explanation: Dynamo does not support tracing `Tensor.item()` with config.capture_scalar_outputs=False.
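
The failing call in the trace is max_seqlen_q = seqlens_q.max().item() (attention_dispatch.py, line 321). For reference, here is a standalone sketch that reproduces the same graph break outside diffusers (illustrative only, not the library's code):

import torch

def max_seqlen_item(seqlens_q: torch.Tensor) -> int:
    # Mirrors the line from the trace above: calling .item() on a
    # traced tensor forces a graph break under Dynamo's default config.
    return seqlens_q.max().item()

compiled = torch.compile(max_seqlen_item, fullgraph=True)
# With the default capture_scalar_outputs=False, calling this raises
# torch._dynamo.exc.Unsupported, matching the error above:
# compiled(torch.tensor([512, 512]))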

Compiling with either

self.pipeline.transformer.compile_repeated_blocks(fullgraph=True)

or

self.pipeline.transformer.to(memory_format=torch.channels_last)
self.pipeline.transformer = torch.compile(
  self.pipeline.transformer, mode="max-autotune", fullgraph=True
)

yields the same error

after setting

self.pipeline.transformer.set_attention_backend("sage_varlen")
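
As a possible workaround, the Dynamo warning in the log suggests enabling scalar-output capture before compiling. A minimal sketch (untested here; it lets Dynamo trace through .item() as a symbolic scalar, but I can't say whether the varlen path then compiles cleanly):

import torch

# Per the Dynamo warning above: capture Tensor.item() as a symbolic
# scalar instead of breaking the graph.
torch._dynamo.config.capture_scalar_outputs = True

# Equivalent environment-variable form, also quoted in the warning:
#   TORCHDYNAMO_CAPTURE_SCALAR_OUTPUTS=1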

Originally posted by @okaris in #11916 (comment)

Metadata

Labels: stale (Issues that haven't received updates)