Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion csrc/python_frontend/python_bindings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -786,7 +786,7 @@ void defineHeuristicParamBindings(py::module& nvfuser) {
.PARAM(MatmulParams, circular_buffer_options)
.PARAM(MatmulParams, supported_vec_size)
.PARAM(MatmulParams, async_gmem_load_operands)
.PARAM(MatmulParams, grid_swizzle_factor)
.PARAM(MatmulParams, grid_traversal_factor)
.PARAM(MatmulParams, use_smem_epilogue)
.PARAM(MatmulParams, promote_prologue_smem_reuse)
.PARAM(MatmulParams, splitk_factor)
Expand Down
12 changes: 8 additions & 4 deletions csrc/scheduler/ampere_multi_matmul.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -532,10 +532,13 @@ void AmpereMultipleMatmulScheduler::cacheOperandsToRegisters(
}
}

void AmpereMultipleMatmulScheduler::swizzleBlockTiles(
void AmpereMultipleMatmulScheduler::reorderBlockTileTraversal(
TensorView* tv,
std::vector<MatmulDimRole>& outer_dim_roles) {
if (params_->grid_swizzle_factor != 1) {
NVF_ERROR(
params_->grid_traversal_factor.second == 1,
"Ampere matmul scheduler does not support 2d grid traversal");
if (params_->grid_traversal_factor.first != 1) {
// Find position of outer M and N dims in schedule_.tiled
int64_t Mo_pos = -1, No_pos = -1;
for (size_t i : arange(outer_dim_roles.size())) {
Expand All @@ -546,7 +549,8 @@ void AmpereMultipleMatmulScheduler::swizzleBlockTiles(
}
}

int factor = std::max(1, params_->grid_swizzle_factor); // must be >=1
int factor =
std::max(1, params_->grid_traversal_factor.first); // must be >=1
switch (params_->cta_order) {
case MatmulParams::TileRasterizationOrder::RowMajor:
// split [I1, I2/factor, factor]
Expand Down Expand Up @@ -690,7 +694,7 @@ std::vector<std::vector<MatmulDimRole>> AmpereMultipleMatmulScheduler::
// scheduling is the next step in this modernization.
mma_utils::makeTile(tv, params_->tile_sizes.cta_tile, merged_roles);

swizzleBlockTiles(tv, merged_roles);
reorderBlockTileTraversal(tv, merged_roles);

all_merged_roles.push_back(merged_roles);

Expand Down
5 changes: 3 additions & 2 deletions csrc/scheduler/ampere_multi_matmul.h
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ class AmpereMultipleMatmulScheduler : public MultipleMatmulScheduler {
//! This updates outer_dim_roles if we introduce a new dimension, which can
//! happen if tv is missing a merged axis, in which case we skip merging after
//! the split. This is analogous to forwarding during transform propagation.
void swizzleBlockTiles(
void reorderBlockTileTraversal(
TensorView* tv,
std::vector<MatmulDimRole>& outer_dim_roles);

Expand All @@ -150,7 +150,8 @@ class AmpereMultipleMatmulScheduler : public MultipleMatmulScheduler {
//! with the same role will be merged.
//! 2) After that, we perform splits according to
//! params_->tile_sizes.cta_tile, e.g. [M, K] -> [Mo, Ko, Mi, Ki].
//! 3) Depending on the value of params_->grid_swizzle_factor, if the TV has
//! 3) Depending on the value of params_->grid_traversal_factor, if the TV
//! has
//! both M and N dimensions, we perform a 2D swizzle of the outer dimensions
//! Mo and No.
//! 4) Finally, we do a split-K split if the splitk_factor is not 1
Expand Down
12 changes: 8 additions & 4 deletions csrc/scheduler/hopper_multi_matmul.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -169,10 +169,13 @@ void HopperMultipleMatmulScheduler::run() {
setUpCircularBuffering();
}

void HopperMultipleMatmulScheduler::swizzleBlockTiles(
void HopperMultipleMatmulScheduler::reorderBlockTileTraversal(
TensorView* tv,
std::vector<MatmulDimRole>& outer_dim_roles) {
if (params_->grid_swizzle_factor != 1) {
NVF_ERROR(
params_->grid_traversal_factor.second == 1,
"Hopper matmul scheduler does not support 2d grid traversal");
if (params_->grid_traversal_factor.first != 1) {
// Find position of outer M and N dims in schedule_.tiled
int64_t Mo_pos = -1, No_pos = -1;
for (size_t i : arange(outer_dim_roles.size())) {
Expand All @@ -183,7 +186,8 @@ void HopperMultipleMatmulScheduler::swizzleBlockTiles(
}
}

int factor = std::max(1, params_->grid_swizzle_factor); // must be >=1
int factor =
std::max(1, params_->grid_traversal_factor.first); // must be >=1
switch (params_->cta_order) {
case MatmulParams::TileRasterizationOrder::RowMajor:
// split [I1, I2/factor, factor]
Expand Down Expand Up @@ -327,7 +331,7 @@ std::vector<std::vector<MatmulDimRole>> HopperMultipleMatmulScheduler::
// scheduling is the next step in this modernization.
mma_utils::makeTile(tv, params_->tile_sizes.cta_tile, merged_roles);

swizzleBlockTiles(tv, merged_roles);
reorderBlockTileTraversal(tv, merged_roles);

all_merged_roles.push_back(merged_roles);

Expand Down
5 changes: 3 additions & 2 deletions csrc/scheduler/hopper_multi_matmul.h
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ class HopperMultipleMatmulScheduler : public MultipleMatmulScheduler {
//! This updates outer_dim_roles if we introduce a new dimension, which can
//! happen if tv is missing a merged axis, in which case we skip merging after
//! the split. This is analogous to forwarding during transform propagation.
void swizzleBlockTiles(
void reorderBlockTileTraversal(
TensorView* tv,
std::vector<MatmulDimRole>& outer_dim_roles);

Expand All @@ -138,7 +138,8 @@ class HopperMultipleMatmulScheduler : public MultipleMatmulScheduler {
//! with the same role will be merged.
//! 2) After that, we perform splits according to
//! params_->tile_sizes.cta_tile, e.g. [M, K] -> [Mo, Ko, Mi, Ki].
//! 3) Depending on the value of params_->grid_swizzle_factor, if the TV has
//! 3) Depending on the value of params_->grid_traversal_factor, if the TV
//! has
//! both M and N dimensions, we perform a 2D swizzle of the outer dimensions
//! Mo and No.
//! 4) Finally, we do a split-K split if the splitk_factor is not 1
Expand Down
13 changes: 7 additions & 6 deletions csrc/scheduler/matmul_heuristic.h
Original file line number Diff line number Diff line change
Expand Up @@ -296,12 +296,12 @@ class MatmulParams : public HeuristicParams {
//! will more likely be forming sub-tiles of the C matrix. This will increase
//! L2 hit rate/data reuse of A and B.
//!
//! Eg for grid_swizzle_factor=2:
//! Eg for grid_traversal_factor = {2, 1}:
//! A1 A2 B1 B2 --> A1 A2 A3 A4 B1 B2 B3 B4
//! A3 A4 B3 B4 C1 C2 C3 C4 D1 D2 D3 D4
//! C1 C2 D1 D2
//! C3 C4 D3 D4
int grid_swizzle_factor = 1;
std::pair<int, int> grid_traversal_factor = {1, 1};

//! Unswizzle MMA results in shared memory to get
//! coalesced write to global memory
Expand Down Expand Up @@ -370,7 +370,7 @@ class MatmulParams : public HeuristicParams {
<< ((cta_order == TileRasterizationOrder::RowMajor) ? "row-major"
: "column-major")
<< "\n"
<< "Grid swizzle factor: " << grid_swizzle_factor << "\n";
<< "Grid swizzle factor: " << grid_traversal_factor << "\n";
ss << "Tiling strategy: ";
switch (tiling_strategy) {
case TilingStrategy::OneTilePerCTA:
Expand Down Expand Up @@ -425,8 +425,9 @@ class MatmulParams : public HeuristicParams {
(circular_buffer_options.hash() << 2) ^
(nvfuser::hash(tile_sizes) << 3) ^
(std::hash<size_t>{}(static_cast<size_t>(cta_order)) << 4) ^
(std::hash<size_t>{}(grid_swizzle_factor) << 5) ^
(std::hash<size_t>{}(splitk_factor) << 6);
(std::hash<size_t>{}(grid_traversal_factor.first) << 5) ^
(std::hash<size_t>{}(grid_traversal_factor.second) << 6) ^
(std::hash<size_t>{}(splitk_factor) << 7);
return attr_hash;
}

Expand All @@ -442,7 +443,7 @@ class MatmulParams : public HeuristicParams {
other->circular_buffer_options == circular_buffer_options &&
other->supported_vec_size == supported_vec_size &&
other->cta_order == cta_order &&
other->grid_swizzle_factor == grid_swizzle_factor &&
other->grid_traversal_factor == grid_traversal_factor &&
other->use_smem_epilogue == use_smem_epilogue &&
other->promote_prologue_smem_reuse == promote_prologue_smem_reuse &&
other->splitk_factor == splitk_factor;
Expand Down
4 changes: 2 additions & 2 deletions csrc/scheduler/matmul_heuristic_plugin.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ void copyParamsToConfig(KernelConfig* config, const MatmulParams* mparams) {
config->cluster_dims[1] = mparams->cluster_dims.y;
config->cluster_dims[2] = mparams->cluster_dims.z;
config->splitk_factor = mparams->splitk_factor;
config->grid_swizzle_factor = mparams->grid_swizzle_factor;
config->grid_swizzle_factor = mparams->grid_traversal_factor.first;
config->cta_order =
mparams->cta_order == MatmulParams::TileRasterizationOrder::RowMajor ? 0
: 1;
Expand Down Expand Up @@ -179,7 +179,7 @@ void copyConfigToParams(MatmulParams* mparams, const KernelConfig* config) {
menc.k = config->instruction_tile[2];
mparams->mma_macro = menc; // cast back to uint64_t
mparams->splitk_factor = config->splitk_factor;
mparams->grid_swizzle_factor = config->grid_swizzle_factor;
mparams->grid_traversal_factor.first = config->grid_swizzle_factor;
switch (config->cta_order) {
case 0:
mparams->cta_order = MatmulParams::TileRasterizationOrder::RowMajor;
Expand Down
11 changes: 5 additions & 6 deletions csrc/scheduler/matmul_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -403,15 +403,14 @@ bool fillDefaultHopperHeuristic(

// We also swizzle the tiles as much as possible up to 16 tiles. Like choosing
// the rasterization order, this is used to increase L2 locality
mparams->grid_swizzle_factor = std::min(swizzled_tiles, 16L);
while (swizzled_tiles % mparams->grid_swizzle_factor != 0) {
int grid_traversal_factor = std::min(swizzled_tiles, 16L);
while (swizzled_tiles % grid_traversal_factor != 0) {
// Decrease the swizzle factor if it would result in nondivisible splits,
// since this would unnecessarily increase the grid size.
mparams->grid_swizzle_factor--;
grid_traversal_factor--;
}
// TODO: grid swizzling is currently disabled on Hopper since we cannot
// properly inline when we swizzle unmapped loop broadcasts
mparams->grid_swizzle_factor = 1L;
// TODO: Use only 1D grid traversal factor for now
mparams->grid_traversal_factor = {grid_traversal_factor, 1};

// TODO: Finally, we set the CGA size

Expand Down
4 changes: 2 additions & 2 deletions tests/cpp/test_matmul.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -526,7 +526,7 @@ TEST_P(MatmulTestWithLayout, AmpereSwizzle) {
mparams.circular_buffer_options.smem_circular_buffer_stage = 3;

mparams.cta_order = order;
mparams.grid_swizzle_factor = swizzle;
mparams.grid_traversal_factor = {swizzle, 1};

SchedulerEntry::makeSchedulerInstance(SchedulerType::Matmul)
->schedule(&fusion, &mparams);
Expand Down Expand Up @@ -5016,7 +5016,7 @@ TEST_F(HopperMatmulTest, MLPGemmPersistentBroadcastInputs) {
MatmulParams::TilingStrategy::DistributeTilesAcrossSMs;
mparams.circular_buffer_options.circular_buffer_smem_write = true;
mparams.circular_buffer_options.circular_buffer_smem_read = false;
mparams.grid_swizzle_factor = 8;
mparams.grid_traversal_factor = {8, 1};
// TODO reduced share memory aliasing because of persistent scheduling
mparams.circular_buffer_options.smem_circular_buffer_stage = 3;
mparams.circular_buffer_options.smem_circular_buffer_prefetch_gap = 1;
Expand Down