Conversation
!test

!test
Greptile Summary

This PR implements the TMA auto-transpose scheduler, adding an ….

Key issues found:
Confidence Score: 2/5
Important Files Changed
Flowchart

```mermaid
%%{init: {'theme': 'neutral'}}%%
flowchart TD
A[TransposeScheduler::computeHeuristics] --> B{TmaTranspose\nenabled?}
B -- No --> C[Non-TMA heuristics]
B -- Yes --> D[tma::getTransposeHeuristics]
D --> E{n_input > n_output?}
E -- Yes --> F[is_output_smem_transpose = true\nuse_tma_load = true\nuse_tma_store = true]
E -- No --> G[is_output_smem_transpose = false\nuse_tma_load = true\nuse_tma_store = false]
F & G --> H[Compute tile_size2 = 128B / dtype\nelements_per_chunk\nestimated_tile_size1]
H --> I[While chunks_per_thread < 4:\n tile_size1 *= 2]
I --> J[Return TransposeParams]
J --> K[tma::scheduleTranspose]
K --> L[cacheInputs + cacheAndForkOutputs]
L --> M{use_tma_load?}
M -- Yes --> N[Set CpAsyncBulkTensorTile on inputs\nMove to shared memory]
M --> O{use_tma_store?}
O -- Yes --> P[Set CpAsyncBulkTensorTile on outputs\nMove to shared memory]
N & P --> Q{is_output_smem_transpose?}
Q -- Yes --> R[Swizzle output smem\nNo swizzle on input smem]
Q -- No --> S[Swizzle input smem\nNo swizzle on output smem]
R & S --> T[Tile ref_tv, propagate transforms\nParallelize TIDx + BIDx\nVectorize smem reads]
T --> U[inlineMost]
```
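As a reading aid, the heuristic branch at the top of the chart can be distilled into a small standalone sketch. All names, and in particular the `chunks_per_thread` formula, are illustrative assumptions based only on the flowchart, not the PR's actual code:

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical distillation of the heuristic flow above; names and the
// chunks_per_thread formula are illustrative, not the PR's actual code.
struct TransposeParamsSketch {
  bool is_output_smem_transpose = false;
  bool use_tma_load = false;
  bool use_tma_store = false;
  int64_t tile_size1 = 0;
  int64_t tile_size2 = 0;
};

TransposeParamsSketch computeHeuristicsSketch(
    int64_t n_input,
    int64_t n_output,
    int64_t dtype_bytes,
    int64_t estimated_tile_size1,
    int64_t elements_per_chunk,
    int64_t threads_per_cta) {
  TransposeParamsSketch p;
  // More inputs than outputs: transpose on the output smem side, so both
  // the load and the store can use TMA.
  p.is_output_smem_transpose = n_input > n_output;
  p.use_tma_load = true;
  p.use_tma_store = p.is_output_smem_transpose;
  // Inner tile dimension covers one 128-byte swizzle row.
  p.tile_size2 = 128 / dtype_bytes;
  // Grow tile_size1 until each thread handles at least 4 chunks
  // (assumed definition of chunks_per_thread).
  p.tile_size1 = estimated_tile_size1;
  while (p.tile_size1 * p.tile_size2 / elements_per_chunk / threads_per_cta <
         4) {
    p.tile_size1 *= 2;
  }
  return p;
}
```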
csrc/scheduler/transpose_tma.cpp
Outdated
```cpp
NVF_ERROR(grouped_inputs_outputs.size() >= 2);
```
```cpp
// When there are more inputs than outputs, output smem transpose should be
// used, however, if it is not, then input smem tranpose will be used, to
```
tranpose should be transpose
```cpp
const int64_t cta_per_sm =
    dev_props->maxThreadsPerMultiProcessor / threads_per_cta;
const int64_t bytes_per_cta = bytes_per_sm / cta_per_sm;
const int64_t bytes_per_tile = bytes_per_cta / n_input;
```
Add a check that `n_input > 0` before this division. While scheduler validation should prevent this, defensive programming would make the code more robust.

Suggested change:

```cpp
NVF_ERROR(n_input > 0, "Expected at least one TensorView input for transpose");
const int64_t bytes_per_tile = bytes_per_cta / n_input;
```
Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!
Review updated until commit bc772db

Description
Relevant files
PR Reviewer Guide
Here are some key observations to aid the review process:
- 🧪 PR contains tests
- ⚡ Recommended focus areas for review
Potential TMA load restriction
This is more restrictive than the original, which checked all loop domains. This could potentially exclude valid TMA loads where some dimensions have extent 1 but other dimensions are parallelized with threads. Need to verify this doesn't break existing TMA use cases.
Additional Comments (2)
If …: on an H100 (`maxThreadsPerMultiProcessor` = 2048, `cta_per_sm` = 8, `bytes_per_cta` = 8192), this triggers when …

The …
Additional Comments (4)
If …, this happens when …. While unlikely for typical transpose fusions (1–2 inputs), this is an unbounded loop with no guard. A simple fix is to initialise ….

Note the asymmetry: Step 3 already guards the analogous constraint with an explicit ….

These ….

Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!
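The unbounded-loop concern above could be addressed with a simple cap. A minimal sketch, assuming `chunks_per_thread` is derived from the tile sizes; both that formula and the `max_tile_size1` cap are illustrative assumptions, not the PR's actual definitions:

```cpp
#include <cstdint>

// Hypothetical guarded version of the tile-growth loop discussed above.
// The chunks_per_thread formula and the max_tile_size1 cap are assumptions
// for illustration, not the PR's actual definitions.
int64_t growTileSize1(
    int64_t tile_size1,
    int64_t tile_size2,
    int64_t elements_per_chunk,
    int64_t threads_per_cta,
    int64_t max_tile_size1) {
  auto chunks_per_thread = [&]() {
    return tile_size1 * tile_size2 / elements_per_chunk / threads_per_cta;
  };
  // Stop doubling once the tile hits the cap, even if the
  // chunks-per-thread target of 4 is never reached.
  while (chunks_per_thread() < 4 && tile_size1 * 2 <= max_tile_size1) {
    tile_size1 *= 2;
  }
  return tile_size1;
}
```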
```cpp
if (std::ranges::any_of(non_trivial_ids, [](const IterDomain* id) {
      return id->isThreadDim() ||
          id->getParallelType() == ParallelType::Serial;
    })) {
```
Trivial optimization of multiple TMA loads; doesn't have to be in this PR.
```cpp
// When not using output smem transpose but inputs > outputs, swap groups
// so group 2 remains the swizzled side.
if (!tparams->is_output_smem_transpose &&
```
This branch is not used in the current heuristics, but may be used in future tuning.
!test

2 similar comments

!test

!test
To reduce the number of transpose ops, `is_output_smem_transpose` is added to control input/output transpose:

1. When there are more inputs than outputs, `is_output_smem_transpose = True`: TMA load without swizzle, TMA store with swizzle, transpose at `regs --> output cached smem`.
2. When there are fewer inputs than outputs, `is_output_smem_transpose = False`: TMA load with swizzle, register store, transpose at `input cached smem -> regs`.

Current performance is in this doc.
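The two modes in the description can be summarized in a small sketch. The struct, function name, and fields here are hypothetical, chosen only to mirror the description above, not the PR's actual API:

```cpp
#include <cstdint>

// Hypothetical summary of the two transpose placements described above.
// Names are illustrative, not the PR's actual API.
struct TmaTransposeMode {
  bool use_tma_load;
  bool use_tma_store;
  bool swizzle_input_smem;   // swizzle on the TMA-load side
  bool swizzle_output_smem;  // swizzle on the TMA-store side
};

TmaTransposeMode pickMode(int64_t n_input, int64_t n_output) {
  if (n_input > n_output) {
    // Mode 1: transpose at regs --> output cached smem.
    // TMA load without swizzle, TMA store with swizzle.
    return {true, true, false, true};
  }
  // Mode 2: transpose at input cached smem -> regs, then register store.
  // TMA load with swizzle, no TMA store.
  return {true, false, true, false};
}
```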