From b7de2b46df7abbcaff64f46f17b1949a578d1732 Mon Sep 17 00:00:00 2001 From: Jacob Hinkle Date: Thu, 17 Aug 2023 11:30:18 -0400 Subject: [PATCH 01/28] Add TensorView::requestReuse and failing test --- csrc/ir/interface_nodes.h | 27 +++++++++++++++ test/test_smem_reuse.cpp | 73 +++++++++++++++++++++++++++++++++++++++ 2 files changed, 100 insertions(+) diff --git a/csrc/ir/interface_nodes.h b/csrc/ir/interface_nodes.h index 7da5eb10b24..c7cea92401d 100644 --- a/csrc/ir/interface_nodes.h +++ b/csrc/ir/interface_nodes.h @@ -505,6 +505,24 @@ class TORCH_CUDA_CU_API TensorView : public Val { // ensure consistency. void commitLeafToRFactor(); + //! Get vector of tensors that must have a block sync after their last use and + //! before the definition of this tensor. + const std::vector& getRequestedReusedTensors() const { + return requested_reuse_tvs_; + } + + //! Append to vector of tensors that must have a block sync after their last + //! use and before the definition of this tensor. + void requestReuse(TensorView* tv) { + TORCH_CHECK( + getMemoryType() == MemoryType::Shared, + "Only shared memory tensors re-use memory using requestReuse."); + TORCH_CHECK( + tv->getMemoryType() == MemoryType::Shared, + "Only shared memory tensors can be requested to be re-used."); + requested_reuse_tvs_.push_back(tv); + } + protected: void setDomain(TensorDomain* td) { domain_ = td; @@ -567,6 +585,15 @@ class TORCH_CUDA_CU_API TensorView : public Val { //! transformed when there may actually be a producer tensor that //! may be computed at. unsigned int maybe_max_producer_pos_ = 0; + + //! This holds a collection of shared memory tensors which should have their + //! last use ordered before this tensor in the generated CUDA kernel. When + //! this list is non-empty, it indicates that we should ensure that thread + //! blocks are synchronized such that all threads have performed their last + //! read of any entries in this vector (or any possible aliased tensors of + //! 
them) before writing to the current tensor. This will then allow us to + //! safely reuse the memory allocated to these tensors. + std::vector requested_reuse_tvs_; }; //! A simple TensorView builder diff --git a/test/test_smem_reuse.cpp b/test/test_smem_reuse.cpp index 5781dfe754d..a9ecc0b8f7e 100644 --- a/test/test_smem_reuse.cpp +++ b/test/test_smem_reuse.cpp @@ -210,4 +210,77 @@ TEST_F(SmemReuseTest, NeedsReorderedPush) { } } +// Same as NeedsReorderedPush but C requests to reuse A instead of pre-existing +// sync +TEST_F(SmemReuseTest, RequestReuse) { + auto fusion = std::make_unique(); + FusionGuard fg(fusion.get()); + + int64_t H_int = 5, W_int = 6; + auto H = IrBuilder::create(H_int); + auto W = IrBuilder::create(W_int); + + auto tv0 = full({H}, fusion->oneVal(), DataType::Float); + auto tv1 = set(tv0); // pos = a. A = tv1 + tv1->setMemoryType(MemoryType::Shared); + + auto tv2 = full({W}, fusion->oneVal(), DataType::Float); // pos = b. B = tv2 + tv2->setMemoryType(MemoryType::Shared); + + auto tv3 = add(tv1, tv1); // pos = c + + auto tv4 = sum(tv3, {0}); // gap between b and c + fusion->addOutput(tv4); + + auto tv5 = broadcast(tv4, {true}); + auto tv6 = mul(tv5, tv3); + + auto tv7 = broadcast(tv6, {true, false}); + auto tv8 = broadcast(tv2, {false, true}); + auto tv9 = mul(tv7, tv8); // pos = d. 
C = tv9 + tv9->setMemoryType(MemoryType::Shared); + + auto tv10 = add(tv2, tv2); // pos = e + fusion->addOutput(tv10); + + auto tv11 = neg(tv9); // pos = f + fusion->addOutput(tv11); + + { // This should not re-use memory + GpuLower gpulw(fusion.get()); + + ExpressionEvaluator ee; + std::unordered_set addresses; + int64_t smem_usage = 0; + for (auto alloc : gpulw.kernel()->summary().dynamic_smem_allocations) { + auto addr = ee.evaluate(alloc->address()).as(); + TORCH_CHECK( + addresses.insert(addr).second, + "Smem addresses should not be re-used"); + auto size = ee.evaluate(alloc->size()).as() * + dataTypeSize(alloc->buffer()->dtype()); + smem_usage = std::max(smem_usage, addr + size); + } + EXPECT_EQ( + smem_usage, + alignInt(alignInt(H_int * W_int * 4) + W_int * 4) + H_int * 4); + } + + { // Now introduce a block reduction and check that we re-use memory + + tv9->requestReuse(tv1); + + GpuLower gpulw(fusion.get()); + ExpressionEvaluator ee; + int64_t smem_usage = 0; + for (auto alloc : gpulw.kernel()->summary().dynamic_smem_allocations) { + auto addr = ee.evaluate(alloc->address()).as(); + auto size = ee.evaluate(alloc->size()).as() * + dataTypeSize(alloc->buffer()->dtype()); + smem_usage = std::max(smem_usage, addr + size); + } + EXPECT_EQ(smem_usage, alignInt(H_int * 4) + W_int * H_int * 4); + } +} + } // namespace nvfuser From b25313ba522c38090a2fcb3ab0b120d9afd60a5c Mon Sep 17 00:00:00 2001 From: Jacob Hinkle Date: Thu, 17 Aug 2023 12:01:55 -0400 Subject: [PATCH 02/28] Clone requested reuses when cloning TV --- csrc/tensor_view.cpp | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/csrc/tensor_view.cpp b/csrc/tensor_view.cpp index 707e71f36a7..791a0a34cd0 100644 --- a/csrc/tensor_view.cpp +++ b/csrc/tensor_view.cpp @@ -252,7 +252,12 @@ TensorView::TensorView(const TensorView* src, IrCloner* ir_cloner) cpu_scalar_(src->cpu_scalar_), has_swizzle_op_(src->has_swizzle_op_), 
compute_with_consumers_(ir_cloner->clone(src->compute_with_consumers_)), - compute_with_pos_(src->compute_with_pos_) {} + compute_with_pos_(src->compute_with_pos_) { + requested_reuse_tvs_.reserve(src->requested_reuse_tvs_.size()); + for (auto dep : src->requested_reuse_tvs_) { + requested_reuse_tvs_.push_back(ir_cloner->clone(dep)); + } +} // sets cpu_scalar_ value, which is special handling for CPU based zero-dim // tensors (i.e. CPU Tensors that only have one value). This is only used if From 94bdabe8e4187244283376b31094ab717619fae0 Mon Sep 17 00:00:00 2001 From: Jacob Hinkle Date: Thu, 17 Aug 2023 14:08:51 -0400 Subject: [PATCH 03/28] First draft of block sync inserter --- csrc/device_lower/pass/alias_memory.cpp | 219 +++++++++++++++++++++--- 1 file changed, 199 insertions(+), 20 deletions(-) diff --git a/csrc/device_lower/pass/alias_memory.cpp b/csrc/device_lower/pass/alias_memory.cpp index b1f17c436a9..7f0287f4395 100644 --- a/csrc/device_lower/pass/alias_memory.cpp +++ b/csrc/device_lower/pass/alias_memory.cpp @@ -31,7 +31,7 @@ // - Find allocations of tensors that can reuse other allocations // (class ReusableAllocationFinder) // - Replace those allocation expressions with their alias fields -// pointing to reused allocations (class AllocationReuseModifier) +// pointing to reused allocations (class AllocationAliasModifier) namespace nvfuser { @@ -571,9 +571,22 @@ struct AllocationInfo { std::unique_ptr outer_live_interval = nullptr; std::unique_ptr outer_subscribed_intevals = nullptr; + // Holds allocations that have alloc_expr as their alias_to + std::vector outer_aliased_by; + + //! Get the last outer read position of either this allocation, or any + //! allocation that is aliased to this allocation. 
+ int getAliasedOuterLastRead() const { + auto last_outer_read = outer_live_interval->lastRead(); + for (auto aliasing : outer_aliased_by) { + last_outer_read = + std::max(last_outer_read, aliasing->outer_live_interval->lastRead()); + } + return last_outer_read; + } }; -class AllocationReuseModifier; +class AllocationAliasModifier; //! Analysis pass to collect the liveness info of local and shared buffers: //! The liveness info is illustrated as follows: @@ -702,8 +715,16 @@ class AllocationInfoMap : private kir::IrVisitor { return all_allocations_; } + std::optional getMaybeAllocInfoFromTV(TensorView* tv) const { + auto alloc_it = tv_to_allocation_map_.find(tv->name()); + if (alloc_it == tv_to_allocation_map_.end()) { + return std::nullopt; + } + return alloc_it->second; + } + protected: - friend AllocationReuseModifier; + friend AllocationAliasModifier; //! When an allocation is registered for replacement, this method should be //! called to update the allocation info so that subsequent lookups behave @@ -894,14 +915,6 @@ class AllocationInfoMap : private kir::IrVisitor { } } - std::optional getMaybeAllocInfoFromTV(TensorView* tv) const { - auto alloc_it = tv_to_allocation_map_.find(tv->name()); - if (alloc_it == tv_to_allocation_map_.end()) { - return std::nullopt; - } - return alloc_it->second; - } - //! Find the loop level of expr that apears in the same scope as //! the reference allocate. Eg. //! 
@@ -942,6 +955,7 @@ class AllocationInfoMap : private kir::IrVisitor { void setAlias(AllocationInfo* from, AllocationInfo* to) { alias_map_[from] = to; from->alias_to = to->alloc_expr; + to->outer_aliased_by.push_back(from); } private: @@ -1359,17 +1373,17 @@ class ReusableAllocationFinder : private kir::IrVisitor { }; // Replace Allocate exprs as determined by the alias analysis -class AllocationReuseModifier : private kir::ExprMutator { +class AllocationAliasModifier : private kir::ExprMutator { public: static std::vector modify( const std::vector& exprs, AllocationInfoMap& allocation_info_map) { - AllocationReuseModifier modifier(exprs, allocation_info_map); + AllocationAliasModifier modifier(exprs, allocation_info_map); return modifier.exprs_; } private: - AllocationReuseModifier( + AllocationAliasModifier( const std::vector& exprs, AllocationInfoMap& allocation_info_map) : allocation_info_map_(allocation_info_map) { @@ -1608,7 +1622,7 @@ class StackBasedSharedMemAllocator : kir::IrVisitor { StackBasedSharedMemAllocator(const AllocationInfoMap& allocation_info_map) : allocation_info_map_(allocation_info_map) {} - void allocate(std::vector& exprs) { + void allocate(const std::vector& exprs) { recordEvents(); // Traverse expressions: reclaim memory when we pass a blockSync, append to @@ -1788,6 +1802,164 @@ class StackBasedSharedMemAllocator : kir::IrVisitor { } // namespace +// Use allocation info map to find aliases, i.e. allocations that are properly +// sized and parallelized so that they can be re-used without any +// synchronization. 
+std::vector aliasMemoryAllocations( + const std::vector& exprs, + AllocationInfoMap& allocation_info_map) { + ReusableAllocationFinder::find(exprs, allocation_info_map); + return AllocationAliasModifier::modify(exprs, allocation_info_map); +} + +class RequestedReuseSyncModifier : private kir::ExprMutator { + public: + RequestedReuseSyncModifier( + const std::vector& exprs, + const AllocationInfoMap& allocation_info_map) + : allocation_info_map_(allocation_info_map) { + for (const auto& alloc_info : allocation_info_map.allAllocationInfos()) { + auto tv = alloc_info->alloc_expr->buffer()->as(); + auto tv_first_write = alloc_info->outer_live_interval->firstWrite(); + for (auto dep : tv->getRequestedReusedTensors()) { + auto dep_alloc_info_opt = + allocation_info_map.getMaybeAllocInfoFromTV(dep); + TORCH_CHECK( + dep_alloc_info_opt.has_value(), + "Could not find allocation info for ", + dep->toString(), + " whose memory was requested to be re-used by ", + tv->toString()); + auto dep_alloc_info = dep_alloc_info_opt.value(); + + int dep_last_read = dep_alloc_info->getAliasedOuterLastRead(); + + // TODO: These aliases should not be created in the first place. We + // should modify the alias map instead to make it reuse-request aware. 
+ TORCH_CHECK( + dep_last_read < tv_first_write, + "Requested re-use dependency ", + dep->toString(), + " has last read position on or after first write of ", + tv->toString()); + + sync_intervals_.emplace(dep_last_read, tv_first_write); + } + } + + if (isDebugDumpEnabled(DebugDumpOption::BufferReuseInfo)) { + debug() << "Verifying syncs within these intervals:" << std::endl; + for (auto [last_read, first_write] : sync_intervals_) { + debug() << " " << last_read << " - " << first_write << std::endl; + } + } + + if (sync_intervals_.empty()) { + exprs_ = exprs; + } else { + traverseAndInsert(exprs); + } + } + + const std::unordered_set& insertedSyncs() const { + return inserted_syncs_; + } + + const std::vector& modifiedExprs() const { + return exprs_; + } + + private: + using kir::ExprMutator::dispatch; + + void dispatch(Expr* expr) final { + // Skip inserted syncs + if (inserted_syncs_.find(expr) != inserted_syncs_.end()) { + if (isDebugDumpEnabled(DebugDumpOption::BufferReuseInfo)) { + debug() << "Skipping new sync expression " << expr->toString(); + } + kir::ExprMutator::dispatch(expr); + } + + position_ = allocation_info_map_.getScopeMap().getExprPos(expr); + + // If this is an upcoming first write that has not yet been erased, it means + // we have not seen a sync in its interval. So we should insert a BlockSync + // before this expr. + if (upcoming_first_writes_.erase(position_)) { + if (isDebugDumpEnabled(DebugDumpOption::BufferReuseInfo)) { + debug() << "Inserting block sync before position " << position_ + << std::endl; + } + auto new_sync = IrBuilder::create(); + inserted_syncs_.insert(new_sync); + registerInsertBefore(expr, new_sync); + } + + // If we have a sync at this location, we can clear any upcoming first + // writes since they can be considered safe. 
+ if (lower_utils::hasBlockSync(expr, GpuLower::current()->threadPredMap())) { + if (isDebugDumpEnabled(DebugDumpOption::BufferReuseInfo)) { + debug() << "Found blocking expression at position " << position_ + << std::endl; + } + upcoming_first_writes_.clear(); + } + + // If this is the lower endpoint of a sync interval, add the upper endpoint + auto range = sync_intervals_.equal_range(position_); + for (auto& it = range.first; it != range.second; ++it) { + if (isDebugDumpEnabled(DebugDumpOption::BufferReuseInfo)) { + debug() << "Found dependency last read at position " << position_ + << " corresponding to first write at " << it->second + << std::endl; + } + upcoming_first_writes_.insert(it->second); + } + + kir::ExprMutator::dispatch(expr); + } + + private: + const AllocationInfoMap& allocation_info_map_; + + // This holds intervals in which we need to ensure a sync exists. All + // threads in a block should arrive at the start of each interval before any + // thread proceeds past the end of the interval. + std::unordered_multimap sync_intervals_; + + // Position within the traversal + int position_ = -1; + + // These are the upper endpoints of needed sync intervals for which we've + // already passed over the lower endpoint. + std::unordered_set upcoming_first_writes_; + + // Holds all new syncs we have inserted + std::unordered_set inserted_syncs_; +}; + +// Insert missing synchronizations in cases where a TensorView is marked as +// needing them. This should be done before allocateSharedMemoryAllocations, +// which uses synchronization points to reclaim unused shared memory. +std::pair, bool> insertRequestedReuseSyncs( + const std::vector& exprs, + AllocationInfoMap& allocation_info_map) { + auto modifier = RequestedReuseSyncModifier(exprs, allocation_info_map); + return {modifier.modifiedExprs(), !modifier.insertedSyncs().empty()}; +} + +// Assign addresses for dynamic shared memory allocations. 
This re-uses memory +// by reclaiming memory that is unused when encountering a block +// synchronization. +void allocateSharedMemoryAllocations( + const std::vector& exprs, + AllocationInfoMap& allocation_info_map) { + StackBasedSharedMemAllocator(allocation_info_map).allocate(exprs); +} + +// Entry point for all memory re-use including unsynced aliasing as well as +// insertion of requested syncs and memory allocation with reclamation. std::vector reuseMemoryAllocations(const std::vector& exprs) { FUSER_PERF_SCOPE("reuseMemoryAllocations"); @@ -1795,14 +1967,21 @@ std::vector reuseMemoryAllocations(const std::vector& exprs) { AllocationInfoMap allocation_info_map(exprs, debug_print); - ReusableAllocationFinder::find(exprs, allocation_info_map); + const auto aliased_exprs = aliasMemoryAllocations(exprs, allocation_info_map); - auto aliased_exprs = - AllocationReuseModifier::modify(exprs, allocation_info_map); + const auto [synced_exprs, inserted_syncs] = + insertRequestedReuseSyncs(aliased_exprs, allocation_info_map); + + // If we inserted sync expressions, we need to recompute positions of any + // downstream expressions. Rather than try to keep those in sync, we just + // recompute the allocation info map here. + if (inserted_syncs) { + allocation_info_map = AllocationInfoMap(synced_exprs, false); + } - StackBasedSharedMemAllocator(allocation_info_map).allocate(aliased_exprs); + allocateSharedMemoryAllocations(synced_exprs, allocation_info_map); - return aliased_exprs; + return synced_exprs; } } // namespace nvfuser From 04d8b9016c6883bea20782283765907b325fa42a Mon Sep 17 00:00:00 2001 From: Jacob Hinkle Date: Fri, 18 Aug 2023 08:20:52 -0400 Subject: [PATCH 04/28] Don't count parallelized loops in AllocationInfoMap Previously, we stacked every ForLoop regardless of parallelization. This meant that when the first few dimensions were left of compute at in the whole fusion, even if they were parallelized all tensors would have the same outer live interval. 
I noticed this for the AmpereMatmulSmemEpilogue_CUDA tests. In that case if you look at the generated CUDA it's clearly not true; the outer for loops do not appear since they are parallelized. This commit fixes this; note that it can affect all reuse analysis including aliasing even of local memory. --- csrc/device_lower/pass/alias_memory.cpp | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/csrc/device_lower/pass/alias_memory.cpp b/csrc/device_lower/pass/alias_memory.cpp index 7f0287f4395..951cd08441f 100644 --- a/csrc/device_lower/pass/alias_memory.cpp +++ b/csrc/device_lower/pass/alias_memory.cpp @@ -764,7 +764,13 @@ class AllocationInfoMap : private kir::IrVisitor { void handle(kir::ForLoop* for_loop) final { auto loop_info = scope_map_.getLoopScopeInfo(for_loop); - current_stack_.push_back(loop_info); + if (!for_loop->iter_domain()->isParallelized()) { + // Parallelized loops do not result in for loops in the CUDA kernel, so + // they should not affect liveness analysis. This means that + // current_stack_ will differ from kir::IrVisitor::for_loops_, which will + // actually hold all ForLoops regardless of parallelization. 
+ current_stack_.push_back(loop_info); + } if (debug_printer_) { debug_printer_->pushScope(); } @@ -772,7 +778,9 @@ class AllocationInfoMap : private kir::IrVisitor { if (debug_printer_) { debug_printer_->popScope(); } - current_stack_.pop_back(); + if (!for_loop->iter_domain()->isParallelized()) { + current_stack_.pop_back(); + } } void handle(kir::IfThenElse* ite) final { From e2ef3a58f724bc60eee89f399e82ed494e7b2f08 Mon Sep 17 00:00:00 2001 From: Jacob Hinkle Date: Fri, 18 Aug 2023 08:24:14 -0400 Subject: [PATCH 05/28] More verbose printing of buffer reuse info --- csrc/device_lower/pass/alias_memory.cpp | 36 ++++++++++++++++++------- 1 file changed, 27 insertions(+), 9 deletions(-) diff --git a/csrc/device_lower/pass/alias_memory.cpp b/csrc/device_lower/pass/alias_memory.cpp index 951cd08441f..d76c95e6396 100644 --- a/csrc/device_lower/pass/alias_memory.cpp +++ b/csrc/device_lower/pass/alias_memory.cpp @@ -1657,6 +1657,10 @@ class StackBasedSharedMemAllocator : kir::IrVisitor { // Reclaim memory whenever we pass an Expr that is known to synchronize the // block if (lower_utils::hasBlockSync(expr, GpuLower::current()->threadPredMap())) { + if (isDebugDumpEnabled(DebugDumpOption::BufferReuseInfo)) { + debug() << "Block syncing expr found at position " << position_ + << ". Reclaiming memory." 
<< std::endl; + } reclaimMemory(); } @@ -1693,6 +1697,12 @@ class StackBasedSharedMemAllocator : kir::IrVisitor { } void pushAndAssign(AllocationInfo* alloc_info) { + if (isDebugDumpEnabled(DebugDumpOption::BufferReuseInfo)) { + auto alloc = alloc_info->alloc_expr; + debug() << "Pushing allocation for T" << alloc->buffer()->name() + << std::endl; + } + // Assign new address assignNextAddress(alloc_info); @@ -1713,8 +1723,11 @@ class StackBasedSharedMemAllocator : kir::IrVisitor { alloc->setAddress(aligned_address); } if (isDebugDumpEnabled(DebugDumpOption::BufferReuseInfo)) { - debug() << "Allocated address " << alloc->address()->toInlineString() - << " for T" << alloc->buffer()->name() << std::endl; + debug() << "Assigned address " << alloc->address()->toInlineString() + << " for T" << alloc->buffer()->name() << " with size " + << alloc->size()->toInlineString() << " * " + << dataTypeSize(alloc->buffer()->dtype()) << " bytes" + << std::endl; } } @@ -1772,6 +1785,12 @@ class StackBasedSharedMemAllocator : kir::IrVisitor { while (!alloc_stack_.empty()) { auto last_read = lastAliasedRead(alloc_stack_.back()); if (last_read <= position_) { + if (isDebugDumpEnabled(DebugDumpOption::BufferReuseInfo)) { + auto alloc = alloc_stack_.back()->alloc_expr; + debug() << "Popping allocation for T" << alloc->buffer()->name() + << " which has assigned address " + << alloc->address()->toInlineString() << std::endl; + } alloc_stack_.pop_back(); } else { break; @@ -1855,16 +1874,15 @@ class RequestedReuseSyncModifier : private kir::ExprMutator { } } - if (isDebugDumpEnabled(DebugDumpOption::BufferReuseInfo)) { - debug() << "Verifying syncs within these intervals:" << std::endl; - for (auto [last_read, first_write] : sync_intervals_) { - debug() << " " << last_read << " - " << first_write << std::endl; - } - } - if (sync_intervals_.empty()) { exprs_ = exprs; } else { + if (isDebugDumpEnabled(DebugDumpOption::BufferReuseInfo)) { + debug() << "Verifying syncs within these intervals:" << 
std::endl; + for (auto [last_read, first_write] : sync_intervals_) { + debug() << " " << last_read << " - " << first_write << std::endl; + } + } traverseAndInsert(exprs); } } From c7043dcfb26beefed78038af4df11028345068ee Mon Sep 17 00:00:00 2001 From: Jacob Hinkle Date: Fri, 18 Aug 2023 11:46:02 -0400 Subject: [PATCH 06/28] Clean up old tests. NeedsReorderedPush actually had the lifetimes not quite overlapping. New version is simpler I think. --- test/test_smem_reuse.cpp | 51 +++++++++++++++++++--------------------- 1 file changed, 24 insertions(+), 27 deletions(-) diff --git a/test/test_smem_reuse.cpp b/test/test_smem_reuse.cpp index a9ecc0b8f7e..e1c6a7b1e06 100644 --- a/test/test_smem_reuse.cpp +++ b/test/test_smem_reuse.cpp @@ -107,6 +107,7 @@ TEST_F(SmemReuseTest, SimpleCase) { ExpressionEvaluator ee; int64_t smem_usage = 0; for (auto alloc : gpulw.kernel()->summary().dynamic_smem_allocations) { + EXPECT_NE(alloc->address(), nullptr); auto addr = ee.evaluate(alloc->address()).as(); auto size = ee.evaluate(alloc->size()).as() * dataTypeSize(alloc->buffer()->dtype()); @@ -143,35 +144,31 @@ TEST_F(SmemReuseTest, NeedsReorderedPush) { auto fusion = std::make_unique(); FusionGuard fg(fusion.get()); - int64_t H_int = 5, W_int = 6; - auto H = IrBuilder::create(H_int); - auto W = IrBuilder::create(W_int); + int64_t H = 5; - auto tv0 = full({H}, fusion->oneVal(), DataType::Float); - auto tv1 = set(tv0); // pos = a. A = tv1 - tv1->setMemoryType(MemoryType::Shared); + auto tv0 = full( + {IrBuilder::create(H)}, + fusion->oneVal(), + DataType::Float); // pos = a. A = tv0 + tv0->setMemoryType(MemoryType::Shared); - auto tv2 = full({W}, fusion->oneVal(), DataType::Float); // pos = b. B = tv2 - tv2->setMemoryType(MemoryType::Shared); + auto tv1 = + pad(tv0, {fusion->zeroVal(), fusion->oneVal()}); // pos = b. 
B = tv1 + tv1->setMemoryType(MemoryType::Shared); - auto tv3 = add(tv1, tv1); // pos = c + auto tv2 = mul(tv1, tv1); // pos = c - auto tv4 = sum(tv3, {0}); // gap between b and c - fusion->addOutput(tv4); - - auto tv5 = broadcast(tv4, {true}); - auto tv6 = mul(tv5, tv3); + auto tv3 = sum(tv2, {0}); // gap between b and c. Can parallelize to sync - auto tv7 = broadcast(tv6, {true, false}); - auto tv8 = broadcast(tv2, {false, true}); - auto tv9 = mul(tv7, tv8); // pos = d. C = tv9 - tv9->setMemoryType(MemoryType::Shared); + auto tv4 = broadcast(tv3, {true}); + auto tv5 = mul(tv4, tv1); // pos = d. C = tv5 + tv5->setMemoryType(MemoryType::Shared); - auto tv10 = add(tv2, tv2); // pos = e - fusion->addOutput(tv10); + auto tv6 = add(tv1, tv1); // pos = e + fusion->addOutput(tv6); - auto tv11 = neg(tv9); // pos = f - fusion->addOutput(tv11); + auto tv7 = neg(tv5); // pos = f + fusion->addOutput(tv7); { // This should not re-use memory GpuLower gpulw(fusion.get()); @@ -180,6 +177,7 @@ TEST_F(SmemReuseTest, NeedsReorderedPush) { std::unordered_set addresses; int64_t smem_usage = 0; for (auto alloc : gpulw.kernel()->summary().dynamic_smem_allocations) { + EXPECT_NE(alloc->address(), nullptr); auto addr = ee.evaluate(alloc->address()).as(); TORCH_CHECK( addresses.insert(addr).second, @@ -189,24 +187,23 @@ TEST_F(SmemReuseTest, NeedsReorderedPush) { smem_usage = std::max(smem_usage, addr + size); } EXPECT_EQ( - smem_usage, - alignInt(alignInt(H_int * W_int * 4) + W_int * 4) + H_int * 4); + smem_usage, alignInt(alignInt((H + 1) * 4) + (H + 1) * 4) + H * 4); } { // Now introduce a block reduction and check that we re-use memory - - tv4->axis(0)->parallelize(ParallelType::TIDx); + tv3->axis(0)->parallelize(ParallelType::TIDx); GpuLower gpulw(fusion.get()); ExpressionEvaluator ee; int64_t smem_usage = 0; for (auto alloc : gpulw.kernel()->summary().dynamic_smem_allocations) { + EXPECT_NE(alloc->address(), nullptr); auto addr = ee.evaluate(alloc->address()).as(); auto size = 
ee.evaluate(alloc->size()).as() * dataTypeSize(alloc->buffer()->dtype()); smem_usage = std::max(smem_usage, addr + size); } - EXPECT_EQ(smem_usage, alignInt(H_int * 4) + W_int * H_int * 4); + EXPECT_EQ(smem_usage, alignInt((H + 1) * 4) + (H + 1) * 4); } } From 7d08a413d7404a8b1881ec96c74564da5158a800 Mon Sep 17 00:00:00 2001 From: Jacob Hinkle Date: Fri, 18 Aug 2023 13:16:01 -0400 Subject: [PATCH 07/28] Fix up reuse tests --- test/test_smem_reuse.cpp | 53 +++++++++++++++++++--------------------- 1 file changed, 25 insertions(+), 28 deletions(-) diff --git a/test/test_smem_reuse.cpp b/test/test_smem_reuse.cpp index e1c6a7b1e06..58429098cfe 100644 --- a/test/test_smem_reuse.cpp +++ b/test/test_smem_reuse.cpp @@ -213,35 +213,31 @@ TEST_F(SmemReuseTest, RequestReuse) { auto fusion = std::make_unique(); FusionGuard fg(fusion.get()); - int64_t H_int = 5, W_int = 6; - auto H = IrBuilder::create(H_int); - auto W = IrBuilder::create(W_int); - - auto tv0 = full({H}, fusion->oneVal(), DataType::Float); - auto tv1 = set(tv0); // pos = a. A = tv1 - tv1->setMemoryType(MemoryType::Shared); + int64_t H = 5; - auto tv2 = full({W}, fusion->oneVal(), DataType::Float); // pos = b. B = tv2 - tv2->setMemoryType(MemoryType::Shared); + auto tv0 = full( + {IrBuilder::create(H)}, + fusion->oneVal(), + DataType::Float); // pos = a. A = tv0 + tv0->setMemoryType(MemoryType::Shared); - auto tv3 = add(tv1, tv1); // pos = c + auto tv1 = + pad(tv0, {fusion->zeroVal(), fusion->oneVal()}); // pos = b. B = tv1 + tv1->setMemoryType(MemoryType::Shared); - auto tv4 = sum(tv3, {0}); // gap between b and c - fusion->addOutput(tv4); + auto tv2 = mul(tv1, tv1); // pos = c - auto tv5 = broadcast(tv4, {true}); - auto tv6 = mul(tv5, tv3); + auto tv3 = sum(tv2, {0}); // gap between b and c. Can parallelize to sync - auto tv7 = broadcast(tv6, {true, false}); - auto tv8 = broadcast(tv2, {false, true}); - auto tv9 = mul(tv7, tv8); // pos = d. 
C = tv9 - tv9->setMemoryType(MemoryType::Shared); + auto tv4 = broadcast(tv3, {true}); + auto tv5 = mul(tv4, tv1); // pos = d. C = tv5 + tv5->setMemoryType(MemoryType::Shared); - auto tv10 = add(tv2, tv2); // pos = e - fusion->addOutput(tv10); + auto tv6 = add(tv1, tv1); // pos = e + fusion->addOutput(tv6); - auto tv11 = neg(tv9); // pos = f - fusion->addOutput(tv11); + auto tv7 = neg(tv5); // pos = f + fusion->addOutput(tv7); { // This should not re-use memory GpuLower gpulw(fusion.get()); @@ -250,6 +246,7 @@ TEST_F(SmemReuseTest, RequestReuse) { std::unordered_set addresses; int64_t smem_usage = 0; for (auto alloc : gpulw.kernel()->summary().dynamic_smem_allocations) { + EXPECT_NE(alloc->address(), nullptr); auto addr = ee.evaluate(alloc->address()).as(); TORCH_CHECK( addresses.insert(addr).second, @@ -259,24 +256,24 @@ TEST_F(SmemReuseTest, RequestReuse) { smem_usage = std::max(smem_usage, addr + size); } EXPECT_EQ( - smem_usage, - alignInt(alignInt(H_int * W_int * 4) + W_int * 4) + H_int * 4); + smem_usage, alignInt(alignInt((H + 1) * 4) + (H + 1) * 4) + H * 4); } - { // Now introduce a block reduction and check that we re-use memory - - tv9->requestReuse(tv1); + { // Request a sync between tv0 and tv5. This will place a __syncthreads just + // before tv5 is written. 
+ tv5->requestReuse(tv0); GpuLower gpulw(fusion.get()); ExpressionEvaluator ee; int64_t smem_usage = 0; for (auto alloc : gpulw.kernel()->summary().dynamic_smem_allocations) { + EXPECT_NE(alloc->address(), nullptr); auto addr = ee.evaluate(alloc->address()).as(); auto size = ee.evaluate(alloc->size()).as() * dataTypeSize(alloc->buffer()->dtype()); smem_usage = std::max(smem_usage, addr + size); } - EXPECT_EQ(smem_usage, alignInt(H_int * 4) + W_int * H_int * 4); + EXPECT_EQ(smem_usage, alignInt((H + 1) * 4) + (H + 1) * 4); } } From bb6e1f44b5e7296bac3aa03769d37904a043f46b Mon Sep 17 00:00:00 2001 From: Jacob Hinkle Date: Fri, 18 Aug 2023 13:16:44 -0400 Subject: [PATCH 08/28] Rename and add verification to allocate step --- csrc/device_lower/pass/alias_memory.cpp | 17 +++++++++++++++-- 1 file changed, 15 insertions(+), 2 deletions(-) diff --git a/csrc/device_lower/pass/alias_memory.cpp b/csrc/device_lower/pass/alias_memory.cpp index d76c95e6396..fe42723af01 100644 --- a/csrc/device_lower/pass/alias_memory.cpp +++ b/csrc/device_lower/pass/alias_memory.cpp @@ -1978,10 +1978,23 @@ std::pair, bool> insertRequestedReuseSyncs( // Assign addresses for dynamic shared memory allocations. This re-uses memory // by reclaiming memory that is unused when encountering a block // synchronization. 
-void allocateSharedMemoryAllocations( +void assignSharedMemoryAllocations( const std::vector& exprs, AllocationInfoMap& allocation_info_map) { StackBasedSharedMemAllocator(allocation_info_map).allocate(exprs); + + // Verify that all smem allocations have a non-null address now + for (auto& alloc_info : allocation_info_map.allAllocationInfos()) { + if (alloc_info->mem_type != MemoryType::Shared || alloc_info->alias_to) { + continue; + } + auto alloc = alloc_info->alloc_expr; + TORCH_INTERNAL_ASSERT( + alloc->address(), + "Unaliased allocation for shared memory tensor ", + alloc->buffer()->toString(), + " was not assigned an address"); + } } // Entry point for all memory re-use including unsynced aliasing as well as @@ -2005,7 +2018,7 @@ std::vector reuseMemoryAllocations(const std::vector& exprs) { allocation_info_map = AllocationInfoMap(synced_exprs, false); } - allocateSharedMemoryAllocations(synced_exprs, allocation_info_map); + assignSharedMemoryAllocations(synced_exprs, allocation_info_map); return synced_exprs; } From ffdb9cd8943d6bb6e0669e8ad06117b07d8c7f7a Mon Sep 17 00:00:00 2001 From: Jacob Hinkle Date: Mon, 21 Aug 2023 07:32:01 -0400 Subject: [PATCH 09/28] Change to promoteReuse interface --- csrc/device_lower/pass/alias_memory.cpp | 92 ++++++++++++++++--------- csrc/ir/interface_nodes.h | 25 ++++++- test/test_smem_reuse.cpp | 10 +-- 3 files changed, 89 insertions(+), 38 deletions(-) diff --git a/csrc/device_lower/pass/alias_memory.cpp b/csrc/device_lower/pass/alias_memory.cpp index fe42723af01..d80417a799c 100644 --- a/csrc/device_lower/pass/alias_memory.cpp +++ b/csrc/device_lower/pass/alias_memory.cpp @@ -1633,6 +1633,20 @@ class StackBasedSharedMemAllocator : kir::IrVisitor { void allocate(const std::vector& exprs) { recordEvents(); + for (auto& [pos, fws] : first_write_positions_) { + std::cout << "Position " << pos + << " is first write pos for:" << std::endl; + for (auto& alloc_info : fws) { + std::cout << " " << 
alloc_info->alloc_expr->buffer()->toString() + << std::endl; + } + } + + for (auto& pos : last_read_positions_) { + std::cout << "Position " << pos + << " is last aliased read pos for some allocs" << std::endl; + } + // Traverse expressions: reclaim memory when we pass a blockSync, append to // waiting_to_push_ when we pass an Allocate handle(exprs); @@ -1646,12 +1660,21 @@ class StackBasedSharedMemAllocator : kir::IrVisitor { void dispatch(Expr* expr) final { position_ = allocation_info_map_.getScopeMap().getExprPos(expr); + if (isDebugDumpEnabled(DebugDumpOption::BufferReuseInfo)) { + debug() << "Position " << position_ << std::endl; + } + // Check whether this is a first write position for any allocations auto it = first_write_positions_.find(position_); if (it != first_write_positions_.end()) { + if (isDebugDumpEnabled(DebugDumpOption::BufferReuseInfo)) { + debug() << "Position " << position_ << " is first write for"; + } for (auto alloc_info : it->second) { + debug() << " T" << alloc_info->alloc_expr->buffer()->name(); waiting_to_push_.push_back(alloc_info); } + debug() << std::endl; } // Reclaim memory whenever we pass an Expr that is known to synchronize the @@ -1734,10 +1757,14 @@ class StackBasedSharedMemAllocator : kir::IrVisitor { //! 
Record first reads and last writes, respecting aliased buffers void recordEvents() { for (auto& alloc_info : allocation_info_map_.allAllocationInfos()) { + std::cout << "alloc info found for " << alloc_info->alloc_expr->toString() + << std::endl; if (alloc_info->mem_type != MemoryType::Shared) { continue; } if (alloc_info->alias_to) { + std::cout << " Allocation aliases" << alloc_info->alias_to->toString() + << std::endl; auto alias_info = allocation_info_map_.getMaybeAllocationInfo(alloc_info->alias_to); TORCH_CHECK(alias_info.has_value()); @@ -1839,38 +1866,38 @@ std::vector aliasMemoryAllocations( return AllocationAliasModifier::modify(exprs, allocation_info_map); } -class RequestedReuseSyncModifier : private kir::ExprMutator { +class PromoteReuseSyncModifier : private kir::ExprMutator { public: - RequestedReuseSyncModifier( + PromoteReuseSyncModifier( const std::vector& exprs, const AllocationInfoMap& allocation_info_map) : allocation_info_map_(allocation_info_map) { + // Find next allocation after last aliased read of all allocations whose + // reuse we need to promote, and record shortest sync intervals with + // subsequent allocations. for (const auto& alloc_info : allocation_info_map.allAllocationInfos()) { - auto tv = alloc_info->alloc_expr->buffer()->as(); - auto tv_first_write = alloc_info->outer_live_interval->firstWrite(); - for (auto dep : tv->getRequestedReusedTensors()) { - auto dep_alloc_info_opt = - allocation_info_map.getMaybeAllocInfoFromTV(dep); - TORCH_CHECK( - dep_alloc_info_opt.has_value(), - "Could not find allocation info for ", - dep->toString(), - " whose memory was requested to be re-used by ", - tv->toString()); - auto dep_alloc_info = dep_alloc_info_opt.value(); - - int dep_last_read = dep_alloc_info->getAliasedOuterLastRead(); - - // TODO: These aliases should not be created in the first place. We - // should modify the alias map instead to make it reuse-request aware. 
- TORCH_CHECK( - dep_last_read < tv_first_write, - "Requested re-use dependency ", - dep->toString(), - " has last read position on or after first write of ", - tv->toString()); - - sync_intervals_.emplace(dep_last_read, tv_first_write); + if (!alloc_info->alloc_expr->buffer() + ->as() + ->getPromoteReuse()) { + continue; + } + auto last_read = alloc_info->getAliasedOuterLastRead(); + + std::optional nearest_first_write = std::nullopt; + + for (const auto& other : allocation_info_map.allAllocationInfos()) { + auto first_write = other->outer_live_interval->firstWrite(); + if (first_write <= last_read) { + continue; + } + if (!nearest_first_write.has_value() || + first_write < nearest_first_write.value()) { + nearest_first_write = first_write; + } + } + + if (nearest_first_write.has_value()) { + sync_intervals_.emplace(last_read, nearest_first_write.value()); } } @@ -1966,12 +1993,13 @@ class RequestedReuseSyncModifier : private kir::ExprMutator { }; // Insert missing synchronizations in cases where a TensorView is marked as -// needing them. This should be done before allocateSharedMemoryAllocations, -// which uses synchronization points to reclaim unused shared memory. -std::pair, bool> insertRequestedReuseSyncs( +// needing reuse promotion. This should be done before +// assignSharedMemoryAllocations, which uses synchronization points to reclaim +// unused shared memory. 
+std::pair, bool> promoteReuseSyncs( const std::vector& exprs, AllocationInfoMap& allocation_info_map) { - auto modifier = RequestedReuseSyncModifier(exprs, allocation_info_map); + auto modifier = PromoteReuseSyncModifier(exprs, allocation_info_map); return {modifier.modifiedExprs(), !modifier.insertedSyncs().empty()}; } @@ -2009,7 +2037,7 @@ std::vector reuseMemoryAllocations(const std::vector& exprs) { const auto aliased_exprs = aliasMemoryAllocations(exprs, allocation_info_map); const auto [synced_exprs, inserted_syncs] = - insertRequestedReuseSyncs(aliased_exprs, allocation_info_map); + promoteReuseSyncs(aliased_exprs, allocation_info_map); // If we inserted sync expressions, we need to recompute positions of any // downstream expressions. Rather than try to keep those in sync, we just diff --git a/csrc/ir/interface_nodes.h b/csrc/ir/interface_nodes.h index c7cea92401d..e04a64e5d37 100644 --- a/csrc/ir/interface_nodes.h +++ b/csrc/ir/interface_nodes.h @@ -505,14 +505,33 @@ class TORCH_CUDA_CU_API TensorView : public Val { // ensure consistency. void commitLeafToRFactor(); + void promoteReuse(bool b = true) { + promote_reuse_ = b; + } + + bool getPromoteReuse() const { + return promote_reuse_; + } + //! Get vector of tensors that must have a block sync after their last use and //! before the definition of this tensor. const std::vector& getRequestedReusedTensors() const { return requested_reuse_tvs_; } - //! Append to vector of tensors that must have a block sync after their last - //! use and before the definition of this tensor. + //! Request that we reclaim the memory of tv before this tensor is defined. + //! + //! This method influences the shared memory allocator that assigns shared + //! memory addresses at lowering. It ensures that the proper synchronization + //! is present in the kernel to reuse memory and inserts new block + //! synchronizations if necessary. + //! + //! This is only possible if the lifetimes of this tensor and tv do not + //! 
overlap, i.e. tv is last written or read before this tensor is written. If + //! this is violated and the lifetimes overlap, an exception will be raised at + //! lowering. + //! + //! This method may only be used on shared-memory tensors. void requestReuse(TensorView* tv) { TORCH_CHECK( getMemoryType() == MemoryType::Shared, @@ -594,6 +613,8 @@ class TORCH_CUDA_CU_API TensorView : public Val { //! them) before writing to the current tensor. This will then allow us to //! safely reuse the memory allocated to these tensors. std::vector requested_reuse_tvs_; + + bool promote_reuse_ = false; }; //! A simple TensorView builder diff --git a/test/test_smem_reuse.cpp b/test/test_smem_reuse.cpp index 58429098cfe..5a9c1849612 100644 --- a/test/test_smem_reuse.cpp +++ b/test/test_smem_reuse.cpp @@ -209,7 +209,7 @@ TEST_F(SmemReuseTest, NeedsReorderedPush) { // Same as NeedsReorderedPush but C requests to reuse A instead of pre-existing // sync -TEST_F(SmemReuseTest, RequestReuse) { +TEST_F(SmemReuseTest, PromoteReuse) { auto fusion = std::make_unique(); FusionGuard fg(fusion.get()); @@ -259,9 +259,9 @@ TEST_F(SmemReuseTest, RequestReuse) { smem_usage, alignInt(alignInt((H + 1) * 4) + (H + 1) * 4) + H * 4); } - { // Request a sync between tv0 and tv5. This will place a __syncthreads just - // before tv5 is written. - tv5->requestReuse(tv0); + { // Request that we re-use the allocation for tv0. This should place a + // __syncthreads() just before tv5 is written. 
+ tv0->promoteReuse(); GpuLower gpulw(fusion.get()); ExpressionEvaluator ee; @@ -277,4 +277,6 @@ TEST_F(SmemReuseTest, RequestReuse) { } } +// TODO: Test involving requested reuse along with automatic aliasing + } // namespace nvfuser From 4da8e92f892be400f3f094b9b914339027996acd Mon Sep 17 00:00:00 2001 From: Jacob Hinkle Date: Mon, 21 Aug 2023 10:15:38 -0400 Subject: [PATCH 10/28] Remove prints and fix tv clone --- csrc/device_lower/pass/alias_memory.cpp | 27 ------------- csrc/ir/interface_nodes.h | 52 +++++++------------------ csrc/tensor_view.cpp | 8 +--- 3 files changed, 17 insertions(+), 70 deletions(-) diff --git a/csrc/device_lower/pass/alias_memory.cpp b/csrc/device_lower/pass/alias_memory.cpp index d80417a799c..3f1fd918c43 100644 --- a/csrc/device_lower/pass/alias_memory.cpp +++ b/csrc/device_lower/pass/alias_memory.cpp @@ -1633,20 +1633,6 @@ class StackBasedSharedMemAllocator : kir::IrVisitor { void allocate(const std::vector& exprs) { recordEvents(); - for (auto& [pos, fws] : first_write_positions_) { - std::cout << "Position " << pos - << " is first write pos for:" << std::endl; - for (auto& alloc_info : fws) { - std::cout << " " << alloc_info->alloc_expr->buffer()->toString() - << std::endl; - } - } - - for (auto& pos : last_read_positions_) { - std::cout << "Position " << pos - << " is last aliased read pos for some allocs" << std::endl; - } - // Traverse expressions: reclaim memory when we pass a blockSync, append to // waiting_to_push_ when we pass an Allocate handle(exprs); @@ -1660,21 +1646,12 @@ class StackBasedSharedMemAllocator : kir::IrVisitor { void dispatch(Expr* expr) final { position_ = allocation_info_map_.getScopeMap().getExprPos(expr); - if (isDebugDumpEnabled(DebugDumpOption::BufferReuseInfo)) { - debug() << "Position " << position_ << std::endl; - } - // Check whether this is a first write position for any allocations auto it = first_write_positions_.find(position_); if (it != first_write_positions_.end()) { - if 
(isDebugDumpEnabled(DebugDumpOption::BufferReuseInfo)) { - debug() << "Position " << position_ << " is first write for"; - } for (auto alloc_info : it->second) { - debug() << " T" << alloc_info->alloc_expr->buffer()->name(); waiting_to_push_.push_back(alloc_info); } - debug() << std::endl; } // Reclaim memory whenever we pass an Expr that is known to synchronize the @@ -1757,14 +1734,10 @@ class StackBasedSharedMemAllocator : kir::IrVisitor { //! Record first reads and last writes, respecting aliased buffers void recordEvents() { for (auto& alloc_info : allocation_info_map_.allAllocationInfos()) { - std::cout << "alloc info found for " << alloc_info->alloc_expr->toString() - << std::endl; if (alloc_info->mem_type != MemoryType::Shared) { continue; } if (alloc_info->alias_to) { - std::cout << " Allocation aliases" << alloc_info->alias_to->toString() - << std::endl; auto alias_info = allocation_info_map_.getMaybeAllocationInfo(alloc_info->alias_to); TORCH_CHECK(alias_info.has_value()); diff --git a/csrc/ir/interface_nodes.h b/csrc/ir/interface_nodes.h index e04a64e5d37..0ab47b6eb15 100644 --- a/csrc/ir/interface_nodes.h +++ b/csrc/ir/interface_nodes.h @@ -505,43 +505,23 @@ class TORCH_CUDA_CU_API TensorView : public Val { // ensure consistency. void commitLeafToRFactor(); + //! Request that we reclaim the memory of this tv before any subsequent + //! tensors are allocated. + //! + //! This method influences the shared memory allocator that assigns shared + //! memory addresses at lowering. It ensures that the proper synchronization + //! is present in the kernel to reuse memory and inserts new block + //! synchronizations if necessary. void promoteReuse(bool b = true) { promote_reuse_ = b; } + //! Returns whether we should insert syncs if needed in order to reuse the + //! memory of this tensor. bool getPromoteReuse() const { return promote_reuse_; } - //! Get vector of tensors that must have a block sync after their last use and - //! 
before the definition of this tensor. - const std::vector& getRequestedReusedTensors() const { - return requested_reuse_tvs_; - } - - //! Request that we reclaim the memory of tv before this tensor is defined. - //! - //! This method influences the shared memory allocator that assigns shared - //! memory addresses at lowering. It ensures that the proper synchronization - //! is present in the kernel to reuse memory and inserts new block - //! synchronizations if necessary. - //! - //! This is only possible if the lifetimes of this tensor and tv do not - //! overlap, i.e. tv is last written or read before this tensor is written. If - //! this is violated and the lifetimes overlap, an exception will be raised at - //! lowering. - //! - //! This method may only be used on shared-memory tensors. - void requestReuse(TensorView* tv) { - TORCH_CHECK( - getMemoryType() == MemoryType::Shared, - "Only shared memory tensors re-use memory using requestReuse."); - TORCH_CHECK( - tv->getMemoryType() == MemoryType::Shared, - "Only shared memory tensors can be requested to be re-used."); - requested_reuse_tvs_.push_back(tv); - } - protected: void setDomain(TensorDomain* td) { domain_ = td; @@ -605,15 +585,13 @@ class TORCH_CUDA_CU_API TensorView : public Val { //! may be computed at. unsigned int maybe_max_producer_pos_ = 0; - //! This holds a collection of shared memory tensors which should have their - //! last use ordered before this tensor in the generated CUDA kernel. When - //! this list is non-empty, it indicates that we should ensure that thread + //! When this is true, it indicates, if this is a shared memory tensor and + //! there are other shared memory tensors whose lifetimes do not overlap and + //! come later than this tensor's lifetime, that we should ensure that thread //! blocks are synchronized such that all threads have performed their last - //! read of any entries in this vector (or any possible aliased tensors of - //! them) before writing to the current tensor. 
This will then allow us to - //! safely reuse the memory allocated to these tensors. - std::vector requested_reuse_tvs_; - + //! read of this tensor (or any tensors aliasing in) before writing to the + //! current tensor. This will then allow us to safely reuse the memory + //! allocated to this tensor. bool promote_reuse_ = false; }; diff --git a/csrc/tensor_view.cpp b/csrc/tensor_view.cpp index 791a0a34cd0..2a22988d736 100644 --- a/csrc/tensor_view.cpp +++ b/csrc/tensor_view.cpp @@ -252,12 +252,8 @@ TensorView::TensorView(const TensorView* src, IrCloner* ir_cloner) cpu_scalar_(src->cpu_scalar_), has_swizzle_op_(src->has_swizzle_op_), compute_with_consumers_(ir_cloner->clone(src->compute_with_consumers_)), - compute_with_pos_(src->compute_with_pos_) { - requested_reuse_tvs_.reserve(src->requested_reuse_tvs_.size()); - for (auto dep : src->requested_reuse_tvs_) { - requested_reuse_tvs_.push_back(ir_cloner->clone(dep)); - } -} + compute_with_pos_(src->compute_with_pos_), + promote_reuse_(src->promote_reuse_) {} // sets cpu_scalar_ value, which is special handling for CPU based zero-dim // tensors (i.e. CPU Tensors that only have one value). 
This is only used if From 9c6ec7cee161c9fd89c5ff769fc15b10b8813d20 Mon Sep 17 00:00:00 2001 From: Jacob Hinkle Date: Mon, 21 Aug 2023 10:23:40 -0400 Subject: [PATCH 11/28] Ignore non-smem tensors --- csrc/device_lower/pass/alias_memory.cpp | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/csrc/device_lower/pass/alias_memory.cpp b/csrc/device_lower/pass/alias_memory.cpp index 3f1fd918c43..6afb6040187 100644 --- a/csrc/device_lower/pass/alias_memory.cpp +++ b/csrc/device_lower/pass/alias_memory.cpp @@ -1846,12 +1846,11 @@ class PromoteReuseSyncModifier : private kir::ExprMutator { const AllocationInfoMap& allocation_info_map) : allocation_info_map_(allocation_info_map) { // Find next allocation after last aliased read of all allocations whose - // reuse we need to promote, and record shortest sync intervals with + // reuse we need to promote, and record shortest sync intervals relative to // subsequent allocations. for (const auto& alloc_info : allocation_info_map.allAllocationInfos()) { - if (!alloc_info->alloc_expr->buffer() - ->as() - ->getPromoteReuse()) { + auto tv = alloc_info->alloc_expr->buffer()->as(); + if (tv->getMemoryType() != MemoryType::Shared || !tv->getPromoteReuse()) { continue; } auto last_read = alloc_info->getAliasedOuterLastRead(); @@ -1878,7 +1877,7 @@ class PromoteReuseSyncModifier : private kir::ExprMutator { exprs_ = exprs; } else { if (isDebugDumpEnabled(DebugDumpOption::BufferReuseInfo)) { - debug() << "Verifying syncs within these intervals:" << std::endl; + debug() << "Ensuring syncs within these intervals:" << std::endl; for (auto [last_read, first_write] : sync_intervals_) { debug() << " " << last_read << " - " << first_write << std::endl; } From 41799ba2f7a4967aa59464db4b59592dd0c45c5d Mon Sep 17 00:00:00 2001 From: Jacob Hinkle Date: Mon, 21 Aug 2023 10:26:42 -0400 Subject: [PATCH 12/28] Clean up comment --- test/test_smem_reuse.cpp | 2 -- 1 file changed, 2 deletions(-) diff --git a/test/test_smem_reuse.cpp 
b/test/test_smem_reuse.cpp index 5a9c1849612..2eeb66cf8b8 100644 --- a/test/test_smem_reuse.cpp +++ b/test/test_smem_reuse.cpp @@ -277,6 +277,4 @@ TEST_F(SmemReuseTest, PromoteReuse) { } } -// TODO: Test involving requested reuse along with automatic aliasing - } // namespace nvfuser From f092673f970fe77ccf4d5a3ebac7ea87e3f6164a Mon Sep 17 00:00:00 2001 From: Jacob Hinkle Date: Mon, 21 Aug 2023 14:05:42 -0400 Subject: [PATCH 13/28] Add a couple new tests --- test/test_smem_reuse.cpp | 156 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 156 insertions(+) diff --git a/test/test_smem_reuse.cpp b/test/test_smem_reuse.cpp index 2eeb66cf8b8..4565e036159 100644 --- a/test/test_smem_reuse.cpp +++ b/test/test_smem_reuse.cpp @@ -277,4 +277,160 @@ TEST_F(SmemReuseTest, PromoteReuse) { } } +// TODO: Add tests with +// - one promoted tensor with multiple other tensors that could reuse the +// promoted tensor +// - multiple promoted tensors +// - aliased promoted tensor showing when we can re-use + +// In this example, we promote a single tensor for re-use in a Fusion with two +// downstream tensors that could use its memory. The first downstream tensor is +// not re-used since it is not promoted. 
+TEST_F(SmemReuseTest, PromoteReuseMultipleDownstream) { + auto fusion = std::make_unique(); + FusionGuard fg(fusion.get()); + + int64_t H = 7; + + auto tv0 = + full({IrBuilder::create(H)}, fusion->oneVal(), DataType::Float); + tv0->setMemoryType(MemoryType::Shared); + + auto tv1 = neg(tv0); + + auto tv2 = pad(tv1, {fusion->zeroVal(), fusion->oneVal()}); + tv2->setMemoryType(MemoryType::Shared); + + auto tv3 = neg(tv2); + + auto tv4 = pad(tv3, {fusion->zeroVal(), fusion->oneVal()}); + tv4->setMemoryType(MemoryType::Shared); + + auto tv5 = neg(tv4); + + fusion->addOutput(tv5); + + { // This should not re-use memory + GpuLower gpulw(fusion.get()); + + ExpressionEvaluator ee; + std::unordered_set addresses; + int64_t smem_usage = 0; + for (auto alloc : gpulw.kernel()->summary().dynamic_smem_allocations) { + EXPECT_NE(alloc->address(), nullptr); + auto addr = ee.evaluate(alloc->address()).as(); + TORCH_CHECK( + addresses.insert(addr).second, + "Smem addresses should not be re-used"); + auto size = ee.evaluate(alloc->size()).as() * + dataTypeSize(alloc->buffer()->dtype()); + smem_usage = std::max(smem_usage, addr + size); + } + EXPECT_EQ( + smem_usage, alignInt(alignInt((H + 2) * 4) + (H + 1) * 4) + H * 4); + } + + { // Request that we re-use the allocation for tv0. This should place a + // __syncthreads() just before tv2 is written. + tv0->promoteReuse(); + + GpuLower gpulw(fusion.get()); + ExpressionEvaluator ee; + int64_t smem_usage = 0; + for (auto alloc : gpulw.kernel()->summary().dynamic_smem_allocations) { + EXPECT_NE(alloc->address(), nullptr); + auto addr = ee.evaluate(alloc->address()).as(); + auto size = ee.evaluate(alloc->size()).as() * + dataTypeSize(alloc->buffer()->dtype()); + smem_usage = std::max(smem_usage, addr + size); + } + EXPECT_EQ(smem_usage, alignInt((H + 1) * 4) + (H + 2) * 4); + } +} + +// In this example, multiple smem tensors are promoted for re-use. We have +// non-overlapping smem allocations A B C D, and A and C are promoted for reuse. 
+// Because of that, B re-uses A, then C does not reuse B but stacks on top of +// it. Then D reuses C, and B is reclaimed in the process. Ultimately this means +// the assigned addresses are: +// +// A: 0. Assigned then reclaimed before assignment of B. +// B: alignInt((H + 2) * 4). Stacked on top of C +// C: 0. Assigned along with B in reverse order of last use +// D: 0. B and C are reclaimed before this assignment. +// +// Note that although B was not explicitly requested for re-use, since its +// lifetime ends before D is defined, we try and reclaim it at the same time C +// is reclaimed. They are also ordered on the stack at that point, in descending +// order of last use, meaning B is placed higher on the stack than C. +TEST_F(SmemReuseTest, MultiplePromoteReuse) { + auto fusion = std::make_unique(); + FusionGuard fg(fusion.get()); + + int64_t H = 7; + + auto tv0 = + full({IrBuilder::create(H)}, fusion->oneVal(), DataType::Float); + tv0->setMemoryType(MemoryType::Shared); + + auto tv1 = neg(neg(tv0)); + + auto tv2 = pad(tv1, {fusion->zeroVal(), fusion->oneVal()}); + tv2->setMemoryType(MemoryType::Shared); + + auto tv3 = neg(neg(tv2)); + + auto tv4 = pad(tv3, {fusion->zeroVal(), fusion->oneVal()}); + tv4->setMemoryType(MemoryType::Shared); + + auto tv5 = neg(neg(tv4)); + + auto tv6 = pad(tv5, {fusion->zeroVal(), fusion->oneVal()}); + tv6->setMemoryType(MemoryType::Shared); + + auto tv7 = neg(tv6); + + fusion->addOutput(tv7); + + { // This should not re-use memory + GpuLower gpulw(fusion.get()); + + ExpressionEvaluator ee; + std::unordered_set addresses; + int64_t smem_usage = 0; + for (auto alloc : gpulw.kernel()->summary().dynamic_smem_allocations) { + EXPECT_NE(alloc->address(), nullptr); + auto addr = ee.evaluate(alloc->address()).as(); + TORCH_CHECK( + addresses.insert(addr).second, + "Smem addresses should not be re-used"); + auto size = ee.evaluate(alloc->size()).as() * + dataTypeSize(alloc->buffer()->dtype()); + smem_usage = std::max(smem_usage, addr + 
size); + } + EXPECT_EQ( + smem_usage, + alignInt(alignInt(alignInt((H + 3) * 4) + (H + 2) * 4) + (H + 1) * 4) + + H * 4); + } + + { // Request that we re-use A and C + tv0->promoteReuse(); + tv4->promoteReuse(); + + GpuLower gpulw(fusion.get()); + ExpressionEvaluator ee; + int64_t smem_usage = 0; + for (auto alloc : gpulw.kernel()->summary().dynamic_smem_allocations) { + EXPECT_NE(alloc->address(), nullptr); + auto addr = ee.evaluate(alloc->address()).as(); + auto size = ee.evaluate(alloc->size()).as() * + dataTypeSize(alloc->buffer()->dtype()); + smem_usage = std::max(smem_usage, addr + size); + } + // High water mark has C stacked on top of B + EXPECT_EQ(smem_usage, alignInt((H + 2) * 4) + (H + 1) * 4); + } +} + } // namespace nvfuser From 5ae92e68b5b4dc6987ee317f9ac9902ffc14539d Mon Sep 17 00:00:00 2001 From: Jacob Hinkle Date: Mon, 21 Aug 2023 14:24:26 -0400 Subject: [PATCH 14/28] Add test with alias shadowing a promoteReuse --- test/test_smem_reuse.cpp | 90 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 90 insertions(+) diff --git a/test/test_smem_reuse.cpp b/test/test_smem_reuse.cpp index 4565e036159..67c0cb20f9a 100644 --- a/test/test_smem_reuse.cpp +++ b/test/test_smem_reuse.cpp @@ -433,4 +433,94 @@ TEST_F(SmemReuseTest, MultiplePromoteReuse) { } } +// In this example we initially have two shared memory tensors A and B. We call +// promoteReuse on A so that B reuses its memory. Then we show that when an +// additional tensor C shaped the same as A is added after B, it aliases A, +// preventing the re-use of A for B's allocation. 
+TEST_F(SmemReuseTest, PromoteReuseAliasShadowed) { + auto fusion = std::make_unique(); + FusionGuard fg(fusion.get()); + + int64_t H = 4; + + auto tv0 = + full({IrBuilder::create(H)}, fusion->oneVal(), DataType::Float); + tv0->setMemoryType(MemoryType::Shared); + + auto tv1 = neg(neg(tv0)); + + auto tv2 = pad(tv1, {fusion->zeroVal(), fusion->oneVal()}); + tv2->setMemoryType(MemoryType::Shared); + + auto tv3 = neg(neg(tv2)); + + fusion->addOutput(tv3); + + { // This should not re-use memory + GpuLower gpulw(fusion.get()); + + ExpressionEvaluator ee; + std::unordered_set addresses; + int64_t smem_usage = 0; + for (auto alloc : gpulw.kernel()->summary().dynamic_smem_allocations) { + EXPECT_NE(alloc->address(), nullptr); + auto addr = ee.evaluate(alloc->address()).as(); + TORCH_CHECK( + addresses.insert(addr).second, + "Smem addresses should not be re-used"); + auto size = ee.evaluate(alloc->size()).as() * + dataTypeSize(alloc->buffer()->dtype()); + smem_usage = std::max(smem_usage, addr + size); + } + EXPECT_EQ(smem_usage, alignInt((H + 1) * 4) + H * 4); + } + + { // Request that we re-use A for B + tv0->promoteReuse(); + + GpuLower gpulw(fusion.get()); + ExpressionEvaluator ee; + int64_t smem_usage = 0; + for (auto alloc : gpulw.kernel()->summary().dynamic_smem_allocations) { + EXPECT_NE(alloc->address(), nullptr); + auto addr = ee.evaluate(alloc->address()).as(); + auto size = ee.evaluate(alloc->size()).as() * + dataTypeSize(alloc->buffer()->dtype()); + smem_usage = std::max(smem_usage, addr + size); + } + // High water mark has C stacked on top of B + EXPECT_EQ(smem_usage, (H + 1) * 4); + } + + { // Request that we re-use A for B, but add another smem tensor C that will + // alias A, preventing re-use of A for B (but reusing A for C without any + // inserted syncs). This changes the last aliased read time of A to come + // after that of B, so in this case, A will stack on top of B, unlike in the + // case without re-use and without C. 
+ + // pad by negative one to trim back to size H + auto tv4 = pad(tv3, {fusion->zeroVal(), neg(fusion->oneVal())}); + tv4->setMemoryType(MemoryType::Shared); + + auto tv5 = neg(neg(tv4)); + + fusion->addOutput(tv5); + + tv0->promoteReuse(); // will not have any effect + + GpuLower gpulw(fusion.get()); + ExpressionEvaluator ee; + int64_t smem_usage = 0; + for (auto alloc : gpulw.kernel()->summary().dynamic_smem_allocations) { + EXPECT_NE(alloc->address(), nullptr); + auto addr = ee.evaluate(alloc->address()).as(); + auto size = ee.evaluate(alloc->size()).as() * + dataTypeSize(alloc->buffer()->dtype()); + smem_usage = std::max(smem_usage, addr + size); + } + // B is stacked on A (which is aliased by C) + EXPECT_EQ(smem_usage, alignInt(H * 4) + (H + 1) * 4); + } +} + } // namespace nvfuser From 2c7074b75ed272e123a3d3cf09de057c63a46f60 Mon Sep 17 00:00:00 2001 From: Jacob Hinkle Date: Mon, 21 Aug 2023 15:21:12 -0400 Subject: [PATCH 15/28] Remove alias test. It was actually not aliasing. It's pretty tough to force aliasing _and_ re-use --- test/test_smem_reuse.cpp | 90 ---------------------------------------- 1 file changed, 90 deletions(-) diff --git a/test/test_smem_reuse.cpp b/test/test_smem_reuse.cpp index 67c0cb20f9a..4565e036159 100644 --- a/test/test_smem_reuse.cpp +++ b/test/test_smem_reuse.cpp @@ -433,94 +433,4 @@ TEST_F(SmemReuseTest, MultiplePromoteReuse) { } } -// In this example we initially have two shared memory tensors A and B. We call -// promoteReuse on A so that B reuses its memory. Then we show that when an -// additional tensor C shaped the same as A is added after B, it aliases A, -// preventing the re-use of A for B's allocation. 
-TEST_F(SmemReuseTest, PromoteReuseAliasShadowed) { - auto fusion = std::make_unique(); - FusionGuard fg(fusion.get()); - - int64_t H = 4; - - auto tv0 = - full({IrBuilder::create(H)}, fusion->oneVal(), DataType::Float); - tv0->setMemoryType(MemoryType::Shared); - - auto tv1 = neg(neg(tv0)); - - auto tv2 = pad(tv1, {fusion->zeroVal(), fusion->oneVal()}); - tv2->setMemoryType(MemoryType::Shared); - - auto tv3 = neg(neg(tv2)); - - fusion->addOutput(tv3); - - { // This should not re-use memory - GpuLower gpulw(fusion.get()); - - ExpressionEvaluator ee; - std::unordered_set addresses; - int64_t smem_usage = 0; - for (auto alloc : gpulw.kernel()->summary().dynamic_smem_allocations) { - EXPECT_NE(alloc->address(), nullptr); - auto addr = ee.evaluate(alloc->address()).as(); - TORCH_CHECK( - addresses.insert(addr).second, - "Smem addresses should not be re-used"); - auto size = ee.evaluate(alloc->size()).as() * - dataTypeSize(alloc->buffer()->dtype()); - smem_usage = std::max(smem_usage, addr + size); - } - EXPECT_EQ(smem_usage, alignInt((H + 1) * 4) + H * 4); - } - - { // Request that we re-use A for B - tv0->promoteReuse(); - - GpuLower gpulw(fusion.get()); - ExpressionEvaluator ee; - int64_t smem_usage = 0; - for (auto alloc : gpulw.kernel()->summary().dynamic_smem_allocations) { - EXPECT_NE(alloc->address(), nullptr); - auto addr = ee.evaluate(alloc->address()).as(); - auto size = ee.evaluate(alloc->size()).as() * - dataTypeSize(alloc->buffer()->dtype()); - smem_usage = std::max(smem_usage, addr + size); - } - // High water mark has C stacked on top of B - EXPECT_EQ(smem_usage, (H + 1) * 4); - } - - { // Request that we re-use A for B, but add another smem tensor C that will - // alias A, preventing re-use of A for B (but reusing A for C without any - // inserted syncs). This changes the last aliased read time of A to come - // after that of B, so in this case, A will stack on top of B, unlike in the - // case without re-use and without C. 
- - // pad by negative one to trim back to size H - auto tv4 = pad(tv3, {fusion->zeroVal(), neg(fusion->oneVal())}); - tv4->setMemoryType(MemoryType::Shared); - - auto tv5 = neg(neg(tv4)); - - fusion->addOutput(tv5); - - tv0->promoteReuse(); // will not have any effect - - GpuLower gpulw(fusion.get()); - ExpressionEvaluator ee; - int64_t smem_usage = 0; - for (auto alloc : gpulw.kernel()->summary().dynamic_smem_allocations) { - EXPECT_NE(alloc->address(), nullptr); - auto addr = ee.evaluate(alloc->address()).as(); - auto size = ee.evaluate(alloc->size()).as() * - dataTypeSize(alloc->buffer()->dtype()); - smem_usage = std::max(smem_usage, addr + size); - } - // B is stacked on A (which is aliased by C) - EXPECT_EQ(smem_usage, alignInt(H * 4) + (H + 1) * 4); - } -} - } // namespace nvfuser From 505427145e877d2ddf4ca30d666ef622b5674800 Mon Sep 17 00:00:00 2001 From: Jacob Hinkle Date: Tue, 22 Aug 2023 07:24:27 -0400 Subject: [PATCH 16/28] Use raw pointers instead of std::optional pointers --- csrc/device_lower/pass/alias_memory.cpp | 76 +++++++++++-------------- 1 file changed, 33 insertions(+), 43 deletions(-) diff --git a/csrc/device_lower/pass/alias_memory.cpp b/csrc/device_lower/pass/alias_memory.cpp index 6afb6040187..2f2722052e5 100644 --- a/csrc/device_lower/pass/alias_memory.cpp +++ b/csrc/device_lower/pass/alias_memory.cpp @@ -640,11 +640,10 @@ class AllocationInfoMap : private kir::IrVisitor { current_stack_.pop_back(); } - std::optional getMaybeAllocationInfo( - const kir::Allocate* alloc) const { + AllocationInfo* getAllocationInfo(const kir::Allocate* alloc) const { auto it = allocation_info_map_.find(alloc); if (it == allocation_info_map_.end()) { - return std::nullopt; + return nullptr; } return it->second; } @@ -715,10 +714,10 @@ class AllocationInfoMap : private kir::IrVisitor { return all_allocations_; } - std::optional getMaybeAllocInfoFromTV(TensorView* tv) const { + AllocationInfo* getAllocInfoFromTV(TensorView* tv) const { auto alloc_it = 
tv_to_allocation_map_.find(tv->name()); if (alloc_it == tv_to_allocation_map_.end()) { - return std::nullopt; + return nullptr; } return alloc_it->second; } @@ -876,48 +875,45 @@ class AllocationInfoMap : private kir::IrVisitor { // expr. The current analysis isn't enough to capture // their liveness range. for (auto input_tv : ir_utils::filterByType(expr->inputs())) { - auto maybe_alloc_info = getMaybeAllocInfoFromTV(input_tv); - if (maybe_alloc_info.has_value()) { + auto alloc_info = getAllocInfoFromTV(input_tv); + if (alloc_info) { if (!isSerialBroadcastResolution(input_tv, for_loops_)) { - maybe_alloc_info.value()->inner_live_interval->markRead(expr_pos); + alloc_info->inner_live_interval->markRead(expr_pos); } else { // Disable inner alias info for this buffer, since line number based // analysis is no longer precise enough for inplace sharing // if a serial broadcast is realized. - maybe_alloc_info.value()->can_use_inner_alias = false; + alloc_info->can_use_inner_alias = false; } - auto outer_loop_info = - ascendLoopNestToSameLevelAs(maybe_alloc_info.value()); + auto outer_loop_info = ascendLoopNestToSameLevelAs(alloc_info); if (outer_loop_info) { - maybe_alloc_info.value()->outer_live_interval->markRead( - outer_loop_info->end_pos); + alloc_info->outer_live_interval->markRead(outer_loop_info->end_pos); } else { // Allocate is inlined in the innermost loop, // so outer live interval is the same as inner. 
- maybe_alloc_info.value()->outer_live_interval->markRead(expr_pos); + alloc_info->outer_live_interval->markRead(expr_pos); } } } for (auto output_tv : ir_utils::filterByType(expr->outputs())) { - auto maybe_alloc_info = getMaybeAllocInfoFromTV(output_tv); - if (maybe_alloc_info.has_value()) { + auto alloc_info = getAllocInfoFromTV(output_tv); + if (alloc_info) { // Reductions use outputs as read-write parameters, so their // outputs need to be marked as read as well const bool is_read_write = ir_utils::isReductionOp(expr); - maybe_alloc_info.value()->inner_live_interval->markWrite(expr_pos); + alloc_info->inner_live_interval->markWrite(expr_pos); if (is_read_write) { - maybe_alloc_info.value()->inner_live_interval->markRead(expr_pos); + alloc_info->inner_live_interval->markRead(expr_pos); } - auto outer_loop_info = - ascendLoopNestToSameLevelAs(maybe_alloc_info.value()); + auto outer_loop_info = ascendLoopNestToSameLevelAs(alloc_info); auto write_pos = outer_loop_info ? outer_loop_info->start_pos : expr_pos; - maybe_alloc_info.value()->outer_live_interval->markWrite(write_pos); + alloc_info->outer_live_interval->markWrite(write_pos); if (is_read_write) { auto read_pos = outer_loop_info ? outer_loop_info->end_pos : expr_pos; - maybe_alloc_info.value()->outer_live_interval->markRead(read_pos); + alloc_info->outer_live_interval->markRead(read_pos); } } } @@ -996,14 +992,15 @@ void BufferReuseDebugPrinter::printAllocInfo(const kir::Allocate* alloc) { TORCH_INTERNAL_ASSERT(allocation_info_map_ != nullptr); std::string message_header(" \033[1;32m^^^^^ ---Buffer Reuse Info--- "); std::string message_end(" \033[0m\n"); - if (!allocation_info_map_->getMaybeAllocationInfo(alloc).has_value()) { + + auto alloc_info = allocation_info_map_->getAllocationInfo(alloc); + + if (!alloc_info) { // This buffer is not considered for any sharing, either // because of un-supported op or size below threshold. 
return; } - auto alloc_info = allocation_info_map_->getMaybeAllocationInfo(alloc).value(); - indent() << message_header; if (alloc_info->alias_to) { if (alloc_info->is_inner_alias) { @@ -1012,8 +1009,7 @@ void BufferReuseDebugPrinter::printAllocInfo(const kir::Allocate* alloc) { os_ << "(outer) "; } os_ << " alias to alloc at pos " - << allocation_info_map_->getMaybeAllocationInfo(alloc_info->alias_to) - .value() + << allocation_info_map_->getAllocationInfo(alloc_info->alias_to) ->alloc_pos << " "; } else { @@ -1078,17 +1074,14 @@ class ReusableAllocationFinder : private kir::IrVisitor { // Check that if this allocation site is one that // we want to re-use or replace with an alias - auto maybe_alloc_info = - allocation_info_map_.getMaybeAllocationInfo(allocate); - if (maybe_alloc_info.has_value() && - maybe_alloc_info.value()->alias_to == nullptr) { + auto alloc_info = allocation_info_map_.getAllocationInfo(allocate); + if (alloc_info && alloc_info->alias_to == nullptr) { // Try to re-use existing allocates - if (!tryReuseOtherAllocate(maybe_alloc_info.value())) { + if (!tryReuseOtherAllocate(alloc_info)) { // If didn't re-use, should register this // allocate so that future allocates // can re-use this one. - current_visible_buffer_stack_.back()->push_back( - maybe_alloc_info.value()); + current_visible_buffer_stack_.back()->push_back(alloc_info); } } } @@ -1402,14 +1395,11 @@ class AllocationAliasModifier : private kir::ExprMutator { //! 
Replace an kir::Allocate with a new aliased Allocate void handle(kir::Allocate* allocate) final { - auto maybe_alloc_info = - allocation_info_map_.getMaybeAllocationInfo(allocate); - if (!maybe_alloc_info.has_value()) { + auto alloc_info_from = allocation_info_map_.getAllocationInfo(allocate); + if (!alloc_info_from) { return; } - AllocationInfo* alloc_info_from = maybe_alloc_info.value(); - auto alias_it = allocation_info_map_.getAliasMap().find(alloc_info_from); if (alias_it == allocation_info_map_.getAliasMap().end()) { return; @@ -1739,10 +1729,10 @@ class StackBasedSharedMemAllocator : kir::IrVisitor { } if (alloc_info->alias_to) { auto alias_info = - allocation_info_map_.getMaybeAllocationInfo(alloc_info->alias_to); - TORCH_CHECK(alias_info.has_value()); - auto prev_last_read = lastAliasedRead(alias_info.value()); - last_aliased_read_[alias_info.value()] = std::max( + allocation_info_map_.getAllocationInfo(alloc_info->alias_to); + TORCH_CHECK(alias_info); + auto prev_last_read = lastAliasedRead(alias_info); + last_aliased_read_[alias_info] = std::max( prev_last_read, alloc_info->outer_live_interval->lastRead()); } else { last_aliased_read_[alloc_info.get()] = From c43efb8f041242605d75e9d5facff40c6b3671a8 Mon Sep 17 00:00:00 2001 From: Jacob Hinkle Date: Tue, 22 Aug 2023 07:51:39 -0400 Subject: [PATCH 17/28] Assert against multi-hop alias --- csrc/device_lower/pass/alias_memory.cpp | 1 + 1 file changed, 1 insertion(+) diff --git a/csrc/device_lower/pass/alias_memory.cpp b/csrc/device_lower/pass/alias_memory.cpp index 2f2722052e5..27fd0b8ba63 100644 --- a/csrc/device_lower/pass/alias_memory.cpp +++ b/csrc/device_lower/pass/alias_memory.cpp @@ -957,6 +957,7 @@ class AllocationInfoMap : private kir::IrVisitor { //! Mark the tensor of "from" be an alias of the tensor of "to". 
void setAlias(AllocationInfo* from, AllocationInfo* to) { + TORCH_CHECK(!to->alias_to, "Multi-hop aliases are not supported"); alias_map_[from] = to; from->alias_to = to->alloc_expr; to->outer_aliased_by.push_back(from); From 7d60ccf23a771d6f0d1d78935aab1d2ab5832d0d Mon Sep 17 00:00:00 2001 From: Jacob Hinkle Date: Tue, 22 Aug 2023 07:51:59 -0400 Subject: [PATCH 18/28] Fix comment about outer smem aliasing --- csrc/device_lower/pass/alias_memory.cpp | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/csrc/device_lower/pass/alias_memory.cpp b/csrc/device_lower/pass/alias_memory.cpp index 27fd0b8ba63..3ed4f9253ed 100644 --- a/csrc/device_lower/pass/alias_memory.cpp +++ b/csrc/device_lower/pass/alias_memory.cpp @@ -1190,12 +1190,10 @@ class ReusableAllocationFinder : private kir::IrVisitor { } } - // TODO: - // Outer interval based sharing supports arbitrary re-indexing into - // the same buffer and would require additional syncs if fully - // enabled. - // Need a few more checks to insert syncs if necessary before turning - // on this sharing. + // Outer aliasing of shared memory requires thread block synchronization + // since it could involve arbitrary re-indexing. Instead, we will leave + // this type of re-use to the allocation phase. See + // assignSharedMemoryAllocations and promoteReuseSyncs. 
if (!inner_aliasing_pass_ && alloc_info->mem_type == MemoryType::Shared) { continue; From a0fce50443eb6932f828d9464194083d4c9876de Mon Sep 17 00:00:00 2001 From: Jacob Hinkle Date: Tue, 22 Aug 2023 07:53:15 -0400 Subject: [PATCH 19/28] Style change to assertion --- csrc/device_lower/pass/alias_memory.cpp | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/csrc/device_lower/pass/alias_memory.cpp b/csrc/device_lower/pass/alias_memory.cpp index 3ed4f9253ed..4f5b6bada0a 100644 --- a/csrc/device_lower/pass/alias_memory.cpp +++ b/csrc/device_lower/pass/alias_memory.cpp @@ -957,7 +957,14 @@ class AllocationInfoMap : private kir::IrVisitor { //! Mark the tensor of "from" be an alias of the tensor of "to". void setAlias(AllocationInfo* from, AllocationInfo* to) { - TORCH_CHECK(!to->alias_to, "Multi-hop aliases are not supported"); + TORCH_INTERNAL_ASSERT( + to->alias_to == nullptr, + "Multi-hop aliases are not supported. Attempted to alias ", + from->alloc_expr->buffer()->toString(), + " to ", + to->alloc_expr->buffer()->toString(), + " which is already aliased to ", + to->alias_to->buffer()->toString()); alias_map_[from] = to; from->alias_to = to->alloc_expr; to->outer_aliased_by.push_back(from); From 00f9342a379886b068c08ecb1c69394d47e3437d Mon Sep 17 00:00:00 2001 From: Jacob Hinkle Date: Tue, 22 Aug 2023 13:41:31 -0400 Subject: [PATCH 20/28] Use ForLoop::isTrivial --- csrc/device_lower/pass/alias_memory.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/csrc/device_lower/pass/alias_memory.cpp b/csrc/device_lower/pass/alias_memory.cpp index 4f5b6bada0a..d263b513122 100644 --- a/csrc/device_lower/pass/alias_memory.cpp +++ b/csrc/device_lower/pass/alias_memory.cpp @@ -763,7 +763,7 @@ class AllocationInfoMap : private kir::IrVisitor { void handle(kir::ForLoop* for_loop) final { auto loop_info = scope_map_.getLoopScopeInfo(for_loop); - if (!for_loop->iter_domain()->isParallelized()) { + if (!for_loop->isTrivial()) { // 
Parallelized loops do not result in for loops in the CUDA kernel, so // they should not affect liveness analysis. This means that // current_stack_ will differ from kir::IrVisitor::for_loops_, which will @@ -777,7 +777,7 @@ class AllocationInfoMap : private kir::IrVisitor { if (debug_printer_) { debug_printer_->popScope(); } - if (!for_loop->iter_domain()->isParallelized()) { + if (!for_loop->isTrivial()) { current_stack_.pop_back(); } } From 6c6459f81aedbd02b0637e0d826ea3937815aed9 Mon Sep 17 00:00:00 2001 From: Jacob Hinkle Date: Tue, 22 Aug 2023 14:26:56 -0400 Subject: [PATCH 21/28] Remove position_ member --- csrc/device_lower/pass/alias_memory.cpp | 15 ++++++--------- 1 file changed, 6 insertions(+), 9 deletions(-) diff --git a/csrc/device_lower/pass/alias_memory.cpp b/csrc/device_lower/pass/alias_memory.cpp index d263b513122..98de84e7a4e 100644 --- a/csrc/device_lower/pass/alias_memory.cpp +++ b/csrc/device_lower/pass/alias_memory.cpp @@ -1902,14 +1902,14 @@ class PromoteReuseSyncModifier : private kir::ExprMutator { kir::ExprMutator::dispatch(expr); } - position_ = allocation_info_map_.getScopeMap().getExprPos(expr); + auto position = allocation_info_map_.getScopeMap().getExprPos(expr); // If this is an upcoming first write that has not yet been erased, it means // we have not seen a sync in its interval. So we should insert a BlockSync // before this expr. - if (upcoming_first_writes_.erase(position_)) { + if (upcoming_first_writes_.erase(position)) { if (isDebugDumpEnabled(DebugDumpOption::BufferReuseInfo)) { - debug() << "Inserting block sync before position " << position_ + debug() << "Inserting block sync before position " << position << std::endl; } auto new_sync = IrBuilder::create(); @@ -1921,17 +1921,17 @@ class PromoteReuseSyncModifier : private kir::ExprMutator { // writes since they can be considered safe. 
if (lower_utils::hasBlockSync(expr, GpuLower::current()->threadPredMap())) { if (isDebugDumpEnabled(DebugDumpOption::BufferReuseInfo)) { - debug() << "Found blocking expression at position " << position_ + debug() << "Found blocking expression at position " << position << std::endl; } upcoming_first_writes_.clear(); } // If this is the lower endpoint of a sync interval, add the upper endpoint - auto range = sync_intervals_.equal_range(position_); + auto range = sync_intervals_.equal_range(position); for (auto& it = range.first; it != range.second; ++it) { if (isDebugDumpEnabled(DebugDumpOption::BufferReuseInfo)) { - debug() << "Found dependency last read at position " << position_ + debug() << "Found dependency last read at position " << position << " corresponding to first write at " << it->second << std::endl; } @@ -1949,9 +1949,6 @@ class PromoteReuseSyncModifier : private kir::ExprMutator { // thread proceeds past the end of the interval. std::unordered_multimap sync_intervals_; - // Position within the traversal - int position_ = -1; - // These are the upper endpoints of needed sync intervals for which we've // already passed over the lower endpoint. std::unordered_set upcoming_first_writes_; From e0ab062d085e17207a615d2e58a77dd57c5e053a Mon Sep 17 00:00:00 2001 From: Jacob Hinkle Date: Tue, 22 Aug 2023 14:27:46 -0400 Subject: [PATCH 22/28] Rename getPromoteReuse to shouldPromoteReuse --- csrc/device_lower/pass/alias_memory.cpp | 3 ++- csrc/ir/interface_nodes.h | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/csrc/device_lower/pass/alias_memory.cpp b/csrc/device_lower/pass/alias_memory.cpp index 98de84e7a4e..cad33da58d8 100644 --- a/csrc/device_lower/pass/alias_memory.cpp +++ b/csrc/device_lower/pass/alias_memory.cpp @@ -1846,7 +1846,8 @@ class PromoteReuseSyncModifier : private kir::ExprMutator { // subsequent allocations. 
for (const auto& alloc_info : allocation_info_map.allAllocationInfos()) { auto tv = alloc_info->alloc_expr->buffer()->as(); - if (tv->getMemoryType() != MemoryType::Shared || !tv->getPromoteReuse()) { + if (tv->getMemoryType() != MemoryType::Shared || + !tv->shouldPromoteReuse()) { continue; } auto last_read = alloc_info->getAliasedOuterLastRead(); diff --git a/csrc/ir/interface_nodes.h b/csrc/ir/interface_nodes.h index 0ab47b6eb15..c69d831868d 100644 --- a/csrc/ir/interface_nodes.h +++ b/csrc/ir/interface_nodes.h @@ -518,7 +518,7 @@ class TORCH_CUDA_CU_API TensorView : public Val { //! Returns whether we should insert syncs if needed in order to reuse the //! memory of this tensor. - bool getPromoteReuse() const { + bool shouldPromoteReuse() const { return promote_reuse_; } From 91fb41442c84dbb47a9bc0e4a9a0b67318b5d9f8 Mon Sep 17 00:00:00 2001 From: Jacob Hinkle Date: Tue, 22 Aug 2023 14:29:24 -0400 Subject: [PATCH 23/28] Assert on memory type in promoteReuse() --- csrc/ir/interface_nodes.h | 3 +++ 1 file changed, 3 insertions(+) diff --git a/csrc/ir/interface_nodes.h b/csrc/ir/interface_nodes.h index c69d831868d..e980775e4a4 100644 --- a/csrc/ir/interface_nodes.h +++ b/csrc/ir/interface_nodes.h @@ -513,6 +513,9 @@ class TORCH_CUDA_CU_API TensorView : public Val { //! is present in the kernel to reuse memory and inserts new block //! synchronizations if necessary. 
void promoteReuse(bool b = true) { + TORCH_CHECK( + memory_type_ == MemoryType::Shared, + "promoteReuse should only be called on shared memory tensors"); promote_reuse_ = b; } From 084fe5418a4314d34004fef235562fac3f06aedf Mon Sep 17 00:00:00 2001 From: Jacob Hinkle Date: Tue, 22 Aug 2023 15:05:33 -0400 Subject: [PATCH 24/28] Remove TODO comment --- test/test_smem_reuse.cpp | 6 ------ 1 file changed, 6 deletions(-) diff --git a/test/test_smem_reuse.cpp b/test/test_smem_reuse.cpp index 4565e036159..4ae1c546161 100644 --- a/test/test_smem_reuse.cpp +++ b/test/test_smem_reuse.cpp @@ -277,12 +277,6 @@ TEST_F(SmemReuseTest, PromoteReuse) { } } -// TODO: Add tests with -// - one promoted tensor with multiple other tensors that could reuse the -// promoted tensor -// - multiple promoted tensors -// - aliased promoted tensor showing when we can re-use - // In this example, we promote a single tensor for re-use in a Fusion with two // downstream tensors that could use its memory. The first downstream tensor is // not re-used since it is not promoted. 
From cc7b03e5b5e9ad370352dd63295636f72b687577 Mon Sep 17 00:00:00 2001 From: Jacob Hinkle Date: Tue, 22 Aug 2023 15:28:18 -0400 Subject: [PATCH 25/28] Skip future allocs except unaliased smem allocs --- csrc/device_lower/pass/alias_memory.cpp | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/csrc/device_lower/pass/alias_memory.cpp b/csrc/device_lower/pass/alias_memory.cpp index cad33da58d8..e4d1af89176 100644 --- a/csrc/device_lower/pass/alias_memory.cpp +++ b/csrc/device_lower/pass/alias_memory.cpp @@ -1855,6 +1855,10 @@ class PromoteReuseSyncModifier : private kir::ExprMutator { std::optional nearest_first_write = std::nullopt; for (const auto& other : allocation_info_map.allAllocationInfos()) { + if (other->alias_to || other->mem_type != MemoryType::Shared) { + // Skip other if it aliases an earlier allocation + continue; + } auto first_write = other->outer_live_interval->firstWrite(); if (first_write <= last_read) { continue; From 7e7456017b5dc16e2434f8b60bedfb7262987a43 Mon Sep 17 00:00:00 2001 From: Jacob Hinkle Date: Tue, 22 Aug 2023 15:28:35 -0400 Subject: [PATCH 26/28] Skip checking inserted syncs in dispatch(Expr*) --- csrc/device_lower/pass/alias_memory.cpp | 8 -------- 1 file changed, 8 deletions(-) diff --git a/csrc/device_lower/pass/alias_memory.cpp b/csrc/device_lower/pass/alias_memory.cpp index e4d1af89176..40a82a6e9c4 100644 --- a/csrc/device_lower/pass/alias_memory.cpp +++ b/csrc/device_lower/pass/alias_memory.cpp @@ -1899,14 +1899,6 @@ class PromoteReuseSyncModifier : private kir::ExprMutator { using kir::ExprMutator::dispatch; void dispatch(Expr* expr) final { - // Skip inserted syncs - if (inserted_syncs_.find(expr) != inserted_syncs_.end()) { - if (isDebugDumpEnabled(DebugDumpOption::BufferReuseInfo)) { - debug() << "Skipping new sync expression " << expr->toString(); - } - kir::ExprMutator::dispatch(expr); - } - auto position = allocation_info_map_.getScopeMap().getExprPos(expr); // If this is an upcoming first write that has not yet been 
erased, it means From d6504a276da5b00e927c4e1890141e48ce0985ba Mon Sep 17 00:00:00 2001 From: Jacob Hinkle Date: Tue, 22 Aug 2023 19:24:15 -0400 Subject: [PATCH 27/28] Refactor test definition into function --- test/test_smem_reuse.cpp | 42 +++++++++++++--------------------------- 1 file changed, 13 insertions(+), 29 deletions(-) diff --git a/test/test_smem_reuse.cpp b/test/test_smem_reuse.cpp index 4ae1c546161..e891566f70f 100644 --- a/test/test_smem_reuse.cpp +++ b/test/test_smem_reuse.cpp @@ -140,11 +140,8 @@ TEST_F(SmemReuseTest, SimpleCase) { // +-+-----+ // a b c * d e f // -TEST_F(SmemReuseTest, NeedsReorderedPush) { - auto fusion = std::make_unique(); - FusionGuard fg(fusion.get()); - - int64_t H = 5; +std::tuple needsReorderedPushDefinition(int64_t H) { + auto fusion = FusionGuard::getCurFusion(); auto tv0 = full( {IrBuilder::create(H)}, @@ -170,6 +167,16 @@ TEST_F(SmemReuseTest, NeedsReorderedPush) { auto tv7 = neg(tv5); // pos = f fusion->addOutput(tv7); + return {tv0, tv3}; +} + +TEST_F(SmemReuseTest, NeedsReorderedPush) { + auto fusion = std::make_unique(); + FusionGuard fg(fusion.get()); + + int64_t H = 5; + auto [tv0, tv3] = needsReorderedPushDefinition(H); + { // This should not re-use memory GpuLower gpulw(fusion.get()); @@ -214,30 +221,7 @@ TEST_F(SmemReuseTest, PromoteReuse) { FusionGuard fg(fusion.get()); int64_t H = 5; - - auto tv0 = full( - {IrBuilder::create(H)}, - fusion->oneVal(), - DataType::Float); // pos = a. A = tv0 - tv0->setMemoryType(MemoryType::Shared); - - auto tv1 = - pad(tv0, {fusion->zeroVal(), fusion->oneVal()}); // pos = b. B = tv1 - tv1->setMemoryType(MemoryType::Shared); - - auto tv2 = mul(tv1, tv1); // pos = c - - auto tv3 = sum(tv2, {0}); // gap between b and c. Can parallelize to sync - - auto tv4 = broadcast(tv3, {true}); - auto tv5 = mul(tv4, tv1); // pos = d. 
C = tv5 - tv5->setMemoryType(MemoryType::Shared); - - auto tv6 = add(tv1, tv1); // pos = e - fusion->addOutput(tv6); - - auto tv7 = neg(tv5); // pos = f - fusion->addOutput(tv7); + auto [tv0, tv3] = needsReorderedPushDefinition(H); { // This should not re-use memory GpuLower gpulw(fusion.get()); From 9a5260f724ab95d5ffd3c22058b6fd007fd54fed Mon Sep 17 00:00:00 2001 From: Jacob Hinkle Date: Tue, 22 Aug 2023 19:33:23 -0400 Subject: [PATCH 28/28] Undo change that skips trivial loops --- csrc/device_lower/pass/alias_memory.cpp | 12 ++---------- 1 file changed, 2 insertions(+), 10 deletions(-) diff --git a/csrc/device_lower/pass/alias_memory.cpp b/csrc/device_lower/pass/alias_memory.cpp index 40a82a6e9c4..0fbe3a65240 100644 --- a/csrc/device_lower/pass/alias_memory.cpp +++ b/csrc/device_lower/pass/alias_memory.cpp @@ -763,13 +763,7 @@ class AllocationInfoMap : private kir::IrVisitor { void handle(kir::ForLoop* for_loop) final { auto loop_info = scope_map_.getLoopScopeInfo(for_loop); - if (!for_loop->isTrivial()) { - // Parallelized loops do not result in for loops in the CUDA kernel, so - // they should not affect liveness analysis. This means that - // current_stack_ will differ from kir::IrVisitor::for_loops_, which will - // actually hold all ForLoops regardless of parallelization. - current_stack_.push_back(loop_info); - } + current_stack_.push_back(loop_info); if (debug_printer_) { debug_printer_->pushScope(); } @@ -777,9 +771,7 @@ class AllocationInfoMap : private kir::IrVisitor { if (debug_printer_) { debug_printer_->popScope(); } - if (!for_loop->isTrivial()) { - current_stack_.pop_back(); - } + current_stack_.pop_back(); } void handle(kir::IfThenElse* ite) final {