From 49dcc77fdde0a442ff3a5bdba34d2f54110a962e Mon Sep 17 00:00:00 2001 From: Michel Migdal Date: Wed, 29 Mar 2023 03:02:11 -0700 Subject: [PATCH] CTA Swizzles (#87) Adds a CTA swizzle to change the order in which the tiles of the output matrix are processed. This swizzle increases data reuse from A and B, when iterating over gridDim.x. Turns out that CTAs are launched in practice by iterating over gridDim.x first (order is unspecified though, it just happens to behave the same). As a result, the current wave will contain CTAs that compute square sub-matrices of C, and so, increase L2 hit rate. Best factor seems to be 4. This will be part of the heuristics. Setting the factor to 1 disables this swizzle. On a 8192x8192x8192 matmul with default config, the speedup is about 20%. An extreme example is following case: `MNK = 6144 6144 6144, layout=NT stages=0, cta_tile = 32 32 128, warp_tile = 16 16 128, instruction_tile = 16 16 16` where the runtime drops from 12.4 ms to 7.28ms ! Thank you @zasdfgbnm for the help. Values measured on NVIDIA A100 SXM4 80 GB --------- Co-authored-by: Gao, Xiang --- csrc/scheduler/matmul.cpp | 23 +++++++++ csrc/scheduler/matmul.h | 15 ++++++ test/test_gpu_tensorcore.cpp | 95 ++++++++++++++++++++++++++++++++++++ 3 files changed, 133 insertions(+) diff --git a/csrc/scheduler/matmul.cpp b/csrc/scheduler/matmul.cpp index 8b479e5d9a4..fcca8504846 100644 --- a/csrc/scheduler/matmul.cpp +++ b/csrc/scheduler/matmul.cpp @@ -181,6 +181,29 @@ void scheduleMatmul( // [... M,N,K] scheduler_utils::matmul_utils::makeTile(cc, gemm_tile.cta_tile.toVector()); + // Applies swizzle factor on C + if (params.grid_swizzle_factor != 1) { + int factor = std::max(1, params.grid_swizzle_factor); // must be >=1 + if (params.rasterization_order == + MatmulParam::TileRasterizationOrder::RowMajor) { + cc->split(1, factor); + // [I1, I2/factor, factor] + cc->reorder({{1, 2}}); + // [I1, factor, I2/factor] + cc->merge(0); + // [I1*factor, I2/factor] + } else if ( + params.rasterization_order == + MatmulParam::TileRasterizationOrder::ColumnMajor) { + cc->split(0, factor); + // [I1/factor, factor, I2] + cc->reorder({{1, 2}}); + // [I1/factor, I2, factor] + cc->merge(1); + // [I1/factor, I2*factor] + } + } + // [Mo, No, Ko, Mi, Ni, Ki] // Propagate tiling globally scheduler_utils::transformPropagateToAllFrom(cc, -1); diff --git a/csrc/scheduler/matmul.h b/csrc/scheduler/matmul.h index 7595bea8ee3..e2e50e4333b 100644 --- a/csrc/scheduler/matmul.h +++ b/csrc/scheduler/matmul.h @@ -48,6 +48,21 @@ class MatmulParam { RowMajor = 0, ColumnMajor = 1 } rasterization_order = TileRasterizationOrder::RowMajor; + + //! Swizzle factor is used to increase L2 hit rate. + //! It horizontally squeezes the grid so that gridDim.x is larger and + //! gridDim.y is smaller. + //! We rely on the observation that the CTAs are scheduled by the GPU by + //! iterating on gridDim.x first. As a result, as blocks are launched, they + //! will more likely be forming sub-tiles of the C matrix. This will increase + //! L2 hit rate/data reuse of A and B. + //! + //! Eg for grid_swizzle_factor=2: + //! A1 A2 B1 B2 --> A1 A2 A3 A4 B1 B2 B3 B4 + //! A3 A4 B3 B4 C1 C2 C3 C4 D1 D2 D3 D4 + //! C1 C2 D1 D2 + //! C3 C4 D3 D4 + int grid_swizzle_factor = 1; }; //! Prototype auto scheduling function. diff --git a/test/test_gpu_tensorcore.cpp b/test/test_gpu_tensorcore.cpp index 061ee418a93..94da9900b67 100644 --- a/test/test_gpu_tensorcore.cpp +++ b/test/test_gpu_tensorcore.cpp @@ -729,6 +729,101 @@ TEST_F(NVFuserTest, FusionAmpereMatmulPipelineGmem_CUDA) { } } +// Matmul test for Ampere MMA: checking CTA Swizzles +TEST_F(NVFuserTest, FusionAmpereSwizzle_CUDA) { + // Keep multiples of 8 to keep vectorizable. + int dim = 8192; + int M = dim, N = dim, K = dim; + const auto all_orders = { + MatmulParam::TileRasterizationOrder::RowMajor, + MatmulParam::TileRasterizationOrder::ColumnMajor}; + + REQUIRE_DEVICE_SMEM_SIZE(70 << 10, 0); + + auto test = [&](MatmulLayout layout, + MatmulParam::TileRasterizationOrder order, + int swizzle, + float& runtime) { + Fusion fusion; + FusionGuard fg(&fusion); + auto tv0 = makeContigTensor(2, DataType::Half); + auto tv1 = makeContigTensor(2, DataType::Half); + + fusion.addInput(tv0); + fusion.addInput(tv1); + + auto tv2 = matmul(tv0, tv1, layout); + + fusion.addOutput(tv2); + + MatMulTileOptions gemm_tile; + gemm_tile.cta_tile = GemmTile(128, 128, 32); + gemm_tile.warp_tile = GemmTile(64, 64, 32); + gemm_tile.instruction_tile = GemmTile(16, 8, 16); + + auto mma_builder = + MmaBuilder(MmaOptions::MacroType::Ampere_16_8_16, gemm_tile) + .layout(layout); + + MatmulParam params(mma_builder); + params.tile_sizes = gemm_tile; + params.async_gmem_load_operands = true; + params.double_buffer_options.double_buffer_smem_write = true; + params.double_buffer_options.double_buffer_smem_read = true; + params.double_buffer_options.smem_double_buffer_stage = 3; + + params.rasterization_order = order; + params.grid_swizzle_factor = swizzle; + + scheduleMatmul(tv2, tv0, tv1, params); + + at::manual_seed(0); + auto inputs = fp16MatmulAtInput(M, N, K, layout); + + FusionExecutor fe; + fe.setMeasureKernelTimeFlag(true); + NVFUSER_TEST_CUDA_ARCH_COMPILE_CHECK( + 8, + 0, + fe.compileFusion( + &fusion, + {inputs.first, inputs.second}, + LaunchParams(), + matmul_cparams)); + auto cg_outputs = fe.runFusion({inputs.first, inputs.second}); + auto tref = atMatmul( + inputs.first.to(at::kFloat), inputs.second.to(at::kFloat), layout); + TORCH_CHECK(cg_outputs[0].allclose(tref, 0.01, 0.01)); + + int gdimx = fe.lastLaunchParams().gdimx(); + int gdimy = fe.lastLaunchParams().gdimy(); + + int expected_gdim_unswizzled = (dim + 128 - 1) / 128; + int expected_gdimx = expected_gdim_unswizzled * swizzle; + int expected_gdimy = (expected_gdim_unswizzled + swizzle - 1) / swizzle; + + TORCH_CHECK(gdimx == expected_gdimx); + TORCH_CHECK(gdimy == expected_gdimy); + + runtime = fe.kernelTimeMs(); + }; + + // Checking only a single layout to keep runtime short (compilation overhead) + for (auto layout : {MatmulLayout::TT}) { + for (auto order : all_orders) { + float runtime1 = 0; + test(layout, order, 1, runtime1); + + float runtime4 = 0; + test(layout, order, 4, runtime4); + + // GRID Swizzle requires further changes to work in main. So for now we + // don't assert the perf benefit here. + // TORCH_CHECK(runtime4 < runtime1); + } + } +} + TEST_F(NVFuserTest, FusionAmpereMatmulRegDoubleBuffer_CUDA) { // Keep multiples of 8 to keep vectorizable. int M = 504, N = 136, K = 248;