diff --git a/csrc/scheduler/matmul.cpp b/csrc/scheduler/matmul.cpp index d8079f26d86..3148d22b7ff 100644 --- a/csrc/scheduler/matmul.cpp +++ b/csrc/scheduler/matmul.cpp @@ -509,6 +509,15 @@ void swizzleSharedMemory( void scheduleProlog(TensorView* shared_mem_tv, const MatmulParams& params) { shared_mem_tv->setMemoryType(MemoryType::Shared); + // The following line allows us to reclaim the memory allocated to + // shared_mem_tv and reuse it for the epilogue, introducing one block sync if + // needed. This is not done by default as we do not insert new syncs unless + // requested to do so. If smem is not used for the epilogue, this call will + // have no effect. + if (params.promote_prologue_smem_reuse) { + shared_mem_tv->promoteReuse(); + } + mma_utils::orderTiledConcreteIdAsRoot(shared_mem_tv); // Swizzle the shared memory data layout diff --git a/csrc/scheduler/matmul_heuristic.h b/csrc/scheduler/matmul_heuristic.h index 14fea9db17c..36a82c2b90f 100644 --- a/csrc/scheduler/matmul_heuristic.h +++ b/csrc/scheduler/matmul_heuristic.h @@ -94,6 +94,9 @@ class MatmulParams : public HeuristicParams { //! coalesced write to global memory bool use_smem_epilogue = false; + //! 
Promote reuse of prologue shared memory + bool promote_prologue_smem_reuse = false; + std::string toString() const override { std::stringstream ss; ss << "\n===== Matmul Parameters ========\n" @@ -117,13 +120,16 @@ class MatmulParams : public HeuristicParams { << "\n" << "Grid swizzle factor: " << grid_swizzle_factor << "\n" << "Use shared memory epilogue: " << use_smem_epilogue << "\n" + << "Promote re-use of prologue shared memory: " + << promote_prologue_smem_reuse << "\n" << "====================================\n"; return ss.str(); } size_t hash() const override { // combine boolean flags for hashing - size_t attr_hash = + size_t attr_hash = (static_cast(promote_prologue_smem_reuse) << 3) | + (static_cast(use_smem_epilogue) << 2) | (static_cast(rotate_ldmatrix_out_of_main_loop) << 1) | (static_cast(async_gmem_load_operands)); @@ -150,7 +156,10 @@ class MatmulParams : public HeuristicParams { other_casted->tile_sizes == tile_sizes && other_casted->double_buffer_options == double_buffer_options && other_casted->cta_order == cta_order && - other_casted->grid_swizzle_factor == grid_swizzle_factor; + other_casted->grid_swizzle_factor == grid_swizzle_factor && + other_casted->use_smem_epilogue == use_smem_epilogue && + other_casted->promote_prologue_smem_reuse == + promote_prologue_smem_reuse; } std::shared_ptr clone() const override { diff --git a/csrc/scheduler/matmul_utils.cpp b/csrc/scheduler/matmul_utils.cpp index 99793578f83..78fb16afb09 100644 --- a/csrc/scheduler/matmul_utils.cpp +++ b/csrc/scheduler/matmul_utils.cpp @@ -143,23 +143,6 @@ inline bool initCoreHeuristics( return true; } -//! A wrapper to get MMA Tensor data types -//! 
The order of returned types: INPUT_A, INPUT_B, OUTPUT_D -inline mma_utils::MmaDataTypes getMmaDataTypes( - const std::map>& roles_map) { - auto getMMADataType = [&](MatmulRole role) { - auto entry = roles_map.find(role); - if (entry != roles_map.end() && !entry->second.empty()) { - return entry->second.front()->dtype(); - } - TORCH_INTERNAL_ASSERT(false, "Get MMA Tensor data type failed!"); - }; - const auto a_type = getMMADataType(MatmulRole::INPUT_A); - const auto b_type = getMMADataType(MatmulRole::INPUT_B); - const auto c_type = getMMADataType(MatmulRole::OUTPUT_D); - return mma_utils::MmaDataTypes{a_type, b_type, c_type}; -} - //! A helper for getting problem shape from fusion and runtime info. ProblemShape getProblemShape( Fusion* fusion, @@ -416,10 +399,13 @@ std::shared_ptr getMatmulHeuristics( const auto& roles_map_opt = mma_utils::getTensorsRoles(fusion); TORCH_INTERNAL_ASSERT( roles_map_opt.isValid(), "Tensor roles map in mma is not valid."); - params->use_smem_epilogue = mma_utils::generateSharedMemoryEpilogueHeuristics( - params->tile_sizes, - params->double_buffer_options.smem_double_buffer_stage, - getMmaDataTypes(roles_map_opt.getData())); + + const auto roles_map = roles_map_opt.getData(); + std::tie(params->use_smem_epilogue, params->promote_prologue_smem_reuse) = + mma_utils::generateSharedMemoryEpilogueHeuristics( + params->tile_sizes, + params->double_buffer_options.smem_double_buffer_stage, + roles_map); if (isDebugDumpEnabled(DebugDumpOption::SchedulerDebug)) { debug() << params->toString() << std::endl; diff --git a/csrc/scheduler/mma_utils.cpp b/csrc/scheduler/mma_utils.cpp index 546e600908f..e1e2eebe270 100644 --- a/csrc/scheduler/mma_utils.cpp +++ b/csrc/scheduler/mma_utils.cpp @@ -19,11 +19,30 @@ namespace nvfuser { namespace mma_utils { -bool generateSharedMemoryEpilogueHeuristics( +//! A wrapper to get MMA Tensor data types +//! 
The order of returned types: INPUT_A, INPUT_B, OUTPUT_D +inline mma_utils::MmaDataTypes getMmaDataTypes( + const std::map>& roles_map) { + auto getMMADataType = [&](MatmulRole role) { + auto entry = roles_map.find(role); + if (entry != roles_map.end() && !entry->second.empty()) { + return entry->second.front()->dtype(); + } + TORCH_INTERNAL_ASSERT(false, "Get MMA Tensor data type failed!"); + }; + const auto a_type = getMMADataType(MatmulRole::INPUT_A); + const auto b_type = getMMADataType(MatmulRole::INPUT_B); + const auto c_type = getMMADataType(MatmulRole::OUTPUT_D); + return mma_utils::MmaDataTypes{a_type, b_type, c_type}; +} + +std::pair generateSharedMemoryEpilogueHeuristics( const MatMulTileOptions& gemm_tile, const int smem_double_buffer_stage, const MmaDataTypes& data_types, - const bool ignore_occupancy_drop) { + bool smem_a_reuse_guaranteed, + bool smem_b_reuse_guaranteed, + bool ignore_occupancy_drop) { const auto properties = at::cuda::getCurrentDeviceProperties(); const size_t device_smem_limit = properties->sharedMemPerBlockOptin; const size_t shared_memory_overhead = properties->reservedSharedMemPerBlock; @@ -49,9 +68,30 @@ bool generateSharedMemoryEpilogueHeuristics( const size_t smem_c = (size_t)(gemm_tile.cta_tile.m * gemm_tile.cta_tile.n) * dataTypeSize(data_types[2]); + // NOTE: we can simply add these sizes since they should be integer multiples + // of 16 bytes, so they will automatically be aligned. This may change with + // FP8, in which case the expressions below should be updated to insert + // alignment expressions, using the expected stack ordering in + // StackBasedSharedMemAllocator. 
+ TORCH_CHECK(smem_a % 16 == 0 && smem_b % 16 == 0 && smem_c % 16 == 0); + + const size_t total_without_smem_epilogue = smem_a + smem_b; + const size_t total_with_noreuse_smem_epilogue = smem_a + smem_b + smem_c; + // Even if we actually do wind up re-claiming smem_a and smem_b, if we + // cannot prove it at this point then we have to assume it will not be + // reclaimed. + const size_t total_with_reused_smem_epilogue = std::max( + smem_a + smem_b, + (smem_a_reuse_guaranteed ? 0 : smem_a) + + (smem_b_reuse_guaranteed ? 0 : smem_b) + smem_c); + // shortcut where occupancy change is ignored. if (ignore_occupancy_drop) { - return shared_memory_available >= smem_a + smem_b + smem_c; + if (shared_memory_available >= total_with_noreuse_smem_epilogue) { + return {true, false}; + } else { + return {shared_memory_available >= total_with_reused_smem_epilogue, true}; + } } // use additional shared memory for epilogue if occupancy is not changed. @@ -59,14 +99,70 @@ bool generateSharedMemoryEpilogueHeuristics( const auto threads_per_sm = getThreadsPerSMGivenRegPerThread(255); const auto blocks_per_sm_by_register = threads_per_sm / threads_per_block; const auto blocks_per_sm_without_smem_epilogue = std::min( - shared_memory_available / (smem_a + smem_b), + shared_memory_available / total_without_smem_epilogue, (size_t)blocks_per_sm_by_register); - const auto blocks_per_sm_with_smem_epilogue = std::min( - shared_memory_available / (smem_a + smem_b + smem_c), + const auto blocks_per_sm_with_reused_smem_epilogue = std::min( + shared_memory_available / total_with_reused_smem_epilogue, (size_t)blocks_per_sm_by_register); + const auto blocks_per_sm_with_noreuse_smem_epilogue = std::min( + shared_memory_available / total_with_noreuse_smem_epilogue, + (size_t)blocks_per_sm_by_register); + + // Return whether we should use smem for epilogue, and whether syncing for + // re-use is desired. We avoid the sync if omitting it does not decrease + // occupancy. 
+ auto promote_prologue_smem_reuse = blocks_per_sm_with_reused_smem_epilogue != + blocks_per_sm_with_noreuse_smem_epilogue; - return blocks_per_sm_with_smem_epilogue == - blocks_per_sm_without_smem_epilogue; + return { + blocks_per_sm_with_reused_smem_epilogue == + blocks_per_sm_without_smem_epilogue, + promote_prologue_smem_reuse}; +} + +std::pair generateSharedMemoryEpilogueHeuristics( + const MatMulTileOptions& gemm_tile, + const int smem_double_buffer_stage, + const RolesMap& roles_map, + const bool ignore_occupancy_drop) { + const auto data_types = getMmaDataTypes(roles_map); + + // smem_a and smem_b are guaranteed to be re-used for smem_c as long as: + // - they are marked for re-use using promoteReuse + // - they are not aliased by another tensor whose lifetime extends past the + // start of smem_epilogue's. + // - their lifetimes do not overlap smem_epilogue + // + // We can guarantee the first condition by calling tv->promoteReuse() in + // scheduleProlog. + // + // The second condition would only be the case if another smem tensor had the + // same indexing and its lifetime did not overlap. Matmul scheduler only uses + // smem for these three arrays, so the only candidate for aliasing is C. If C + // aliases either A or B, the following expression is still valid. + // + // The third condition is satisfied in the simple cases where the inputs to + // the matmul have only this use. However, it could be violated if a or b has + // other uses that get ordered after the matmul; for example when computing + // matmul(A, B) + A for square matrices A and B. In that case, the smem tensor + // resulting from A->cacheAfter() will be used in both the matmul as well as + // the addition that occurs in the epilogue, extending the lifetime such that + // it violates the third condition above. In order to avoid errors in these + // cases, we check that there is no re-use when there is more than one use of + // either a or b. 
If there are multiple uses we might wind up re-using memory, + // but in that case the calculation below will be overly conservative. + TensorView* a = roles_map.at(MatmulRole::INPUT_A).front(); + TensorView* b = roles_map.at(MatmulRole::INPUT_B).front(); + bool smem_a_reuse_guaranteed = a->uses().size() == 1; + bool smem_b_reuse_guaranteed = b->uses().size() == 1; + + return generateSharedMemoryEpilogueHeuristics( + gemm_tile, + smem_double_buffer_stage, + data_types, + smem_a_reuse_guaranteed, + smem_b_reuse_guaranteed, + ignore_occupancy_drop); } void scheduleWarpTileWithReduction(TensorView* tv, MatMulTileOptions tile) { diff --git a/csrc/scheduler/mma_utils.h b/csrc/scheduler/mma_utils.h index 773fac1da3b..2fe875696f6 100644 --- a/csrc/scheduler/mma_utils.h +++ b/csrc/scheduler/mma_utils.h @@ -293,19 +293,34 @@ TORCH_CUDA_CU_API ProblemIterDomainsOpt getProblemIterDomains(Fusion* fusion); //! be gathered. TORCH_CUDA_CU_API RolesMapOpt getTensorsRoles(Fusion* fusion); -//! Return whether use shared memory epilogue or not. -//! Returns true if using shared memory epilogue won't cause -//! the decrease of occupancy ratio. The occupancy ratio is -//! estimated using register and shared memory usage. -//! If ignore_occupancy_drop is set to true, returns true if -//! there is enough shared memory to launch the kernel without -//! considering the occupancy, useful for debug and validate -//! shared memory epilogue implementation. -TORCH_CUDA_CU_API bool generateSharedMemoryEpilogueHeuristics( +//! Return pair of whether use shared memory epilogue or not and whether to +//! reuse shared memory for the prologue at the expense of an additional block +//! sync. +//! +//! Returns true in first position if using shared memory epilogue won't cause +//! the decrease of occupancy ratio. The occupancy ratio is estimated using +//! register and shared memory usage. If ignore_occupancy_drop is set to true, +//! 
returns true if there is enough shared memory to launch the kernel without +//! considering the occupancy, useful for debug and validate shared memory +//! epilogue implementation. +//! +//! Returns true in the second position if re-using prologue shared memory for +//! the epilogue increases occupancy (making the extra block sync worthwhile). +TORCH_CUDA_CU_API std::pair generateSharedMemoryEpilogueHeuristics( + const MatMulTileOptions& gemm_tile, + const int smem_double_buffer_stage, + const RolesMap& roles_map, + bool ignore_occupancy_drop = false); + +//! This version assumes roles_map has been analyzed to determine smem datatypes +//! as well as guarantees about prologue smem reuse. +TORCH_CUDA_CU_API std::pair generateSharedMemoryEpilogueHeuristics( const MatMulTileOptions& gemm_tile, const int smem_double_buffer_stage, const MmaDataTypes& data_types, - const bool ignore_occupancy_drop = false); + bool smem_a_reuse_guaranteed = false, + bool smem_b_reuse_guaranteed = false, + bool ignore_occupancy_drop = false); } // namespace mma_utils diff --git a/test/test_gpu_tensorcore.cpp b/test/test_gpu_tensorcore.cpp index e8f882bceb0..6557b825f93 100644 --- a/test/test_gpu_tensorcore.cpp +++ b/test/test_gpu_tensorcore.cpp @@ -3258,7 +3258,7 @@ TEST_F(NVFuserTest, FusionAmpereMatmulTileCheck4warp_CUDA) { params.tile_sizes = gemm_tile; params.async_gmem_load_operands = true; params.double_buffer_options.double_buffer_smem_write = true; - params.use_smem_epilogue = + std::tie(params.use_smem_epilogue, params.promote_prologue_smem_reuse) = mma_utils::generateSharedMemoryEpilogueHeuristics( gemm_tile, params.double_buffer_options.smem_double_buffer_stage, @@ -3326,7 +3326,8 @@ TEST_F(NVFuserTest, FusionAmpereMatmulTileCheck8warp_CUDA) { params.double_buffer_options.double_buffer_smem_write = true; params.double_buffer_options.double_buffer_smem_read = true; params.double_buffer_options.smem_double_buffer_stage = 2; - params.use_smem_epilogue = + std::tie( + params.use_smem_epilogue, 
params.promote_prologue_smem_reuse) = mma_utils::generateSharedMemoryEpilogueHeuristics( gemm_tile, params.double_buffer_options.smem_double_buffer_stage, @@ -3389,7 +3390,7 @@ TEST_F(NVFuserTest, FusionAmpereMatmulTileCheck6warp_CUDA) { params.double_buffer_options.double_buffer_smem_write = true; params.double_buffer_options.double_buffer_smem_read = true; params.double_buffer_options.smem_double_buffer_stage = 2; - params.use_smem_epilogue = + std::tie(params.use_smem_epilogue, params.promote_prologue_smem_reuse) = mma_utils::generateSharedMemoryEpilogueHeuristics( gemm_tile, params.double_buffer_options.smem_double_buffer_stage, @@ -3938,7 +3939,7 @@ TEST_F(NVFuserTest, FusionAmpereMatmulSmemEpilogue_CUDA) { params.double_buffer_options.double_buffer_smem_write = true; params.double_buffer_options.double_buffer_smem_read = true; params.double_buffer_options.smem_double_buffer_stage = 2; - params.use_smem_epilogue = + std::tie(params.use_smem_epilogue, params.promote_prologue_smem_reuse) = mma_utils::generateSharedMemoryEpilogueHeuristics( gemm_tile, params.double_buffer_options.smem_double_buffer_stage, @@ -3991,6 +3992,47 @@ TEST_F(NVFuserTest, FusionAmpereMatmulSmemEpilogue_CUDA) { GTEST_SKIP() << "Test conducted without utilizing shared memory epilogue due to the device's constrained shared memory capacity."; } + + // Check that smem is allocated as expected. 
+ // There are three cases that are determined by the current device in + // mma_utils::generateSharedMemoryEpilogueHeuristics: + // - !use_smem_epilogue : A + B (this test is skipped in this case) + // - use_smem_epilogue && !promote_prologue_smem_reuse : A + B + C + // - use_smem_epilogue && promote_prologue_smem_reuse : max(A + B, C) + auto smem_allocs = fe.kernel()->summary().dynamic_smem_allocations; + TORCH_CHECK(smem_allocs.size() == 3); + if (params.promote_prologue_smem_reuse) { + // Check prologue shared memory re-use + // smem_allocs = {A, B, C} where C is the epilogue buffer + // since A and B have no further uses, we should be able to reuse both + // of them, implying that the address of C is zero. In this case, B will + // also be allocated at address 0 with A stacked above it at position + // 8192. + EXPECT_EQ( + smem_allocs.at(0)->address()->evaluateInt(), + // Assuming B numel times size(dtype) is a multiple of 16 so that + // this address is aligned + smem_allocs.at(1)->size()->evaluateInt() * + dataTypeSize(smem_allocs.at(1)->buffer()->dtype())); + EXPECT_EQ(smem_allocs.at(1)->address()->evaluateInt(), 0L); + EXPECT_EQ(smem_allocs.at(2)->address()->evaluateInt(), 0L); + } else { + // Prologue shared memory is not re-used. In this case, memory should + // stack in C, B, A order. 
+ EXPECT_EQ( + smem_allocs.at(0)->address()->evaluateInt(), + // Assuming for B and C that numel times size(dtype) is a multiple + // of 16 so that this address is aligned + smem_allocs.at(1)->size()->evaluateInt() * + dataTypeSize(smem_allocs.at(1)->buffer()->dtype()) + + smem_allocs.at(2)->size()->evaluateInt() * + dataTypeSize(smem_allocs.at(2)->buffer()->dtype())); + EXPECT_EQ( + smem_allocs.at(1)->address()->evaluateInt(), + smem_allocs.at(2)->size()->evaluateInt() * + dataTypeSize(smem_allocs.at(2)->buffer()->dtype())); + EXPECT_EQ(smem_allocs.at(2)->address()->evaluateInt(), 0L); + } } } @@ -4025,7 +4067,7 @@ TEST_F(NVFuserTest, FusionAmpereMatmulSmemEpilogueCast_CUDA) { params.double_buffer_options.double_buffer_smem_write = true; params.double_buffer_options.double_buffer_smem_read = true; params.double_buffer_options.smem_double_buffer_stage = 4; - params.use_smem_epilogue = + std::tie(params.use_smem_epilogue, params.promote_prologue_smem_reuse) = mma_utils::generateSharedMemoryEpilogueHeuristics( gemm_tile, params.double_buffer_options.smem_double_buffer_stage, @@ -4112,7 +4154,7 @@ TEST_F(NVFuserTest, FusionAmpereMatmulSmemEpilogueRelu_CUDA) { params.double_buffer_options.double_buffer_smem_write = true; params.double_buffer_options.double_buffer_smem_read = true; params.double_buffer_options.smem_double_buffer_stage = 4; - params.use_smem_epilogue = + std::tie(params.use_smem_epilogue, params.promote_prologue_smem_reuse) = mma_utils::generateSharedMemoryEpilogueHeuristics( gemm_tile, params.double_buffer_options.smem_double_buffer_stage, diff --git a/test/test_matmul_sass.cpp b/test/test_matmul_sass.cpp index 36426dfff82..92033f2100f 100644 --- a/test/test_matmul_sass.cpp +++ b/test/test_matmul_sass.cpp @@ -46,7 +46,8 @@ sass::Container getSASSFor( int N, int K, const int smem_double_buffer_stage = 4, - const bool use_smem_epilogue = false) { + const bool use_smem_epilogue = false, + const bool promote_prologue_smem_reuse = false) { Fusion fusion; 
FusionGuard fg(&fusion); auto tv0 = makeContigTensor(2, DataType::Half); @@ -73,6 +74,7 @@ sass::Container getSASSFor( params.double_buffer_options.smem_double_buffer_stage = smem_double_buffer_stage; params.use_smem_epilogue = use_smem_epilogue; + params.promote_prologue_smem_reuse = promote_prologue_smem_reuse; scheduleMatmul(&fusion, params); auto inputs = matmulAtInput(M, N, K, layout); @@ -328,7 +330,7 @@ TEST_F(MatmulSASSTest, AmpereModifiersSharedMemoryEpilogue_CUDA) { gemm_tile.instruction_tile = GemmTile(16, 8, 16); const int smem_double_buffer_stage = 4; const bool ignore_occupancy_drop = true; - const bool use_smem_epilogue = + const auto [use_smem_epilogue, promote_prologue_smem_reuse] = mma_utils::generateSharedMemoryEpilogueHeuristics( gemm_tile, smem_double_buffer_stage, @@ -347,9 +349,9 @@ TEST_F(MatmulSASSTest, AmpereModifiersSharedMemoryEpilogue_CUDA) { bool found_LDGDEPBAR = false; bool found_DEPBAR = false; // kAllSupportedMatmulLayout; int BAR_COUNT = 0; - // we have three shared memory barriers in the kernel if - // use_shared_epilogue - const int EXPECTED_BAR_COUNT = 3; + // we have at least three shared memory barriers in the kernel if + // use_shared_epilogue. If promote_prologue_smem_reuse, then 4 + const int EXPECTED_BAR_COUNT = promote_prologue_smem_reuse ? 4 : 3; sass::Container sass; NVFUSER_TEST_CUDA_ARCH_COMPILE_CHECK( 8, @@ -364,7 +366,8 @@ TEST_F(MatmulSASSTest, AmpereModifiersSharedMemoryEpilogue_CUDA) { N, K, smem_double_buffer_stage, - use_smem_epilogue)); + use_smem_epilogue, + promote_prologue_smem_reuse)); for (auto inst : sass.code) { std::visit( [&](auto&& i) {