diff --git a/csrc/scheduler/matmul.cpp b/csrc/scheduler/matmul.cpp
index 8b479e5d9a4..fcca8504846 100644
--- a/csrc/scheduler/matmul.cpp
+++ b/csrc/scheduler/matmul.cpp
@@ -181,6 +181,32 @@ void scheduleMatmul(
   // [... M,N,K]
   scheduler_utils::matmul_utils::makeTile(cc, gemm_tile.cta_tile.toVector());
 
+  // Apply the grid swizzle factor to the two outermost axes of cc. The shape
+  // comments below list only those two axes (I1 = axis 0, I2 = axis 1).
+  // Non-positive factors are clamped to 1, and a factor of 1 leaves the
+  // schedule untouched instead of inserting a degenerate split/merge pair.
+  const int factor = std::max(1, params.grid_swizzle_factor);
+  if (factor > 1) {
+    if (params.rasterization_order ==
+        MatmulParam::TileRasterizationOrder::RowMajor) {
+      cc->split(1, factor);
+      // [I1, I2/factor, factor]
+      cc->reorder({{1, 2}});
+      // [I1, factor, I2/factor]
+      cc->merge(0);
+      // [I1*factor, I2/factor]
+    } else if (
+        params.rasterization_order ==
+        MatmulParam::TileRasterizationOrder::ColumnMajor) {
+      cc->split(0, factor);
+      // [I1/factor, factor, I2]
+      cc->reorder({{1, 2}});
+      // [I1/factor, I2, factor]
+      cc->merge(1);
+      // [I1/factor, I2*factor]
+    }
+  }
+
   // [Mo, No, Ko, Mi, Ni, Ki]
   // Propagate tiling globally
   scheduler_utils::transformPropagateToAllFrom(cc, -1);
diff --git a/csrc/scheduler/matmul.h b/csrc/scheduler/matmul.h
index 7595bea8ee3..e2e50e4333b 100644
--- a/csrc/scheduler/matmul.h
+++ b/csrc/scheduler/matmul.h
@@ -48,6 +48,24 @@ class MatmulParam {
     RowMajor = 0,
     ColumnMajor = 1
   } rasterization_order = TileRasterizationOrder::RowMajor;
+
+  //! Swizzle factor is used to increase L2 hit rate.
+  //!  It horizontally squeezes the grid so that gridDim.x is larger and
+  //!  gridDim.y is smaller.
+  //!  We rely on the observation that the CTAs are scheduled by the GPU by
+  //!  iterating on gridDim.x first. As a result, as blocks are launched, they
+  //!  will more likely be forming sub-tiles of the C matrix. This will increase
+  //!  L2 hit rate/data reuse of A and B.
+  //!
+  //! Eg for grid_swizzle_factor=2:
+  //!    A1 A2 B1 B2 -->   A1 A2 A3 A4 B1 B2 B3 B4
+  //!    A3 A4 B3 B4       C1 C2 C3 C4 D1 D2 D3 D4
+  //!    C1 C2 D1 D2
+  //!    C3 C4 D3 D4
+  //!
+  //! A value of 1 (the default) disables the swizzle; scheduleMatmul clamps
+  //! values below 1 to 1.
+  int grid_swizzle_factor = 1;
 };
 
 //! Prototype auto scheduling function.
diff --git a/test/test_gpu_tensorcore.cpp b/test/test_gpu_tensorcore.cpp
index 061ee418a93..94da9900b67 100644
--- a/test/test_gpu_tensorcore.cpp
+++ b/test/test_gpu_tensorcore.cpp
@@ -729,6 +729,118 @@ TEST_F(NVFuserTest, FusionAmpereMatmulPipelineGmem_CUDA) {
   }
 }
 
+// Matmul test for Ampere MMA: checking CTA Swizzles
+TEST_F(NVFuserTest, FusionAmpereSwizzle_CUDA) {
+  // Keep multiples of 8 to keep vectorizable.
+  int dim = 8192;
+  int M = dim, N = dim, K = dim;
+  const auto all_orders = {
+      MatmulParam::TileRasterizationOrder::RowMajor,
+      MatmulParam::TileRasterizationOrder::ColumnMajor};
+
+  REQUIRE_DEVICE_SMEM_SIZE(70 << 10, 0);
+
+  // Schedules, compiles and runs one matmul for the given layout,
+  // rasterization order and grid swizzle factor. Checks the output against
+  // the atMatmul result and the launched grid against the expected swizzled
+  // shape, then stores fe.kernelTimeMs() in `runtime`.
+  auto test = [&](MatmulLayout layout,
+                  MatmulParam::TileRasterizationOrder order,
+                  int swizzle,
+                  float& runtime) {
+    Fusion fusion;
+    FusionGuard fg(&fusion);
+    auto tv0 = makeContigTensor(2, DataType::Half);
+    auto tv1 = makeContigTensor(2, DataType::Half);
+
+    fusion.addInput(tv0);
+    fusion.addInput(tv1);
+
+    auto tv2 = matmul(tv0, tv1, layout);
+
+    fusion.addOutput(tv2);
+
+    // Edge length of the square CTA tile; reused below to derive the
+    // expected launch grid.
+    constexpr int cta_tile_mn = 128;
+    MatMulTileOptions gemm_tile;
+    gemm_tile.cta_tile = GemmTile(cta_tile_mn, cta_tile_mn, 32);
+    gemm_tile.warp_tile = GemmTile(64, 64, 32);
+    gemm_tile.instruction_tile = GemmTile(16, 8, 16);
+
+    auto mma_builder =
+        MmaBuilder(MmaOptions::MacroType::Ampere_16_8_16, gemm_tile)
+            .layout(layout);
+
+    MatmulParam params(mma_builder);
+    params.tile_sizes = gemm_tile;
+    params.async_gmem_load_operands = true;
+    params.double_buffer_options.double_buffer_smem_write = true;
+    params.double_buffer_options.double_buffer_smem_read = true;
+    params.double_buffer_options.smem_double_buffer_stage = 3;
+
+    params.rasterization_order = order;
+    params.grid_swizzle_factor = swizzle;
+
+    scheduleMatmul(tv2, tv0, tv1, params);
+
+    at::manual_seed(0);
+    auto inputs = fp16MatmulAtInput(M, N, K, layout);
+
+    FusionExecutor fe;
+    fe.setMeasureKernelTimeFlag(true);
+    NVFUSER_TEST_CUDA_ARCH_COMPILE_CHECK(
+        8,
+        0,
+        fe.compileFusion(
+            &fusion,
+            {inputs.first, inputs.second},
+            LaunchParams(),
+            matmul_cparams));
+    auto cg_outputs = fe.runFusion({inputs.first, inputs.second});
+    auto tref = atMatmul(
+        inputs.first.to(at::kFloat), inputs.second.to(at::kFloat), layout);
+    TORCH_CHECK(cg_outputs[0].allclose(tref, 0.01, 0.01));
+
+    int gdimx = fe.lastLaunchParams().gdimx();
+    int gdimy = fe.lastLaunchParams().gdimy();
+
+    int expected_gdim_unswizzled = (dim + cta_tile_mn - 1) / cta_tile_mn;
+    int expected_gdimx = expected_gdim_unswizzled * swizzle;
+    int expected_gdimy = (expected_gdim_unswizzled + swizzle - 1) / swizzle;
+
+    TORCH_CHECK(
+        gdimx == expected_gdimx,
+        "gdimx mismatch: got ",
+        gdimx,
+        ", expected ",
+        expected_gdimx);
+    TORCH_CHECK(
+        gdimy == expected_gdimy,
+        "gdimy mismatch: got ",
+        gdimy,
+        ", expected ",
+        expected_gdimy);
+
+    runtime = fe.kernelTimeMs();
+  };
+
+  // Checking only a single layout to keep runtime short (compilation overhead)
+  for (auto layout : {MatmulLayout::TT}) {
+    for (auto order : all_orders) {
+      float runtime1 = 0;
+      test(layout, order, 1, runtime1);
+
+      float runtime4 = 0;
+      test(layout, order, 4, runtime4);
+
+      // GRID Swizzle requires further changes to work in main. So for now we
+      // don't assert the perf benefit here.
+      // TORCH_CHECK(runtime4 < runtime1);
+    }
+  }
+}
+
 TEST_F(NVFuserTest, FusionAmpereMatmulRegDoubleBuffer_CUDA) {
   // Keep multiples of 8 to keep vectorizable.
   int M = 504, N = 136, K = 248;