
add auto tma transpose scheduler #6018

Open
liqiangxl wants to merge 9 commits into main from llu/transpose_output_smem_auto

Conversation

@liqiangxl
Collaborator

To reduce the number of transpose ops, is_output_smem_transpose is added to control input/output transpose:

1. When there are more inputs than outputs, is_output_smem_transpose = True
TMA load without swizzle, TMA store with swizzle, transpose at regs --> output cached smem

2. When there are fewer inputs than outputs, is_output_smem_transpose = False
TMA load with swizzle, register store, transpose at input cached smem --> regs

Current performance is in this doc.

@liqiangxl
Collaborator Author

!test

@liqiangxl
Collaborator Author

!test

@liqiangxl liqiangxl marked this pull request as ready for review February 27, 2026 15:40
@greptile-apps
Contributor

greptile-apps bot commented Feb 27, 2026

Greptile Summary

This PR implements the TMA auto-transpose scheduler, adding an is_output_smem_transpose flag that selects which side of shared memory carries the swizzled (transposed) layout, then wires up heuristic selection and TMA scheduling for both the load-side and store-side paths. Two supporting infrastructure changes accompany it: EnableOption::TmaTranspose is added so the TMA path is opt-in, and getBatchableTmaLoads is tightened to skip trivial size-1 loop-domain IDs that should not block TMA batching.

Key issues found:

  • Infinite loop in heuristics (transpose_tma.cpp line 89–103): estimated_tile_size1 = bytes_per_tile / kTmaSwizzleBytes truncates to 0 for fusions with many inputs (≥ 65 on H100), and 0 * 2 never grows, so the while loop hangs. Fix: clamp the initial estimate with std::max(int64_t(1), ...).
  • Out-of-range crash with inconsistent params (transpose_tma.cpp line 259): tma_store_tvs.at(0) is called inside if (tparams->is_output_smem_transpose), but tma_store_tvs is only populated when use_tma_store=true. The header comment explicitly states these two flags are independent, so the crash is reachable by design when overriding params directly (as done in TmaTransposeParamsTestP).
  • Debug std::cout in test (test_transpose.cpp line 1944): left-over print in OutputTransposeBankconflict will produce console noise on CI even for passing runs.

Confidence Score: 2/5

  • Not safe to merge — contains an infinite loop and a potential out-of-range crash in the new scheduler core.
  • Two logic bugs in transpose_tma.cpp block a safe merge: the infinite loop caused by estimated_tile_size1 = 0 will hang any fusion with a large number of inputs, and the unguarded tma_store_tvs.at(0) will crash with param combinations that the design explicitly documents as valid.
  • Pay close attention to csrc/scheduler/transpose_tma.cpp — both bugs are in this file.

Important Files Changed

Filename Overview
csrc/scheduler/transpose_tma.cpp Core implementation of TMA transpose heuristics and scheduling; contains two logic bugs: an infinite loop when estimated_tile_size1 truncates to 0, and a crash when is_output_smem_transpose=true but use_tma_store=false.
csrc/scheduler/transpose_heuristic.h Adds use_tma_store, is_output_smem_transpose, chunks_per_thread, and elements_per_chunk fields with correct updates to sameAs, toString, and hash; no issues found.
csrc/scheduler/transpose.cpp Gates TMA scheduler behind EnableOption::TmaTranspose flag and extends routing condition to cover use_tma_store; changes are correct and minimal.
csrc/options.h Adds TmaTranspose to EnableOption, pins all option enums to uint8_t underlying type, adds <cstdint> include, and refactors copy constructor to use member-initializer list; all changes are safe given current enum sizes.
csrc/options.cpp Registers "tma_transpose" string to EnableOption::TmaTranspose and modernizes std::sort to std::ranges::sort; straightforward and correct.
csrc/device_lower/analysis/tma.cpp Filters trivial (size-1) loop-domain IDs before checking for thread/serial parallelization in getBatchableTmaLoads, allowing TMA loads on tensors with padded unit-extent dimensions; logic is semantically correct.
tests/cpp/test_transpose.cpp Adds parameterized tests for TMA transpose across dtypes, dimension pairs, and explicit param combinations; good coverage, but a debug std::cout was left in OutputTransposeBankconflict.

Flowchart

%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A[TransposeScheduler::computeHeuristics] --> B{TmaTranspose\nenabled?}
    B -- No --> C[Non-TMA heuristics]
    B -- Yes --> D[tma::getTransposeHeuristics]
    D --> E{n_input > n_output?}
    E -- Yes --> F[is_output_smem_transpose = true\nuse_tma_load = true\nuse_tma_store = true]
    E -- No --> G[is_output_smem_transpose = false\nuse_tma_load = true\nuse_tma_store = false]
    F & G --> H[Compute tile_size2 = 128B / dtype\nelements_per_chunk\nestimated_tile_size1]
    H --> I[While chunks_per_thread < 4:\n  tile_size1 *= 2]
    I --> J[Return TransposeParams]
    J --> K[tma::scheduleTranspose]
    K --> L[cacheInputs + cacheAndForkOutputs]
    L --> M{use_tma_load?}
    M -- Yes --> N[Set CpAsyncBulkTensorTile on inputs\nMove to shared memory]
    M --> O{use_tma_store?}
    O -- Yes --> P[Set CpAsyncBulkTensorTile on outputs\nMove to shared memory]
    N & P --> Q{is_output_smem_transpose?}
    Q -- Yes --> R[Swizzle output smem\nNo swizzle on input smem]
    Q -- No --> S[Swizzle input smem\nNo swizzle on output smem]
    R & S --> T[Tile ref_tv, propagate transforms\nParallelize TIDx + BIDx\nVectorize smem reads]
    T --> U[inlineMost]

Comments Outside Diff (3)

  1. csrc/scheduler/transpose_tma.cpp, line 89-103 (link)

    Infinite loop when estimated_tile_size1 evaluates to zero

    estimated_tile_size1 is computed via integer division, so when bytes_per_tile < kTmaSwizzleBytes (128), the result truncates to 0. Because 0 * 2 == 0, the while loop at line 101 never makes progress and hangs forever.

    This is reachable in practice. On an H100 (maxThreadsPerMultiProcessor = 2048):

    • cta_per_sm = 2048 / 256 = 8
    • bytes_per_cta = 65536 / 8 = 8192
    • With n_input = 128: bytes_per_tile = 8192 / 128 = 64, so estimated_tile_size1 = 64 / 128 = 0

Fusions with a large number of inputs (≥ 65 on H100) will trigger this. The fix is to clamp the initial estimate to at least 1 via std::max(int64_t(1), ...).

  2. csrc/scheduler/transpose_tma.cpp, line 256-265 (link)

    tma_store_tvs.at(0) crashes when is_output_smem_transpose=true but use_tma_store=false

    tma_store_tvs is only populated inside the if (tparams->use_tma_store) block (lines 162–171). If is_output_smem_transpose is true while use_tma_store is false, tma_store_tvs will be empty and tma_store_tvs.at(0) will throw std::out_of_range.

    The header comment for these two fields explicitly states they are independent:

    "This is independent of use_tma_load/use_tma_store — TMA can be used for either side regardless of where the transpose swizzle lives."

    So the design intent allows this combination, but the implementation crashes on it. A minimal guard would prevent the out-of-range access and surface a clearer error:

    if (tparams->is_output_smem_transpose) {
        NVF_ERROR(
            !tma_store_tvs.empty(),
            "is_output_smem_transpose requires use_tma_store to populate tma_store_tvs");
        MmaInputSmemSwizzle swizzle =
            mma_utils::tmaSwizzleSharedMemory(tma_store_tvs.at(0));
  3. tests/cpp/test_transpose.cpp, line 1944-1946 (link)

    Debug std::cout left in test

This std::cout will pollute test output on CI for every run where a bank conflict happens. Since EXPECT_TRUE(bank_conflicts.empty()) already flags the failure, the print is redundant in a passing run and noisy in a failing one. Consider removing it or routing it through gtest's logging.

Last reviewed commit: f722d67


@greptile-apps greptile-apps bot left a comment


7 files reviewed, 2 comments


NVF_ERROR(grouped_inputs_outputs.size() >= 2);

// When there are more inputs than outputs, output smem transpose should be
// used, however, if it is not, then input smem tranpose will be used, to


tranpose should be transpose

const int64_t cta_per_sm =
dev_props->maxThreadsPerMultiProcessor / threads_per_cta;
const int64_t bytes_per_cta = bytes_per_sm / cta_per_sm;
const int64_t bytes_per_tile = bytes_per_cta / n_input;


Add check that n_input > 0 before this division. While the scheduler validation should prevent this, defensive programming would make the code more robust.

Suggested change
const int64_t bytes_per_tile = bytes_per_cta / n_input;
NVF_ERROR(n_input > 0, "Expected at least one TensorView input for transpose");
const int64_t bytes_per_tile = bytes_per_cta / n_input;

Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!

@liqiangxl liqiangxl requested a review from rdspring1 February 27, 2026 17:24
@github-actions

github-actions bot commented Mar 2, 2026

Review updated until commit bc772db

Description

  • Implements automatic TMA (Tensor Memory Accelerator) transpose scheduler with two paths: input smem transpose (swizzle on input) and output smem transpose (swizzle on output)

  • Adds new TmaTranspose enable option to toggle the feature; scheduler falls back to non-TMA when disabled

  • Introduces new parameters: use_tma_store, is_output_smem_transpose, chunks_per_thread, elements_per_chunk for flexible TMA configuration

  • Adds comprehensive tests covering different dtypes, transpose dimensions, and TMA parameter combinations

Changes walkthrough

Relevant files

PR Reviewer Guide

Here are some key observations to aid the review process:

🧪 PR contains tests
⚡ Recommended focus areas for review
Potential TMA load restriction

The new code filters loop domains to only include non-trivial IDs (extent > 1 or non-const) before checking for thread/serial dims.
This is more restrictive than the original which checked all loop domains. This could potentially exclude valid TMA loads
where some dimensions have extent 1 but other dimensions are parallelized with threads. Need to verify this doesn't break
existing TMA use cases.

auto non_trivial_ids =
    tv->getLoopDomain() | std::views::filter([](const IterDomain* id) {
      return !id->extent()->isConstScalar() ||
          id->extent()->evaluate().as<int64_t>() > 1;
    });
if (std::ranges::any_of(non_trivial_ids, [](const IterDomain* id) {
      return id->isThreadDim() ||
          id->getParallelType() == ParallelType::Serial;
    })) {
  return {};
}
Missing null check

In scheduleTranspose, when setting up TMA store (lines 165-172), the code accesses fusion->outputs()[output_idx] without
checking if output_idx is within bounds. While cached_outputs should correspond to outputs, a bounds check would be safer.

for (auto [cached_output, output_idx] : cached_outputs) {
  auto output = fusion->outputs()[output_idx]->as<TensorView>();
  output->definition()->as<LoadStoreOp>()->setOpType(
      LoadStoreOpType::CpAsyncBulkTensorTile);
  cached_output->setMemoryType(MemoryType::Shared);
  cached_output->cacheBefore();
  tma_store_tvs.push_back(cached_output);
}
Thread safety consideration

The copy constructor was modified to use a lambda that captures other.mutex_ and returns other.options_. While this appears
correct, the original implementation directly assigned options_. The new approach should be verified to maintain the same
thread-safety semantics under concurrent access patterns.

Options(const Options& other)
    : options_([&other]() {
        std::lock_guard<std::mutex> lock_other(other.mutex_);
        return other.options_;
      }()) {}

@greptile-apps
Contributor

greptile-apps bot commented Mar 3, 2026

Additional Comments (2)

csrc/scheduler/transpose_tma.cpp, line 106
Infinite loop when estimated_tile_size1 starts at zero

If bytes_per_tile < kTmaSwizzleBytes (line 91-92), integer division yields estimated_tile_size1 = 0. The while loop (line 104) then spins forever because 0 * 2 == 0 and get_chunks_per_thread() (line 98-102) stays at 0, which is always less than min_chunks_per_thread = 4.

On an H100 (maxThreadsPerMultiProcessor = 2048, cta_per_sm = 8, bytes_per_cta = 8192), this triggers when n_input > 64. Add an initialization guard before the loop:

  // Ensure we start from at least 1 to avoid multiplying 0 forever.
  if (estimated_tile_size1 == 0) {
    estimated_tile_size1 = 1;
  }
  while (get_chunks_per_thread() < min_chunks_per_thread) {
    estimated_tile_size1 *= 2;
  }

tests/cpp/test_transpose.cpp, line 1947
Unconditional debug output will pollute test logs

The std::cout block (lines 1945–1947) prints every bank conflict unconditionally. This makes test runner output noisy, especially since the BFloat16 path is expected to have bank conflicts. Consider wrapping the print in a debug flag or removing it:

      if (auto* ke = dynamic_cast<KernelExecutor*>(executor.get())) {
        auto bank_conflicts = getBankConflictInfo(ke->compiledKernel()->kernel());
        if (dtype == DataType::Float) {
          EXPECT_TRUE(bank_conflicts.empty());
        } else {
          // TODO: update to EXPECT_TRUE once bf16 bank conflicts are resolved.
          EXPECT_FALSE(bank_conflicts.empty());
        }
      }

@greptile-apps
Contributor

greptile-apps bot commented Mar 3, 2026

Additional Comments (4)

csrc/scheduler/transpose_tma.cpp, line 107
Potential infinite loop when estimated_tile_size1 initializes to zero

If bytes_per_tile < kTmaSwizzleBytes (128), integer division yields estimated_tile_size1 = 0. The while loop then evaluates get_chunks_per_thread() as 0 (because the numerator is 0 * tile_size2 = 0) and multiplies: 0 * 2 = 0 — the loop never terminates.

This happens when bytes_per_cta / n_input < 128. With an SM90 GPU (maxThreadsPerMultiProcessor = 2048), cta_per_sm = 8, giving bytes_per_cta = 8192. So the loop infinite-hangs when n_input > 64.

While unlikely for typical transpose fusions (1–2 inputs), this is an unbounded loop with no guard. A simple fix is to initialise estimated_tile_size1 to at least 1:

int64_t estimated_tile_size1 =
    std::max(int64_t(1), bytes_per_tile / kTmaSwizzleBytes);

csrc/scheduler/transpose_tma.cpp, line 267
Missing guard before accessing tma_store_tvs when use_tma_store may be false

tma_store_tvs is only populated when tparams->use_tma_store == true (lines 164–173), but this block checks only tparams->is_output_smem_transpose. If is_output_smem_transpose = true but use_tma_store = false, then tma_store_tvs will be empty and .at(0) throws std::out_of_range.

Note the asymmetry: Step 3 already guards the analogous constraint with an explicit NVF_ERROR(tparams->use_tma_load, ...) at line 286-288. Adding the same guard here would be consistent:

if (tparams->is_output_smem_transpose) {
    NVF_ERROR(
        tparams->use_tma_store,
        "TMA store must be used when output smem is transposed");
    MmaInputSmemSwizzle swizzle =
        mma_utils::tmaSwizzleSharedMemory(tma_store_tvs.at(0));

tests/cpp/test_transpose.cpp, line 1949
Debug std::cout in test code — use GTest facilities instead

These std::cout lines will only fire when bank conflicts are detected (when the test is already failing). However, raw std::cout in tests is unconventional — GTest's ADD_FAILURE() / SCOPED_TRACE or just the EXPECT_TRUE failure message would be more idiomatic:

      for (auto& [expr, ways] : bank_conflicts) {
        auto [read_ways, write_ways] = ways;
        ADD_FAILURE() << "Bank conflict in: " << expr->toString()
                      << "  read=" << read_ways << "-way"
                      << ", write=" << write_ways << "-way";
      }



tests/cpp/test_transpose.cpp, line 1969
Typo "tranapose" should be "transpose" in multiple lines

// Test different combinations of TMA transpose parameters:
// (is_output_smem, use_tma_load, use_tma_store)
//   (false, true, false)  - input smem transpose, TMA load only
//   (false, true, true)   - input smem transpose, TMA load + TMA store
//   (true,  true, true)   - output smem transpose, TMA load + TMA store
//   (true,  false, true)  - output smem transpose, TMA store only

if (std::ranges::any_of(non_trivial_ids, [](const IterDomain* id) {
      return id->isThreadDim() ||
          id->getParallelType() == ParallelType::Serial;
    })) {
Collaborator Author

The reason will be displayed to describe this comment to others. Learn more.

This is a trivial optimization for multiple TMA loads; it doesn't have to be in this PR.


// When not using output smem transpose but inputs > outputs, swap groups
// so group 2 remains the swizzled side.
if (!tparams->is_output_smem_transpose &&
Collaborator Author

The reason will be displayed to describe this comment to others. Learn more.

This branch is not used in the current heuristics, but may be used in future tuning.

@liqiangxl
Collaborator Author

!test

@liqiangxl
Collaborator Author

!test

@liqiangxl
Collaborator Author

!test
