diff --git a/csrc/scheduler/matmul.cpp b/csrc/scheduler/matmul.cpp index 8ab8283ff51..1f42f540086 100644 --- a/csrc/scheduler/matmul.cpp +++ b/csrc/scheduler/matmul.cpp @@ -57,54 +57,83 @@ void moveInnerBroadcastLeft(TensorView* tv, int number_of_inner_pos = 3) { tv->reorder(order_map); } +// Utility to check concrete static size: +inline void checkConcreteStaticDim(IterDomain* id) { + TORCH_INTERNAL_ASSERT( + !id->isBroadcast() && !id->isReduction(), + "no support for reduction or broadcast domains, but got ", + id->toString()); + TORCH_INTERNAL_ASSERT( + id->extent()->isConstInt(), + "swizzled dimension's extend must be known during scheduling, got ", + id->toString()); +} + //! Automatically generates the shared memory swizzled data layout -//! for matmul mainloop. -//! The shared mem datalayout is always 2D currently, and this utility -//! function assumes that the innermost 2 dimensions on shared_mem_tv -//! are the ones begin swizzled. -void prologSwizzle(TensorView* shared_mem_tv, const MatmulParams& params) { +//! for matmul mainloop and epilogue. +//! The shared mem data layout is always 2D currently, and this utility +//! function assumes that the shared_mem_tv has the following structure: +//! [tile_row, tile_col, ***skip***] where the parameter `skip` is the number +//! of reduction domains to be skipped. The IDs of tile_row and tile_col are +//! the ones being swizzled. +//! If the input tensorview is not stored in shared memory, the function will +//! skip the actual swizzle. This is used to help the domain mapping between +//! mma_result and the epilogue tensor. +void swizzleSharedMemory( + TensorView* shared_mem_tv, + const MatmulParams& params) { + // Set skip to skip all consecutive reduction domains starting from the + // innermost dimension. + int skip = 0; + for (int i = (int)shared_mem_tv->nDims() - 1; i >= 0; --i) { + if (shared_mem_tv->axis(i)->isReduction()) { + skip++; + } else { + break; + } + } + // Check that the innermost 2 dimensions are concrete and static // sized so that the swizzle function can be defined. - - // Utility to check concrete static size: - auto check_concrete_static_dim = [](IterDomain* id) { - TORCH_INTERNAL_ASSERT( - !id->isBroadcast() && !id->isReduction(), - "no support on reduction or broadcast dims, but get ", - id->toString()); - TORCH_INTERNAL_ASSERT( - id->extent()->isConstInt(), - "swizzled dimensions need to be statically, but get ", - id->toString()); - }; - TORCH_INTERNAL_ASSERT( - shared_mem_tv->nDims() >= 2, - "At least 2D input needed for swizzling, but get ", + shared_mem_tv->nDims() >= (size_t)(2 + skip), + "At least 2D input (excluding consecutive reduction domains starting from the innermost dim) needed for swizzling, but get ", shared_mem_tv->toString()); - check_concrete_static_dim(shared_mem_tv->axis(-2)); - check_concrete_static_dim(shared_mem_tv->axis(-1)); + checkConcreteStaticDim(shared_mem_tv->axis(-2 - skip)); + checkConcreteStaticDim(shared_mem_tv->axis(-1 - skip)); // Extract the constant sizes of the swizzled tile - const auto tile_size_x = shared_mem_tv->axis(-2)->extent()->evaluateInt(); - const auto tile_size_y = shared_mem_tv->axis(-1)->extent()->evaluateInt(); + const int64_t tile_size_x = + shared_mem_tv->axis(-2 - skip)->extent()->evaluateInt(); + const int64_t tile_size_y = + shared_mem_tv->axis(-1 - skip)->extent()->evaluateInt(); if (isTuring(params.mma_macro) || isAmpere(params.mma_macro)) { - // TODO: right now, we are assuming ldmatrix access, which only supports - // sizeof(T) == 16bit (i.e. half/bfloat16) load according to offical doc: - // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-load-instruction-ldmatrix - // In the future, when we start adding support for tf32(different macro), - // fp32(ffma), double, int8, fp8, etc. we need to update this function. - TORCH_INTERNAL_ASSERT(dataTypeSize(*shared_mem_tv->getDataType()) == 2); - - // ldmatrix loads a ldmatrix_rows x ldmatrix_cols = 8 x 8 matrix each time, - constexpr int64_t ldmatrix_rows = 8; - constexpr int64_t ldmatrix_cols = 8; + // Only tested for (1) ldmatrix access with sizeof(T) == 16bit (i.e. + // half/bfloat16) and (2) epilogue general access with sizeof(T) == 32bit + // (i.e. float) + const int64_t data_type_size = + (int64_t)dataTypeSize(*shared_mem_tv->getDataType()); + TORCH_INTERNAL_ASSERT(data_type_size == 2 || data_type_size == 4); + + // For main loop, ldmatrix loads a n_rows x n_cols = 8 x 8 matrix each time. + // For epilogue, threads in a warp is organized as 8 rows x 4 columns. + // Each thread vectorized write 2 items, so 8 items per row. + //--0--1--2--3 + //--4--5--6--7 + //--8--9--10-11 + //--12-13-14-15 + //--16-17-18-19 + //--20-21-22-23 + //--24-25-26-27 + //--28-29-30-31 + constexpr int64_t n_rows = 8; + constexpr int64_t n_cols = 8; // Column size of the tile needs to be multiples of 8 for ldmatrix to work. TORCH_INTERNAL_ASSERT( - tile_size_x >= ldmatrix_rows && tile_size_x % ldmatrix_rows == 0 && - tile_size_y >= ldmatrix_cols && tile_size_y % ldmatrix_cols == 0, + tile_size_x >= n_rows && tile_size_x % n_rows == 0 && + tile_size_y >= n_cols && tile_size_y % n_cols == 0, "Prolog swizzle for ldmatrix, illegal tile size for prolog swizzle", tile_size_x, "x", @@ -148,11 +177,10 @@ void prologSwizzle(TensorView* shared_mem_tv, const MatmulParams& params) { * has 8 rows, and each row has exactly one unit. */ - constexpr int64_t items_per_unit = ldmatrix_cols; - constexpr int64_t bytes_per_unit = - items_per_unit * primDataTypeSize(DataType::Half); - constexpr int64_t words_per_unit = bytes_per_unit / smem_bytes_per_word; - constexpr int64_t num_megabanks = smem_banks / words_per_unit; + constexpr int64_t items_per_unit = n_cols; + const int64_t bytes_per_unit = items_per_unit * data_type_size; + const int64_t words_per_unit = bytes_per_unit / smem_bytes_per_word; + const int64_t num_megabanks = smem_banks / words_per_unit; /* In the following example, each CTA tile contains 2 rows and 3 colums of * matrices, each 8x8 size: @@ -172,7 +200,7 @@ void prologSwizzle(TensorView* shared_mem_tv, const MatmulParams& params) { /* So the bank conflicting problem is now converted to the following game: * I have a clock that has one pointer and `num_megabanks` ticks. I start * my game by making my pointer pointing to somewhere, and turn forward - * the pointer `ldmatrix_rows` times, each time by `row_stride` ticks. + * the pointer `n_rows` times, each time by `row_stride` ticks. * This problem can be well modeled by modular arithmetic in number theory * using the concept "integers modulo n" a.k.a. "Z/nZ"[1]. * Take n = 6 as an example, Z/6Z only has 6 elements: 0, 1, 2, 3, 4, 5. @@ -199,7 +227,6 @@ void prologSwizzle(TensorView* shared_mem_tv, const MatmulParams& params) { // assert(row_stride >= 0); // assert(num_megabanks >= 0); int64_t row_stride_znz = row_stride % num_megabanks; - /* Consider the following function in Z/nZ: * f(i; init) = init + i * stride * where init is the initial position of the pointer in the clock when we @@ -290,7 +317,7 @@ void prologSwizzle(TensorView* shared_mem_tv, const MatmulParams& params) { int64_t repeated_pattern_size = num_megabanks / g; - if (repeated_pattern_size >= ldmatrix_rows) { + if (repeated_pattern_size >= n_rows) { return; // No need to swizzle in this case. } @@ -357,18 +384,20 @@ void prologSwizzle(TensorView* shared_mem_tv, const MatmulParams& params) { */ TORCH_INTERNAL_ASSERT( - ldmatrix_rows % repeated_pattern_size == 0, + n_rows % repeated_pattern_size == 0, "Can not partition matrix into megarows"); - int64_t num_gigarows = ldmatrix_rows / repeated_pattern_size; + int64_t num_gigarows = n_rows / repeated_pattern_size; int64_t num_gigabanks = g; // also = num_megabanks / repeated_pattern_size // -2 -1 // [row, col] - shared_mem_tv->split(-2, repeated_pattern_size); - shared_mem_tv->split(-1, ldmatrix_cols); + if (repeated_pattern_size > 1) { + shared_mem_tv->split(-2 - skip, repeated_pattern_size); + } + shared_mem_tv->split(-1 - skip, n_cols); // -4 -3 -2 -1 // [gigarow id, gigarow, matrix id, matrix] - shared_mem_tv->split(-2, num_gigabanks); + shared_mem_tv->split(-2 - skip, num_gigabanks); // -5 -4 -3 -2 -1 // [gigarow id, gigarow, y outer, gigabank id, matrix] // Note that megabanks inside a gigabank are not contiguous, so the gigabank @@ -418,22 +447,49 @@ void prologSwizzle(TensorView* shared_mem_tv, const MatmulParams& params) { // -5 -4 -3 -2 -1 // [gigarow id, gigarow, y outer, gigabank id, matrix] - shared_mem_tv->split(-5, num_gigabanks); + int axis_of_gigarow_id = repeated_pattern_size > 1 ? -5 : -4; + shared_mem_tv->split(axis_of_gigarow_id - skip, num_gigabanks); // -6 -5 -4 -3 -2 -1 // [wave id, wave, gigarow, y outer, gigabank id, matrix] - if (isPowOf2(num_gigabanks)) { - shared_mem_tv->swizzle(Swizzle2DType::XOR, -5, -2); - } else { - shared_mem_tv->swizzle(Swizzle2DType::CyclicShift, -5, -2); + // swizzle wave with gigabank id to make threads in a wave access different + // gigabank. Apply swizzle only when shared_mem_tv is stored in shared + // memory. + // TODO: This is a temporary workaround for the following issue: + // For the mma output, we have the following schedule: + // rFactor: [...., X, Y] -> mma-swizzle transformations -> leaf + // For epilogue smem tensor, the schedule is + // rFactor: [...., X, Y] -> split -> [...., X1, X2, X3, Y1, Y2, Y3] + // -> swizzle X2, Y2 -> [...., X1, X2', X3, Y1, Y2', Y3] + // -> merge back -> [...., X', Y'] + // -> mma-swizzle transformations -> leaf + // The mma-swizzle transformations for the mma output and epilogue smem + // tensor are the same. In indexing, we do require {X, X'} and {Y, Y'} to be + // mapped in CA map, however, we currently can not handle that. So we have + // to do the same split and merge to the mma output without actually + // applying the swizzle, and this check is to detect and handle this + // specific case. We should remove this special handling when we fix our CA + // mapping. + if (shared_mem_tv->getMemoryType() == MemoryType::Shared) { + int axis_of_gigarow_id = repeated_pattern_size > 1 ? -5 : -4; + if (isPowOf2(num_gigabanks)) { + shared_mem_tv->swizzle( + Swizzle2DType::XOR, axis_of_gigarow_id - skip, -2 - skip); + } else { + shared_mem_tv->swizzle( + Swizzle2DType::CyclicShift, axis_of_gigarow_id - skip, -2 - skip); + } + } + + if (repeated_pattern_size > 1) { + shared_mem_tv->merge(-6 - skip); } + shared_mem_tv->merge(-5 - skip); + + // merge back tile_size_y + shared_mem_tv->merge(-3 - skip); + shared_mem_tv->merge(-2 - skip); - // Merge back the tile for subsequent vectorization scheduling - // TODO: could potentially simplify away the merges - shared_mem_tv->merge(-6); - shared_mem_tv->merge(-5); - shared_mem_tv->merge(-3); - shared_mem_tv->merge(-2); } else if (isVolta(params.mma_macro)) { // TODO: Volta is slightly more complex, and a fixed recipe would // not scale. In a follow up this would be inferred from the mma @@ -456,8 +512,7 @@ void scheduleProlog(TensorView* shared_mem_tv, const MatmulParams& params) { mma_utils::orderTiledConcreteIdAsRoot(shared_mem_tv); // Swizzle the shared memory data layout - prologSwizzle(shared_mem_tv, params); - + swizzleSharedMemory(shared_mem_tv, params); // Assuming we are always vectorizing smem write by 128b at the moment: // TODO: would need a data-type and alignment dependent interface // to support non-vectorizable shapes. @@ -477,6 +532,79 @@ void scheduleProlog(TensorView* shared_mem_tv, const MatmulParams& params) { .propagateParallelType()); } +void scheduleOutputTensor( + TensorView* mma_result, + TensorView* c, + const MatMulTileOptions& gemm_tile) { + // input tensor is in the form of [Mo,No,cta_tile_m,cta_tile_n] + checkConcreteStaticDim(c->axis(-2)); + checkConcreteStaticDim(c->axis(-1)); + const int64_t tile_size_m = c->axis(-2)->extent()->evaluateInt(); + const int64_t tile_size_n = c->axis(-1)->extent()->evaluateInt(); + TORCH_INTERNAL_ASSERT( + tile_size_m == gemm_tile.cta_tile.m, + "Actual tile size at axis(-2) in output tensor is different from CTA tile size! Expected: ", + gemm_tile.cta_tile.m, + ", actual: ", + tile_size_m); + TORCH_INTERNAL_ASSERT( + tile_size_n == gemm_tile.cta_tile.n, + "Actual tile size at axis(-1) in output tensor is different from CTA tile size! Expected: ", + gemm_tile.cta_tile.n, + ", actual: ", + tile_size_n); + const int64_t tot_elements = tile_size_m * tile_size_n; + const int64_t data_type_size = (int64_t)dataTypeSize(*c->getDataType()); + constexpr int64_t warp_size = 32l; + const int64_t vectorization_factor = 16l / data_type_size; + const int64_t tidx = warp_size; + const int64_t tidy = gemm_tile.cta_tile.n / gemm_tile.warp_tile.n; + const int64_t tidz = gemm_tile.cta_tile.m / gemm_tile.warp_tile.m; + // step-1, merge last 2 dims + c->merge(-2); + // [Mo, No, m*n] + + // step-2, set vectorization to maximum + // We have fixed tidx, tidy, and tidz, so we need to make sure that the output + // tensor is divisible by tidx * tidy * tidz * vectorization_factor + TORCH_INTERNAL_ASSERT( + tot_elements % (tidx * tidy * tidz * vectorization_factor) == 0, + "Output tensor cannot be fully vectorized! tot_elements:", + tot_elements, + ", tidx: ", + tidx, + ", tidy: ", + tidy, + ", tidz: ", + tidz, + ", vectorization_factor: ", + vectorization_factor); + c->split(-1, vectorization_factor); + c->axis(-1)->parallelize(ParallelType::Vectorize); + // [Mo, No, m*n/vect, vect] + + // step-3, Split out a warp for TIDx + c->split(-2, tidx); + c->axis(-2)->parallelize(ParallelType::TIDx); + // [Mo, No, m*n/vect/TIDx, TIDx, vect] + + // step-4, Split out for TIDy and TIDz + // TIDy = cta_tile_n/warp_tile_n + // TIDz = cta_tile_m/warp_tile_m + c->split(-3, tidy); + c->axis(-3)->parallelize(ParallelType::TIDy); + + c->split(-4, tidz); + c->axis(-4)->parallelize(ParallelType::TIDz); + // [Mo, No, m*n/vect/TIDx/TIDy/TIDz, TIDz, TIDy, TIDx, vect] + + // step-5, Parallel first 2 dims same as mma_result + scheduler_utils::parallelizeAllLike( + mma_result, + 2, + {c}, + {ParallelType::BIDx, ParallelType::BIDy, ParallelType::BIDz}); +} //! Propagates transformations from fusion output to fusion tv inputs that are //! producers in the epilogue. Transformations' propagation aims at input tvs //! which are not assigned to core roles, that is, are not MMA inputs. @@ -604,6 +732,11 @@ void scheduleMatmul(Fusion* fusion, const MatmulParams& params) { // Mma object is valid only because cacheBefore has been done on // TV which is not output of MmaOp, as there is an epilogue auto mma_result = has_epilogue ? mma->out()->as() : dc; + + // Unswizzle mma result in shared memory + auto smem_epilogue = + params.use_smem_epilogue ? mma_result->cacheAfter() : mma_result; + // Clear MmaOp pointer, it's not needed from now on mma = nullptr; @@ -732,12 +865,21 @@ void scheduleMatmul(Fusion* fusion, const MatmulParams& params) { // Propagate tiling globally scheduler_utils::transformPropagateToAllFrom(mma_result, -1); + if (params.use_smem_epilogue) { + // Transform mma_result through the epilogue swizzle without actually + // swizzling the axes. This is done to enable the domains + // are mapped between mma_result and smem_epilogue. + swizzleSharedMemory(mma_result, params); + } + // Schedule warp tile mma_utils::scheduleWarpTileWithReduction(mma_result, gemm_tile); + // 0 1 2 3 4 5 6 7 8 9 10 + // [Mo No Ko Kw Mwo Nwo Mwi Nwi Mi, Ni, Ki] // Propagate warp tile to main loop and epilog/output tvs scheduler_utils::BoundedDirectionalTransformPropagator::bothWays( - mma_result, -1, {acw_smem, bcw_smem}, {d}); + mma_result, -1, {acw_smem, bcw_smem}, {smem_epilogue}); // Schedule prolog: // TODO: this section needs more configurability. @@ -753,7 +895,6 @@ void scheduleMatmul(Fusion* fusion, const MatmulParams& params) { moveInnerBroadcastLeft(ab); moveInnerBroadcastLeft(bb); } - ab->applyMmaSwizzle(mma_builder.operand(MmaOptions::Operand::A).build()); bb->applyMmaSwizzle(mma_builder.operand(MmaOptions::Operand::B).build()); @@ -812,16 +953,36 @@ void scheduleMatmul(Fusion* fusion, const MatmulParams& params) { {acr, bcr, ab, bb}, {ParallelType::TIDy, ParallelType::TIDz}); - scheduler_utils::BoundedDirectionalTransformPropagator::forward( - mma_result, - -1, - {d}, - scheduler_utils::BoundedDirectionalTransformPropagator::Options() - .propagateParallelType() - .propagateToBoundary()); - - d->axis(-1)->parallelize(ParallelType::Vectorize); - + if (params.use_smem_epilogue) { + smem_epilogue->setMemoryType(MemoryType::Shared); + swizzleSharedMemory(smem_epilogue, params); + scheduler_utils::BoundedDirectionalTransformPropagator::forward( + mma_result, + -1, + {smem_epilogue}, + scheduler_utils::BoundedDirectionalTransformPropagator::Options() + .propagateParallelType() + .propagateToBoundary()); + smem_epilogue->axis(-1)->parallelize(ParallelType::Vectorize); + + // Schedule output tensor differently for better global memory access + // pattern. + scheduleOutputTensor(mma_result, d, gemm_tile); + d->axis(-1)->parallelize(ParallelType::Vectorize); + + // Propagate output tensor transformations back to smem_epilogue + scheduler_utils::BoundedDirectionalTransformPropagator::backward( + d, -1, {smem_epilogue}); + } else { + scheduler_utils::BoundedDirectionalTransformPropagator::forward( + mma_result, + -1, + {d}, + scheduler_utils::BoundedDirectionalTransformPropagator::Options() + .propagateParallelType() + .propagateToBoundary()); + d->axis(-1)->parallelize(ParallelType::Vectorize); + } // propagate output transformations to all inputs that are part of epilogue // operations, input tvs with non-core roles // core roles: essential for matmul, for example mma inputs' producers diff --git a/csrc/scheduler/matmul_heuristic.h b/csrc/scheduler/matmul_heuristic.h index dc9a03cdf32..14fea9db17c 100644 --- a/csrc/scheduler/matmul_heuristic.h +++ b/csrc/scheduler/matmul_heuristic.h @@ -90,6 +90,10 @@ class MatmulParams : public HeuristicParams { //! C3 C4 D3 D4 int grid_swizzle_factor = 1; + //! Unswizzle MMA results in shared memory to get + //! coalesced write to global memory + bool use_smem_epilogue = false; + std::string toString() const override { std::stringstream ss; ss << "\n===== Matmul Parameters ========\n" @@ -112,6 +116,7 @@ class MatmulParams : public HeuristicParams { : "column-major") << "\n" << "Grid swizzle factor: " << grid_swizzle_factor << "\n" + << "Use shared memory epilogue: " << use_smem_epilogue << "\n" << "====================================\n"; return ss.str(); } diff --git a/csrc/scheduler/matmul_utils.cpp b/csrc/scheduler/matmul_utils.cpp index b15f6562490..19cafd8725a 100644 --- a/csrc/scheduler/matmul_utils.cpp +++ b/csrc/scheduler/matmul_utils.cpp @@ -150,6 +150,23 @@ inline bool initExtraHeuristics( return true; } +//! A wrapper to get MMA Tensor data types +//! The order of returned types: INPUT_A, INPUT_B, OUTPUT_D +inline mma_utils::MmaDataTypes getMmaDataTypes( + const std::map>& roles_map) { + auto getMMADataType = [&](MatmulRole role) { + auto entry = roles_map.find(role); + if (entry != roles_map.end() && !entry->second.empty()) { + return entry->second.front()->dtype(); + } + TORCH_INTERNAL_ASSERT(false, "Get MMA Tensor data type failed!"); + }; + const auto a_type = getMMADataType(MatmulRole::INPUT_A); + const auto b_type = getMMADataType(MatmulRole::INPUT_B); + const auto c_type = getMMADataType(MatmulRole::OUTPUT_D); + return mma_utils::MmaDataTypes{a_type, b_type, c_type}; +} + //! A helper for getting problem shape from fusion and runtime info. ProblemShape getProblemShape( Fusion* fusion, @@ -398,6 +415,15 @@ std::shared_ptr getMatmulHeuristics( // Disable magic zero for matmul kernels params->cparams.enable_magic_zero = false; + // Set whether to use shared memory for epilogue + const auto& roles_map_opt = mma_utils::getTensorsRoles(fusion); + TORCH_INTERNAL_ASSERT( + roles_map_opt.isValid(), "Tensor roles map in mma is not valid."); + params->use_smem_epilogue = mma_utils::generateSharedMemoryEpilogueHeuristics( + params->tile_sizes, + params->double_buffer_options.smem_double_buffer_stage, + getMmaDataTypes(roles_map_opt.getData())); + if (isDebugDumpEnabled(DebugDumpOption::MatmulChecks)) { printMsg(params->toString()); } diff --git a/csrc/scheduler/mma_utils.cpp b/csrc/scheduler/mma_utils.cpp index 3209c4329b6..808256f7f43 100644 --- a/csrc/scheduler/mma_utils.cpp +++ b/csrc/scheduler/mma_utils.cpp @@ -6,6 +6,7 @@ */ // clang-format on +#include #include #include #include @@ -14,11 +15,49 @@ #include #include #include "mma_type.h" - namespace nvfuser { namespace mma_utils { +bool generateSharedMemoryEpilogueHeuristics( + const MatMulTileOptions& gemm_tile, + const int smem_double_buffer_stage, + const MmaDataTypes& data_types) { + const auto properties = at::cuda::getCurrentDeviceProperties(); + const size_t device_smem_limit = properties->sharedMemPerBlockOptin; + + auto warp_dims = gemm_tile.cta_tile / gemm_tile.warp_tile; + const auto threads_per_block = + warp_dims.m * warp_dims.n * warp_dims.k * properties->warpSize; + // a thread can use up to 255 registers, blocks per sm is limited by available + // registers + const auto threads_per_sm = getThreadsPerSMGivenRegPerThread(255); + const auto blocks_per_sm_by_register = threads_per_sm / threads_per_block; + // see scheduleContiguousVectorLoad + const int vector_word = 8; + const int round_to_factor = warp_dims.m * warp_dims.n * warp_dims.k * + properties->warpSize * vector_word; + const int mk = gemm_tile.cta_tile.m * gemm_tile.cta_tile.k; + const int nk = gemm_tile.cta_tile.n * gemm_tile.cta_tile.k; + const size_t smem_a = (size_t)(ceilDiv(mk, round_to_factor) * + round_to_factor * smem_double_buffer_stage) * + dataTypeSize(data_types[0]); + const size_t smem_b = (size_t)(ceilDiv(nk, round_to_factor) * + round_to_factor * smem_double_buffer_stage) * + dataTypeSize(data_types[1]); + const size_t smem_c = (size_t)(gemm_tile.cta_tile.m * gemm_tile.cta_tile.n) * + dataTypeSize(data_types[2]); + + // use additional shared memory for epilogue if blocks per sm is not changed + const auto blocks_per_sm_without_smem_epilogue = std::min( + device_smem_limit / (smem_a + smem_b), (size_t)blocks_per_sm_by_register); + const auto blocks_per_sm_with_smem_epilogue = std::min( + device_smem_limit / (smem_a + smem_b + smem_c), + (size_t)blocks_per_sm_by_register); + return blocks_per_sm_with_smem_epilogue == + blocks_per_sm_without_smem_epilogue; +} + void scheduleWarpTileWithReduction(TensorView* tv, MatMulTileOptions tile) { // Assumes // [M, N, K] @@ -379,11 +418,6 @@ bool canValidateIsInnerDim( if (!split->factor()->isConstInt()) { return false; } - if (split->factor()->evaluateInt() < inner_dim_size) { - // This might be too restrictive. Would need more - // bookkeeping to relax. - return false; - } leaf = split->in(); } else if (auto merge = dynamic_cast(expr)) { // Might consider just rejecting merge. @@ -396,9 +430,6 @@ bool canValidateIsInnerDim( if (!leaf->extent()->isConstInt()) { return false; } - if (leaf->extent()->evaluateInt() != inner_dim_size) { - return false; - } leaf = merge->inner(); } else { // No support for swizzled inner dim for now. @@ -438,7 +469,9 @@ void checkDimSize( ":", id->extent()->evaluateInt(), "vs", - expect[axis_index]); + expect[axis_index], + "\n for tv: ", + tv->toString()); } } diff --git a/csrc/scheduler/mma_utils.h b/csrc/scheduler/mma_utils.h index 5372877a15d..be6de628277 100644 --- a/csrc/scheduler/mma_utils.h +++ b/csrc/scheduler/mma_utils.h @@ -226,6 +226,10 @@ using ProblemIterDomains = std::array; //! a single tv, for example input for beta scaling in epilogue using RolesMap = std::map>; +//! An alias for storing data types of the tensors in the mma op +//! the order is INPUT_A, INPUT_B, OUTPUT_D +using MmaDataTypes = std::array; + //! A wrapper for data containers with optional error message stored if //! initialization of the data fails. template @@ -289,6 +293,15 @@ TORCH_CUDA_CU_API ProblemIterDomainsOpt getProblemIterDomains(Fusion* fusion); //! be gathered. TORCH_CUDA_CU_API RolesMapOpt getTensorsRoles(Fusion* fusion); +//! Return whether use shared memory epilogue or not. +//! Returns true if using shared memory epilogue won't cause +//! the decrease of occupancy ratio. The occupancy ratio is +//! estimated using register and shared memory usage. +TORCH_CUDA_CU_API bool generateSharedMemoryEpilogueHeuristics( + const MatMulTileOptions& gemm_tile, + const int smem_double_buffer_stage, + const MmaDataTypes& data_types); + } // namespace mma_utils } // namespace nvfuser diff --git a/test/test_gpu_tensorcore.cpp b/test/test_gpu_tensorcore.cpp index 32451068cdf..9c6efd8babc 100644 --- a/test/test_gpu_tensorcore.cpp +++ b/test/test_gpu_tensorcore.cpp @@ -3153,7 +3153,6 @@ TEST_F(NVFuserTest, FusionAmpereMatmulLargeLoad_CUDA) { gemm_tile.cta_tile = GemmTile(128, 128, 64); gemm_tile.warp_tile = GemmTile(64, 64, 64); gemm_tile.instruction_tile = GemmTile(16, 16, 16); - MatmulParams params; params.mma_macro = MmaOptions::MacroType::Ampere_16_16_16; params.tile_sizes = gemm_tile; @@ -3262,6 +3261,11 @@ TEST_F(NVFuserTest, FusionAmpereMatmulTileCheck4warp_CUDA) { params.tile_sizes = gemm_tile; params.async_gmem_load_operands = true; params.double_buffer_options.double_buffer_smem_write = true; + params.use_smem_epilogue = + mma_utils::generateSharedMemoryEpilogueHeuristics( + gemm_tile, + params.double_buffer_options.smem_double_buffer_stage, + {DataType::Half, DataType::Half, DataType::Float}); scheduleMatmul(&fusion, params); auto inputs = matmulAtInput(M, N, K, layout); @@ -3325,6 +3329,11 @@ TEST_F(NVFuserTest, FusionAmpereMatmulTileCheck8warp_CUDA) { params.double_buffer_options.double_buffer_smem_write = true; params.double_buffer_options.double_buffer_smem_read = true; params.double_buffer_options.smem_double_buffer_stage = 2; + params.use_smem_epilogue = + mma_utils::generateSharedMemoryEpilogueHeuristics( + gemm_tile, + params.double_buffer_options.smem_double_buffer_stage, + {DataType::Half, DataType::Half, DataType::Float}); scheduleMatmul(&fusion, params); @@ -3383,7 +3392,11 @@ TEST_F(NVFuserTest, FusionAmpereMatmulTileCheck6warp_CUDA) { params.double_buffer_options.double_buffer_smem_write = true; params.double_buffer_options.double_buffer_smem_read = true; params.double_buffer_options.smem_double_buffer_stage = 2; - + params.use_smem_epilogue = + mma_utils::generateSharedMemoryEpilogueHeuristics( + gemm_tile, + params.double_buffer_options.smem_double_buffer_stage, + {DataType::Half, DataType::Half, DataType::Float}); scheduleMatmul(&fusion, params); auto inputs = matmulAtInput(M, N, K, layout); @@ -4467,6 +4480,238 @@ TEST_F(NVFuserTest, FusionAmpereSplitKLikeStridedBatchedMatmul_CUDA) { } } +TEST_F(NVFuserTest, FusionAmpereMatmulSmemEpilogue_CUDA) { + NVFUSER_TEST_CUDA_ARCH_RANGE_GUARD(8, 0, 9, 0); + // Keep multiples of 8 to keep vectorizable. + int M = 4096, N = 4096, K = 4096; + for (auto layout : kAllSupportedMatmulLayout) { + Fusion fusion; + FusionGuard fg(&fusion); + auto tv0 = makeContigTensor(2, DataType::Half); + auto tv1 = makeContigTensor(2, DataType::Half); + + fusion.addInput(tv0); + fusion.addInput(tv1); + + auto tv2 = matmul(tv0, tv1, layout, true); + + fusion.addOutput(tv2); + + // The settings of cta_tile, warp_tile, and smem_double_buffer_stage have + // been purposefully selected to produce a constant occupancy of 25%. This + // allows us to effectively evaluate the influence of the use_smem_epilogue + // parameter on performance, since changing its value to either true or + // false will not affect the occupancy rate. + MatMulTileOptions gemm_tile; + gemm_tile.cta_tile = GemmTile(64, 128, 32); + gemm_tile.warp_tile = GemmTile(32, 32, 32); + gemm_tile.instruction_tile = GemmTile(16, 8, 16); + + MatmulParams params; + params.mma_macro = MmaOptions::MacroType::Ampere_16_8_16; + params.tile_sizes = gemm_tile; + params.use_smem_epilogue = true; + params.async_gmem_load_operands = true; + params.double_buffer_options.double_buffer_smem_write = true; + params.double_buffer_options.double_buffer_smem_read = true; + params.double_buffer_options.smem_double_buffer_stage = 2; + scheduleMatmul(&fusion, params); + + // If use_smem_epilogue is true, there should be 3 shared memory tensors 2 + // for prologue and 1 for epilogue. + int num_shared_mem_tensors = 0; + int expected_num_shared_mem_tensors = params.use_smem_epilogue ? 3 : 2; + for (const auto& tv : ir_utils::allTvs(&fusion)) { + if (tv->getMemoryType() == MemoryType::Shared) { + num_shared_mem_tensors++; + } + } + TORCH_CHECK( + num_shared_mem_tensors == expected_num_shared_mem_tensors, + "Number of shared memory tensors doesn't match!", + "Expected: ", + expected_num_shared_mem_tensors, + ", Got: ", + num_shared_mem_tensors); + + at::manual_seed(0); + auto inputs = matmulAtInput(M, N, K, layout); + + FusionExecutor fe; + NVFUSER_TEST_CUDA_ARCH_COMPILE_CHECK( + 8, + 0, + fe.compileFusion( + &fusion, + {inputs.first, inputs.second}, + LaunchParams(), + matmul_cparams)); + auto cg_outputs = fe.runFusion({inputs.first, inputs.second}); + auto tref = atMatmul( + inputs.first.to(at::kFloat), inputs.second.to(at::kFloat), layout); + + // check bank conflicts + ASSERT_TRUE(getBankConflictInfo(fe.kernel()).empty()); + // (0.001, 0.001) passed on local A100 but failed on CI A100 + TORCH_CHECK( + cg_outputs[0].allclose(tref, 0.01, 0.01), + "Result validation failed. Max diff: ", + (cg_outputs[0] - tref).abs().max()); + } +} + +TEST_F(NVFuserTest, FusionAmpereMatmulSmemEpilogueCast_CUDA) { + NVFUSER_TEST_CUDA_ARCH_RANGE_GUARD(8, 0, 9, 0); + // Keep multiples of 8 to keep vectorizable. + int M = 4096, N = 4096, K = 4096; + for (auto layout : kAllSupportedMatmulLayout) { + Fusion fusion; + FusionGuard fg(&fusion); + auto tv0 = makeContigTensor(2, DataType::Half); + auto tv1 = makeContigTensor(2, DataType::Half); + + fusion.addInput(tv0); + fusion.addInput(tv1); + + auto tv2 = matmul(tv0, tv1, layout, true); + auto tv3 = castOp(DataType::Half, tv2); + + fusion.addOutput(tv3); + + MatMulTileOptions gemm_tile; + gemm_tile.cta_tile = GemmTile(128, 128, 32); + gemm_tile.warp_tile = GemmTile(64, 64, 32); + gemm_tile.instruction_tile = GemmTile(16, 8, 16); + + MatmulParams params; + params.mma_macro = MmaOptions::MacroType::Ampere_16_8_16; + params.tile_sizes = gemm_tile; + params.use_smem_epilogue = true; + params.async_gmem_load_operands = true; + params.double_buffer_options.double_buffer_smem_write = true; + params.double_buffer_options.double_buffer_smem_read = true; + params.double_buffer_options.smem_double_buffer_stage = 4; + scheduleMatmul(&fusion, params); + + // If use_smem_epilogue is true, there should be 3 shared memory tensors 2 + // for prologue and 1 for epilogue. + int num_shared_mem_tensors = 0; + int expected_num_shared_mem_tensors = params.use_smem_epilogue ? 3 : 2; + for (const auto& tv : ir_utils::allTvs(&fusion)) { + if (tv->getMemoryType() == MemoryType::Shared) { + num_shared_mem_tensors++; + } + } + TORCH_CHECK( + num_shared_mem_tensors == expected_num_shared_mem_tensors, + "Number of shared memory tensors doesn't match!", + "Expected: ", + expected_num_shared_mem_tensors, + ", Got: ", + num_shared_mem_tensors); + + at::manual_seed(0); + auto inputs = matmulAtInput(M, N, K, layout); + + FusionExecutor fe; + NVFUSER_TEST_CUDA_ARCH_COMPILE_CHECK( + 8, + 0, + fe.compileFusion( + &fusion, + {inputs.first, inputs.second}, + LaunchParams(), + matmul_cparams)); + auto cg_outputs = fe.runFusion({inputs.first, inputs.second}); + auto tref = atMatmul( + inputs.first.to(at::kFloat), inputs.second.to(at::kFloat), layout); + tref = tref.to(at::kHalf); + // check bank conflicts + ASSERT_TRUE(getBankConflictInfo(fe.kernel()).empty()); + // (0.001, 0.001) passed on local A100 but failed on CI A100 + TORCH_CHECK( + cg_outputs[0].allclose(tref, 0.01, 0.01), + "Result validation failed. Max diff: ", + (cg_outputs[0] - tref).abs().max()); + } +} + +TEST_F(NVFuserTest, FusionAmpereMatmulSmemEpilogueRelu_CUDA) { + NVFUSER_TEST_CUDA_ARCH_RANGE_GUARD(8, 0, 9, 0); + // Keep multiples of 8 to keep vectorizable. + int M = 4096, N = 4096, K = 4096; + for (auto layout : kAllSupportedMatmulLayout) { + Fusion fusion; + FusionGuard fg(&fusion); + auto tv0 = makeContigTensor(2, DataType::Half); + auto tv1 = makeContigTensor(2, DataType::Half); + + fusion.addInput(tv0); + fusion.addInput(tv1); + + auto tv2 = matmul(tv0, tv1, layout, true); + auto tv3 = relu(tv2); + + fusion.addOutput(tv3); + + MatMulTileOptions gemm_tile; + gemm_tile.cta_tile = GemmTile(128, 128, 32); + gemm_tile.warp_tile = GemmTile(64, 64, 32); + gemm_tile.instruction_tile = GemmTile(16, 8, 16); + + MatmulParams params; + params.mma_macro = MmaOptions::MacroType::Ampere_16_8_16; + params.tile_sizes = gemm_tile; + params.use_smem_epilogue = true; + params.async_gmem_load_operands = true; + params.double_buffer_options.double_buffer_smem_write = true; + params.double_buffer_options.double_buffer_smem_read = true; + params.double_buffer_options.smem_double_buffer_stage = 4; + scheduleMatmul(&fusion, params); + + // If use_smem_epilogue is true, there should be 3 shared memory tensors 2 + // for prologue and 1 for epilogue. + int num_shared_mem_tensors = 0; + int expected_num_shared_mem_tensors = params.use_smem_epilogue ? 3 : 2; + for (const auto& tv : ir_utils::allTvs(&fusion)) { + if (tv->getMemoryType() == MemoryType::Shared) { + num_shared_mem_tensors++; + } + } + TORCH_CHECK( + num_shared_mem_tensors == expected_num_shared_mem_tensors, + "Number of shared memory tensors doesn't match!", + "Expected: ", + expected_num_shared_mem_tensors, + ", Got: ", + num_shared_mem_tensors); + + at::manual_seed(0); + auto inputs = matmulAtInput(M, N, K, layout); + + FusionExecutor fe; + NVFUSER_TEST_CUDA_ARCH_COMPILE_CHECK( + 8, + 0, + fe.compileFusion( + &fusion, + {inputs.first, inputs.second}, + LaunchParams(), + matmul_cparams)); + auto cg_outputs = fe.runFusion({inputs.first, inputs.second}); + auto t2 = atMatmul( + inputs.first.to(at::kFloat), inputs.second.to(at::kFloat), layout); + auto tref = at::relu(t2).to(at::kFloat); + + // check bank conflicts + ASSERT_TRUE(getBankConflictInfo(fe.kernel()).empty()); + // (0.001, 0.001) passed on local A100 but failed on CI A100 + TORCH_CHECK( + cg_outputs[0].allclose(tref, 0.01, 0.01), + "Result validation failed. Max diff: ", + (cg_outputs[0] - tref).abs().max()); + } +} #undef NVFUSER_TEST_CUDA_ARCH_GUARD } // namespace nvfuser diff --git a/test/test_matmul_sass.cpp b/test/test_matmul_sass.cpp index bb54f36eb9c..b6553965b69 100644 --- a/test/test_matmul_sass.cpp +++ b/test/test_matmul_sass.cpp @@ -43,7 +43,8 @@ sass::Container getSASSFor( MmaOptions::MacroType macro, int M, int N, - int K) { + int K, + const bool use_shared_epilogue = false) { Fusion fusion; FusionGuard fg(&fusion); auto tv0 = makeContigTensor(2, DataType::Half); @@ -62,6 +63,7 @@ sass::Container getSASSFor( gemm_tile.instruction_tile = instruction_tile; MatmulParams params; + params.use_smem_epilogue = use_shared_epilogue; params.mma_macro = macro; params.tile_sizes = gemm_tile; params.async_gmem_load_operands = true; @@ -204,110 +206,121 @@ TEST_F(MatmulSASSTest, AmpereModifiers_CUDA) { NVFUSER_TEST_CUDA_ARCH_RANGE_GUARD(8, 0, 9, 0); // Keep multiples of 8 to keep vectorizable. int M = 504, N = 136, K = 248; - bool found_LDGSTS = false; - bool found_LDSM = false; - bool found_HMMA = false; - bool found_LDGDEPBAR = false; - bool found_BAR = false; - bool found_DEPBAR = false; // kAllSupportedMatmulLayout; - for (auto layout : {MatmulLayout::TT}) { - sass::Container sass; - NVFUSER_TEST_CUDA_ARCH_COMPILE_CHECK( - 8, - 0, - sass = getSASSFor( - layout, - GemmTile(128, 128, 32), - GemmTile(64, 64, 32), - GemmTile(16, 8, 16), - MmaOptions::MacroType::Ampere_16_8_16, - M, - N, - K)); - for (auto inst : sass.code) { - std::visit( - [&](auto&& i) { - using T = std::decay_t; - if constexpr (std::is_same_v) { - if (i.opCode() == "LDGSTS") { - const std::vector expect = { - "E", "BYPASS", "LTC128B", "128"}; - TORCH_CHECK( - i.modifiers() == expect, - "Modifiers for LDGSTS has changed. " - "Please manually check if the new modifiers makes sense and update this test. " - "Expect: ", - expect, - " Get: ", - i.modifiers()); - found_LDGSTS = true; - } else if (i.opCode() == "LDGDEPBAR") { - const std::vector expect; - TORCH_CHECK( - i.modifiers() == expect, - "Modifiers for LDGDEPBAR has changed. " - "Please manually check if the new modifiers makes sense and update this test. " - "Expect: ", - expect, - " Get: ", - i.modifiers()); - found_LDGDEPBAR = true; - } else if (i.opCode() == "LDSM") { - const std::vector expect1 = {"16", "M88", "2"}; - const std::vector expect2 = {"16", "M88", "4"}; - const std::vector expect3 = {"16", "MT88", "2"}; - const std::vector expect4 = {"16", "MT88", "4"}; - TORCH_CHECK( - i.modifiers() == expect1 || i.modifiers() == expect2 || - i.modifiers() == expect3 || i.modifiers() == expect4, - "Modifiers for LDGDEPBAR has changed. " - "Please manually check if the new modifiers makes sense and update this test."); - found_LDSM = true; - } else if (i.opCode() == "HMMA") { - const std::vector expect = {"16816", "F32"}; - TORCH_CHECK( - i.modifiers() == expect, - "Modifiers for HMMA has changed. " - "Please manually check if the new modifiers makes sense and update this test. " - "Expect: ", - expect, - " Get: ", - i.modifiers()); - found_HMMA = true; - } else if (i.opCode() == "BAR") { - const std::vector expect = { - "SYNC", "DEFER_BLOCKING"}; - TORCH_CHECK( - i.modifiers() == expect, - "Modifiers for BAR has changed. " - "Please manually check if the new modifiers makes sense and update this test. " - "Expect: ", - expect, - " Get: ", - i.modifiers()); - found_BAR = true; - } else if (i.opCode() == "DEPBAR") { - const std::vector expect = {"LE"}; - TORCH_CHECK( - i.modifiers() == expect, - "Modifiers for DEPBAR has changed. " - "Please manually check if the new modifiers makes sense and update this test. " - "Expect: ", - expect, - " Get: ", - i.modifiers()); - found_DEPBAR = true; + for (auto use_shared_epilogue : {true, false}) { + for (auto layout : {MatmulLayout::TT}) { + bool found_LDGSTS = false; + bool found_LDSM = false; + bool found_HMMA = false; + bool found_LDGDEPBAR = false; + bool found_DEPBAR = false; // kAllSupportedMatmulLayout; + int BAR_COUNT = 0; + // we have three shared memory barriers in the kernel if + // use_shared_epilogue + const int EXPECTED_BAR_COUNT = use_shared_epilogue ? 3 : 2; + sass::Container sass; + NVFUSER_TEST_CUDA_ARCH_COMPILE_CHECK( + 8, + 0, + sass = getSASSFor( + layout, + GemmTile(128, 128, 32), + GemmTile(64, 64, 32), + GemmTile(16, 8, 16), + MmaOptions::MacroType::Ampere_16_8_16, + M, + N, + K, + use_shared_epilogue)); + for (auto inst : sass.code) { + std::visit( + [&](auto&& i) { + using T = std::decay_t; + if constexpr (std::is_same_v) { + if (i.opCode() == "LDGSTS") { + const std::vector expect = { + "E", "BYPASS", "LTC128B", "128"}; + TORCH_CHECK( + i.modifiers() == expect, + "Modifiers for LDGSTS has changed. " + "Please manually check if the new modifiers makes sense and update this test. " + "Expect: ", + expect, + " Get: ", + i.modifiers()); + found_LDGSTS = true; + } else if (i.opCode() == "LDGDEPBAR") { + const std::vector expect; + TORCH_CHECK( + i.modifiers() == expect, + "Modifiers for LDGDEPBAR has changed. " + "Please manually check if the new modifiers makes sense and update this test. " + "Expect: ", + expect, + " Get: ", + i.modifiers()); + found_LDGDEPBAR = true; + } else if (i.opCode() == "LDSM") { + const std::vector expect1 = {"16", "M88", "2"}; + const std::vector expect2 = {"16", "M88", "4"}; + const std::vector expect3 = {"16", "MT88", "2"}; + const std::vector expect4 = {"16", "MT88", "4"}; + TORCH_CHECK( + i.modifiers() == expect1 || i.modifiers() == expect2 || + i.modifiers() == expect3 || i.modifiers() == expect4, + "Modifiers for LDGDEPBAR has changed. " + "Please manually check if the new modifiers makes sense and update this test."); + found_LDSM = true; + } else if (i.opCode() == "HMMA") { + const std::vector expect = {"16816", "F32"}; + TORCH_CHECK( + i.modifiers() == expect, + "Modifiers for HMMA has changed. " + "Please manually check if the new modifiers makes sense and update this test. " + "Expect: ", + expect, + " Get: ", + i.modifiers()); + found_HMMA = true; + } else if (i.opCode() == "BAR") { + const std::vector expect = { + "SYNC", "DEFER_BLOCKING"}; + TORCH_CHECK( + i.modifiers() == expect, + "Modifiers for BAR has changed. " + "Please manually check if the new modifiers makes sense and update this test. " + "Expect: ", + expect, + " Get: ", + i.modifiers()); + BAR_COUNT++; + } else if (i.opCode() == "DEPBAR") { + const std::vector expect = {"LE"}; + TORCH_CHECK( + i.modifiers() == expect, + "Modifiers for DEPBAR has changed. " + "Please manually check if the new modifiers makes sense and update this test. " + "Expect: ", + expect, + " Get: ", + i.modifiers()); + found_DEPBAR = true; + } } - } - }, - inst); + }, + inst); + } + TORCH_CHECK(found_LDGSTS); + TORCH_CHECK(found_LDSM); + TORCH_CHECK(found_HMMA); + TORCH_CHECK(found_LDGDEPBAR); + TORCH_CHECK( + BAR_COUNT == EXPECTED_BAR_COUNT, + "Expect ", + EXPECTED_BAR_COUNT, + " BARs, got ", + BAR_COUNT); + TORCH_CHECK(found_DEPBAR); } - TORCH_CHECK(found_LDGSTS); - TORCH_CHECK(found_LDSM); - TORCH_CHECK(found_HMMA); - TORCH_CHECK(found_LDGDEPBAR); - TORCH_CHECK(found_BAR); - TORCH_CHECK(found_DEPBAR); } }