From 9c06a18a73d4f2ebd6f2f77788734eecd53371ab Mon Sep 17 00:00:00 2001 From: Jacob Hinkle Date: Tue, 22 Aug 2023 19:30:56 -0400 Subject: [PATCH 01/30] Ignore trivial loops in memory aliasing pass --- csrc/device_lower/pass/alias_memory.cpp | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/csrc/device_lower/pass/alias_memory.cpp b/csrc/device_lower/pass/alias_memory.cpp index b1f17c436a9..1367761f89b 100644 --- a/csrc/device_lower/pass/alias_memory.cpp +++ b/csrc/device_lower/pass/alias_memory.cpp @@ -743,7 +743,13 @@ class AllocationInfoMap : private kir::IrVisitor { void handle(kir::ForLoop* for_loop) final { auto loop_info = scope_map_.getLoopScopeInfo(for_loop); - current_stack_.push_back(loop_info); + if (!for_loop->isTrivial()) { + // Parallelized loops do not result in for loops in the CUDA kernel, so + // they should not affect liveness analysis. This means that + // current_stack_ will differ from kir::IrVisitor::for_loops_, which will + // actually hold all ForLoops regardless of parallelization. + current_stack_.push_back(loop_info); + } if (debug_printer_) { debug_printer_->pushScope(); } @@ -751,7 +757,9 @@ class AllocationInfoMap : private kir::IrVisitor { if (debug_printer_) { debug_printer_->popScope(); } - current_stack_.pop_back(); + if (!for_loop->isTrivial()) { + current_stack_.pop_back(); + } } void handle(kir::IfThenElse* ite) final { From 8a5dafe5496224e05d01968320b39a90c88b3d76 Mon Sep 17 00:00:00 2001 From: Jacob Hinkle Date: Thu, 17 Aug 2023 11:30:18 -0400 Subject: [PATCH 02/30] Add TensorView::requestReuse and failing test --- test/test_smem_reuse.cpp | 73 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 73 insertions(+) diff --git a/test/test_smem_reuse.cpp b/test/test_smem_reuse.cpp index e891566f70f..d75827fab3a 100644 --- a/test/test_smem_reuse.cpp +++ b/test/test_smem_reuse.cpp @@ -411,4 +411,77 @@ TEST_F(SmemReuseTest, MultiplePromoteReuse) { } } +// Same as NeedsReorderedPush but C requests to reuse A instead of pre-existing +// sync +TEST_F(SmemReuseTest, RequestReuse) { + auto fusion = std::make_unique(); + FusionGuard fg(fusion.get()); + + int64_t H_int = 5, W_int = 6; + auto H = IrBuilder::create(H_int); + auto W = IrBuilder::create(W_int); + + auto tv0 = full({H}, fusion->oneVal(), DataType::Float); + auto tv1 = set(tv0); // pos = a. A = tv1 + tv1->setMemoryType(MemoryType::Shared); + + auto tv2 = full({W}, fusion->oneVal(), DataType::Float); // pos = b. B = tv2 + tv2->setMemoryType(MemoryType::Shared); + + auto tv3 = add(tv1, tv1); // pos = c + + auto tv4 = sum(tv3, {0}); // gap between b and c + fusion->addOutput(tv4); + + auto tv5 = broadcast(tv4, {true}); + auto tv6 = mul(tv5, tv3); + + auto tv7 = broadcast(tv6, {true, false}); + auto tv8 = broadcast(tv2, {false, true}); + auto tv9 = mul(tv7, tv8); // pos = d. C = tv9 + tv9->setMemoryType(MemoryType::Shared); + + auto tv10 = add(tv2, tv2); // pos = e + fusion->addOutput(tv10); + + auto tv11 = neg(tv9); // pos = f + fusion->addOutput(tv11); + + { // This should not re-use memory + GpuLower gpulw(fusion.get()); + + ExpressionEvaluator ee; + std::unordered_set addresses; + int64_t smem_usage = 0; + for (auto alloc : gpulw.kernel()->summary().dynamic_smem_allocations) { + auto addr = ee.evaluate(alloc->address()).as(); + TORCH_CHECK( + addresses.insert(addr).second, + "Smem addresses should not be re-used"); + auto size = ee.evaluate(alloc->size()).as() * + dataTypeSize(alloc->buffer()->dtype()); + smem_usage = std::max(smem_usage, addr + size); + } + EXPECT_EQ( + smem_usage, + alignInt(alignInt(H_int * W_int * 4) + W_int * 4) + H_int * 4); + } + + { // Now introduce a block reduction and check that we re-use memory + + tv9->requestReuse(tv1); + + GpuLower gpulw(fusion.get()); + ExpressionEvaluator ee; + int64_t smem_usage = 0; + for (auto alloc : gpulw.kernel()->summary().dynamic_smem_allocations) { + auto addr = ee.evaluate(alloc->address()).as(); + auto size = ee.evaluate(alloc->size()).as() * + dataTypeSize(alloc->buffer()->dtype()); + smem_usage = std::max(smem_usage, addr + size); + } + EXPECT_EQ(smem_usage, alignInt(H_int * 4) + W_int * H_int * 4); + } +} + } // namespace nvfuser From 047e2607365cfb4204e0a1a67110c107d5db053a Mon Sep 17 00:00:00 2001 From: Jacob Hinkle Date: Fri, 18 Aug 2023 08:20:52 -0400 Subject: [PATCH 03/30] Don't count parallelized loops in AllocationInfoMap Previously, we stacked every ForLoop regardless of parallelization. This meant that when the first few dimensions were left of compute at in the whole fusion, even if they were parallelized all tensors would have the same outer live interval. I noticed this for the AmpereMatmulSmemEpilogue_CUDA tests. In that case if you look at the generated CUDA it's clearly not true; the outer for loops do not appear since they are parallelized. This commit fixes this; note that it can affect all reuse analysis including aliasing even of local memory. --- csrc/device_lower/pass/alias_memory.cpp | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/csrc/device_lower/pass/alias_memory.cpp b/csrc/device_lower/pass/alias_memory.cpp index 0fbe3a65240..4305bec96f3 100644 --- a/csrc/device_lower/pass/alias_memory.cpp +++ b/csrc/device_lower/pass/alias_memory.cpp @@ -763,7 +763,13 @@ class AllocationInfoMap : private kir::IrVisitor { void handle(kir::ForLoop* for_loop) final { auto loop_info = scope_map_.getLoopScopeInfo(for_loop); - current_stack_.push_back(loop_info); + if (!for_loop->iter_domain()->isParallelized()) { + // Parallelized loops do not result in for loops in the CUDA kernel, so + // they should not affect liveness analysis. This means that + // current_stack_ will differ from kir::IrVisitor::for_loops_, which will + // actually hold all ForLoops regardless of parallelization. + current_stack_.push_back(loop_info); + } if (debug_printer_) { debug_printer_->pushScope(); } @@ -771,7 +777,9 @@ class AllocationInfoMap : private kir::IrVisitor { if (debug_printer_) { debug_printer_->popScope(); } - current_stack_.pop_back(); + if (!for_loop->iter_domain()->isParallelized()) { + current_stack_.pop_back(); + } } void handle(kir::IfThenElse* ite) final { From 4d9cc351a72018700ff1437ba3eea5718a4a00b4 Mon Sep 17 00:00:00 2001 From: Jacob Hinkle Date: Fri, 18 Aug 2023 13:16:01 -0400 Subject: [PATCH 04/30] Fix up reuse tests --- test/test_smem_reuse.cpp | 53 +++++++++++++++++++--------------------- 1 file changed, 25 insertions(+), 28 deletions(-) diff --git a/test/test_smem_reuse.cpp b/test/test_smem_reuse.cpp index d75827fab3a..fd57f454fca 100644 --- a/test/test_smem_reuse.cpp +++ b/test/test_smem_reuse.cpp @@ -417,35 +417,31 @@ TEST_F(SmemReuseTest, RequestReuse) { auto fusion = std::make_unique(); FusionGuard fg(fusion.get()); - int64_t H_int = 5, W_int = 6; - auto H = IrBuilder::create(H_int); - auto W = IrBuilder::create(W_int); - - auto tv0 = full({H}, fusion->oneVal(), DataType::Float); - auto tv1 = set(tv0); // pos = a. A = tv1 - tv1->setMemoryType(MemoryType::Shared); + int64_t H = 5; - auto tv2 = full({W}, fusion->oneVal(), DataType::Float); // pos = b. B = tv2 - tv2->setMemoryType(MemoryType::Shared); + auto tv0 = full( + {IrBuilder::create(H)}, + fusion->oneVal(), + DataType::Float); // pos = a. A = tv0 + tv0->setMemoryType(MemoryType::Shared); - auto tv3 = add(tv1, tv1); // pos = c + auto tv1 = + pad(tv0, {fusion->zeroVal(), fusion->oneVal()}); // pos = b. B = tv1 + tv1->setMemoryType(MemoryType::Shared); - auto tv4 = sum(tv3, {0}); // gap between b and c - fusion->addOutput(tv4); + auto tv2 = mul(tv1, tv1); // pos = c - auto tv5 = broadcast(tv4, {true}); - auto tv6 = mul(tv5, tv3); + auto tv3 = sum(tv2, {0}); // gap between b and c. Can parallelize to sync - auto tv7 = broadcast(tv6, {true, false}); - auto tv8 = broadcast(tv2, {false, true}); - auto tv9 = mul(tv7, tv8); // pos = d. C = tv9 - tv9->setMemoryType(MemoryType::Shared); + auto tv4 = broadcast(tv3, {true}); + auto tv5 = mul(tv4, tv1); // pos = d. C = tv5 + tv5->setMemoryType(MemoryType::Shared); - auto tv10 = add(tv2, tv2); // pos = e - fusion->addOutput(tv10); + auto tv6 = add(tv1, tv1); // pos = e + fusion->addOutput(tv6); - auto tv11 = neg(tv9); // pos = f - fusion->addOutput(tv11); + auto tv7 = neg(tv5); // pos = f + fusion->addOutput(tv7); { // This should not re-use memory GpuLower gpulw(fusion.get()); @@ -454,6 +450,7 @@ TEST_F(SmemReuseTest, RequestReuse) { std::unordered_set addresses; int64_t smem_usage = 0; for (auto alloc : gpulw.kernel()->summary().dynamic_smem_allocations) { + EXPECT_NE(alloc->address(), nullptr); auto addr = ee.evaluate(alloc->address()).as(); TORCH_CHECK( addresses.insert(addr).second, @@ -463,24 +460,24 @@ TEST_F(SmemReuseTest, RequestReuse) { smem_usage = std::max(smem_usage, addr + size); } EXPECT_EQ( - smem_usage, - alignInt(alignInt(H_int * W_int * 4) + W_int * 4) + H_int * 4); + smem_usage, alignInt(alignInt((H + 1) * 4) + (H + 1) * 4) + H * 4); } - { // Now introduce a block reduction and check that we re-use memory - - tv9->requestReuse(tv1); + { // Request a sync between tv0 and tv5. This will place a __syncthreads just + // before tv5 is written. + tv5->requestReuse(tv0); GpuLower gpulw(fusion.get()); ExpressionEvaluator ee; int64_t smem_usage = 0; for (auto alloc : gpulw.kernel()->summary().dynamic_smem_allocations) { + EXPECT_NE(alloc->address(), nullptr); auto addr = ee.evaluate(alloc->address()).as(); auto size = ee.evaluate(alloc->size()).as() * dataTypeSize(alloc->buffer()->dtype()); smem_usage = std::max(smem_usage, addr + size); } - EXPECT_EQ(smem_usage, alignInt(H_int * 4) + W_int * H_int * 4); + EXPECT_EQ(smem_usage, alignInt((H + 1) * 4) + (H + 1) * 4); } } From 6a82eb4c1ea041ff129ab84fb014400f46129c8d Mon Sep 17 00:00:00 2001 From: Jacob Hinkle Date: Fri, 18 Aug 2023 13:27:25 -0400 Subject: [PATCH 05/30] Enable smem reuse in matmul epilogue --- csrc/scheduler/matmul.cpp | 8 ++++++++ csrc/scheduler/mma_utils.cpp | 10 +++++++--- 2 files changed, 15 insertions(+), 3 deletions(-) diff --git a/csrc/scheduler/matmul.cpp b/csrc/scheduler/matmul.cpp index d8079f26d86..f8900ecdca2 100644 --- a/csrc/scheduler/matmul.cpp +++ b/csrc/scheduler/matmul.cpp @@ -968,6 +968,14 @@ void scheduleMatmul(Fusion* fusion, const MatmulParams& params) { if (params.use_smem_epilogue) { smem_epilogue->setMemoryType(MemoryType::Shared); + + // The following line allows us to reclaim the memory allocated to acw_smem + // and bcw_smem and reuse it for smem_epilogue, introducing one block sync. + // This is not done by default as we do not insert new syncs unless + // requested to do so. Note that bcw_smem's lifetime overlaps acw_smem's, so + // it will also be reclaimed, even though we do not explicitly request that + smem_epilogue->requestReuse(acw_smem); + swizzleSharedMemory(smem_epilogue, params); scheduler_utils::BoundedDirectionalTransformPropagator::forward( mma_result, diff --git a/csrc/scheduler/mma_utils.cpp b/csrc/scheduler/mma_utils.cpp index 546e600908f..6ab705edf14 100644 --- a/csrc/scheduler/mma_utils.cpp +++ b/csrc/scheduler/mma_utils.cpp @@ -49,9 +49,13 @@ bool generateSharedMemoryEpilogueHeuristics( const size_t smem_c = (size_t)(gemm_tile.cta_tile.m * gemm_tile.cta_tile.n) * dataTypeSize(data_types[2]); + const size_t total_without_smem_epilogue = smem_a + smem_b; + // Note that we reclaim smem_a and smem_b before allocating smem_c + const size_t total_with_smem_epilogue = std::max(smem_a + smem_b, smem_c); + // shortcut where occupancy change is ignored. if (ignore_occupancy_drop) { - return shared_memory_available >= smem_a + smem_b + smem_c; + return shared_memory_available >= total_with_smem_epilogue; } // use additional shared memory for epilogue if occupancy is not changed. @@ -59,10 +63,10 @@ bool generateSharedMemoryEpilogueHeuristics( const auto threads_per_sm = getThreadsPerSMGivenRegPerThread(255); const auto blocks_per_sm_by_register = threads_per_sm / threads_per_block; const auto blocks_per_sm_without_smem_epilogue = std::min( - shared_memory_available / (smem_a + smem_b), + shared_memory_available / total_without_smem_epilogue, (size_t)blocks_per_sm_by_register); const auto blocks_per_sm_with_smem_epilogue = std::min( - shared_memory_available / (smem_a + smem_b + smem_c), + shared_memory_available / total_with_smem_epilogue, (size_t)blocks_per_sm_by_register); return blocks_per_sm_with_smem_epilogue == From fd5a7db14912494ce8dfe18ef81fe5427e8fd80d Mon Sep 17 00:00:00 2001 From: Jacob Hinkle Date: Mon, 21 Aug 2023 07:32:01 -0400 Subject: [PATCH 06/30] Change to promoteReuse interface --- csrc/device_lower/pass/alias_memory.cpp | 27 +++++++++++++++++++++++++ test/test_smem_reuse.cpp | 8 +++++--- 2 files changed, 32 insertions(+), 3 deletions(-) diff --git a/csrc/device_lower/pass/alias_memory.cpp b/csrc/device_lower/pass/alias_memory.cpp index 4305bec96f3..85b0847970e 100644 --- a/csrc/device_lower/pass/alias_memory.cpp +++ b/csrc/device_lower/pass/alias_memory.cpp @@ -1629,6 +1629,20 @@ class StackBasedSharedMemAllocator : kir::IrVisitor { void allocate(const std::vector& exprs) { recordEvents(); + for (auto& [pos, fws] : first_write_positions_) { + std::cout << "Position " << pos + << " is first write pos for:" << std::endl; + for (auto& alloc_info : fws) { + std::cout << " " << alloc_info->alloc_expr->buffer()->toString() + << std::endl; + } + } + + for (auto& pos : last_read_positions_) { + std::cout << "Position " << pos + << " is last aliased read pos for some allocs" << std::endl; + } + // Traverse expressions: reclaim memory when we pass a blockSync, append to // waiting_to_push_ when we pass an Allocate handle(exprs); @@ -1642,12 +1656,21 @@ class StackBasedSharedMemAllocator : kir::IrVisitor { void dispatch(Expr* expr) final { position_ = allocation_info_map_.getScopeMap().getExprPos(expr); + if (isDebugDumpEnabled(DebugDumpOption::BufferReuseInfo)) { + debug() << "Position " << position_ << std::endl; + } + // Check whether this is a first write position for any allocations auto it = first_write_positions_.find(position_); if (it != first_write_positions_.end()) { + if (isDebugDumpEnabled(DebugDumpOption::BufferReuseInfo)) { + debug() << "Position " << position_ << " is first write for"; + } for (auto alloc_info : it->second) { + debug() << " T" << alloc_info->alloc_expr->buffer()->name(); waiting_to_push_.push_back(alloc_info); } + debug() << std::endl; } // Reclaim memory whenever we pass an Expr that is known to synchronize the @@ -1730,10 +1753,14 @@ class StackBasedSharedMemAllocator : kir::IrVisitor { //! Record first reads and last writes, respecting aliased buffers void recordEvents() { for (auto& alloc_info : allocation_info_map_.allAllocationInfos()) { + std::cout << "alloc info found for " << alloc_info->alloc_expr->toString() + << std::endl; if (alloc_info->mem_type != MemoryType::Shared) { continue; } if (alloc_info->alias_to) { + std::cout << " Allocation aliases" << alloc_info->alias_to->toString() + << std::endl; auto alias_info = allocation_info_map_.getAllocationInfo(alloc_info->alias_to); TORCH_CHECK(alias_info); diff --git a/test/test_smem_reuse.cpp b/test/test_smem_reuse.cpp index fd57f454fca..0ee2e6b4685 100644 --- a/test/test_smem_reuse.cpp +++ b/test/test_smem_reuse.cpp @@ -463,9 +463,9 @@ TEST_F(SmemReuseTest, RequestReuse) { smem_usage, alignInt(alignInt((H + 1) * 4) + (H + 1) * 4) + H * 4); } - { // Request a sync between tv0 and tv5. This will place a __syncthreads just - // before tv5 is written. - tv5->requestReuse(tv0); + { // Request that we re-use the allocation for tv0. This should place a + // __syncthreads() just before tv5 is written. + tv0->promoteReuse(); GpuLower gpulw(fusion.get()); ExpressionEvaluator ee; @@ -481,4 +481,6 @@ TEST_F(SmemReuseTest, RequestReuse) { } } +// TODO: Test involving requested reuse along with automatic aliasing + } // namespace nvfuser From df5476bb0384de6017b69c6e5c898f43b9dfa31a Mon Sep 17 00:00:00 2001 From: Jacob Hinkle Date: Mon, 21 Aug 2023 10:15:38 -0400 Subject: [PATCH 07/30] Remove prints and fix tv clone --- csrc/device_lower/pass/alias_memory.cpp | 27 ------------------------- 1 file changed, 27 deletions(-) diff --git a/csrc/device_lower/pass/alias_memory.cpp b/csrc/device_lower/pass/alias_memory.cpp index 85b0847970e..4305bec96f3 100644 --- a/csrc/device_lower/pass/alias_memory.cpp +++ b/csrc/device_lower/pass/alias_memory.cpp @@ -1629,20 +1629,6 @@ class StackBasedSharedMemAllocator : kir::IrVisitor { void allocate(const std::vector& exprs) { recordEvents(); - for (auto& [pos, fws] : first_write_positions_) { - std::cout << "Position " << pos - << " is first write pos for:" << std::endl; - for (auto& alloc_info : fws) { - std::cout << " " << alloc_info->alloc_expr->buffer()->toString() - << std::endl; - } - } - - for (auto& pos : last_read_positions_) { - std::cout << "Position " << pos - << " is last aliased read pos for some allocs" << std::endl; - } - // Traverse expressions: reclaim memory when we pass a blockSync, append to // waiting_to_push_ when we pass an Allocate handle(exprs); @@ -1656,21 +1642,12 @@ class StackBasedSharedMemAllocator : kir::IrVisitor { void dispatch(Expr* expr) final { position_ = allocation_info_map_.getScopeMap().getExprPos(expr); - if (isDebugDumpEnabled(DebugDumpOption::BufferReuseInfo)) { - debug() << "Position " << position_ << std::endl; - } - // Check whether this is a first write position for any allocations auto it = first_write_positions_.find(position_); if (it != first_write_positions_.end()) { - if (isDebugDumpEnabled(DebugDumpOption::BufferReuseInfo)) { - debug() << "Position " << position_ << " is first write for"; - } for (auto alloc_info : it->second) { - debug() << " T" << alloc_info->alloc_expr->buffer()->name(); waiting_to_push_.push_back(alloc_info); } - debug() << std::endl; } // Reclaim memory whenever we pass an Expr that is known to synchronize the @@ -1753,14 +1730,10 @@ class StackBasedSharedMemAllocator : kir::IrVisitor { //! Record first reads and last writes, respecting aliased buffers void recordEvents() { for (auto& alloc_info : allocation_info_map_.allAllocationInfos()) { - std::cout << "alloc info found for " << alloc_info->alloc_expr->toString() - << std::endl; if (alloc_info->mem_type != MemoryType::Shared) { continue; } if (alloc_info->alias_to) { - std::cout << " Allocation aliases" << alloc_info->alias_to->toString() - << std::endl; auto alias_info = allocation_info_map_.getAllocationInfo(alloc_info->alias_to); TORCH_CHECK(alias_info); From 2a1b51a2b24158ada8ae9c1e6168d076a67487d5 Mon Sep 17 00:00:00 2001 From: Jacob Hinkle Date: Mon, 21 Aug 2023 10:26:42 -0400 Subject: [PATCH 08/30] Clean up comment --- test/test_smem_reuse.cpp | 2 -- 1 file changed, 2 deletions(-) diff --git a/test/test_smem_reuse.cpp b/test/test_smem_reuse.cpp index 0ee2e6b4685..485a570066c 100644 --- a/test/test_smem_reuse.cpp +++ b/test/test_smem_reuse.cpp @@ -481,6 +481,4 @@ TEST_F(SmemReuseTest, RequestReuse) { } } -// TODO: Test involving requested reuse along with automatic aliasing - } // namespace nvfuser From 263f42895a2148d79f669619dcb2968b5f87f452 Mon Sep 17 00:00:00 2001 From: Jacob Hinkle Date: Mon, 21 Aug 2023 10:43:00 -0400 Subject: [PATCH 09/30] Switch to using promoteReuse --- csrc/scheduler/matmul.cpp | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/csrc/scheduler/matmul.cpp b/csrc/scheduler/matmul.cpp index f8900ecdca2..1fafa8e9a42 100644 --- a/csrc/scheduler/matmul.cpp +++ b/csrc/scheduler/matmul.cpp @@ -509,6 +509,12 @@ void swizzleSharedMemory( void scheduleProlog(TensorView* shared_mem_tv, const MatmulParams& params) { shared_mem_tv->setMemoryType(MemoryType::Shared); + // The following line allows us to reclaim the memory allocated to + // shared_mem_tv and reuse it for the epilogue, introducing one block sync. + // This is not done by default as we do not insert new syncs unless requested + // to do so. + shared_mem_tv->promoteReuse(); + mma_utils::orderTiledConcreteIdAsRoot(shared_mem_tv); // Swizzle the shared memory data layout @@ -969,13 +975,6 @@ void scheduleMatmul(Fusion* fusion, const MatmulParams& params) { if (params.use_smem_epilogue) { smem_epilogue->setMemoryType(MemoryType::Shared); - // The following line allows us to reclaim the memory allocated to acw_smem - // and bcw_smem and reuse it for smem_epilogue, introducing one block sync. - // This is not done by default as we do not insert new syncs unless - // requested to do so. Note that bcw_smem's lifetime overlaps acw_smem's, so - // it will also be reclaimed, even though we do not explicitly request that - smem_epilogue->requestReuse(acw_smem); - swizzleSharedMemory(smem_epilogue, params); scheduler_utils::BoundedDirectionalTransformPropagator::forward( mma_result, From 998e194ab6d12492ed74b12492b917a9f7b8032d Mon Sep 17 00:00:00 2001 From: Jacob Hinkle Date: Mon, 21 Aug 2023 12:41:43 -0400 Subject: [PATCH 10/30] Almost guarantee reuse before assuming it --- csrc/scheduler/matmul.cpp | 7 ++++--- csrc/scheduler/mma_utils.cpp | 38 ++++++++++++++++++++++++++++++++++-- 2 files changed, 40 insertions(+), 5 deletions(-) diff --git a/csrc/scheduler/matmul.cpp b/csrc/scheduler/matmul.cpp index 1fafa8e9a42..951d61466e1 100644 --- a/csrc/scheduler/matmul.cpp +++ b/csrc/scheduler/matmul.cpp @@ -510,9 +510,10 @@ void scheduleProlog(TensorView* shared_mem_tv, const MatmulParams& params) { shared_mem_tv->setMemoryType(MemoryType::Shared); // The following line allows us to reclaim the memory allocated to - // shared_mem_tv and reuse it for the epilogue, introducing one block sync. - // This is not done by default as we do not insert new syncs unless requested - // to do so. + // shared_mem_tv and reuse it for the epilogue, introducing one block sync if + // needed. This is not done by default as we do not insert new syncs unless + // requested to do so. If smem is not used for the epilogue, this call will + // have no effect. shared_mem_tv->promoteReuse(); mma_utils::orderTiledConcreteIdAsRoot(shared_mem_tv); diff --git a/csrc/scheduler/mma_utils.cpp b/csrc/scheduler/mma_utils.cpp index 6ab705edf14..d1c21746370 100644 --- a/csrc/scheduler/mma_utils.cpp +++ b/csrc/scheduler/mma_utils.cpp @@ -50,8 +50,42 @@ bool generateSharedMemoryEpilogueHeuristics( dataTypeSize(data_types[2]); const size_t total_without_smem_epilogue = smem_a + smem_b; - // Note that we reclaim smem_a and smem_b before allocating smem_c - const size_t total_with_smem_epilogue = std::max(smem_a + smem_b, smem_c); + + // smem_a and smem_b are guaranteed to be re-used for smem_c as long as: + // - they are marked for re-use using promoteReuse + // - they are not aliased by another tensor whose lifetime extends past the + // start of smem_epilogue's. + // - their lifetimes do not overlap smem_epilogue + // + // We guarantee the first condition by calling tv->promoteReuse() in + // scheduleProlog. + // + // The second condition would only be the case if another smem tensor had the + // same indexing and its lifetime did not overlap. This scheduler only uses + // smem for these three arrays, so the only candidate for aliasing is C. If C + // aliases either A or B, the following expression is still valid. + // + // The third condition is satisfied in the simple cases where the inputs to + // the matmul have only this use. However, it could be violated if a or b has + // other uses that get ordered after the matmul; for example when computing + // matmul(A, B) + A for square matrices A and B. In that case, the smem tensor + // resulting from A->cacheAfter() will be used in both the matmul as well as + // the addition that occurs in the epilogue, extending the lifetime such that + // it violates the third condition above. In order to avoid errors in these + // cases, we check that there is no re-use when there is more than one use of + // either a or b. If there are multiple uses we might wind up re-using memory, + // but in that case the calculation below will be overly conservative. + + // TODO: place this logic somewhere up the call stack + /*const auto roles_map = roles_map_opt.getData(); + TensorView* a = roles_map.at(MatmulRole::INPUT_A).front(); + TensorView* b = roles_map.at(MatmulRole::INPUT_B).front(); + bool smem_reuse_guaranteed = a->uses().size() == 1 && b->uses().size() == 1; + */ + bool smem_reuse_guaranteed = true; + const size_t total_with_smem_epilogue = smem_reuse_guaranteed + ? std::max(smem_a + smem_b, smem_c) + : smem_a + smem_b + smem_c; // shortcut where occupancy change is ignored. if (ignore_occupancy_drop) { From 8ea47ccbcb6e19ba6306d149f06690404ec4dc0a Mon Sep 17 00:00:00 2001 From: Jacob Hinkle Date: Wed, 23 Aug 2023 09:47:25 -0400 Subject: [PATCH 11/30] Remove old RequestReuse test --- test/test_smem_reuse.cpp | 70 ---------------------------------------- 1 file changed, 70 deletions(-) diff --git a/test/test_smem_reuse.cpp b/test/test_smem_reuse.cpp index 485a570066c..e891566f70f 100644 --- a/test/test_smem_reuse.cpp +++ b/test/test_smem_reuse.cpp @@ -411,74 +411,4 @@ TEST_F(SmemReuseTest, MultiplePromoteReuse) { } } -// Same as NeedsReorderedPush but C requests to reuse A instead of pre-existing -// sync -TEST_F(SmemReuseTest, RequestReuse) { - auto fusion = std::make_unique(); - FusionGuard fg(fusion.get()); - - int64_t H = 5; - - auto tv0 = full( - {IrBuilder::create(H)}, - fusion->oneVal(), - DataType::Float); // pos = a. A = tv0 - tv0->setMemoryType(MemoryType::Shared); - - auto tv1 = - pad(tv0, {fusion->zeroVal(), fusion->oneVal()}); // pos = b. B = tv1 - tv1->setMemoryType(MemoryType::Shared); - - auto tv2 = mul(tv1, tv1); // pos = c - - auto tv3 = sum(tv2, {0}); // gap between b and c. Can parallelize to sync - - auto tv4 = broadcast(tv3, {true}); - auto tv5 = mul(tv4, tv1); // pos = d. C = tv5 - tv5->setMemoryType(MemoryType::Shared); - - auto tv6 = add(tv1, tv1); // pos = e - fusion->addOutput(tv6); - - auto tv7 = neg(tv5); // pos = f - fusion->addOutput(tv7); - - { // This should not re-use memory - GpuLower gpulw(fusion.get()); - - ExpressionEvaluator ee; - std::unordered_set addresses; - int64_t smem_usage = 0; - for (auto alloc : gpulw.kernel()->summary().dynamic_smem_allocations) { - EXPECT_NE(alloc->address(), nullptr); - auto addr = ee.evaluate(alloc->address()).as(); - TORCH_CHECK( - addresses.insert(addr).second, - "Smem addresses should not be re-used"); - auto size = ee.evaluate(alloc->size()).as() * - dataTypeSize(alloc->buffer()->dtype()); - smem_usage = std::max(smem_usage, addr + size); - } - EXPECT_EQ( - smem_usage, alignInt(alignInt((H + 1) * 4) + (H + 1) * 4) + H * 4); - } - - { // Request that we re-use the allocation for tv0. This should place a - // __syncthreads() just before tv5 is written. - tv0->promoteReuse(); - - GpuLower gpulw(fusion.get()); - ExpressionEvaluator ee; - int64_t smem_usage = 0; - for (auto alloc : gpulw.kernel()->summary().dynamic_smem_allocations) { - EXPECT_NE(alloc->address(), nullptr); - auto addr = ee.evaluate(alloc->address()).as(); - auto size = ee.evaluate(alloc->size()).as() * - dataTypeSize(alloc->buffer()->dtype()); - smem_usage = std::max(smem_usage, addr + size); - } - EXPECT_EQ(smem_usage, alignInt((H + 1) * 4) + (H + 1) * 4); - } -} - } // namespace nvfuser From 57a3bab852667ff7eee58dd4f71eb33cee00a75d Mon Sep 17 00:00:00 2001 From: Jacob Hinkle Date: Wed, 23 Aug 2023 10:09:25 -0400 Subject: [PATCH 12/30] Move reuse guarantee to getMatmulHeuristics --- csrc/scheduler/matmul_utils.cpp | 32 ++++++++++++++++++++++++- csrc/scheduler/mma_utils.cpp | 41 ++++++--------------------------- csrc/scheduler/mma_utils.h | 3 ++- 3 files changed, 40 insertions(+), 36 deletions(-) diff --git a/csrc/scheduler/matmul_utils.cpp b/csrc/scheduler/matmul_utils.cpp index 99793578f83..580b4a52ff8 100644 --- a/csrc/scheduler/matmul_utils.cpp +++ b/csrc/scheduler/matmul_utils.cpp @@ -416,10 +416,40 @@ std::shared_ptr getMatmulHeuristics( const auto& roles_map_opt = mma_utils::getTensorsRoles(fusion); TORCH_INTERNAL_ASSERT( roles_map_opt.isValid(), "Tensor roles map in mma is not valid."); + + // smem_a and smem_b are guaranteed to be re-used for smem_c as long as: + // - they are marked for re-use using promoteReuse + // - they are not aliased by another tensor whose lifetime extends past the + // start of smem_epilogue's. + // - their lifetimes do not overlap smem_epilogue + // + // We guarantee the first condition by calling tv->promoteReuse() in + // scheduleProlog. + // + // The second condition would only be the case if another smem tensor had the + // same indexing and its lifetime did not overlap. This scheduler only uses + // smem for these three arrays, so the only candidate for aliasing is C. If C + // aliases either A or B, the following expression is still valid. + // + // The third condition is satisfied in the simple cases where the inputs to + // the matmul have only this use. However, it could be violated if a or b has + // other uses that get ordered after the matmul; for example when computing + // matmul(A, B) + A for square matrices A and B. In that case, the smem tensor + // resulting from A->cacheAfter() will be used in both the matmul as well as + // the addition that occurs in the epilogue, extending the lifetime such that + // it violates the third condition above. In order to avoid errors in these + // cases, we check that there is no re-use when there is more than one use of + // either a or b. If there are multiple uses we might wind up re-using memory, + // but in that case the calculation below will be overly conservative. + const auto roles_map = roles_map_opt.getData(); + TensorView* a = roles_map.at(MatmulRole::INPUT_A).front(); + TensorView* b = roles_map.at(MatmulRole::INPUT_B).front(); + bool smem_reuse_guaranteed = a->uses().size() == 1 && b->uses().size() == 1; params->use_smem_epilogue = mma_utils::generateSharedMemoryEpilogueHeuristics( params->tile_sizes, params->double_buffer_options.smem_double_buffer_stage, - getMmaDataTypes(roles_map_opt.getData())); + getMmaDataTypes(roles_map_opt.getData()), + smem_reuse_guaranteed); if (isDebugDumpEnabled(DebugDumpOption::SchedulerDebug)) { debug() << params->toString() << std::endl; diff --git a/csrc/scheduler/mma_utils.cpp b/csrc/scheduler/mma_utils.cpp index d1c21746370..52fb574ae03 100644 --- a/csrc/scheduler/mma_utils.cpp +++ b/csrc/scheduler/mma_utils.cpp @@ -23,7 +23,8 @@ bool generateSharedMemoryEpilogueHeuristics( const MatMulTileOptions& gemm_tile, const int smem_double_buffer_stage, const MmaDataTypes& data_types, - const bool ignore_occupancy_drop) { + const bool ignore_occupancy_drop, + const bool smem_reuse_guaranteed) { const auto properties = at::cuda::getCurrentDeviceProperties(); const size_t device_smem_limit = properties->sharedMemPerBlockOptin; const size_t shared_memory_overhead = properties->reservedSharedMemPerBlock; @@ -49,40 +50,12 @@ bool generateSharedMemoryEpilogueHeuristics( const size_t smem_c = (size_t)(gemm_tile.cta_tile.m * gemm_tile.cta_tile.n) * dataTypeSize(data_types[2]); + // NOTE: we can simply add these sizes since they should be integer multiples + // of 16 bytes, so they will automatically be aligned. const size_t total_without_smem_epilogue = smem_a + smem_b; - - // smem_a and smem_b are guaranteed to be re-used for smem_c as long as: - // - they are marked for re-use using promoteReuse - // - they are not aliased by another tensor whose lifetime extends past the - // start of smem_epilogue's. - // - their lifetimes do not overlap smem_epilogue - // - // We guarantee the first condition by calling tv->promoteReuse() in - // scheduleProlog. - // - // The second condition would only be the case if another smem tensor had the - // same indexing and its lifetime did not overlap. This scheduler only uses - // smem for these three arrays, so the only candidate for aliasing is C. If C - // aliases either A or B, the following expression is still valid. - // - // The third condition is satisfied in the simple cases where the inputs to - // the matmul have only this use. However, it could be violated if a or b has - // other uses that get ordered after the matmul; for example when computing - // matmul(A, B) + A for square matrices A and B. In that case, the smem tensor - // resulting from A->cacheAfter() will be used in both the matmul as well as - // the addition that occurs in the epilogue, extending the lifetime such that - // it violates the third condition above. In order to avoid errors in these - // cases, we check that there is no re-use when there is more than one use of - // either a or b. If there are multiple uses we might wind up re-using memory, - // but in that case the calculation below will be overly conservative. - - // TODO: place this logic somewhere up the call stack - /*const auto roles_map = roles_map_opt.getData(); - TensorView* a = roles_map.at(MatmulRole::INPUT_A).front(); - TensorView* b = roles_map.at(MatmulRole::INPUT_B).front(); - bool smem_reuse_guaranteed = a->uses().size() == 1 && b->uses().size() == 1; - */ - bool smem_reuse_guaranteed = true; + // Even if we actually do wind up re-claiming smem_a and smem_b, if we + // cannot prove it at this point then we have to assume it will not be + // reclaimed. const size_t total_with_smem_epilogue = smem_reuse_guaranteed ? std::max(smem_a + smem_b, smem_c) : smem_a + smem_b + smem_c; diff --git a/csrc/scheduler/mma_utils.h b/csrc/scheduler/mma_utils.h index 773fac1da3b..f8a84562031 100644 --- a/csrc/scheduler/mma_utils.h +++ b/csrc/scheduler/mma_utils.h @@ -305,7 +305,8 @@ TORCH_CUDA_CU_API bool generateSharedMemoryEpilogueHeuristics( const MatMulTileOptions& gemm_tile, const int smem_double_buffer_stage, const MmaDataTypes& data_types, - const bool ignore_occupancy_drop = false); + const bool ignore_occupancy_drop = false, + const bool smem_reuse_guaranteed = false); } // namespace mma_utils From d894c4b4be792a63bf289da493b60db3b8e287de Mon Sep 17 00:00:00 2001 From: Jacob Hinkle Date: Wed, 23 Aug 2023 10:59:58 -0400 Subject: [PATCH 13/30] Remove blank space --- csrc/scheduler/matmul.cpp | 1 - 1 file changed, 1 deletion(-) diff --git a/csrc/scheduler/matmul.cpp b/csrc/scheduler/matmul.cpp index 951d61466e1..05311e5031e 100644 --- a/csrc/scheduler/matmul.cpp +++ b/csrc/scheduler/matmul.cpp @@ -975,7 +975,6 @@ void scheduleMatmul(Fusion* fusion, const MatmulParams& params) { if (params.use_smem_epilogue) { smem_epilogue->setMemoryType(MemoryType::Shared); - swizzleSharedMemory(smem_epilogue, params); scheduler_utils::BoundedDirectionalTransformPropagator::forward( mma_result, From d1e986d09e097efd426c079ba518928e9d32e108 Mon Sep 17 00:00:00 2001 From: Jacob Hinkle Date: Sun, 27 Aug 2023 20:14:25 -0400 Subject: [PATCH 14/30] Guarantee a and b reuse separately --- csrc/scheduler/matmul_utils.cpp | 6 ++++-- csrc/scheduler/mma_utils.cpp | 10 ++++++---- csrc/scheduler/mma_utils.h | 3 ++- 3 files changed, 12 insertions(+), 7 deletions(-) diff --git a/csrc/scheduler/matmul_utils.cpp b/csrc/scheduler/matmul_utils.cpp index 580b4a52ff8..0ffe8d9ef33 100644 --- a/csrc/scheduler/matmul_utils.cpp +++ b/csrc/scheduler/matmul_utils.cpp @@ -444,12 +444,14 @@ std::shared_ptr getMatmulHeuristics( const auto roles_map = roles_map_opt.getData(); TensorView* a = roles_map.at(MatmulRole::INPUT_A).front(); TensorView* b = roles_map.at(MatmulRole::INPUT_B).front(); - bool smem_reuse_guaranteed = a->uses().size() == 1 && b->uses().size() == 1; + bool smem_a_reuse_guaranteed = a->uses().size() == 1; + bool smem_b_reuse_guaranteed = b->uses().size() == 1; params->use_smem_epilogue = mma_utils::generateSharedMemoryEpilogueHeuristics( params->tile_sizes, params->double_buffer_options.smem_double_buffer_stage, getMmaDataTypes(roles_map_opt.getData()), - smem_reuse_guaranteed); + smem_a_reuse_guaranteed, + smem_b_reuse_guaranteed); if (isDebugDumpEnabled(DebugDumpOption::SchedulerDebug)) { debug() << params->toString() << std::endl; diff --git a/csrc/scheduler/mma_utils.cpp b/csrc/scheduler/mma_utils.cpp index 52fb574ae03..db3c7ed35be 100644 --- a/csrc/scheduler/mma_utils.cpp +++ b/csrc/scheduler/mma_utils.cpp @@ -24,7 +24,8 @@ bool generateSharedMemoryEpilogueHeuristics( const int smem_double_buffer_stage, const MmaDataTypes& data_types, const bool ignore_occupancy_drop, - const bool smem_reuse_guaranteed) { + const bool smem_a_reuse_guaranteed, + const bool smem_b_reuse_guaranteed) { const auto properties = at::cuda::getCurrentDeviceProperties(); const size_t device_smem_limit = properties->sharedMemPerBlockOptin; const size_t shared_memory_overhead = properties->reservedSharedMemPerBlock; @@ -56,9 +57,10 @@ bool generateSharedMemoryEpilogueHeuristics( // Even if we actually do wind up re-claiming smem_a and smem_b, if we // cannot prove it at this point then we have to assume it will not be // reclaimed. - const size_t total_with_smem_epilogue = smem_reuse_guaranteed - ? std::max(smem_a + smem_b, smem_c) - : smem_a + smem_b + smem_c; + const size_t total_with_smem_epilogue = std::max( + smem_a + smem_b, + (smem_a_reuse_guaranteed ? 0 : smem_a) + + (smem_b_reuse_guaranteed ? 0 : smem_b) + smem_c); // shortcut where occupancy change is ignored. if (ignore_occupancy_drop) { diff --git a/csrc/scheduler/mma_utils.h b/csrc/scheduler/mma_utils.h index f8a84562031..f57bd937b07 100644 --- a/csrc/scheduler/mma_utils.h +++ b/csrc/scheduler/mma_utils.h @@ -306,7 +306,8 @@ TORCH_CUDA_CU_API bool generateSharedMemoryEpilogueHeuristics( const int smem_double_buffer_stage, const MmaDataTypes& data_types, const bool ignore_occupancy_drop = false, - const bool smem_reuse_guaranteed = false); + const bool smem_a_reuse_guaranteed = false, + const bool smem_b_reuse_guaranteed = false); } // namespace mma_utils From b9db3ac219ce8c5a05cabbddea824f2022c6ebf4 Mon Sep 17 00:00:00 2001 From: Jacob Hinkle Date: Tue, 29 Aug 2023 08:22:38 -0400 Subject: [PATCH 15/30] Separate smem epilogue from promote_reuse param --- csrc/scheduler/matmul.cpp | 4 +++- csrc/scheduler/matmul_heuristic.h | 5 +++++ csrc/scheduler/matmul_utils.cpp | 13 +++++++------ csrc/scheduler/mma_utils.cpp | 31 ++++++++++++++++++++++++------- csrc/scheduler/mma_utils.h | 11 ++++++++--- test/test_gpu_tensorcore.cpp | 13 +++++++------ test/test_matmul_sass.cpp | 2 +- 7 files changed, 55 insertions(+), 24 deletions(-) diff --git a/csrc/scheduler/matmul.cpp b/csrc/scheduler/matmul.cpp index 05311e5031e..3148d22b7ff 100644 --- a/csrc/scheduler/matmul.cpp +++ b/csrc/scheduler/matmul.cpp @@ -514,7 +514,9 @@ void scheduleProlog(TensorView* shared_mem_tv, const MatmulParams& params) { // needed. This is not done by default as we do not insert new syncs unless // requested to do so. If smem is not used for the epilogue, this call will // have no effect. - shared_mem_tv->promoteReuse(); + if (params.promote_prologue_smem_reuse) { + shared_mem_tv->promoteReuse(); + } mma_utils::orderTiledConcreteIdAsRoot(shared_mem_tv); diff --git a/csrc/scheduler/matmul_heuristic.h b/csrc/scheduler/matmul_heuristic.h index 14fea9db17c..f3150abba58 100644 --- a/csrc/scheduler/matmul_heuristic.h +++ b/csrc/scheduler/matmul_heuristic.h @@ -94,6 +94,9 @@ class MatmulParams : public HeuristicParams { //! coalesced write to global memory bool use_smem_epilogue = false; + //! Promote reuse of prologue shared memory + bool promote_prologue_smem_reuse = false; + std::string toString() const override { std::stringstream ss; ss << "\n===== Matmul Parameters ========\n" @@ -117,6 +120,8 @@ class MatmulParams : public HeuristicParams { << "\n" << "Grid swizzle factor: " << grid_swizzle_factor << "\n" << "Use shared memory epilogue: " << use_smem_epilogue << "\n" + << "Promote re-use of prologue shared memory: " + << promote_prologue_smem_reuse << "\n" << "====================================\n"; return ss.str(); } diff --git a/csrc/scheduler/matmul_utils.cpp b/csrc/scheduler/matmul_utils.cpp index 0ffe8d9ef33..06ca76669c7 100644 --- a/csrc/scheduler/matmul_utils.cpp +++ b/csrc/scheduler/matmul_utils.cpp @@ -446,12 +446,13 @@ std::shared_ptr getMatmulHeuristics( TensorView* b = roles_map.at(MatmulRole::INPUT_B).front(); bool smem_a_reuse_guaranteed = a->uses().size() == 1; bool smem_b_reuse_guaranteed = b->uses().size() == 1; - params->use_smem_epilogue = mma_utils::generateSharedMemoryEpilogueHeuristics( - params->tile_sizes, - params->double_buffer_options.smem_double_buffer_stage, - getMmaDataTypes(roles_map_opt.getData()), - smem_a_reuse_guaranteed, - smem_b_reuse_guaranteed); + std::tie(params->use_smem_epilogue, params->promote_prologue_smem_reuse) = + mma_utils::generateSharedMemoryEpilogueHeuristics( + params->tile_sizes, + params->double_buffer_options.smem_double_buffer_stage, + getMmaDataTypes(roles_map_opt.getData()), + smem_a_reuse_guaranteed, + smem_b_reuse_guaranteed); if (isDebugDumpEnabled(DebugDumpOption::SchedulerDebug)) { debug() << params->toString() << std::endl; diff --git a/csrc/scheduler/mma_utils.cpp b/csrc/scheduler/mma_utils.cpp index db3c7ed35be..2588710a640 100644 --- a/csrc/scheduler/mma_utils.cpp +++ b/csrc/scheduler/mma_utils.cpp @@ -19,7 +19,7 @@ namespace nvfuser { namespace mma_utils { -bool generateSharedMemoryEpilogueHeuristics( +std::pair generateSharedMemoryEpilogueHeuristics( const MatMulTileOptions& gemm_tile, const int smem_double_buffer_stage, const MmaDataTypes& data_types, @@ -57,14 +57,20 @@ bool generateSharedMemoryEpilogueHeuristics( // Even if we actually do wind up re-claiming smem_a and smem_b, if we // cannot prove it at this point then we have to assume it will not be // reclaimed. - const size_t total_with_smem_epilogue = std::max( + const size_t total_with_reused_smem_epilogue = std::max( smem_a + smem_b, (smem_a_reuse_guaranteed ? 0 : smem_a) + (smem_b_reuse_guaranteed ? 0 : smem_b) + smem_c); + const size_t total_with_noreuse_smem_epilogue = smem_a + smem_b + smem_c; + // shortcut where occupancy change is ignored. if (ignore_occupancy_drop) { - return shared_memory_available >= total_with_smem_epilogue; + if (shared_memory_available >= total_with_noreuse_smem_epilogue) { + return {true, false}; + } else { + return {shared_memory_available >= total_with_reused_smem_epilogue, true}; + } } // use additional shared memory for epilogue if occupancy is not changed. @@ -74,12 +80,23 @@ bool generateSharedMemoryEpilogueHeuristics( const auto blocks_per_sm_without_smem_epilogue = std::min( shared_memory_available / total_without_smem_epilogue, (size_t)blocks_per_sm_by_register); - const auto blocks_per_sm_with_smem_epilogue = std::min( - shared_memory_available / total_with_smem_epilogue, + const auto blocks_per_sm_with_reused_smem_epilogue = std::min( + shared_memory_available / total_with_reused_smem_epilogue, (size_t)blocks_per_sm_by_register); + const auto blocks_per_sm_with_noreuse_smem_epilogue = std::min( + shared_memory_available / total_with_noreuse_smem_epilogue, + (size_t)blocks_per_sm_by_register); + + // Return whether we should use smem for epilogue, and whether syncing for + // re-use is desired. We avoid the sync if omitting it does not decrease + // occupancy. + auto promote_prologue_smem_reuse = blocks_per_sm_with_reused_smem_epilogue != + blocks_per_sm_with_noreuse_smem_epilogue; - return blocks_per_sm_with_smem_epilogue == - blocks_per_sm_without_smem_epilogue; + return { + blocks_per_sm_with_reused_smem_epilogue == + blocks_per_sm_without_smem_epilogue, + promote_prologue_smem_reuse}; } void scheduleWarpTileWithReduction(TensorView* tv, MatMulTileOptions tile) { diff --git a/csrc/scheduler/mma_utils.h b/csrc/scheduler/mma_utils.h index f57bd937b07..287e48f9b17 100644 --- a/csrc/scheduler/mma_utils.h +++ b/csrc/scheduler/mma_utils.h @@ -293,15 +293,20 @@ TORCH_CUDA_CU_API ProblemIterDomainsOpt getProblemIterDomains(Fusion* fusion); //! be gathered. TORCH_CUDA_CU_API RolesMapOpt getTensorsRoles(Fusion* fusion); -//! Return whether use shared memory epilogue or not. -//! Returns true if using shared memory epilogue won't cause +//! Return pair of whether use shared memory epilogue or not and whether to +//! reuse shared memory for the prologue at the expense of an additional block +//! sync. +//! Returns true in first position if using shared memory epilogue won't cause //! the decrease of occupancy ratio. The occupancy ratio is //! estimated using register and shared memory usage. //! If ignore_occupancy_drop is set to true, returns true if //! there is enough shared memory to launch the kernel without //! considering the occupancy, useful for debug and validate //! shared memory epilogue implementation. -TORCH_CUDA_CU_API bool generateSharedMemoryEpilogueHeuristics( +//! +//! Returns true in the second position if reusing shared memory for the +//! epilogue does not increase occupancy. +TORCH_CUDA_CU_API std::pair generateSharedMemoryEpilogueHeuristics( const MatMulTileOptions& gemm_tile, const int smem_double_buffer_stage, const MmaDataTypes& data_types, diff --git a/test/test_gpu_tensorcore.cpp b/test/test_gpu_tensorcore.cpp index e8f882bceb0..2ea9a5205e4 100644 --- a/test/test_gpu_tensorcore.cpp +++ b/test/test_gpu_tensorcore.cpp @@ -3258,7 +3258,7 @@ TEST_F(NVFuserTest, FusionAmpereMatmulTileCheck4warp_CUDA) { params.tile_sizes = gemm_tile; params.async_gmem_load_operands = true; params.double_buffer_options.double_buffer_smem_write = true; - params.use_smem_epilogue = + std::tie(params.use_smem_epilogue, params.promote_prologue_smem_reuse) = mma_utils::generateSharedMemoryEpilogueHeuristics( gemm_tile, params.double_buffer_options.smem_double_buffer_stage, @@ -3326,7 +3326,8 @@ TEST_F(NVFuserTest, FusionAmpereMatmulTileCheck8warp_CUDA) { params.double_buffer_options.double_buffer_smem_write = true; params.double_buffer_options.double_buffer_smem_read = true; params.double_buffer_options.smem_double_buffer_stage = 2; - params.use_smem_epilogue = + std::tie( + params.use_smem_epilogue, params.promote_prologue_smem_reuse) = mma_utils::generateSharedMemoryEpilogueHeuristics( gemm_tile, params.double_buffer_options.smem_double_buffer_stage, @@ -3389,7 +3390,7 @@ TEST_F(NVFuserTest, FusionAmpereMatmulTileCheck6warp_CUDA) { params.double_buffer_options.double_buffer_smem_write = true; params.double_buffer_options.double_buffer_smem_read = true; params.double_buffer_options.smem_double_buffer_stage = 2; - params.use_smem_epilogue = + std::tie(params.use_smem_epilogue, params.promote_prologue_smem_reuse) = mma_utils::generateSharedMemoryEpilogueHeuristics( gemm_tile, params.double_buffer_options.smem_double_buffer_stage, @@ -3938,7 +3939,7 @@ TEST_F(NVFuserTest, FusionAmpereMatmulSmemEpilogue_CUDA) { params.double_buffer_options.double_buffer_smem_write = true; params.double_buffer_options.double_buffer_smem_read = true; params.double_buffer_options.smem_double_buffer_stage = 2; - params.use_smem_epilogue = + std::tie(params.use_smem_epilogue, params.promote_prologue_smem_reuse) = mma_utils::generateSharedMemoryEpilogueHeuristics( gemm_tile, params.double_buffer_options.smem_double_buffer_stage, @@ -4025,7 +4026,7 @@ TEST_F(NVFuserTest, FusionAmpereMatmulSmemEpilogueCast_CUDA) { params.double_buffer_options.double_buffer_smem_write = true; params.double_buffer_options.double_buffer_smem_read = true; params.double_buffer_options.smem_double_buffer_stage = 4; - params.use_smem_epilogue = + std::tie(params.use_smem_epilogue, params.promote_prologue_smem_reuse) = mma_utils::generateSharedMemoryEpilogueHeuristics( gemm_tile, params.double_buffer_options.smem_double_buffer_stage, @@ -4112,7 +4113,7 @@ TEST_F(NVFuserTest, FusionAmpereMatmulSmemEpilogueRelu_CUDA) { params.double_buffer_options.double_buffer_smem_write = true; params.double_buffer_options.double_buffer_smem_read = true; params.double_buffer_options.smem_double_buffer_stage = 4; - params.use_smem_epilogue = + std::tie(params.use_smem_epilogue, params.promote_prologue_smem_reuse) = mma_utils::generateSharedMemoryEpilogueHeuristics( gemm_tile, params.double_buffer_options.smem_double_buffer_stage, diff --git a/test/test_matmul_sass.cpp b/test/test_matmul_sass.cpp index 36426dfff82..e2b7ea5725a 100644 --- a/test/test_matmul_sass.cpp +++ b/test/test_matmul_sass.cpp @@ -328,7 +328,7 @@ TEST_F(MatmulSASSTest, AmpereModifiersSharedMemoryEpilogue_CUDA) { gemm_tile.instruction_tile = GemmTile(16, 8, 16); const int smem_double_buffer_stage = 4; const bool ignore_occupancy_drop = true; - const bool use_smem_epilogue = + const auto [use_smem_epilogue, skip_promote_reuse] = mma_utils::generateSharedMemoryEpilogueHeuristics( gemm_tile, smem_double_buffer_stage, From 81601358020db1ed8c00c2dfbf862db82b180a69 Mon Sep 17 00:00:00 2001 From: Jacob Hinkle Date: Tue, 29 Aug 2023 09:05:42 -0400 Subject: [PATCH 16/30] Move roles_map analysis to generateSharedMemoryEpilogueHeuristics tests failing at the moment --- csrc/scheduler/matmul_utils.cpp | 49 +-------------------------- csrc/scheduler/mma_utils.cpp | 59 +++++++++++++++++++++++++++++---- csrc/scheduler/mma_utils.h | 6 ++-- 3 files changed, 56 insertions(+), 58 deletions(-) diff --git a/csrc/scheduler/matmul_utils.cpp b/csrc/scheduler/matmul_utils.cpp index 06ca76669c7..78fb16afb09 100644 --- a/csrc/scheduler/matmul_utils.cpp +++ b/csrc/scheduler/matmul_utils.cpp @@ -143,23 +143,6 @@ inline bool initCoreHeuristics( return true; } -//! A wrapper to get MMA Tensor data types -//! The order of returned types: INPUT_A, INPUT_B, OUTPUT_D -inline mma_utils::MmaDataTypes getMmaDataTypes( - const std::map>& roles_map) { - auto getMMADataType = [&](MatmulRole role) { - auto entry = roles_map.find(role); - if (entry != roles_map.end() && !entry->second.empty()) { - return entry->second.front()->dtype(); - } - TORCH_INTERNAL_ASSERT(false, "Get MMA Tensor data type failed!"); - }; - const auto a_type = getMMADataType(MatmulRole::INPUT_A); - const auto b_type = getMMADataType(MatmulRole::INPUT_B); - const auto c_type = getMMADataType(MatmulRole::OUTPUT_D); - return mma_utils::MmaDataTypes{a_type, b_type, c_type}; -} - //! A helper for getting problem shape from fusion and runtime info. ProblemShape getProblemShape( Fusion* fusion, @@ -417,42 +400,12 @@ std::shared_ptr getMatmulHeuristics( TORCH_INTERNAL_ASSERT( roles_map_opt.isValid(), "Tensor roles map in mma is not valid."); - // smem_a and smem_b are guaranteed to be re-used for smem_c as long as: - // - they are marked for re-use using promoteReuse - // - they are not aliased by another tensor whose lifetime extends past the - // start of smem_epilogue's. - // - their lifetimes do not overlap smem_epilogue - // - // We guarantee the first condition by calling tv->promoteReuse() in - // scheduleProlog. - // - // The second condition would only be the case if another smem tensor had the - // same indexing and its lifetime did not overlap. This scheduler only uses - // smem for these three arrays, so the only candidate for aliasing is C. If C - // aliases either A or B, the following expression is still valid. - // - // The third condition is satisfied in the simple cases where the inputs to - // the matmul have only this use. However, it could be violated if a or b has - // other uses that get ordered after the matmul; for example when computing - // matmul(A, B) + A for square matrices A and B. In that case, the smem tensor - // resulting from A->cacheAfter() will be used in both the matmul as well as - // the addition that occurs in the epilogue, extending the lifetime such that - // it violates the third condition above. In order to avoid errors in these - // cases, we check that there is no re-use when there is more than one use of - // either a or b. If there are multiple uses we might wind up re-using memory, - // but in that case the calculation below will be overly conservative. const auto roles_map = roles_map_opt.getData(); - TensorView* a = roles_map.at(MatmulRole::INPUT_A).front(); - TensorView* b = roles_map.at(MatmulRole::INPUT_B).front(); - bool smem_a_reuse_guaranteed = a->uses().size() == 1; - bool smem_b_reuse_guaranteed = b->uses().size() == 1; std::tie(params->use_smem_epilogue, params->promote_prologue_smem_reuse) = mma_utils::generateSharedMemoryEpilogueHeuristics( params->tile_sizes, params->double_buffer_options.smem_double_buffer_stage, - getMmaDataTypes(roles_map_opt.getData()), - smem_a_reuse_guaranteed, - smem_b_reuse_guaranteed); + roles_map); if (isDebugDumpEnabled(DebugDumpOption::SchedulerDebug)) { debug() << params->toString() << std::endl; diff --git a/csrc/scheduler/mma_utils.cpp b/csrc/scheduler/mma_utils.cpp index 2588710a640..98d9437bd2f 100644 --- a/csrc/scheduler/mma_utils.cpp +++ b/csrc/scheduler/mma_utils.cpp @@ -19,13 +19,28 @@ namespace nvfuser { namespace mma_utils { +//! A wrapper to get MMA Tensor data types +//! The order of returned types: INPUT_A, INPUT_B, OUTPUT_D +inline mma_utils::MmaDataTypes getMmaDataTypes( + const std::map>& roles_map) { + auto getMMADataType = [&](MatmulRole role) { + auto entry = roles_map.find(role); + if (entry != roles_map.end() && !entry->second.empty()) { + return entry->second.front()->dtype(); + } + TORCH_INTERNAL_ASSERT(false, "Get MMA Tensor data type failed!"); + }; + const auto a_type = getMMADataType(MatmulRole::INPUT_A); + const auto b_type = getMMADataType(MatmulRole::INPUT_B); + const auto c_type = getMMADataType(MatmulRole::OUTPUT_D); + return mma_utils::MmaDataTypes{a_type, b_type, c_type}; +} + std::pair generateSharedMemoryEpilogueHeuristics( const MatMulTileOptions& gemm_tile, const int smem_double_buffer_stage, - const MmaDataTypes& data_types, - const bool ignore_occupancy_drop, - const bool smem_a_reuse_guaranteed, - const bool smem_b_reuse_guaranteed) { + const RolesMap& roles_map, + const bool ignore_occupancy_drop) { const auto properties = at::cuda::getCurrentDeviceProperties(); const size_t device_smem_limit = properties->sharedMemPerBlockOptin; const size_t shared_memory_overhead = properties->reservedSharedMemPerBlock; @@ -36,6 +51,8 @@ std::pair generateSharedMemoryEpilogueHeuristics( const auto threads_per_block = warp_dims.m * warp_dims.n * warp_dims.k * properties->warpSize; + const auto data_types = getMmaDataTypes(roles_map); + // see scheduleContiguousVectorLoad const int vector_word = 8; const int round_to_factor = warp_dims.m * warp_dims.n * warp_dims.k * @@ -51,9 +68,41 @@ std::pair generateSharedMemoryEpilogueHeuristics( const size_t smem_c = (size_t)(gemm_tile.cta_tile.m * gemm_tile.cta_tile.n) * dataTypeSize(data_types[2]); + // smem_a and smem_b are guaranteed to be re-used for smem_c as long as: + // - they are marked for re-use using promoteReuse + // - they are not aliased by another tensor whose lifetime extends past the + // start of smem_epilogue's. + // - their lifetimes do not overlap smem_epilogue + // + // We can guarantee the first condition by calling tv->promoteReuse() in + // scheduleProlog. + // + // The second condition would only be the case if another smem tensor had the + // same indexing and its lifetime did not overlap. This scheduler only uses + // smem for these three arrays, so the only candidate for aliasing is C. If C + // aliases either A or B, the following expression is still valid. + // + // The third condition is satisfied in the simple cases where the inputs to + // the matmul have only this use. However, it could be violated if a or b has + // other uses that get ordered after the matmul; for example when computing + // matmul(A, B) + A for square matrices A and B. In that case, the smem tensor + // resulting from A->cacheAfter() will be used in both the matmul as well as + // the addition that occurs in the epilogue, extending the lifetime such that + // it violates the third condition above. In order to avoid errors in these + // cases, we check that there is no re-use when there is more than one use of + // either a or b. If there are multiple uses we might wind up re-using memory, + // but in that case the calculation below will be overly conservative. + TensorView* a = roles_map.at(MatmulRole::INPUT_A).front(); + TensorView* b = roles_map.at(MatmulRole::INPUT_B).front(); + bool smem_a_reuse_guaranteed = a->uses().size() == 1; + bool smem_b_reuse_guaranteed = b->uses().size() == 1; + // NOTE: we can simply add these sizes since they should be integer multiples // of 16 bytes, so they will automatically be aligned. + TORCH_CHECK(smem_a % 16 == 0 && smem_b % 16 == 0 && smem_b % 16 == 0); + const size_t total_without_smem_epilogue = smem_a + smem_b; + const size_t total_with_noreuse_smem_epilogue = smem_a + smem_b + smem_c; // Even if we actually do wind up re-claiming smem_a and smem_b, if we // cannot prove it at this point then we have to assume it will not be // reclaimed. @@ -62,8 +111,6 @@ std::pair generateSharedMemoryEpilogueHeuristics( (smem_a_reuse_guaranteed ? 0 : smem_a) + (smem_b_reuse_guaranteed ? 0 : smem_b) + smem_c); - const size_t total_with_noreuse_smem_epilogue = smem_a + smem_b + smem_c; - // shortcut where occupancy change is ignored. if (ignore_occupancy_drop) { if (shared_memory_available >= total_with_noreuse_smem_epilogue) { diff --git a/csrc/scheduler/mma_utils.h b/csrc/scheduler/mma_utils.h index 287e48f9b17..82dcb39b468 100644 --- a/csrc/scheduler/mma_utils.h +++ b/csrc/scheduler/mma_utils.h @@ -309,10 +309,8 @@ TORCH_CUDA_CU_API RolesMapOpt getTensorsRoles(Fusion* fusion); TORCH_CUDA_CU_API std::pair generateSharedMemoryEpilogueHeuristics( const MatMulTileOptions& gemm_tile, const int smem_double_buffer_stage, - const MmaDataTypes& data_types, - const bool ignore_occupancy_drop = false, - const bool smem_a_reuse_guaranteed = false, - const bool smem_b_reuse_guaranteed = false); + const RolesMap& roles_map, + bool ignore_occupancy_drop = false); } // namespace mma_utils From 1d546e3d431d3719480b53165b2ba7eead4fc81d Mon Sep 17 00:00:00 2001 From: Jacob Hinkle Date: Tue, 29 Aug 2023 09:06:26 -0400 Subject: [PATCH 17/30] Check for reuse in AmpereMatmulSmemEpilogue_CUDA --- test/test_gpu_tensorcore.cpp | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/test/test_gpu_tensorcore.cpp b/test/test_gpu_tensorcore.cpp index 2ea9a5205e4..eb0852640ae 100644 --- a/test/test_gpu_tensorcore.cpp +++ b/test/test_gpu_tensorcore.cpp @@ -3992,6 +3992,14 @@ TEST_F(NVFuserTest, FusionAmpereMatmulSmemEpilogue_CUDA) { GTEST_SKIP() << "Test conducted without utilizing shared memory epilogue due to the device's constrained shared memory capacity."; } + + auto smem_allocs = fe.kernel()->summary().dynamic_smem_allocations; + TORCH_CHECK(smem_allocs.size() == 3); + if (params.promote_prologue_smem_reuse) { + // Check prologue shared memory re-use + TORCH_CHECK(smem_allocs.at(1)->address()->isZero()); + TORCH_CHECK(smem_allocs.at(2)->address()->isZero()); + } } } From 036777785d2d02e019c679fc7d4966dc1d1dc963 Mon Sep 17 00:00:00 2001 From: Jacob Hinkle Date: Tue, 29 Aug 2023 10:00:37 -0400 Subject: [PATCH 18/30] Restore old signature as overload for generateSharedMemoryEpilogueHeuristics --- csrc/scheduler/mma_utils.cpp | 87 ++++++++++++++++++++++-------------- csrc/scheduler/mma_utils.h | 10 +++++ 2 files changed, 63 insertions(+), 34 deletions(-) diff --git a/csrc/scheduler/mma_utils.cpp b/csrc/scheduler/mma_utils.cpp index 98d9437bd2f..4f281791e33 100644 --- a/csrc/scheduler/mma_utils.cpp +++ b/csrc/scheduler/mma_utils.cpp @@ -39,8 +39,10 @@ inline mma_utils::MmaDataTypes getMmaDataTypes( std::pair generateSharedMemoryEpilogueHeuristics( const MatMulTileOptions& gemm_tile, const int smem_double_buffer_stage, - const RolesMap& roles_map, - const bool ignore_occupancy_drop) { + const MmaDataTypes& data_types, + bool smem_a_reuse_guaranteed, + bool smem_b_reuse_guaranteed, + bool ignore_occupancy_drop) { const auto properties = at::cuda::getCurrentDeviceProperties(); const size_t device_smem_limit = properties->sharedMemPerBlockOptin; const size_t shared_memory_overhead = properties->reservedSharedMemPerBlock; @@ -51,8 +53,6 @@ std::pair generateSharedMemoryEpilogueHeuristics( const auto threads_per_block = warp_dims.m * warp_dims.n * warp_dims.k * properties->warpSize; - const auto data_types = getMmaDataTypes(roles_map); - // see scheduleContiguousVectorLoad const int vector_word = 8; const int round_to_factor = warp_dims.m * warp_dims.n * warp_dims.k * @@ -68,37 +68,11 @@ std::pair generateSharedMemoryEpilogueHeuristics( const size_t smem_c = (size_t)(gemm_tile.cta_tile.m * gemm_tile.cta_tile.n) * dataTypeSize(data_types[2]); - // smem_a and smem_b are guaranteed to be re-used for smem_c as long as: - // - they are marked for re-use using promoteReuse - // - they are not aliased by another tensor whose lifetime extends past the - // start of smem_epilogue's. - // - their lifetimes do not overlap smem_epilogue - // - // We can guarantee the first condition by calling tv->promoteReuse() in - // scheduleProlog. - // - // The second condition would only be the case if another smem tensor had the - // same indexing and its lifetime did not overlap. This scheduler only uses - // smem for these three arrays, so the only candidate for aliasing is C. If C - // aliases either A or B, the following expression is still valid. - // - // The third condition is satisfied in the simple cases where the inputs to - // the matmul have only this use. However, it could be violated if a or b has - // other uses that get ordered after the matmul; for example when computing - // matmul(A, B) + A for square matrices A and B. In that case, the smem tensor - // resulting from A->cacheAfter() will be used in both the matmul as well as - // the addition that occurs in the epilogue, extending the lifetime such that - // it violates the third condition above. In order to avoid errors in these - // cases, we check that there is no re-use when there is more than one use of - // either a or b. If there are multiple uses we might wind up re-using memory, - // but in that case the calculation below will be overly conservative. - TensorView* a = roles_map.at(MatmulRole::INPUT_A).front(); - TensorView* b = roles_map.at(MatmulRole::INPUT_B).front(); - bool smem_a_reuse_guaranteed = a->uses().size() == 1; - bool smem_b_reuse_guaranteed = b->uses().size() == 1; - // NOTE: we can simply add these sizes since they should be integer multiples - // of 16 bytes, so they will automatically be aligned. + // of 16 bytes, so they will automatically be aligned. This may change with + // FP8, in which case the expressions below should be updated to insert + // alignment expressions, using the expected stack ordering in + // StackBasedSharedMemAllocator. TORCH_CHECK(smem_a % 16 == 0 && smem_b % 16 == 0 && smem_b % 16 == 0); const size_t total_without_smem_epilogue = smem_a + smem_b; @@ -146,6 +120,51 @@ std::pair generateSharedMemoryEpilogueHeuristics( promote_prologue_smem_reuse}; } +std::pair generateSharedMemoryEpilogueHeuristics( + const MatMulTileOptions& gemm_tile, + const int smem_double_buffer_stage, + const RolesMap& roles_map, + const bool ignore_occupancy_drop) { + const auto data_types = getMmaDataTypes(roles_map); + + // smem_a and smem_b are guaranteed to be re-used for smem_c as long as: + // - they are marked for re-use using promoteReuse + // - they are not aliased by another tensor whose lifetime extends past the + // start of smem_epilogue's. + // - their lifetimes do not overlap smem_epilogue + // + // We can guarantee the first condition by calling tv->promoteReuse() in + // scheduleProlog. + // + // The second condition would only be the case if another smem tensor had the + // same indexing and its lifetime did not overlap. This scheduler only uses + // smem for these three arrays, so the only candidate for aliasing is C. If C + // aliases either A or B, the following expression is still valid. + // + // The third condition is satisfied in the simple cases where the inputs to + // the matmul have only this use. However, it could be violated if a or b has + // other uses that get ordered after the matmul; for example when computing + // matmul(A, B) + A for square matrices A and B. In that case, the smem tensor + // resulting from A->cacheAfter() will be used in both the matmul as well as + // the addition that occurs in the epilogue, extending the lifetime such that + // it violates the third condition above. In order to avoid errors in these + // cases, we check that there is no re-use when there is more than one use of + // either a or b. If there are multiple uses we might wind up re-using memory, + // but in that case the calculation below will be overly conservative. + TensorView* a = roles_map.at(MatmulRole::INPUT_A).front(); + TensorView* b = roles_map.at(MatmulRole::INPUT_B).front(); + bool smem_a_reuse_guaranteed = a->uses().size() == 1; + bool smem_b_reuse_guaranteed = b->uses().size() == 1; + + return generateSharedMemoryEpilogueHeuristics( + gemm_tile, + smem_double_buffer_stage, + data_types, + smem_a_reuse_guaranteed, + smem_b_reuse_guaranteed, + ignore_occupancy_drop); +} + void scheduleWarpTileWithReduction(TensorView* tv, MatMulTileOptions tile) { // Assumes // [M, N, K] diff --git a/csrc/scheduler/mma_utils.h b/csrc/scheduler/mma_utils.h index 82dcb39b468..3aa89862c9c 100644 --- a/csrc/scheduler/mma_utils.h +++ b/csrc/scheduler/mma_utils.h @@ -312,6 +312,16 @@ TORCH_CUDA_CU_API std::pair generateSharedMemoryEpilogueHeuristics( const RolesMap& roles_map, bool ignore_occupancy_drop = false); +//! This version assumes roles_map has been analyzed to determine smem datatypes +//! as well as guarantees about prologue smem reuse. +TORCH_CUDA_CU_API std::pair generateSharedMemoryEpilogueHeuristics( + const MatMulTileOptions& gemm_tile, + const int smem_double_buffer_stage, + const MmaDataTypes& data_types, + bool smem_a_reuse_guaranteed = false, + bool smem_b_reuse_guaranteed = false, + bool ignore_occupancy_drop = false); + } // namespace mma_utils } // namespace nvfuser From 366a5e7d98b472c18126a377418d3446f698f892 Mon Sep 17 00:00:00 2001 From: Jacob Hinkle Date: Tue, 29 Aug 2023 12:30:38 -0400 Subject: [PATCH 19/30] Fix MatmulSASSTest.AmpereModifiersSharedMemoryEpilogue_CUDA --- test/test_matmul_sass.cpp | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/test/test_matmul_sass.cpp b/test/test_matmul_sass.cpp index e2b7ea5725a..92033f2100f 100644 --- a/test/test_matmul_sass.cpp +++ b/test/test_matmul_sass.cpp @@ -46,7 +46,8 @@ sass::Container getSASSFor( int N, int K, const int smem_double_buffer_stage = 4, - const bool use_smem_epilogue = false) { + const bool use_smem_epilogue = false, + const bool promote_prologue_smem_reuse = false) { Fusion fusion; FusionGuard fg(&fusion); auto tv0 = makeContigTensor(2, DataType::Half); @@ -73,6 +74,7 @@ sass::Container getSASSFor( params.double_buffer_options.smem_double_buffer_stage = smem_double_buffer_stage; params.use_smem_epilogue = use_smem_epilogue; + params.promote_prologue_smem_reuse = promote_prologue_smem_reuse; scheduleMatmul(&fusion, params); auto inputs = matmulAtInput(M, N, K, layout); @@ -328,7 +330,7 @@ TEST_F(MatmulSASSTest, AmpereModifiersSharedMemoryEpilogue_CUDA) { gemm_tile.instruction_tile = GemmTile(16, 8, 16); const int smem_double_buffer_stage = 4; const bool ignore_occupancy_drop = true; - const auto [use_smem_epilogue, skip_promote_reuse] = + const auto [use_smem_epilogue, promote_prologue_smem_reuse] = mma_utils::generateSharedMemoryEpilogueHeuristics( gemm_tile, smem_double_buffer_stage, @@ -347,9 +349,9 @@ TEST_F(MatmulSASSTest, AmpereModifiersSharedMemoryEpilogue_CUDA) { bool found_LDGDEPBAR = false; bool found_DEPBAR = false; // kAllSupportedMatmulLayout; int BAR_COUNT = 0; - // we have three shared memory barriers in the kernel if - // use_shared_epilogue - const int EXPECTED_BAR_COUNT = 3; + // we have at least three shared memory barriers in the kernel if + // use_shared_epilogue. If promote_prologue_smem_reuse, then 4 + const int EXPECTED_BAR_COUNT = promote_prologue_smem_reuse ? 4 : 3; sass::Container sass; NVFUSER_TEST_CUDA_ARCH_COMPILE_CHECK( 8, @@ -364,7 +366,8 @@ TEST_F(MatmulSASSTest, AmpereModifiersSharedMemoryEpilogue_CUDA) { N, K, smem_double_buffer_stage, - use_smem_epilogue)); + use_smem_epilogue, + promote_prologue_smem_reuse)); for (auto inst : sass.code) { std::visit( [&](auto&& i) { From 6b54a5c99fac1e242c3891d31a5a9447f40327ea Mon Sep 17 00:00:00 2001 From: Jacob Hinkle Date: Tue, 29 Aug 2023 13:54:51 -0400 Subject: [PATCH 20/30] Update sameAs and hash --- csrc/scheduler/matmul_heuristic.h | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/csrc/scheduler/matmul_heuristic.h b/csrc/scheduler/matmul_heuristic.h index f3150abba58..e20958f81c8 100644 --- a/csrc/scheduler/matmul_heuristic.h +++ b/csrc/scheduler/matmul_heuristic.h @@ -137,7 +137,8 @@ class MatmulParams : public HeuristicParams { (nvfuser::hash(mma_macro) << 1) ^ (double_buffer_options.hash() << 2) ^ (nvfuser::hash(tile_sizes) << 3) ^ (std::hash{}(static_cast(cta_order)) << 4) ^ - (std::hash{}(grid_swizzle_factor) << 5); + (std::hash{}(grid_swizzle_factor) << 5) ^ + (use_smem_epilogue << 6) ^ (use_smem_epilogue << 7); return attr_hash; } @@ -155,7 +156,10 @@ class MatmulParams : public HeuristicParams { other_casted->tile_sizes == tile_sizes && other_casted->double_buffer_options == double_buffer_options && other_casted->cta_order == cta_order && - other_casted->grid_swizzle_factor == grid_swizzle_factor; + other_casted->grid_swizzle_factor == grid_swizzle_factor && + other_casted->use_smem_epilogue == use_smem_epilogue && + other_casted->promote_prologue_smem_reuse == + promote_prologue_smem_reuse; } std::shared_ptr clone() const override { From 5cc78837579ccda47af5c695de3fd30992555976 Mon Sep 17 00:00:00 2001 From: Jacob Hinkle Date: Tue, 29 Aug 2023 19:31:42 -0400 Subject: [PATCH 21/30] Better comment in reuse check in epilogue test --- test/test_gpu_tensorcore.cpp | 15 +++++++++++++-- 1 file changed, 13 insertions(+), 2 deletions(-) diff --git a/test/test_gpu_tensorcore.cpp b/test/test_gpu_tensorcore.cpp index eb0852640ae..877b99ca928 100644 --- a/test/test_gpu_tensorcore.cpp +++ b/test/test_gpu_tensorcore.cpp @@ -3997,8 +3997,19 @@ TEST_F(NVFuserTest, FusionAmpereMatmulSmemEpilogue_CUDA) { TORCH_CHECK(smem_allocs.size() == 3); if (params.promote_prologue_smem_reuse) { // Check prologue shared memory re-use - TORCH_CHECK(smem_allocs.at(1)->address()->isZero()); - TORCH_CHECK(smem_allocs.at(2)->address()->isZero()); + // smem_allocs = {A, B, C} where C is the epilogue buffer + // since A and B have no further uses, we should be able to reuse both of + // them, implying that the address of C is zero. In this case, B will also + // be allocated at address 0 with A stacked above it at position 8192. + EXPECT_EQ(smem_allocs.size(), 3); + EXPECT_EQ( + smem_allocs.at(0)->address()->evaluateInt(), + // Assuming B size times size(dtype) is a multiple of 16 so that this + // address is aligned + smem_allocs.at(1)->size()->evaluateInt() * + dataTypeSize(smem_allocs.at(1)->buffer()->dtype())); + EXPECT_EQ(smem_allocs.at(1)->address()->evaluateInt(), 0L); + EXPECT_EQ(smem_allocs.at(2)->address()->evaluateInt(), 0L); } } } From 7cf708f698b543d200a8e1997226d1696d9ec7f1 Mon Sep 17 00:00:00 2001 From: Jacob Hinkle <1454944+jacobhinkle@users.noreply.github.com> Date: Wed, 30 Aug 2023 07:43:30 -0400 Subject: [PATCH 22/30] State which scheduler is considered in mma_utils.cpp Co-authored-by: Andrzej Bekas <118676880+drzejan2@users.noreply.github.com> --- csrc/scheduler/mma_utils.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/csrc/scheduler/mma_utils.cpp b/csrc/scheduler/mma_utils.cpp index 4f281791e33..e1e2eebe270 100644 --- a/csrc/scheduler/mma_utils.cpp +++ b/csrc/scheduler/mma_utils.cpp @@ -137,7 +137,7 @@ std::pair generateSharedMemoryEpilogueHeuristics( // scheduleProlog. // // The second condition would only be the case if another smem tensor had the - // same indexing and its lifetime did not overlap. This scheduler only uses + // same indexing and its lifetime did not overlap. Matmul scheduler only uses // smem for these three arrays, so the only candidate for aliasing is C. If C // aliases either A or B, the following expression is still valid. // From cae968319a7b7512960603fdb78503672294ad46 Mon Sep 17 00:00:00 2001 From: "Gao, Xiang" Date: Tue, 29 Aug 2023 09:01:19 -0700 Subject: [PATCH 23/30] Add a new way to hold structs in `PolymorphicValue` (#791) See note `[Struct Support in PolymorphicValue]` for description, and the new test `PolymorphicValueTest.Struct` for examples. --- .lintrunner.toml | 4 +- csrc/ir/builder.cpp | 15 ---- csrc/ir/builder.h | 17 ++++- csrc/polymorphic_value.h | 45 ++++++++++++ csrc/struct.inl | 125 ++++++++++++++++++++++++++++++++ csrc/type.cpp | 6 +- csrc/type.h | 19 ++++- test/test_evaluator.cpp | 3 +- test/test_polymorphic_value.cpp | 96 ++++++++++++++++++++++++ 9 files changed, 309 insertions(+), 21 deletions(-) create mode 100644 csrc/struct.inl diff --git a/.lintrunner.toml b/.lintrunner.toml index 4efdbb7b33d..5d7536a1289 100644 --- a/.lintrunner.toml +++ b/.lintrunner.toml @@ -26,6 +26,7 @@ include_patterns = [ '**/*.h', '**/*.cpp', '**/*.cu', + '**/*.inl', ] exclude_patterns = [ 'third_party/**', @@ -50,9 +51,10 @@ is_formatter = true [[linter]] code = 'CLANGTIDY' include_patterns = [ + '**/*.h', '**/*.cpp', '**/*.cu', - '**/*.h', + '**/*.inl', ] exclude_patterns = [ 'csrc/serde/fusion_cache_generated.h', diff --git a/csrc/ir/builder.cpp b/csrc/ir/builder.cpp index 2a1501348ea..abb949edaf4 100644 --- a/csrc/ir/builder.cpp +++ b/csrc/ir/builder.cpp @@ -258,21 +258,6 @@ Val* IrBuilder::metadataExpr(TensorView* tv) { return tv->fusion()->metadataOf(tv); } -Val* IrBuilder::structExpr( - const std::vector>& fields, - std::string name) { - std::vector field_infos; - field_infos.reserve(fields.size()); - for (auto& field : fields) { - field_infos.emplace_back(StructType::FieldInfo{ - field.first, std::make_shared(field.second->dtype()), true}); - } - DataType dtype = StructType::make(std::move(field_infos), std::move(name)); - auto out = newScalar(dtype); - create(out, fields); - return out; -} - Val* SimplifyingIrBuilder::negExpr(Val* val) { if (val->isZeroInt()) { return val->container()->zeroVal(val->dtype()); diff --git a/csrc/ir/builder.h b/csrc/ir/builder.h index 8dbc6478e60..5bc0994a086 100644 --- a/csrc/ir/builder.h +++ b/csrc/ir/builder.h @@ -126,9 +126,24 @@ class TORCH_CUDA_CU_API IrBuilder { } } + template static Val* structExpr( const std::vector>& fields, - std::string name = ""); + std::string name = "") { + std::vector field_infos; + field_infos.reserve(fields.size()); + for (auto& field : fields) { + field_infos.emplace_back(StructType::FieldInfo{ + field.first, + std::make_shared(field.second->dtype()), + true}); + } + DataType dtype = + StructType::make(std::move(field_infos), std::move(name)); + auto out = newScalar(dtype); + create(out, fields); + return out; + } static Val* newScalar(DataType dtype); diff --git a/csrc/polymorphic_value.h b/csrc/polymorphic_value.h index 7cc227abff7..95e6e7a2354 100644 --- a/csrc/polymorphic_value.h +++ b/csrc/polymorphic_value.h @@ -238,8 +238,51 @@ inline std::ostream& operator<<(std::ostream& os, const Pointer& ptr) { return os; } +struct Struct; +class Accessor; +struct StructType; + +// See Note [Struct Support in PolymorphicValue] for documentation. +class StructHandle { + std::shared_ptr struct_ptr_; + + public: + StructHandle(std::shared_ptr struct_ptr) + : struct_ptr_(std::move(struct_ptr)) {} + StructHandle& operator=(std::shared_ptr struct_ptr) { + struct_ptr_ = std::move(struct_ptr); + return *this; + } + + StructHandle(const StructHandle& other) = default; + StructHandle(StructHandle&& other) = default; + StructHandle& operator=(const StructHandle& other) = default; + StructHandle& operator=(StructHandle&& other) = default; + + template + bool is() const { + return std::dynamic_pointer_cast(struct_ptr_) != nullptr; + } + + template + inline T& as() const { + return *std::dynamic_pointer_cast(struct_ptr_); + } + + inline StructType type() const; + + template + inline std::enable_if_t, Ret&> operator->*( + Ret Class::*member) const { + return as().*member; + } + + inline Accessor operator->*(const std::string& key) const; +}; + using PolymorphicValue = dynamic_type::DynamicType< dynamic_type::Containers, + StructHandle, Pointer, Opaque, at::Tensor, @@ -377,3 +420,5 @@ inline PolymorphicValue toTensor( } // namespace PolymorphicValue_functions } // namespace nvfuser + +#include diff --git a/csrc/struct.inl b/csrc/struct.inl new file mode 100644 index 00000000000..c69c1de712c --- /dev/null +++ b/csrc/struct.inl @@ -0,0 +1,125 @@ +// clang-format off +/* + * SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES. + * All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + */ +// clang-format on +#pragma once + +#include +#include +#include + +namespace nvfuser { + +// Note [Struct Support in PolymorphicValue] +// +// PolymorphicValue supports structs, which is just a list of named fields. The +// most straightforward way to support structs is to use a map from field name +// to value, something like: +// template +// using Struct = std::unordered_map; +// using PolymorphicValue = DynamicType, ...>; +// However, the performance of this approach is not ideal. So instead of making +// the struct support truely dynamic fields by using a map, we decide to make it +// semi-dynamic: each struct type in nvFuser must be backed by a real struct in +// C++, which mean, the fields have static storage types. But, on the other +// hand, struct fields can also be accessed dynamically, that is, you can get or +// set a struct field without knowning the actual C++ struct and the type of the +// field. Instead, by using solely the string name of the field, you shall be +// able to access fields as a PolymorphicValue. For example, if your struct is +// defined as: +// struct A { int64_t x; double y; }; +// PolymorphicValue v = some struct of type A; +// Then you can access the fields statically like: +// const int64_t& x = v->*&A::x; +// v->*&A::x = 1; +// Static accesses should be very efficient, as fast as dynamic casts + pointer +// dereferences. However, if you don't have access to the definition of `A`, you +// can still access the fields dynamically: +// PolymorphicValue x = v->*"x"; +// v->*"x" = 1; +// Dynamic accesses are slower than static accesses, because you need to do +// string comparisons to find the field, and do casts between the actual field +// type and PolymorphicValue. This can be slow especially when the struct has +// some fields of containers like std::vector, because you need to do +// the conversion between std::vector and std::vector +// every time you get or set a field. +// +// The implementation of this feature requires a few components working +// together: +// 1. StructType: a data type that describes the name and fields of a struct. +// More importantly, it stores a function that can create an instance of a +// struct without requiring the caller to know the actual struct type. +// 2. Struct: a base class for all structs, which provides the virtual interface +// for accessing fields dynamically, as well as an interface for getting the +// StructType of the struct. +// 3. StructHandle: a wrapper around Struct, which maintains the ownership of +// struct objects and provides the overloaded ->* operator for accessing +// fields statically and dynamically. StructHandle is a candidate type for +// PolymorphicValue. +// 4. Accessor: a helper class returned by the dynamic ->* operator, which +// provides the overloaded casting to PolymorphicValue and = operator for +// getting and setting fields dynamically. +// +// With the above components, define a struct type that supports dynamic access +// to fields is basically subclassing Struct and implementing the virtual +// methods. Please check the test PolymorphicValueTest.Struct for an example. + +struct Struct { + virtual ~Struct() = default; + + virtual StructType type() const = 0; + virtual std::function getter( + const std::string& key) const = 0; + virtual std::function setter( + const std::string& key) = 0; +}; + +class Accessor { + std::function getter_; + std::function setter_; + + public: + Accessor( + std::function getter, + std::function setter) + : getter_(std::move(getter)), setter_(std::move(setter)) {} + Accessor(const Accessor& value) = default; + Accessor(Accessor&& value) = default; + Accessor& operator=(const Accessor& value) = default; + Accessor& operator=(Accessor&& value) = default; + + inline const Accessor& operator=(const PolymorphicValue& value) const { + setter_(std::move(value)); + return *this; + } + + inline operator PolymorphicValue() const { + return getter_(); + } +}; + +inline Accessor StructHandle::operator->*(const std::string& key) const { + return Accessor(struct_ptr_->getter(key), struct_ptr_->setter(key)); +} + +// If a struct type is only used in kernel and we will never create an instance +// on the host, we can just use this dummy struct as a placeholder for the +// convenience +struct NotImplementedStruct : public Struct { + StructType type() const override; + + std::function getter( + const std::string& key) const override { + TORCH_INTERNAL_ASSERT(false, "Not implemented"); + } + + std::function setter( + const std::string& key) override { + TORCH_INTERNAL_ASSERT(false, "Not implemented"); + } +}; + +} // namespace nvfuser diff --git a/csrc/type.cpp b/csrc/type.cpp index 23522a33a37..eef9c721852 100644 --- a/csrc/type.cpp +++ b/csrc/type.cpp @@ -17,6 +17,10 @@ namespace nvfuser { +StructType NotImplementedStruct::type() const { + TORCH_INTERNAL_ASSERT(false, "Not implemented"); +} + StructType globalTensorMetaData( const PrimDataType& dtype, size_t dim, @@ -54,7 +58,7 @@ StructType globalTensorMetaData( ArrayType{std::make_shared(DataType::Index), alloc_dim}); alloc_stride_field.used_in_kernel = true; - return StructType::make( + return StructType::make( {data_field, logical_size_field, logical_stride_field, diff --git a/csrc/type.h b/csrc/type.h index a6b06a1a0d5..a0128458660 100644 --- a/csrc/type.h +++ b/csrc/type.h @@ -106,6 +106,7 @@ struct PointerType { struct StructType { std::string name; + std::function()> create; struct FieldInfo { std::string name; @@ -115,8 +116,18 @@ struct StructType { std::vector fields; + template static StructType make(std::vector fields, std::string name = "") { - return StructType{.name = std::move(name), .fields = std::move(fields)}; + static_assert( + std::is_base_of::value, + "StructType::make only accepts Struct types"); + return StructType{ + .name = std::move(name), + .create = + []() { + return std::static_pointer_cast(std::make_shared()); + }, + .fields = std::move(fields)}; } inline const DataType& fieldDataType(const std::string& name) const { @@ -204,6 +215,10 @@ bool StructType::operator==(const StructType& other) const { return true; } +inline StructType StructHandle::type() const { + return struct_ptr_->type(); +} + StructType globalTensorMetaData( const PrimDataType& dtype, size_t dim, @@ -406,7 +421,7 @@ inline DataType getDataType(const PolymorphicValue& value) { std::make_shared( getDataType(NVFUSER_MAYBE_STAR value))}); } - dtype = StructType::make(std::move(fields_info)); + dtype = StructType::make(std::move(fields_info)); } } else if constexpr (std::is_same_v) { // For pointers in polymorphic value, we only store the data size of the diff --git a/test/test_evaluator.cpp b/test/test_evaluator.cpp index 5366770c8e3..78a38362c21 100644 --- a/test/test_evaluator.cpp +++ b/test/test_evaluator.cpp @@ -301,7 +301,8 @@ TEST_F(ExprEvalTest, Struct) { auto* a = IrBuilder::create(DataType::Int); auto* b = IrBuilder::create(DataType::Int); - auto struct_ = IrBuilder::structExpr({{"a", a}, {"b", b}}, "test_struct"); + auto struct_ = IrBuilder::structExpr( + {{"a", a}, {"b", b}}, "test_struct"); auto aa = IrBuilder::getAttrExpr(struct_, "a"); auto bb = IrBuilder::getAttrExpr(struct_, "b"); diff --git a/test/test_polymorphic_value.cpp b/test/test_polymorphic_value.cpp index 50480d66c2e..a76035b9290 100644 --- a/test/test_polymorphic_value.cpp +++ b/test/test_polymorphic_value.cpp @@ -12,9 +12,12 @@ #include #include +#include namespace nvfuser { +using dynamic_type::opcheck; + class PolymorphicValueTest : public NVFuserTest {}; TEST_F(PolymorphicValueTest, OpaqueEquality) { @@ -25,4 +28,97 @@ TEST_F(PolymorphicValueTest, OpaqueEquality) { EXPECT_EQ(b, a); } +TEST_F(PolymorphicValueTest, Struct) { + struct A : public Struct { + int64_t x; + double y; + + StructType type() const override { + std::vector fields(2); + fields.at(0) = {"x", std::make_shared(DataType::Int), true}; + fields.at(1) = {"y", std::make_shared(DataType::Double), false}; + return StructType::make(fields, "A"); + } + + std::function getter( + const std::string& key) const override { + if (key == "x") { + return [this]() { return PolymorphicValue(x); }; + } else if (key == "y") { + return [this]() { return PolymorphicValue(y); }; + } else { + TORCH_INTERNAL_ASSERT(false, "Invalid key"); + } + } + + std::function setter( + const std::string& key) override { + if (key == "x") { + return [this](const PolymorphicValue& value) { x = (int64_t)value; }; + } else if (key == "y") { + return [this](const PolymorphicValue& value) { y = (double)value; }; + } else { + TORCH_INTERNAL_ASSERT(false, "Invalid key"); + } + } + }; + + static_assert(opcheck->*opcheck); + static_assert(opcheck->*opcheck); + + // In a "static context", i.e. we know the C++ type of the struct, we can + // use pointer-to-member syntax to access fields. This is the most efficient + // way to access fields. Accessing fields in this way will give references to + // the fields, so the types of fields are also static. + PolymorphicValue a = std::static_pointer_cast(std::make_shared()); + static_assert(std::is_same_v*&A::x), int64_t&>); + static_assert(std::is_same_v*&A::y), double&>); + a->*& A::x = 299792458; + a->*& A::y = 3.1415926; + EXPECT_EQ(a->*&A::x, 299792458); + EXPECT_EQ(a->*&A::y, 3.1415926); + a->*& A::x = 2788; + a->*& A::y = 2.71828; + EXPECT_EQ(a->*&A::x, 2788); + EXPECT_EQ(a->*&A::y, 2.71828); + + StructType type = (a->*&StructHandle::type)(); + EXPECT_EQ(type.name, "A"); + EXPECT_EQ(type.fields.size(), 2); + EXPECT_EQ(type.fields.at(0).name, "x"); + EXPECT_EQ(*type.fields.at(0).type, DataType::Int); + EXPECT_TRUE(type.fields.at(0).used_in_kernel); + EXPECT_EQ(type.fields.at(1).name, "y"); + EXPECT_EQ(*type.fields.at(1).type, DataType::Double); + EXPECT_FALSE(type.fields.at(1).used_in_kernel); + + { + // intentionally create a new scope and define another struct with the same + // name to make sure the previous struct is not accessible + struct A { + int64_t x; + double y; + }; + static_assert(!(opcheck->*opcheck)); + static_assert(!(opcheck->*opcheck)); + + // In a "dynamic context", i.e. we don't know the C++ type of the struct, we + // can use string keys to access fields, and PolymorphicValue for values. + // This is less efficient than the static context because type conversions + // and key checking are required, but it is the only way to access fields if + // we don't know the C++ type of the struct. + PolymorphicValue b = type.create(); + b->*"x" = 2788; + b->*"y" = 2.71828; + EXPECT_EQ((PolymorphicValue)(b->*"x"), 2788); + EXPECT_EQ((PolymorphicValue)(b->*"y"), 2.71828); + b->*"x" = 299792458; + b->*"y" = 3.1415926; + EXPECT_EQ((PolymorphicValue)(b->*"x"), 299792458); + EXPECT_EQ((PolymorphicValue)(b->*"y"), 3.1415926); + + EXPECT_EQ(type, (b->*&StructHandle::type)()); + } +} + } // namespace nvfuser From 2649b45b6128b2aecd03460d228bb7a603da95f2 Mon Sep 17 00:00:00 2001 From: "Wang, Xiao" <24860335+xwang233@users.noreply.github.com> Date: Tue, 29 Aug 2023 13:46:17 -0700 Subject: [PATCH 24/30] Fixed a minor error in tools/compare_codegen.sh (#810) per title --- tools/compare_codegen.sh | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tools/compare_codegen.sh b/tools/compare_codegen.sh index 3e09b6276c0..c063e10736a 100755 --- a/tools/compare_codegen.sh +++ b/tools/compare_codegen.sh @@ -214,4 +214,5 @@ cleanup set +e # exit status of diff is 1 if there are any mismatches echo -e "\n\nDIFF RESULT:\n" diff -qr -x '*.log' "$outdir/$origcommit" "$outdir/$comparecommit" && echo "No difference found" -echo $? + +exit $? From 7c4ab6845a4930ac5ff574ea3ddfabe6fc793578 Mon Sep 17 00:00:00 2001 From: Jacob Hinkle <1454944+jacobhinkle@users.noreply.github.com> Date: Tue, 29 Aug 2023 18:06:30 -0400 Subject: [PATCH 25/30] Skip trivial Resize, handle resize of Broadcast at concretization. (#800) This updates the root to rfactor propagation in IterType concretization of dynamic fusions. Previously, although we only overwrote Symbolic IterDomains in this step, we still asserted that we could infer an IterType for each I moved that check so that it is only applied when we need to make a change. Additionally, we previously propagated Broadcast-only IterDomains as Symbolic, since we combine with our previous estimate using promoteIterType. As mentioned in a comment, this means Broadcast gets propagated as Symbolic. Instead we now only fall back to promoteIterType when there are multiple input IterTypes to the IterDomain expression. Fixes #798 --- csrc/dynamic_transform.cpp | 34 +++++++++++++++++++++++++------- csrc/ir/nodes.cpp | 18 +++++++++++++++-- test/test_resize.cpp | 40 ++++++++++++++++++++++++++++++++++++++ 3 files changed, 83 insertions(+), 9 deletions(-) diff --git a/csrc/dynamic_transform.cpp b/csrc/dynamic_transform.cpp index c82520e8a11..404506ac3f8 100644 --- a/csrc/dynamic_transform.cpp +++ b/csrc/dynamic_transform.cpp @@ -627,20 +627,40 @@ void DynamicTransformConcretizer::mutate(TensorView* tv) { // Determine the output IterType IterType iter_type = IterType::Symbolic; - for (auto inp_id : ir_utils::filterByType(expr->inputs())) { + const auto input_ids = + ir_utils::filterByType(expr->inputs()).vector(); + for (auto i : c10::irange(input_ids.size())) { + auto inp_id = input_ids.at(i); auto updated_id = maybeMutated(inp_id)->as(); - iter_type = ops::promoteIterType(iter_type, updated_id->getIterType()); + TORCH_CHECK( + updated_id == inp_id || !updated_id->isSymbolic(), + "Mutated IterDomains between root and rfactor should not be symbolic"); + if (i == 0) { + // ops::promoteIterType will favor Symbolic if it encounters it + // alongside Broadcast. This is preferable at fusion definition, but + // here we are propagating, and if we only see Broadcast in some + // dimension, then we should not retain Symbolic. To work around this, + // we always overwrite Symbolic with the first concrete IterType we + // encounter. + iter_type = updated_id->getIterType(); + } else { + iter_type = + ops::promoteIterType(iter_type, updated_id->getIterType()); + } } - TORCH_INTERNAL_ASSERT( - iter_type != IterType::Symbolic, - "Failed to concretize an output IterType for expression: ", - expr->toString()); - // Update the IterType of each output for (auto out_id : ir_utils::filterByType(expr->outputs())) { if (!out_id->isSymbolic()) { continue; } + + // If out_id is Symbolic, we need to concretize it here. If we did not + // yet determine its IterType, then we've missed our chance. + TORCH_INTERNAL_ASSERT( + iter_type != IterType::Symbolic, + "Failed to concretize an output IterType for expression: ", + expr->toString()); + auto concretized_out_id = IterDomainBuilder(maybeMutated(out_id)->as()) .iter_type(iter_type) diff --git a/csrc/ir/nodes.cpp b/csrc/ir/nodes.cpp index fa96cff28ba..db38a2a7737 100644 --- a/csrc/ir/nodes.cpp +++ b/csrc/ir/nodes.cpp @@ -2783,6 +2783,19 @@ IterDomain* IterDomain::resize( "Expansion factor must be an integer scalar: ", right_expansion->toString()); + if (left_expansion->isConstInt() && right_expansion->isConstInt()) { + auto left = left_expansion->evaluateInt(); + auto right = right_expansion->evaluateInt(); + if (left == 0 && right == 0) { + // This is a trivial resize. Check that we are not changing the IterType, + // then return the input. + TORCH_CHECK( + !iter_type_opt.has_value() || + iter_type_opt.value() == in->getIterType(), + "If IterType is specified in pad with zero expansion then it must match input"); + return in; + } + } TORCH_CHECK( in->getIterType() == IterType::Iteration || in->getIterType() == IterType::Broadcast || @@ -2823,12 +2836,13 @@ IterDomain* IterDomain::resize( if (iter_type_opt.has_value()) { iter_type = iter_type_opt.value(); } else if (left_expansion->isConstInt() && right_expansion->isConstInt()) { + auto left = left_expansion->evaluateInt(); + auto right = right_expansion->evaluateInt(); if (resized_id_size->isConstInt()) { // Means input extent is also known auto out_extent = resized_id_size->evaluateInt(); iter_type = out_extent == 1 ? IterType::Broadcast : IterType::Iteration; - } else if ( - left_expansion->evaluateInt() + right_expansion->evaluateInt() > 1) { + } else if (left + right > 1) { // Input extent is non-negative, so we know out_extent > 1 iter_type = IterType::Iteration; } diff --git a/test/test_resize.cpp b/test/test_resize.cpp index 9baaa8d95f1..b3c58077746 100644 --- a/test/test_resize.cpp +++ b/test/test_resize.cpp @@ -483,6 +483,46 @@ TEST_F(ResizeTest, FusionResizePadScheduler4) { __FILE__); } +// Pad a broadcast +// See https://github.com/NVIDIA/Fuser/issues/798 +TEST_F(ResizeTest, FusionResizePadBroadcastInput) { + auto fusion = std::make_unique(); + FusionGuard fg(fusion.get()); + + // IterTypes are {Broadcast, Iteration} + auto tv0 = makeConcreteTensor({1, -1}); + fusion->addInput(tv0); + + // trivial pad of broadcast dimension + auto tv1 = + pad(tv0, + {fusion->oneVal(), + fusion->zeroVal(), + fusion->zeroVal(), + fusion->zeroVal()}); + fusion->addOutput(tv1); + + std::vector shape({1, 2}); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + + auto t0 = at::randn(shape, options); + std::vector aten_inputs({t0}); + + FusionExecutorCache executor_cache(std::move(fusion)); + auto cg_outputs = executor_cache.runFusionWithInputs(aten_inputs); + + auto t1 = at::pad(t0, {1, 0, 0, 0}); + + testValidate( + executor_cache.fusion(), + cg_outputs, + aten_inputs, + {t1}, + __LINE__, + __FILE__); +} + // Trivial cat TEST_F(ResizeTest, FusionResizeCat1) { Fusion fusion; From 50919f1e797ce4c04f810114a0b67c97bd037a60 Mon Sep 17 00:00:00 2001 From: "Gao, Xiang" Date: Tue, 29 Aug 2023 16:45:32 -0700 Subject: [PATCH 26/30] Fix no-benchmark build (#809) --- CMakeLists.txt | 2 ++ cmake/Dependencies.cmake | 2 -- lib/dynamic_type/CMakeLists.txt | 52 +++++++++++++++++---------------- 3 files changed, 29 insertions(+), 27 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index c943472a8aa..1c7798238fa 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -58,6 +58,8 @@ else() # PROJECT_IS_TOP_LEVEL find_package(CUDAToolkit REQUIRED) endif() # PROJECT_IS_TOP_LEVEL +add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/lib/dynamic_type) + # TODO: fix MSVC if(NOT MSVC) find_library(LIBNVTOOLSEXT libnvToolsExt.so PATHS ${CUDA_TOOLKIT_ROOT_DIR}/lib64/) diff --git a/cmake/Dependencies.cmake b/cmake/Dependencies.cmake index ce17fc615fc..343c52387b8 100644 --- a/cmake/Dependencies.cmake +++ b/cmake/Dependencies.cmake @@ -12,8 +12,6 @@ set(BUILD_SHARED_LIBS OFF CACHE BOOL "Build shared libs" FORCE) set(INSTALL_GTEST OFF CACHE BOOL "Install gtest." FORCE) set(BUILD_GMOCK ON CACHE BOOL "Build gmock." FORCE) -add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/../lib/dynamic_type) - # Add googletest subdirectory but make sure our INCLUDE_DIRECTORIES do not bleed into it. # This is because libraries installed into the root conda env (e.g. MKL) add a global /opt/conda/include directory, # and if there is gtest installed in conda, the third_party/googletest/**.cc source files would try to include headers diff --git a/lib/dynamic_type/CMakeLists.txt b/lib/dynamic_type/CMakeLists.txt index 283db2a8993..7cbd565ca74 100644 --- a/lib/dynamic_type/CMakeLists.txt +++ b/lib/dynamic_type/CMakeLists.txt @@ -5,30 +5,32 @@ add_library(dynamic_type INTERFACE) target_include_directories(dynamic_type INTERFACE src) -function(add_test_for_standard std_version) - set(target test_dynamic_type_${std_version}) - add_executable(${target} - test/ForAllTypes.cpp - test/assignment.cpp - test/binary_ops.cpp - test/container.cpp - test/examples.cpp - test/hash.cpp - test/member.cpp - test/move.cpp - test/null.cpp - test/opcheck.cpp - test/print.cpp - test/typing.cpp - test/unary_ops.cpp - ) - target_include_directories(${target} PUBLIC src) - target_link_libraries(${target} PRIVATE gtest_main gmock_main) - set_property(TARGET ${target} PROPERTY CXX_STANDARD ${std_version}) -endfunction() +if(BUILD_TEST) + function(add_test_for_standard std_version) + set(target test_dynamic_type_${std_version}) + add_executable(${target} + test/ForAllTypes.cpp + test/assignment.cpp + test/binary_ops.cpp + test/container.cpp + test/examples.cpp + test/hash.cpp + test/member.cpp + test/move.cpp + test/null.cpp + test/opcheck.cpp + test/print.cpp + test/typing.cpp + test/unary_ops.cpp + ) + target_include_directories(${target} PUBLIC src) + target_link_libraries(${target} PRIVATE gtest_main gmock_main) + set_property(TARGET ${target} PROPERTY CXX_STANDARD ${std_version}) + endfunction() -add_test_for_standard(17) + add_test_for_standard(17) -# add_test_for_standard(20) -# add_test_for_standard(23) -# add_test_for_standard(26) + # add_test_for_standard(20) + # add_test_for_standard(23) + # add_test_for_standard(26) +endif() From 2948ca720bdeb2948a5c979eb305fb8ae09f5830 Mon Sep 17 00:00:00 2001 From: Jacob Hinkle Date: Wed, 30 Aug 2023 07:41:12 -0400 Subject: [PATCH 27/30] Fix hash --- csrc/scheduler/matmul_heuristic.h | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/csrc/scheduler/matmul_heuristic.h b/csrc/scheduler/matmul_heuristic.h index e20958f81c8..36a82c2b90f 100644 --- a/csrc/scheduler/matmul_heuristic.h +++ b/csrc/scheduler/matmul_heuristic.h @@ -128,7 +128,8 @@ class MatmulParams : public HeuristicParams { size_t hash() const override { // combine boolean flags for hashing - size_t attr_hash = + size_t attr_hash = (static_cast(promote_prologue_smem_reuse) << 3) | + (static_cast(use_smem_epilogue) << 2) | (static_cast(rotate_ldmatrix_out_of_main_loop) << 1) | (static_cast(async_gmem_load_operands)); @@ -137,8 +138,7 @@ class MatmulParams : public HeuristicParams { (nvfuser::hash(mma_macro) << 1) ^ (double_buffer_options.hash() << 2) ^ (nvfuser::hash(tile_sizes) << 3) ^ (std::hash{}(static_cast(cta_order)) << 4) ^ - (std::hash{}(grid_swizzle_factor) << 5) ^ - (use_smem_epilogue << 6) ^ (use_smem_epilogue << 7); + (std::hash{}(grid_swizzle_factor) << 5); return attr_hash; } From 0a29cc6f30a21a440279c054d6d32a0dc25f7b1d Mon Sep 17 00:00:00 2001 From: Jacob Hinkle Date: Wed, 30 Aug 2023 07:42:40 -0400 Subject: [PATCH 28/30] Reformat comment --- csrc/scheduler/mma_utils.h | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/csrc/scheduler/mma_utils.h b/csrc/scheduler/mma_utils.h index 3aa89862c9c..2fe875696f6 100644 --- a/csrc/scheduler/mma_utils.h +++ b/csrc/scheduler/mma_utils.h @@ -294,18 +294,18 @@ TORCH_CUDA_CU_API ProblemIterDomainsOpt getProblemIterDomains(Fusion* fusion); TORCH_CUDA_CU_API RolesMapOpt getTensorsRoles(Fusion* fusion); //! Return pair of whether use shared memory epilogue or not and whether to -//! reuse shared memory for the prologue at the expense of an additional block -//! sync. -//! Returns true in first position if using shared memory epilogue won't cause -//! the decrease of occupancy ratio. The occupancy ratio is -//! estimated using register and shared memory usage. -//! If ignore_occupancy_drop is set to true, returns true if -//! there is enough shared memory to launch the kernel without -//! considering the occupancy, useful for debug and validate -//! shared memory epilogue implementation. +//! reuse shared memory for the prologue at the expense of an additional block +//! sync. +//! +//! Returns true in first position if using shared memory epilogue won't cause +//! the decrease of occupancy ratio. The occupancy ratio is estimated using +//! register and shared memory usage. If ignore_occupancy_drop is set to true, +//! returns true if there is enough shared memory to launch the kernel without +//! considering the occupancy, useful for debug and validate shared memory +//! epilogue implementation. //! //! Returns true in the second position if reusing shared memory for the -//! epilogue does not increase occupancy. +//! epilogue does not increase occupancy. TORCH_CUDA_CU_API std::pair generateSharedMemoryEpilogueHeuristics( const MatMulTileOptions& gemm_tile, const int smem_double_buffer_stage, From cdf7c242501865510e9f9cec714c6e630c584696 Mon Sep 17 00:00:00 2001 From: Jacob Hinkle Date: Wed, 30 Aug 2023 12:36:42 -0400 Subject: [PATCH 29/30] Handle all three cases for smem epilogue in test --- test/test_gpu_tensorcore.cpp | 53 +++++++++++++++++++++++++++++------- 1 file changed, 43 insertions(+), 10 deletions(-) diff --git a/test/test_gpu_tensorcore.cpp b/test/test_gpu_tensorcore.cpp index 877b99ca928..ae3d3cfdc13 100644 --- a/test/test_gpu_tensorcore.cpp +++ b/test/test_gpu_tensorcore.cpp @@ -3993,23 +3993,56 @@ TEST_F(NVFuserTest, FusionAmpereMatmulSmemEpilogue_CUDA) { << "Test conducted without utilizing shared memory epilogue due to the device's constrained shared memory capacity."; } + // Check that smem is allocated as expected. + // There are three cases that are determined by the current device in + // mma_utils::generateSharedMemoryEpilogueHeuristics: + // - !use_smem_epilogue : A + B + // - use_smem_epilogue && !promote_prologue_smem_reuse : A + B + C + // - use_smem_epilogue && promote_prologue_smem_reuse : max(A + B, C) auto smem_allocs = fe.kernel()->summary().dynamic_smem_allocations; - TORCH_CHECK(smem_allocs.size() == 3); - if (params.promote_prologue_smem_reuse) { - // Check prologue shared memory re-use - // smem_allocs = {A, B, C} where C is the epilogue buffer - // since A and B have no further uses, we should be able to reuse both of - // them, implying that the address of C is zero. In this case, B will also - // be allocated at address 0 with A stacked above it at position 8192. - EXPECT_EQ(smem_allocs.size(), 3); + if (params.use_smem_epilogue) { + TORCH_CHECK(smem_allocs.size() == 3); + if (params.promote_prologue_smem_reuse) { + // Check prologue shared memory re-use + // smem_allocs = {A, B, C} where C is the epilogue buffer + // since A and B have no further uses, we should be able to reuse both + // of them, implying that the address of C is zero. In this case, B will + // also be allocated at address 0 with A stacked above it at position + // 8192. + EXPECT_EQ( + smem_allocs.at(0)->address()->evaluateInt(), + // Assuming B numel times size(dtype) is a multiple of 16 so that + // this address is aligned + smem_allocs.at(1)->size()->evaluateInt() * + dataTypeSize(smem_allocs.at(1)->buffer()->dtype())); + EXPECT_EQ(smem_allocs.at(1)->address()->evaluateInt(), 0L); + EXPECT_EQ(smem_allocs.at(2)->address()->evaluateInt(), 0L); + } else { + // Prologue shared memory is not re-used. In this case, memory should + // stack in C, B, A order. + EXPECT_EQ( + smem_allocs.at(0)->address()->evaluateInt(), + // Assuming for B and C that numel times size(dtype) is a multiple + // of 16 so that this address is aligned + smem_allocs.at(1)->size()->evaluateInt() * + dataTypeSize(smem_allocs.at(1)->buffer()->dtype()) + + smem_allocs.at(2)->size()->evaluateInt() * + dataTypeSize(smem_allocs.at(2)->buffer()->dtype())); + EXPECT_EQ( + smem_allocs.at(1)->address()->evaluateInt(), + smem_allocs.at(2)->size()->evaluateInt() * + dataTypeSize(smem_allocs.at(2)->buffer()->dtype())); + EXPECT_EQ(smem_allocs.at(2)->address()->evaluateInt(), 0L); + } + } else { + TORCH_CHECK(smem_allocs.size() == 2); EXPECT_EQ( smem_allocs.at(0)->address()->evaluateInt(), - // Assuming B size times size(dtype) is a multiple of 16 so that this + // Assuming B numel times size(dtype) is a multiple of 16 so that this // address is aligned smem_allocs.at(1)->size()->evaluateInt() * dataTypeSize(smem_allocs.at(1)->buffer()->dtype())); EXPECT_EQ(smem_allocs.at(1)->address()->evaluateInt(), 0L); - EXPECT_EQ(smem_allocs.at(2)->address()->evaluateInt(), 0L); } } } From 9d0319a7a18ddbb3bf268485d686cc9cd9c55004 Mon Sep 17 00:00:00 2001 From: Jacob Hinkle Date: Thu, 31 Aug 2023 08:24:07 -0400 Subject: [PATCH 30/30] Remove !use_smem_epilogue case. Test is skipped in this case anyway --- test/test_gpu_tensorcore.cpp | 67 +++++++++++++++--------------------- 1 file changed, 28 insertions(+), 39 deletions(-) diff --git a/test/test_gpu_tensorcore.cpp b/test/test_gpu_tensorcore.cpp index ae3d3cfdc13..6557b825f93 100644 --- a/test/test_gpu_tensorcore.cpp +++ b/test/test_gpu_tensorcore.cpp @@ -3996,53 +3996,42 @@ TEST_F(NVFuserTest, FusionAmpereMatmulSmemEpilogue_CUDA) { // Check that smem is allocated as expected. // There are three cases that are determined by the current device in // mma_utils::generateSharedMemoryEpilogueHeuristics: - // - !use_smem_epilogue : A + B + // - !use_smem_epilogue : A + B (this test is skipped in this case) // - use_smem_epilogue && !promote_prologue_smem_reuse : A + B + C // - use_smem_epilogue && promote_prologue_smem_reuse : max(A + B, C) auto smem_allocs = fe.kernel()->summary().dynamic_smem_allocations; - if (params.use_smem_epilogue) { - TORCH_CHECK(smem_allocs.size() == 3); - if (params.promote_prologue_smem_reuse) { - // Check prologue shared memory re-use - // smem_allocs = {A, B, C} where C is the epilogue buffer - // since A and B have no further uses, we should be able to reuse both - // of them, implying that the address of C is zero. In this case, B will - // also be allocated at address 0 with A stacked above it at position - // 8192. - EXPECT_EQ( - smem_allocs.at(0)->address()->evaluateInt(), - // Assuming B numel times size(dtype) is a multiple of 16 so that - // this address is aligned - smem_allocs.at(1)->size()->evaluateInt() * - dataTypeSize(smem_allocs.at(1)->buffer()->dtype())); - EXPECT_EQ(smem_allocs.at(1)->address()->evaluateInt(), 0L); - EXPECT_EQ(smem_allocs.at(2)->address()->evaluateInt(), 0L); - } else { - // Prologue shared memory is not re-used. In this case, memory should - // stack in C, B, A order. - EXPECT_EQ( - smem_allocs.at(0)->address()->evaluateInt(), - // Assuming for B and C that numel times size(dtype) is a multiple - // of 16 so that this address is aligned - smem_allocs.at(1)->size()->evaluateInt() * - dataTypeSize(smem_allocs.at(1)->buffer()->dtype()) + - smem_allocs.at(2)->size()->evaluateInt() * - dataTypeSize(smem_allocs.at(2)->buffer()->dtype())); - EXPECT_EQ( - smem_allocs.at(1)->address()->evaluateInt(), - smem_allocs.at(2)->size()->evaluateInt() * - dataTypeSize(smem_allocs.at(2)->buffer()->dtype())); - EXPECT_EQ(smem_allocs.at(2)->address()->evaluateInt(), 0L); - } - } else { - TORCH_CHECK(smem_allocs.size() == 2); + TORCH_CHECK(smem_allocs.size() == 3); + if (params.promote_prologue_smem_reuse) { + // Check prologue shared memory re-use + // smem_allocs = {A, B, C} where C is the epilogue buffer + // since A and B have no further uses, we should be able to reuse both + // of them, implying that the address of C is zero. In this case, B will + // also be allocated at address 0 with A stacked above it at position + // 8192. EXPECT_EQ( smem_allocs.at(0)->address()->evaluateInt(), - // Assuming B numel times size(dtype) is a multiple of 16 so that this - // address is aligned + // Assuming B numel times size(dtype) is a multiple of 16 so that + // this address is aligned smem_allocs.at(1)->size()->evaluateInt() * dataTypeSize(smem_allocs.at(1)->buffer()->dtype())); EXPECT_EQ(smem_allocs.at(1)->address()->evaluateInt(), 0L); + EXPECT_EQ(smem_allocs.at(2)->address()->evaluateInt(), 0L); + } else { + // Prologue shared memory is not re-used. In this case, memory should + // stack in C, B, A order. + EXPECT_EQ( + smem_allocs.at(0)->address()->evaluateInt(), + // Assuming for B and C that numel times size(dtype) is a multiple + // of 16 so that this address is aligned + smem_allocs.at(1)->size()->evaluateInt() * + dataTypeSize(smem_allocs.at(1)->buffer()->dtype()) + + smem_allocs.at(2)->size()->evaluateInt() * + dataTypeSize(smem_allocs.at(2)->buffer()->dtype())); + EXPECT_EQ( + smem_allocs.at(1)->address()->evaluateInt(), + smem_allocs.at(2)->size()->evaluateInt() * + dataTypeSize(smem_allocs.at(2)->buffer()->dtype())); + EXPECT_EQ(smem_allocs.at(2)->address()->evaluateInt(), 0L); } } }