Conversation

dd8af40 to 5cd4db8 (compare)
For the RL CI failure: I don't know why strided shard showed up, and I couldn't reproduce it on pretraining Qwen3 0.6B with NGPU=2, TP=2.

A quick fix can be pytorch/pytorch#178735. What's the config you are using to repro?
```python
xv_packed = xv_packed.to(torch.bfloat16)
```

```diff
- return varlen_attn(
+ out = varlen_attn(
```
What is the shape of the varlen attention output? Is it (bs*seq, heads, dim)?

From the doc here: https://docs.pytorch.org/docs/stable/nn.attention.varlen.html, yes, it's (total_tokens, heads, dim).
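For intuition, a small plain-Python sketch (no torch) of how the packed varlen layout collapses batch and sequence: `total_tokens` is the sum of the per-sequence lengths, i.e. the last entry of the cumulative-length array that varlen attention APIs typically consume.

```python
# Packed (varlen) layout: variable-length sequences are concatenated along
# one token axis, so attention output is (total_tokens, heads, dim)
# rather than (bsz, seq, heads, dim).

seq_lens = [5, 3, 7]  # three sequences of different lengths in one "batch"

# Cumulative sequence lengths, as varlen attention APIs usually take them.
cu_seqlens = [0]
for n in seq_lens:
    cu_seqlens.append(cu_seqlens[-1] + n)

total_tokens = cu_seqlens[-1]
print(cu_seqlens)    # [0, 5, 8, 15]
print(total_tokens)  # 15
```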
```python
output = self.inner_attention(xq, xk, xv)
# Reshape back to the format expected by GQAttention.forward()
output = output_flat.view(batch_size, seq_len, -1, head_dim)
```
Here VLLMAttentionWrapper returns a 4D output, while VarlenAttention returns a 3D output.

I think the bug might be because RL is using full DTensor in the TP region, and used LocalMapAttention. In this function, the `out_placement` is inferred from the input (q/k/v) placements.
Previous:
- In RL's trainer model, when TP is enabled, the input placement is Shard(1) (inner_attention gets (bsz, heads, seq, dim)), so the varlen output's placement will also be Shard(1).
- The VarlenAttention output has shape (bsz * seq, heads, dim) with Shard(1) placement.

After this PR:
- Inner attention now gets (bsz, seq, heads, dim) as input, so the input placement is Shard(2), and LocalMapAttention will infer the `output_placement` to be Shard(2) as well.
- The VarlenAttention output still has shape (bsz * seq, heads, dim), but Shard(2) is a wrong placement.
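A minimal plain-Python sketch of the inference bug described above. The function name `infer_out_placement` is hypothetical; the real logic is DTensor placement inference inside LocalMapAttention. The point is only that naively reusing the input's shard dim breaks when the output has a different rank/layout.

```python
# Hypothetical sketch: "infer output placement from input placement" by
# reusing the input's sharded dim index, as the comment above describes.

def infer_out_placement(in_shard_dim: int) -> int:
    # Naive rule: output is sharded on the same dim index as the input.
    return in_shard_dim

out_layout = ["bsz*seq", "heads", "dim"]  # varlen output is 3D

# Before the PR: input is (bsz, heads, seq, dim); TP shards heads -> Shard(1).
in_before = ["bsz", "heads", "seq", "dim"]
shard_before = infer_out_placement(in_before.index("heads"))  # -> 1
assert out_layout[shard_before] == "heads"  # Shard(1) still lands on heads: OK

# After the PR: input is (bsz, seq, heads, dim); heads is now dim 2 -> Shard(2).
in_after = ["bsz", "seq", "heads", "dim"]
shard_after = infer_out_placement(in_after.index("heads"))  # -> 2
assert out_layout[shard_after] == "dim"  # Shard(2) lands on head_dim: wrong
```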
Oh, the mismatch should be the reason; let me retry.

Meaning `_StridedShard` shouldn't show up?
One nit: it would be nice to have some README somewhere in Attention that describes a little about how inner config passing works and where the configs are expected to be constructed. Also, say you want to write a new model definition that uses attention impl Y: here is how you should do it, instead of reverse engineering it from other model defs.

@drisspg Definitely! But maybe after this wave of refactoring? We are making massive changes to the codebase, especially to the config system, so any doc would easily become outdated and a debt to maintain.
wwwjn left a comment:

Took a detailed look at the RL part and the core trainer, and briefly looked into the models. The changes look good to me. Should we wait a little bit for the model CI tests to pick up the PyTorch changes and run successfully?
```python
"--generator.parallelism.tensor_parallel_degree 2",
"--generator.num_samples_per_prompt 2",
"--no_batch_invariant_mode",
"--trainer.compile.no-enable",
```
Why explicitly disable the trainer's compile here? Because it's enabled by default?

Because this test is about no_compile. The generator is already no-compile; this PR makes the trainer consistent.
```python
from torchtitan.models.common import FlexAttention
```

```python
model_spec = model_registry("debugmodel")
```
We already have attention_backend_override in model_registry; we should override it here.

I don't think we have that in llama3. We only have it in qwen3. We can make the style consistent in a later PR.
The attention refactor in #2761 moved GQA head expansion out of the attention modules, but the varlen_attn / varlen_attn_out calls were not updated to pass enable_gqa=True. This causes a ValueError when query and key/value have different numbers of heads (e.g. Hq=8, Hkv=4). This PR fixes both the core VarlenAttention (trainer path) and the RL PyTorchFlashAttentionImpl (vLLM generator path).
For several things:
- Save a pair of transposes for varlen attention, by moving the transpose for SDPA / flex inside their own modules. I had to change `_ContextParallel(seq_dim=2, ...)` to `seq_dim=1` -> I also had to change the flux model impl to adapt.
- The above change follows pytorch#2709 which is only for the generator. With this PR I'm able to remove the `VLLMGQAttention` module added in that PR.
- Make `inner_attention` a Python config, instead of the string `attn_backend`, which is now used in qwen3 `model_registry` as a convenient way to switch config.
- Thanks to that, we can support FlexAttention with configurable options, including `block_size` and `kernel_options` to enable the Flash attention backend pytorch#2171. Added a config `qwen3_debugmodel_flex_flash` to the registry.
- Move `annotate_flex_attention_for_regional_inductor` from the general attention file to `graph_trainer` where it's used.
- Remove outdated artificial restrictions on CP / CP + TP.
- Some naming improvements.

NOTE:
- GraphTrainer tests fail on main as well.
- The model test on Llama4+PP+compile fails on main as well.
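To illustrate the "inner_attention as a Python config" idea, here is a minimal sketch. All class and function names below are hypothetical stand-ins, not torchtitan's actual API: the point is that a structured config object with fields like `block_size` and `kernel_options` replaces string dispatch on `attn_backend`.

```python
from dataclasses import dataclass, field

# Hypothetical sketch: structured attention configs instead of a string
# backend name like attn_backend="flex".

@dataclass
class FlexAttentionConfig:
    block_size: int = 128
    kernel_options: dict = field(default_factory=dict)

@dataclass
class SDPAConfig:
    pass

def build_inner_attention(cfg) -> str:
    # Dispatch on the config's type; the config already carries every
    # backend-specific knob, so no extra string parsing is needed.
    if isinstance(cfg, FlexAttentionConfig):
        return f"flex(block_size={cfg.block_size}, kernel_options={cfg.kernel_options})"
    return "sdpa"

cfg = FlexAttentionConfig(block_size=64, kernel_options={"BLOCK_M": 64})
print(build_inner_attention(cfg))
```

Compared with a string switch, this makes backend options discoverable and type-checked at config-construction time.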