From 3663d4cec7b9e8d22491229905de780837facf3f Mon Sep 17 00:00:00 2001 From: Ryan Spring Date: Fri, 11 Apr 2025 10:43:29 -0700 Subject: [PATCH 1/3] rename grid_swizzle_factor to grid_traversal_factor --- csrc/python_frontend/python_bindings.cpp | 2 +- csrc/scheduler/ampere_multi_matmul.cpp | 4 ++-- csrc/scheduler/ampere_multi_matmul.h | 2 +- csrc/scheduler/hopper_multi_matmul.cpp | 4 ++-- csrc/scheduler/hopper_multi_matmul.h | 2 +- csrc/scheduler/matmul_heuristic.h | 10 +++++----- csrc/scheduler/matmul_heuristic_plugin.cpp | 4 ++-- csrc/scheduler/matmul_utils.cpp | 8 ++++---- tests/cpp/test_matmul.cpp | 4 ++-- 9 files changed, 20 insertions(+), 20 deletions(-) diff --git a/csrc/python_frontend/python_bindings.cpp b/csrc/python_frontend/python_bindings.cpp index 4a438260422..123eec51263 100644 --- a/csrc/python_frontend/python_bindings.cpp +++ b/csrc/python_frontend/python_bindings.cpp @@ -786,7 +786,7 @@ void defineHeuristicParamBindings(py::module& nvfuser) { .PARAM(MatmulParams, circular_buffer_options) .PARAM(MatmulParams, supported_vec_size) .PARAM(MatmulParams, async_gmem_load_operands) - .PARAM(MatmulParams, grid_swizzle_factor) + .PARAM(MatmulParams, grid_traversal_factor) .PARAM(MatmulParams, use_smem_epilogue) .PARAM(MatmulParams, promote_prologue_smem_reuse) .PARAM(MatmulParams, splitk_factor) diff --git a/csrc/scheduler/ampere_multi_matmul.cpp b/csrc/scheduler/ampere_multi_matmul.cpp index a001ab10d8b..c5c73edf6b9 100644 --- a/csrc/scheduler/ampere_multi_matmul.cpp +++ b/csrc/scheduler/ampere_multi_matmul.cpp @@ -535,7 +535,7 @@ void AmpereMultipleMatmulScheduler::cacheOperandsToRegisters( void AmpereMultipleMatmulScheduler::swizzleBlockTiles( TensorView* tv, std::vector& outer_dim_roles) { - if (params_->grid_swizzle_factor != 1) { + if (params_->grid_traversal_factor != 1) { // Find position of outer M and N dims in schedule_.tiled int64_t Mo_pos = -1, No_pos = -1; for (size_t i : arange(outer_dim_roles.size())) { @@ -546,7 +546,7 @@ void AmpereMultipleMatmulScheduler::swizzleBlockTiles( } } - int factor = std::max(1, params_->grid_swizzle_factor); // must be >=1 + int factor = std::max(1, params_->grid_traversal_factor); // must be >=1 switch (params_->cta_order) { case MatmulParams::TileRasterizationOrder::RowMajor: // split [I1, I2/factor, factor] diff --git a/csrc/scheduler/ampere_multi_matmul.h b/csrc/scheduler/ampere_multi_matmul.h index a7ec8e60a13..b0db75849c9 100644 --- a/csrc/scheduler/ampere_multi_matmul.h +++ b/csrc/scheduler/ampere_multi_matmul.h @@ -150,7 +150,7 @@ class AmpereMultipleMatmulScheduler : public MultipleMatmulScheduler { //! with the same role will be merged. //! 2) After that, we perform splits according to //! params_->tile_sizes.cta_tile, e.g. [M, K] -> [Mo, Ko, Mi, Ki]. - //! 3) Depending on the value of params_->grid_swizzle_factor, if the TV has + //! 3) Depending on the value of params_->grid_traversal_factor, if the TV has //! both M and N dimensions, we perform a 2D swizzle of the outer dimensions //! Mo and No. //! 4) Finally, we do a split-K split if the splitk_factor is not 1 diff --git a/csrc/scheduler/hopper_multi_matmul.cpp b/csrc/scheduler/hopper_multi_matmul.cpp index 479d91cc339..f0cdc0b5231 100644 --- a/csrc/scheduler/hopper_multi_matmul.cpp +++ b/csrc/scheduler/hopper_multi_matmul.cpp @@ -172,7 +172,7 @@ void HopperMultipleMatmulScheduler::run() { void HopperMultipleMatmulScheduler::swizzleBlockTiles( TensorView* tv, std::vector& outer_dim_roles) { - if (params_->grid_swizzle_factor != 1) { + if (params_->grid_traversal_factor != 1) { // Find position of outer M and N dims in schedule_.tiled int64_t Mo_pos = -1, No_pos = -1; for (size_t i : arange(outer_dim_roles.size())) { @@ -183,7 +183,7 @@ void HopperMultipleMatmulScheduler::swizzleBlockTiles( } } - int factor = std::max(1, params_->grid_swizzle_factor); // must be >=1 + int factor = std::max(1, params_->grid_traversal_factor); // must be >=1 switch (params_->cta_order) { case MatmulParams::TileRasterizationOrder::RowMajor: // split [I1, I2/factor, factor] diff --git a/csrc/scheduler/hopper_multi_matmul.h b/csrc/scheduler/hopper_multi_matmul.h index 524b7dbf535..59bb8758742 100644 --- a/csrc/scheduler/hopper_multi_matmul.h +++ b/csrc/scheduler/hopper_multi_matmul.h @@ -138,7 +138,7 @@ class HopperMultipleMatmulScheduler : public MultipleMatmulScheduler { //! with the same role will be merged. //! 2) After that, we perform splits according to //! params_->tile_sizes.cta_tile, e.g. [M, K] -> [Mo, Ko, Mi, Ki]. - //! 3) Depending on the value of params_->grid_swizzle_factor, if the TV has + //! 3) Depending on the value of params_->grid_traversal_factor, if the TV has //! both M and N dimensions, we perform a 2D swizzle of the outer dimensions //! Mo and No. //! 4) Finally, we do a split-K split if the splitk_factor is not 1 diff --git a/csrc/scheduler/matmul_heuristic.h b/csrc/scheduler/matmul_heuristic.h index cdd8b78fec1..8bc0f54dd90 100644 --- a/csrc/scheduler/matmul_heuristic.h +++ b/csrc/scheduler/matmul_heuristic.h @@ -296,12 +296,12 @@ class MatmulParams : public HeuristicParams { //! will more likely be forming sub-tiles of the C matrix. This will increase //! L2 hit rate/data reuse of A and B. //! - //! Eg for grid_swizzle_factor=2: + //! Eg for grid_traversal_factor=2: //! A1 A2 B1 B2 --> A1 A2 A3 A4 B1 B2 B3 B4 //! A3 A4 B3 B4 C1 C2 C3 C4 D1 D2 D3 D4 //! C1 C2 D1 D2 //! C3 C4 D3 D4 - int grid_swizzle_factor = 1; + int grid_traversal_factor = 1; //! Unswizzle MMA results in shared memory to get //! coalesced write to global memory @@ -370,7 +370,7 @@ class MatmulParams : public HeuristicParams { << ((cta_order == TileRasterizationOrder::RowMajor) ? "row-major" : "column-major") << "\n" - << "Grid swizzle factor: " << grid_swizzle_factor << "\n"; + << "Grid swizzle factor: " << grid_traversal_factor << "\n"; ss << "Tiling strategy: "; switch (tiling_strategy) { case TilingStrategy::OneTilePerCTA: @@ -425,7 +425,7 @@ class MatmulParams : public HeuristicParams { (circular_buffer_options.hash() << 2) ^ (nvfuser::hash(tile_sizes) << 3) ^ (std::hash{}(static_cast(cta_order)) << 4) ^ - (std::hash{}(grid_swizzle_factor) << 5) ^ + (std::hash{}(grid_traversal_factor) << 5) ^ (std::hash{}(splitk_factor) << 6); return attr_hash; } @@ -442,7 +442,7 @@ class MatmulParams : public HeuristicParams { other->circular_buffer_options == circular_buffer_options && other->supported_vec_size == supported_vec_size && other->cta_order == cta_order && - other->grid_swizzle_factor == grid_swizzle_factor && + other->grid_traversal_factor == grid_traversal_factor && other->use_smem_epilogue == use_smem_epilogue && other->promote_prologue_smem_reuse == promote_prologue_smem_reuse && other->splitk_factor == splitk_factor; diff --git a/csrc/scheduler/matmul_heuristic_plugin.cpp b/csrc/scheduler/matmul_heuristic_plugin.cpp index 36ec123f220..9ad0522c7e1 100644 --- a/csrc/scheduler/matmul_heuristic_plugin.cpp +++ b/csrc/scheduler/matmul_heuristic_plugin.cpp @@ -143,7 +143,7 @@ void copyParamsToConfig(KernelConfig* config, const MatmulParams* mparams) { config->cluster_dims[1] = mparams->cluster_dims.y; config->cluster_dims[2] = mparams->cluster_dims.z; config->splitk_factor = mparams->splitk_factor; - config->grid_swizzle_factor = mparams->grid_swizzle_factor; + config->grid_swizzle_factor = mparams->grid_traversal_factor.first; config->cta_order = mparams->cta_order == MatmulParams::TileRasterizationOrder::RowMajor ? 0 : 1; @@ -179,7 +179,7 @@ void copyConfigToParams(MatmulParams* mparams, const KernelConfig* config) { menc.k = config->instruction_tile[2]; mparams->mma_macro = menc; // cast back to uint64_t mparams->splitk_factor = config->splitk_factor; - mparams->grid_swizzle_factor = config->grid_swizzle_factor; + mparams->grid_traversal_factor = config->grid_traversal_factor; switch (config->cta_order) { case 0: mparams->cta_order = MatmulParams::TileRasterizationOrder::RowMajor; diff --git a/csrc/scheduler/matmul_utils.cpp b/csrc/scheduler/matmul_utils.cpp index c7730472b0f..a58ea3fec40 100644 --- a/csrc/scheduler/matmul_utils.cpp +++ b/csrc/scheduler/matmul_utils.cpp @@ -403,15 +403,15 @@ bool fillDefaultHopperHeuristic( // We also swizzle the tiles as much as possible up to 16 tiles. Like choosing // the rasterization order, this is used to increase L2 locality - mparams->grid_swizzle_factor = std::min(swizzled_tiles, 16L); - while (swizzled_tiles % mparams->grid_swizzle_factor != 0) { + mparams->grid_traversal_factor = std::min(swizzled_tiles, 16L); + while (swizzled_tiles % mparams->grid_traversal_factor != 0) { // Decrease the swizzle factor if it would result in nondivisible splits, // since this would unnecessarily increase the grid size. - mparams->grid_swizzle_factor--; + mparams->grid_traversal_factor--; } // TODO: grid swizzling is currently disabled on Hopper since we cannot // properly inline when we swizzle unmapped loop broadcasts - mparams->grid_swizzle_factor = 1L; + mparams->grid_traversal_factor = 1L; // TODO: Finally, we set the CGA size diff --git a/tests/cpp/test_matmul.cpp b/tests/cpp/test_matmul.cpp index eb96b8917ce..6f4acdae955 100644 --- a/tests/cpp/test_matmul.cpp +++ b/tests/cpp/test_matmul.cpp @@ -526,7 +526,7 @@ TEST_P(MatmulTestWithLayout, AmpereSwizzle) { mparams.circular_buffer_options.smem_circular_buffer_stage = 3; mparams.cta_order = order; - mparams.grid_swizzle_factor = swizzle; + mparams.grid_traversal_factor = swizzle; SchedulerEntry::makeSchedulerInstance(SchedulerType::Matmul) ->schedule(&fusion, &mparams); @@ -5016,7 +5016,7 @@ TEST_F(HopperMatmulTest, MLPGemmPersistentBroadcastInputs) { MatmulParams::TilingStrategy::DistributeTilesAcrossSMs; mparams.circular_buffer_options.circular_buffer_smem_write = true; mparams.circular_buffer_options.circular_buffer_smem_read = false; - mparams.grid_swizzle_factor = 8; + mparams.grid_traversal_factor = 8; // TODO reduced share memory aliasing because of persistent scheduling mparams.circular_buffer_options.smem_circular_buffer_stage = 3; mparams.circular_buffer_options.smem_circular_buffer_prefetch_gap = 1; From c13ba4e83b9f7f3bf627f76a7581cad98222e67b Mon Sep 17 00:00:00 2001 From: Ryan Spring Date: Fri, 11 Apr 2025 10:53:41 -0700 Subject: [PATCH 2/3] convert to std::pair --- csrc/scheduler/ampere_multi_matmul.cpp | 5 +++-- csrc/scheduler/ampere_multi_matmul.h | 3 ++- csrc/scheduler/hopper_multi_matmul.cpp | 5 +++-- csrc/scheduler/hopper_multi_matmul.h | 3 ++- csrc/scheduler/matmul_heuristic.h | 9 +++++---- csrc/scheduler/matmul_heuristic_plugin.cpp | 2 +- csrc/scheduler/matmul_utils.cpp | 11 +++++------ tests/cpp/test_matmul.cpp | 4 ++-- 8 files changed, 23 insertions(+), 19 deletions(-) diff --git a/csrc/scheduler/ampere_multi_matmul.cpp b/csrc/scheduler/ampere_multi_matmul.cpp index c5c73edf6b9..bac4896a339 100644 --- a/csrc/scheduler/ampere_multi_matmul.cpp +++ b/csrc/scheduler/ampere_multi_matmul.cpp @@ -535,7 +535,7 @@ void AmpereMultipleMatmulScheduler::cacheOperandsToRegisters( void AmpereMultipleMatmulScheduler::swizzleBlockTiles( TensorView* tv, std::vector& outer_dim_roles) { - if (params_->grid_traversal_factor != 1) { + if (params_->grid_traversal_factor.first != 1) { // Find position of outer M and N dims in schedule_.tiled int64_t Mo_pos = -1, No_pos = -1; for (size_t i : arange(outer_dim_roles.size())) { @@ -546,7 +546,8 @@ void AmpereMultipleMatmulScheduler::swizzleBlockTiles( } } - int factor = std::max(1, params_->grid_traversal_factor); // must be >=1 + int factor = + std::max(1, params_->grid_traversal_factor.first); // must be >=1 switch (params_->cta_order) { case MatmulParams::TileRasterizationOrder::RowMajor: // split [I1, I2/factor, factor] diff --git a/csrc/scheduler/ampere_multi_matmul.h b/csrc/scheduler/ampere_multi_matmul.h index b0db75849c9..669ad887f30 100644 --- a/csrc/scheduler/ampere_multi_matmul.h +++ b/csrc/scheduler/ampere_multi_matmul.h @@ -150,7 +150,8 @@ class AmpereMultipleMatmulScheduler : public MultipleMatmulScheduler { //! with the same role will be merged. //! 2) After that, we perform splits according to //! params_->tile_sizes.cta_tile, e.g. [M, K] -> [Mo, Ko, Mi, Ki]. - //! 3) Depending on the value of params_->grid_traversal_factor, if the TV has + //! 3) Depending on the value of params_->grid_traversal_factor, if the TV + //! has //! both M and N dimensions, we perform a 2D swizzle of the outer dimensions //! Mo and No. //! 4) Finally, we do a split-K split if the splitk_factor is not 1 diff --git a/csrc/scheduler/hopper_multi_matmul.cpp b/csrc/scheduler/hopper_multi_matmul.cpp index f0cdc0b5231..a0dc4fc2b1f 100644 --- a/csrc/scheduler/hopper_multi_matmul.cpp +++ b/csrc/scheduler/hopper_multi_matmul.cpp @@ -172,7 +172,7 @@ void HopperMultipleMatmulScheduler::run() { void HopperMultipleMatmulScheduler::swizzleBlockTiles( TensorView* tv, std::vector& outer_dim_roles) { - if (params_->grid_traversal_factor != 1) { + if (params_->grid_traversal_factor.first != 1) { // Find position of outer M and N dims in schedule_.tiled int64_t Mo_pos = -1, No_pos = -1; for (size_t i : arange(outer_dim_roles.size())) { @@ -183,7 +183,8 @@ void HopperMultipleMatmulScheduler::swizzleBlockTiles( } } - int factor = std::max(1, params_->grid_traversal_factor); // must be >=1 + int factor = + std::max(1, params_->grid_traversal_factor.first); // must be >=1 switch (params_->cta_order) { case MatmulParams::TileRasterizationOrder::RowMajor: // split [I1, I2/factor, factor] diff --git a/csrc/scheduler/hopper_multi_matmul.h b/csrc/scheduler/hopper_multi_matmul.h index 59bb8758742..0df9c959456 100644 --- a/csrc/scheduler/hopper_multi_matmul.h +++ b/csrc/scheduler/hopper_multi_matmul.h @@ -138,7 +138,8 @@ class HopperMultipleMatmulScheduler : public MultipleMatmulScheduler { //! with the same role will be merged. //! 2) After that, we perform splits according to //! params_->tile_sizes.cta_tile, e.g. [M, K] -> [Mo, Ko, Mi, Ki]. - //! 3) Depending on the value of params_->grid_traversal_factor, if the TV has + //! 3) Depending on the value of params_->grid_traversal_factor, if the TV + //! has //! both M and N dimensions, we perform a 2D swizzle of the outer dimensions //! Mo and No. //! 4) Finally, we do a split-K split if the splitk_factor is not 1 diff --git a/csrc/scheduler/matmul_heuristic.h b/csrc/scheduler/matmul_heuristic.h index 8bc0f54dd90..52baa1f2da0 100644 --- a/csrc/scheduler/matmul_heuristic.h +++ b/csrc/scheduler/matmul_heuristic.h @@ -296,12 +296,12 @@ class MatmulParams : public HeuristicParams { //! will more likely be forming sub-tiles of the C matrix. This will increase //! L2 hit rate/data reuse of A and B. //! - //! Eg for grid_traversal_factor=2: + //! Eg for grid_traversal_factor = {2, 1}: //! A1 A2 B1 B2 --> A1 A2 A3 A4 B1 B2 B3 B4 //! A3 A4 B3 B4 C1 C2 C3 C4 D1 D2 D3 D4 //! C1 C2 D1 D2 //! C3 C4 D3 D4 - int grid_traversal_factor = 1; + std::pair grid_traversal_factor = {1, 1}; //! Unswizzle MMA results in shared memory to get //! coalesced write to global memory @@ -425,8 +425,9 @@ class MatmulParams : public HeuristicParams { (circular_buffer_options.hash() << 2) ^ (nvfuser::hash(tile_sizes) << 3) ^ (std::hash{}(static_cast(cta_order)) << 4) ^ - (std::hash{}(grid_traversal_factor) << 5) ^ - (std::hash{}(splitk_factor) << 6); + (std::hash{}(grid_traversal_factor.first) << 5) ^ + (std::hash{}(grid_traversal_factor.second) << 6) ^ + (std::hash{}(splitk_factor) << 7); return attr_hash; } diff --git a/csrc/scheduler/matmul_heuristic_plugin.cpp b/csrc/scheduler/matmul_heuristic_plugin.cpp index 9ad0522c7e1..3879147ed29 100644 --- a/csrc/scheduler/matmul_heuristic_plugin.cpp +++ b/csrc/scheduler/matmul_heuristic_plugin.cpp @@ -179,7 +179,7 @@ void copyConfigToParams(MatmulParams* mparams, const KernelConfig* config) { menc.k = config->instruction_tile[2]; mparams->mma_macro = menc; // cast back to uint64_t mparams->splitk_factor = config->splitk_factor; - mparams->grid_traversal_factor = config->grid_traversal_factor; + mparams->grid_traversal_factor.first = config->grid_swizzle_factor; switch (config->cta_order) { case 0: mparams->cta_order = MatmulParams::TileRasterizationOrder::RowMajor; diff --git a/csrc/scheduler/matmul_utils.cpp b/csrc/scheduler/matmul_utils.cpp index a58ea3fec40..e65d571ed76 100644 --- a/csrc/scheduler/matmul_utils.cpp +++ b/csrc/scheduler/matmul_utils.cpp @@ -403,15 +403,14 @@ bool fillDefaultHopperHeuristic( // We also swizzle the tiles as much as possible up to 16 tiles. Like choosing // the rasterization order, this is used to increase L2 locality - mparams->grid_traversal_factor = std::min(swizzled_tiles, 16L); - while (swizzled_tiles % mparams->grid_traversal_factor != 0) { + int grid_traversal_factor = std::min(swizzled_tiles, 16L); + while (swizzled_tiles % grid_traversal_factor != 0) { // Decrease the swizzle factor if it would result in nondivisible splits, // since this would unnecessarily increase the grid size. - mparams->grid_traversal_factor--; + grid_traversal_factor--; } - // TODO: grid swizzling is currently disabled on Hopper since we cannot - // properly inline when we swizzle unmapped loop broadcasts - mparams->grid_traversal_factor = 1L; + // TODO: Use only 1D grid traversal factor for now + mparams->grid_traversal_factor = {grid_traversal_factor, 1}; // TODO: Finally, we set the CGA size diff --git a/tests/cpp/test_matmul.cpp b/tests/cpp/test_matmul.cpp index 6f4acdae955..6c5e738e542 100644 --- a/tests/cpp/test_matmul.cpp +++ b/tests/cpp/test_matmul.cpp @@ -526,7 +526,7 @@ TEST_P(MatmulTestWithLayout, AmpereSwizzle) { mparams.circular_buffer_options.smem_circular_buffer_stage = 3; mparams.cta_order = order; - mparams.grid_traversal_factor = swizzle; + mparams.grid_traversal_factor = {swizzle, 1}; SchedulerEntry::makeSchedulerInstance(SchedulerType::Matmul) ->schedule(&fusion, &mparams); @@ -5016,7 +5016,7 @@ TEST_F(HopperMatmulTest, MLPGemmPersistentBroadcastInputs) { MatmulParams::TilingStrategy::DistributeTilesAcrossSMs; mparams.circular_buffer_options.circular_buffer_smem_write = true; mparams.circular_buffer_options.circular_buffer_smem_read = false; - mparams.grid_traversal_factor = 8; + mparams.grid_traversal_factor = {8, 1}; // TODO reduced share memory aliasing because of persistent scheduling mparams.circular_buffer_options.smem_circular_buffer_stage = 3; mparams.circular_buffer_options.smem_circular_buffer_prefetch_gap = 1; From fb5a8f237ad646731a7eccade968868ff8679ed2 Mon Sep 17 00:00:00 2001 From: Ryan Spring Date: Fri, 11 Apr 2025 11:20:52 -0700 Subject: [PATCH 3/3] rename swizzleBlockTiles to reorderBlockTileTraversal --- csrc/scheduler/ampere_multi_matmul.cpp | 7 +++++-- csrc/scheduler/ampere_multi_matmul.h | 2 +- csrc/scheduler/hopper_multi_matmul.cpp | 7 +++++-- csrc/scheduler/hopper_multi_matmul.h | 2 +- 4 files changed, 12 insertions(+), 6 deletions(-) diff --git a/csrc/scheduler/ampere_multi_matmul.cpp b/csrc/scheduler/ampere_multi_matmul.cpp index bac4896a339..4b691a6413b 100644 --- a/csrc/scheduler/ampere_multi_matmul.cpp +++ b/csrc/scheduler/ampere_multi_matmul.cpp @@ -532,9 +532,12 @@ void AmpereMultipleMatmulScheduler::cacheOperandsToRegisters( } } -void AmpereMultipleMatmulScheduler::swizzleBlockTiles( +void AmpereMultipleMatmulScheduler::reorderBlockTileTraversal( TensorView* tv, std::vector& outer_dim_roles) { + NVF_ERROR( + params_->grid_traversal_factor.second == 1, + "Ampere matmul scheduler does not support 2d grid traversal"); if (params_->grid_traversal_factor.first != 1) { // Find position of outer M and N dims in schedule_.tiled int64_t Mo_pos = -1, No_pos = -1; @@ -691,7 +694,7 @@ std::vector> AmpereMultipleMatmulScheduler:: // scheduling is the next step in this modernization. mma_utils::makeTile(tv, params_->tile_sizes.cta_tile, merged_roles); - swizzleBlockTiles(tv, merged_roles); + reorderBlockTileTraversal(tv, merged_roles); all_merged_roles.push_back(merged_roles); diff --git a/csrc/scheduler/ampere_multi_matmul.h b/csrc/scheduler/ampere_multi_matmul.h index 669ad887f30..785132e6b3c 100644 --- a/csrc/scheduler/ampere_multi_matmul.h +++ b/csrc/scheduler/ampere_multi_matmul.h @@ -132,7 +132,7 @@ class AmpereMultipleMatmulScheduler : public MultipleMatmulScheduler { //! This updates outer_dim_roles if we introduce a new dimension, which can //! happen if tv is missing a merged axis, in which case we skip merging after //! the split. This is analogous to forwarding during transform propagation. - void swizzleBlockTiles( + void reorderBlockTileTraversal( TensorView* tv, std::vector& outer_dim_roles); diff --git a/csrc/scheduler/hopper_multi_matmul.cpp b/csrc/scheduler/hopper_multi_matmul.cpp index a0dc4fc2b1f..be13487e0f1 100644 --- a/csrc/scheduler/hopper_multi_matmul.cpp +++ b/csrc/scheduler/hopper_multi_matmul.cpp @@ -169,9 +169,12 @@ void HopperMultipleMatmulScheduler::run() { setUpCircularBuffering(); } -void HopperMultipleMatmulScheduler::swizzleBlockTiles( +void HopperMultipleMatmulScheduler::reorderBlockTileTraversal( TensorView* tv, std::vector& outer_dim_roles) { + NVF_ERROR( + params_->grid_traversal_factor.second == 1, + "Hopper matmul scheduler does not support 2d grid traversal"); if (params_->grid_traversal_factor.first != 1) { // Find position of outer M and N dims in schedule_.tiled int64_t Mo_pos = -1, No_pos = -1; @@ -328,7 +331,7 @@ std::vector> HopperMultipleMatmulScheduler:: // scheduling is the next step in this modernization. mma_utils::makeTile(tv, params_->tile_sizes.cta_tile, merged_roles); - swizzleBlockTiles(tv, merged_roles); + reorderBlockTileTraversal(tv, merged_roles); all_merged_roles.push_back(merged_roles); diff --git a/csrc/scheduler/hopper_multi_matmul.h b/csrc/scheduler/hopper_multi_matmul.h index 0df9c959456..854d0705234 100644 --- a/csrc/scheduler/hopper_multi_matmul.h +++ b/csrc/scheduler/hopper_multi_matmul.h @@ -120,7 +120,7 @@ class HopperMultipleMatmulScheduler : public MultipleMatmulScheduler { //! This updates outer_dim_roles if we introduce a new dimension, which can //! happen if tv is missing a merged axis, in which case we skip merging after //! the split. This is analogous to forwarding during transform propagation. - void swizzleBlockTiles( + void reorderBlockTileTraversal( TensorView* tv, std::vector& outer_dim_roles);