
refactor inner attention module #2761

Merged

tianyu-l merged 1 commit into main from attention
Mar 31, 2026

Conversation

@tianyu-l
Contributor

@tianyu-l tianyu-l commented Mar 31, 2026

This PR does several things:

  • Save a pair of transposes for varlen attention by moving the transpose for SDPA / flex inside their own modules. I had to change _ContextParallel(seq_dim=2, ...) to seq_dim=1, and adapt the flux model implementation accordingly.
  • The above change follows [rl] remove duplicated transpose around VllmAttention in generator #2709, which covered only the generator. With this PR I'm able to remove the VLLMGQAttention module added in that PR.
  • Make inner_attention a Python config instead of the string attn_backend, which is now used in the qwen3 model_registry as a convenient way to switch configs.
  • Thanks to that, we can support FlexAttention with configurable options, including block_size and kernel_options to enable the Flash attention backend (Add Flex flash backend to flex attention module #2171). Added a config qwen3_debugmodel_flex_flash to the registry.
  • Move annotate_flex_attention_for_regional_inductor from the general attention file to graph_trainer, where it's used.
  • Remove outdated artificial restrictions on CP / CP + TP.
  • Some naming improvements.
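The layout change in the first bullet can be sketched at the shape level. This is a plain-Python illustration only; function names like sdpa_inner and varlen_inner are made up for the sketch, not the actual torchtitan API:

```python
# Shape-level sketch of the transpose refactor: each inner-attention backend
# now owns its layout handling, so callers always pass (bs, seq, heads, dim)
# and varlen backends avoid a transpose round-trip. Names are illustrative.

def sdpa_inner(shape):
    """SDPA-style backend: transposes to (bs, heads, seq, dim) internally,
    then transposes back, so the caller's layout is unchanged."""
    bs, seq, heads, dim = shape
    return (bs, seq, heads, dim)

def varlen_inner(shape):
    """Varlen backend: consumes the caller layout directly as token-packed
    (total_tokens, heads, dim), with no transpose round-trip needed."""
    bs, seq, heads, dim = shape
    return (bs * seq, heads, dim)

caller_layout = (2, 128, 8, 64)      # (bs, seq, heads, dim)
print(sdpa_inner(caller_layout))     # (2, 128, 8, 64)
print(varlen_inner(caller_layout))   # (256, 8, 64)
```

Since the sequence dimension now sits at dim 1 of the caller layout, this is also why the `_ContextParallel` wrapper moves from `seq_dim=2` to `seq_dim=1`.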

NOTE:

  • GraphTrainer tests fail on main as well
  • Model test on Llama4+PP+compile fails on main as well

@meta-cla meta-cla bot added the CLA Signed This label is managed by the Meta Open Source bot. label Mar 31, 2026
@tianyu-l tianyu-l force-pushed the attention branch 4 times, most recently from dd8af40 to 5cd4db8 Compare March 31, 2026 06:24
@tianyu-l
Contributor Author

For the RL CI failure

ValueError: use_strided_shard_as_shard_order is True, but placements: (_StridedShard(dim=1, sf=8),) is unable to be interpreted into a corresponding shard_order

I don't know why a strided shard showed up, and I couldn't reproduce it when pretraining Qwen3 0.6B with NGPU=2, TP=2.
cc @wwwjn @weifengpy @zpcore

@weifengpy
Contributor

ValueError: use_strided_shard_as_shard_order is True, but placements: (_StridedShard(dim=1, sf=8),) is unable to be interpreted into a corresponding shard_order

A quick fix could be pytorch/pytorch#178735.

I don't know why strided shard showed up, and I couldn't reproduce on pretraining Qwen3 0.6B with NGPU=2, TP=2. cc @wwwjn @weifengpy @zpcore

what's the config you are using to repro?

xv_packed = xv_packed.to(torch.bfloat16)

return varlen_attn(
out = varlen_attn(
Contributor

What is the shape of varlen attention output? Is it (bs*seq, heads, dim)?

Contributor

From the docs here: https://docs.pytorch.org/docs/stable/nn.attention.varlen.html — yes, it's (total_tokens, heads, dim).
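The token-packed varlen layout from the linked docs can be illustrated with a small shape calculation (plain Python, values are made up for the example):

```python
# Varlen attention concatenates variable-length sequences along dim 0 and
# tracks boundaries via cumulative sequence lengths, so the output shape is
# (total_tokens, heads, dim) rather than a padded (bs, seq, heads, dim).

seq_lens = [5, 3, 7]                       # three variable-length sequences
cu_seqlens = [0]
for n in seq_lens:
    cu_seqlens.append(cu_seqlens[-1] + n)  # cumulative boundaries

heads, dim = 8, 64
total_tokens = cu_seqlens[-1]
out_shape = (total_tokens, heads, dim)
print(cu_seqlens)   # [0, 5, 8, 15]
print(out_shape)    # (15, 8, 64)
```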


output = self.inner_attention(xq, xk, xv)
# Reshape back to the format expected by GQAttention.forward()
output = output.view(batch_size, seq_len, -1, head_dim)
Contributor

@wwwjn wwwjn Mar 31, 2026

Here VLLMAttentionWrapper returns a 4D output, while VarlenAttention returns a 3D output.

I think the bug might be because RL is using full DTensor in the TP region, and uses LocalMapAttention

class LocalMapAttention(Module):
. In this function, the out_placement is inferred from the input (q/k/v) placements.

Previously:

  • In RL's trainer model, when TP is enabled, the input placement is Shard(1) (inner_attention gets (bsz, heads, seq, dim)), so the varlen output's placement will also be Shard(1).
  • The VarlenAttention output has shape (bsz*seq, heads, dim) with Shard(1) placement.

After this PR:

  • Inner attention now gets (bsz, seq, heads, dim) as input, so the input placement is Shard(2), and LocalMapAttention infers the output placement to be Shard(2) as well.
  • The VarlenAttention output still has shape (bsz*seq, heads, dim), but Shard(2) is a wrong placement.
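The mismatch described above can be reproduced at a toy level in plain Python. The functions below are illustrative stand-ins, not the LocalMapAttention code: one naively reuses the input shard dim (the buggy inference), the other maps it through the (bsz, seq) collapse of a token-packed output:

```python
# Input layout after this PR: (bsz, seq, heads, dim); varlen output layout:
# (bsz*seq, heads, dim). Dims 0 and 1 of the input fold into output dim 0,
# so heads moves from input dim 2 to output dim 1.

def naive_out_shard_dim(in_shard_dim):
    """Reuse the input's shard dim unchanged (the buggy inference)."""
    return in_shard_dim

def packed_out_shard_dim(in_shard_dim):
    """Correct mapping for a (bsz*seq, heads, dim) output: bsz (dim 0) and
    seq (dim 1) both collapse into dim 0; later dims shift down by one."""
    return 0 if in_shard_dim in (0, 1) else in_shard_dim - 1

# Before this PR the input was (bsz, heads, seq, dim) with heads sharded at
# dim 1, and the 3D output also has heads at dim 1, so naive reuse worked.
print(naive_out_shard_dim(1))    # 1 (happened to be right)
# After this PR heads is at input dim 2, but on the 3D output dim 2 is the
# head dim `dim`, not heads, so naive reuse produces a wrong placement.
print(naive_out_shard_dim(2))    # 2 (wrong: should be 1)
print(packed_out_shard_dim(2))   # 1 (heads is at output dim 1)
```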

Contributor Author

Oh, the mismatch should be the reason; let me retry.

Contributor

meaning _StridedShard shouldn't show up?

Contributor Author

Right

@drisspg
Contributor

drisspg commented Mar 31, 2026

One nit: it would be nice to have some README somewhere in Attention that describes a little about how inner config passing works and where the configs are expected to be constructed. Also, say you want to write a new model definition that uses attention impl Y: here is how you should do it, instead of reverse-engineering from other definitions.

@tianyu-l
Contributor Author

One nit: it would be nice to have some README somewhere in Attention that describes a little about how inner config passing works and where the configs are expected to be constructed. Also, say you want to write a new model definition that uses attention impl Y: here is how you should do it, instead of reverse-engineering from other definitions.

@drisspg Definitely! But maybe after this wave of refactoring? We are making massive changes to the codebase, especially the config system, so any doc would easily become outdated and a maintenance debt.

@tianyu-l tianyu-l requested a review from wwwjn March 31, 2026 20:57
Contributor

@wwwjn wwwjn left a comment

Took a detailed look at the RL part and the core trainer, and briefly looked into the models. The changes look good to me. Should we wait a bit for the model CI tests to pick up the PyTorch changes and run successfully?

"--generator.parallelism.tensor_parallel_degree 2",
"--generator.num_samples_per_prompt 2",
"--no_batch_invariant_mode",
"--trainer.compile.no-enable",
Contributor

Why explicitly disable the trainer's compile here? Because it's enabled by default?

Contributor Author

Because this test is about no_compile. The generator already runs without compile; this PR makes the trainer consistent.


from torchtitan.models.common import FlexAttention

model_spec = model_registry("debugmodel")
Contributor

We already have attention_backend_override in model_registry; we should use that to override here.

Contributor Author

I don't think we have that in llama3; we only have it in qwen3. We can make the style consistent in a later PR.

@tianyu-l tianyu-l merged commit 2618196 into main Mar 31, 2026
24 of 36 checks passed
@tianyu-l tianyu-l deleted the attention branch March 31, 2026 21:24
daniellepintz added a commit that referenced this pull request Apr 9, 2026
The attention refactor in #2761 moved GQA head expansion out of the
attention modules, but the varlen_attn/varlen_attn_out calls were not
updated to pass enable_gqa=True. This causes a ValueError when query
and key/value have different numbers of heads (e.g. Hq=8, Hkv=4).

Fixes both the core VarlenAttention (trainer path) and the RL
PyTorchFlashAttentionImpl (vLLM generator path).
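The GQA head-count mismatch that follow-up commit describes can be illustrated at the shape level. This is a plain-Python sketch; expand_kv_heads is a made-up stand-in for the head-expansion an attention kernel does when enable_gqa is set:

```python
# With Hq=8 query heads and Hkv=4 key/value heads, an attention call errors
# out unless the kernel handles grouped-query attention (enable_gqa=True)
# or the KV heads are expanded to match the query heads beforehand.

def expand_kv_heads(kv_shape, n_q_heads):
    """Expand (tokens, Hkv, dim) -> (tokens, Hq, dim) by repeating each KV
    head for its group of query heads (shape-level stand-in)."""
    tokens, h_kv, dim = kv_shape
    assert n_q_heads % h_kv == 0, "Hq must be a multiple of Hkv"
    return (tokens, n_q_heads, dim)

q_shape, kv_shape = (15, 8, 64), (15, 4, 64)
print(q_shape[1] == kv_shape[1])              # False: raises without enable_gqa
print(expand_kv_heads(kv_shape, q_shape[1]))  # (15, 8, 64): now matches
```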
TXacs pushed a commit to McmillanTAC/torchtitan that referenced this pull request Apr 13, 2026

Labels: ciflow/8gpu, CLA Signed

5 participants