diff --git a/csrc/device_lower/pass/alias_memory.cpp b/csrc/device_lower/pass/alias_memory.cpp index b1f17c436a9..0fbe3a65240 100644 --- a/csrc/device_lower/pass/alias_memory.cpp +++ b/csrc/device_lower/pass/alias_memory.cpp @@ -31,7 +31,7 @@ // - Find allocations of tensors that can reuse other allocations // (class ReusableAllocationFinder) // - Replace those allocation expressions with their alias fields -// pointing to reused allocations (class AllocationReuseModifier) +// pointing to reused allocations (class AllocationAliasModifier) namespace nvfuser { @@ -571,9 +571,22 @@ struct AllocationInfo { std::unique_ptr outer_live_interval = nullptr; std::unique_ptr outer_subscribed_intevals = nullptr; + // Holds allocations that have alloc_expr as their alias_to + std::vector outer_aliased_by; + + //! Get the last outer read position of either this allocation, or any + //! allocation that is aliased to this allocation. + int getAliasedOuterLastRead() const { + auto last_outer_read = outer_live_interval->lastRead(); + for (auto aliasing : outer_aliased_by) { + last_outer_read = + std::max(last_outer_read, aliasing->outer_live_interval->lastRead()); + } + return last_outer_read; + } }; -class AllocationReuseModifier; +class AllocationAliasModifier; //! Analysis pass to collect the liveness info of local and shared buffers: //! The liveness info is illustrated as follows: @@ -627,11 +640,10 @@ class AllocationInfoMap : private kir::IrVisitor { current_stack_.pop_back(); } - std::optional getMaybeAllocationInfo( - const kir::Allocate* alloc) const { + AllocationInfo* getAllocationInfo(const kir::Allocate* alloc) const { auto it = allocation_info_map_.find(alloc); if (it == allocation_info_map_.end()) { - return std::nullopt; + return nullptr; } return it->second; } @@ -702,8 +714,16 @@ class AllocationInfoMap : private kir::IrVisitor { return all_allocations_; } + AllocationInfo* getAllocInfoFromTV(TensorView* tv) const { + auto alloc_it = tv_to_allocation_map_.find(tv->name()); + if (alloc_it == tv_to_allocation_map_.end()) { + return nullptr; + } + return alloc_it->second; + } + protected: - friend AllocationReuseModifier; + friend AllocationAliasModifier; //! When an allocation is registered for replacement, this method should be //! called to update the allocation info so that subsequent lookups behave @@ -847,61 +867,50 @@ class AllocationInfoMap : private kir::IrVisitor { // expr. The current analysis isn't enough to capture // their liveness range. for (auto input_tv : ir_utils::filterByType(expr->inputs())) { - auto maybe_alloc_info = getMaybeAllocInfoFromTV(input_tv); - if (maybe_alloc_info.has_value()) { + auto alloc_info = getAllocInfoFromTV(input_tv); + if (alloc_info) { if (!isSerialBroadcastResolution(input_tv, for_loops_)) { - maybe_alloc_info.value()->inner_live_interval->markRead(expr_pos); + alloc_info->inner_live_interval->markRead(expr_pos); } else { // Disable inner alias info for this buffer, since line number based // analysis is no longer precise enough for inplace sharing // if a serial broadcast is realized. - maybe_alloc_info.value()->can_use_inner_alias = false; + alloc_info->can_use_inner_alias = false; } - auto outer_loop_info = - ascendLoopNestToSameLevelAs(maybe_alloc_info.value()); + auto outer_loop_info = ascendLoopNestToSameLevelAs(alloc_info); if (outer_loop_info) { - maybe_alloc_info.value()->outer_live_interval->markRead( - outer_loop_info->end_pos); + alloc_info->outer_live_interval->markRead(outer_loop_info->end_pos); } else { // Allocate is inlined in the innermost loop, // so outer live interval is the same as inner. - maybe_alloc_info.value()->outer_live_interval->markRead(expr_pos); + alloc_info->outer_live_interval->markRead(expr_pos); } } } for (auto output_tv : ir_utils::filterByType(expr->outputs())) { - auto maybe_alloc_info = getMaybeAllocInfoFromTV(output_tv); - if (maybe_alloc_info.has_value()) { + auto alloc_info = getAllocInfoFromTV(output_tv); + if (alloc_info) { // Reductions use outputs as read-write parameters, so their // outputs need to be marked as read as well const bool is_read_write = ir_utils::isReductionOp(expr); - maybe_alloc_info.value()->inner_live_interval->markWrite(expr_pos); + alloc_info->inner_live_interval->markWrite(expr_pos); if (is_read_write) { - maybe_alloc_info.value()->inner_live_interval->markRead(expr_pos); + alloc_info->inner_live_interval->markRead(expr_pos); } - auto outer_loop_info = - ascendLoopNestToSameLevelAs(maybe_alloc_info.value()); + auto outer_loop_info = ascendLoopNestToSameLevelAs(alloc_info); auto write_pos = outer_loop_info ? outer_loop_info->start_pos : expr_pos; - maybe_alloc_info.value()->outer_live_interval->markWrite(write_pos); + alloc_info->outer_live_interval->markWrite(write_pos); if (is_read_write) { auto read_pos = outer_loop_info ? outer_loop_info->end_pos : expr_pos; - maybe_alloc_info.value()->outer_live_interval->markRead(read_pos); + alloc_info->outer_live_interval->markRead(read_pos); } } } } - std::optional getMaybeAllocInfoFromTV(TensorView* tv) const { - auto alloc_it = tv_to_allocation_map_.find(tv->name()); - if (alloc_it == tv_to_allocation_map_.end()) { - return std::nullopt; - } - return alloc_it->second; - } - //! Find the loop level of expr that apears in the same scope as //! the reference allocate. Eg. //! @@ -940,8 +949,17 @@ class AllocationInfoMap : private kir::IrVisitor { //! Mark the tensor of "from" be an alias of the tensor of "to". void setAlias(AllocationInfo* from, AllocationInfo* to) { + TORCH_INTERNAL_ASSERT( + to->alias_to == nullptr, + "Multi-hop aliases are not supported. Attempted to alias ", + from->alloc_expr->buffer()->toString(), + " to ", + to->alloc_expr->buffer()->toString(), + " which is already aliased to ", + to->alias_to->buffer()->toString()); alias_map_[from] = to; from->alias_to = to->alloc_expr; + to->outer_aliased_by.push_back(from); } private: @@ -974,14 +992,15 @@ void BufferReuseDebugPrinter::printAllocInfo(const kir::Allocate* alloc) { TORCH_INTERNAL_ASSERT(allocation_info_map_ != nullptr); std::string message_header(" \033[1;32m^^^^^ ---Buffer Reuse Info--- "); std::string message_end(" \033[0m\n"); - if (!allocation_info_map_->getMaybeAllocationInfo(alloc).has_value()) { + + auto alloc_info = allocation_info_map_->getAllocationInfo(alloc); + + if (!alloc_info) { // This buffer is not considered for any sharing, either // because of un-supported op or size below threshold. return; } - auto alloc_info = allocation_info_map_->getMaybeAllocationInfo(alloc).value(); - indent() << message_header; if (alloc_info->alias_to) { if (alloc_info->is_inner_alias) { @@ -990,8 +1009,7 @@ void BufferReuseDebugPrinter::printAllocInfo(const kir::Allocate* alloc) { os_ << "(outer) "; } os_ << " alias to alloc at pos " - << allocation_info_map_->getMaybeAllocationInfo(alloc_info->alias_to) - .value() + << allocation_info_map_->getAllocationInfo(alloc_info->alias_to) ->alloc_pos << " "; } else { @@ -1056,17 +1074,14 @@ class ReusableAllocationFinder : private kir::IrVisitor { // Check that if this allocation site is one that // we want to re-use or replace with an alias - auto maybe_alloc_info = - allocation_info_map_.getMaybeAllocationInfo(allocate); - if (maybe_alloc_info.has_value() && - maybe_alloc_info.value()->alias_to == nullptr) { + auto alloc_info = allocation_info_map_.getAllocationInfo(allocate); + if (alloc_info && alloc_info->alias_to == nullptr) { // Try to re-use existing allocates - if (!tryReuseOtherAllocate(maybe_alloc_info.value())) { + if (!tryReuseOtherAllocate(alloc_info)) { // If didn't re-use, should register this // allocate so that future allocates // can re-use this one. - current_visible_buffer_stack_.back()->push_back( - maybe_alloc_info.value()); + current_visible_buffer_stack_.back()->push_back(alloc_info); } } } @@ -1174,12 +1189,10 @@ class ReusableAllocationFinder : private kir::IrVisitor { } } - // TODO: - // Outer interval based sharing supports arbitrary re-indexing into - // the same buffer and would require additional syncs if fully - // enabled. - // Need a few more checks to insert syncs if necessary before turning - // on this sharing. + // Outer aliasing of shared memory requires thread block synchronization + // since it could involve arbitrary re-indexing. Instead, we will leave + // this type of re-use to the allocation phase. See + // assignSharedMemoryAllocations and promoteReuseSyncs. if (!inner_aliasing_pass_ && alloc_info->mem_type == MemoryType::Shared) { continue; @@ -1359,17 +1372,17 @@ class ReusableAllocationFinder : private kir::IrVisitor { }; // Replace Allocate exprs as determined by the alias analysis -class AllocationReuseModifier : private kir::ExprMutator { +class AllocationAliasModifier : private kir::ExprMutator { public: static std::vector modify( const std::vector& exprs, AllocationInfoMap& allocation_info_map) { - AllocationReuseModifier modifier(exprs, allocation_info_map); + AllocationAliasModifier modifier(exprs, allocation_info_map); return modifier.exprs_; } private: - AllocationReuseModifier( + AllocationAliasModifier( const std::vector& exprs, AllocationInfoMap& allocation_info_map) : allocation_info_map_(allocation_info_map) { @@ -1380,14 +1393,11 @@ class AllocationReuseModifier : private kir::ExprMutator { //! Replace an kir::Allocate with a new aliased Allocate void handle(kir::Allocate* allocate) final { - auto maybe_alloc_info = - allocation_info_map_.getMaybeAllocationInfo(allocate); - if (!maybe_alloc_info.has_value()) { + auto alloc_info_from = allocation_info_map_.getAllocationInfo(allocate); + if (!alloc_info_from) { return; } - AllocationInfo* alloc_info_from = maybe_alloc_info.value(); - auto alias_it = allocation_info_map_.getAliasMap().find(alloc_info_from); if (alias_it == allocation_info_map_.getAliasMap().end()) { return; @@ -1608,7 +1618,7 @@ class StackBasedSharedMemAllocator : kir::IrVisitor { StackBasedSharedMemAllocator(const AllocationInfoMap& allocation_info_map) : allocation_info_map_(allocation_info_map) {} - void allocate(std::vector& exprs) { + void allocate(const std::vector& exprs) { recordEvents(); // Traverse expressions: reclaim memory when we pass a blockSync, append to @@ -1635,6 +1645,10 @@ class StackBasedSharedMemAllocator : kir::IrVisitor { // Reclaim memory whenever we pass an Expr that is known to synchronize the // block if (lower_utils::hasBlockSync(expr, GpuLower::current()->threadPredMap())) { + if (isDebugDumpEnabled(DebugDumpOption::BufferReuseInfo)) { + debug() << "Block syncing expr found at position " << position_ + << ". Reclaiming memory." << std::endl; + } reclaimMemory(); } @@ -1671,6 +1685,12 @@ class StackBasedSharedMemAllocator : kir::IrVisitor { } void pushAndAssign(AllocationInfo* alloc_info) { + if (isDebugDumpEnabled(DebugDumpOption::BufferReuseInfo)) { + auto alloc = alloc_info->alloc_expr; + debug() << "Pushing allocation for T" << alloc->buffer()->name() + << std::endl; + } + // Assign new address assignNextAddress(alloc_info); @@ -1691,8 +1711,11 @@ class StackBasedSharedMemAllocator : kir::IrVisitor { alloc->setAddress(aligned_address); } if (isDebugDumpEnabled(DebugDumpOption::BufferReuseInfo)) { - debug() << "Allocated address " << alloc->address()->toInlineString() - << " for T" << alloc->buffer()->name() << std::endl; + debug() << "Assigned address " << alloc->address()->toInlineString() + << " for T" << alloc->buffer()->name() << " with size " + << alloc->size()->toInlineString() << " * " + << dataTypeSize(alloc->buffer()->dtype()) << " bytes" + << std::endl; } } @@ -1704,10 +1727,10 @@ class StackBasedSharedMemAllocator : kir::IrVisitor { } if (alloc_info->alias_to) { auto alias_info = - allocation_info_map_.getMaybeAllocationInfo(alloc_info->alias_to); - TORCH_CHECK(alias_info.has_value()); - auto prev_last_read = lastAliasedRead(alias_info.value()); - last_aliased_read_[alias_info.value()] = std::max( + allocation_info_map_.getAllocationInfo(alloc_info->alias_to); + TORCH_CHECK(alias_info); + auto prev_last_read = lastAliasedRead(alias_info); + last_aliased_read_[alias_info] = std::max( prev_last_read, alloc_info->outer_live_interval->lastRead()); } else { last_aliased_read_[alloc_info.get()] = @@ -1750,6 +1773,12 @@ class StackBasedSharedMemAllocator : kir::IrVisitor { while (!alloc_stack_.empty()) { auto last_read = lastAliasedRead(alloc_stack_.back()); if (last_read <= position_) { + if (isDebugDumpEnabled(DebugDumpOption::BufferReuseInfo)) { + auto alloc = alloc_stack_.back()->alloc_expr; + debug() << "Popping allocation for T" << alloc->buffer()->name() + << " which has assigned address " + << alloc->address()->toInlineString() << std::endl; + } alloc_stack_.pop_back(); } else { break; @@ -1788,6 +1817,170 @@ class StackBasedSharedMemAllocator : kir::IrVisitor { } // namespace +// Use allocation info map to find aliases, i.e. allocations that are properly +// sized and parallelized so that they can be re-used without any +// synchronization. +std::vector aliasMemoryAllocations( + const std::vector& exprs, + AllocationInfoMap& allocation_info_map) { + ReusableAllocationFinder::find(exprs, allocation_info_map); + return AllocationAliasModifier::modify(exprs, allocation_info_map); +} + +class PromoteReuseSyncModifier : private kir::ExprMutator { + public: + PromoteReuseSyncModifier( + const std::vector& exprs, + const AllocationInfoMap& allocation_info_map) + : allocation_info_map_(allocation_info_map) { + // Find next allocation after last aliased read of all allocations whose + // reuse we need to promote, and record shortest sync intervals relative to + // subsequent allocations. + for (const auto& alloc_info : allocation_info_map.allAllocationInfos()) { + auto tv = alloc_info->alloc_expr->buffer()->as(); + if (tv->getMemoryType() != MemoryType::Shared || + !tv->shouldPromoteReuse()) { + continue; + } + auto last_read = alloc_info->getAliasedOuterLastRead(); + + std::optional nearest_first_write = std::nullopt; + + for (const auto& other : allocation_info_map.allAllocationInfos()) { + if (other->alias_to || other->mem_type != MemoryType::Shared) { + // Skip other if it aliases an earlier allocation + continue; + } + auto first_write = other->outer_live_interval->firstWrite(); + if (first_write <= last_read) { + continue; + } + if (!nearest_first_write.has_value() || + first_write < nearest_first_write.value()) { + nearest_first_write = first_write; + } + } + + if (nearest_first_write.has_value()) { + sync_intervals_.emplace(last_read, nearest_first_write.value()); + } + } + + if (sync_intervals_.empty()) { + exprs_ = exprs; + } else { + if (isDebugDumpEnabled(DebugDumpOption::BufferReuseInfo)) { + debug() << "Ensuring syncs within these intervals:" << std::endl; + for (auto [last_read, first_write] : sync_intervals_) { + debug() << " " << last_read << " - " << first_write << std::endl; + } + } + traverseAndInsert(exprs); + } + } + + const std::unordered_set& insertedSyncs() const { + return inserted_syncs_; + } + + const std::vector& modifiedExprs() const { + return exprs_; + } + + private: + using kir::ExprMutator::dispatch; + + void dispatch(Expr* expr) final { + auto position = allocation_info_map_.getScopeMap().getExprPos(expr); + + // If this is an upcoming first write that has not yet been erased, it means + // we have not seen a sync in its interval. So we should insert a BlockSync + // before this expr. + if (upcoming_first_writes_.erase(position)) { + if (isDebugDumpEnabled(DebugDumpOption::BufferReuseInfo)) { + debug() << "Inserting block sync before position " << position + << std::endl; + } + auto new_sync = IrBuilder::create(); + inserted_syncs_.insert(new_sync); + registerInsertBefore(expr, new_sync); + } + + // If we have a sync at this location, we can clear any upcoming first + // writes since they can be considered safe. + if (lower_utils::hasBlockSync(expr, GpuLower::current()->threadPredMap())) { + if (isDebugDumpEnabled(DebugDumpOption::BufferReuseInfo)) { + debug() << "Found blocking expression at position " << position + << std::endl; + } + upcoming_first_writes_.clear(); + } + + // If this is the lower endpoint of a sync interval, add the upper endpoint + auto range = sync_intervals_.equal_range(position); + for (auto& it = range.first; it != range.second; ++it) { + if (isDebugDumpEnabled(DebugDumpOption::BufferReuseInfo)) { + debug() << "Found dependency last read at position " << position + << " corresponding to first write at " << it->second + << std::endl; + } + upcoming_first_writes_.insert(it->second); + } + + kir::ExprMutator::dispatch(expr); + } + + private: + const AllocationInfoMap& allocation_info_map_; + + // This holds intervals in which we need to ensure a sync exists. All + // threads in a block should arrive at the start of each interval before any + // thread proceeds past the end of the interval. + std::unordered_multimap sync_intervals_; + + // These are the upper endpoints of needed sync intervals for which we've + // already passed over the lower endpoint. + std::unordered_set upcoming_first_writes_; + + // Holds all new syncs we have inserted + std::unordered_set inserted_syncs_; +}; + +// Insert missing synchronizations in cases where a TensorView is marked as +// needing reuse promotion. This should be done before +// allocateSharedMemoryAllocations, which uses synchronization points to reclaim +// unused shared memory. +std::pair, bool> promoteReuseSyncs( + const std::vector& exprs, + AllocationInfoMap& allocation_info_map) { + auto modifier = PromoteReuseSyncModifier(exprs, allocation_info_map); + return {modifier.modifiedExprs(), !modifier.insertedSyncs().empty()}; +} + +// Assign addresses for dynamic shared memory allocations. This re-uses memory +// by reclaiming memory that is unused when encountering a block +// synchronization. +void assignSharedMemoryAllocations( + const std::vector& exprs, + AllocationInfoMap& allocation_info_map) { + StackBasedSharedMemAllocator(allocation_info_map).allocate(exprs); + + // Verify that all smem allocations have a non-null address now + for (auto& alloc_info : allocation_info_map.allAllocationInfos()) { + if (alloc_info->mem_type != MemoryType::Shared || alloc_info->alias_to) { + continue; + } + auto alloc = alloc_info->alloc_expr; + TORCH_INTERNAL_ASSERT( + alloc->address(), + "Unaliased allocation for shared memory tensor ", + alloc->buffer()->toString(), + " was not assigned an address"); + } +} + +// Entry point for all memory re-use including unsynced aliasing as well as +// insertion of requested syncs and memory allocation with reclamation. std::vector reuseMemoryAllocations(const std::vector& exprs) { FUSER_PERF_SCOPE("reuseMemoryAllocations"); @@ -1795,14 +1988,21 @@ std::vector reuseMemoryAllocations(const std::vector& exprs) { AllocationInfoMap allocation_info_map(exprs, debug_print); - ReusableAllocationFinder::find(exprs, allocation_info_map); + const auto aliased_exprs = aliasMemoryAllocations(exprs, allocation_info_map); - auto aliased_exprs = - AllocationReuseModifier::modify(exprs, allocation_info_map); + const auto [synced_exprs, inserted_syncs] = + promoteReuseSyncs(aliased_exprs, allocation_info_map); + + // If we inserted sync expressions, we need to recompute positions of any + // downstream expressions. Rather than try to keep those in sync, we just + // recompute the allocation info map here. + if (inserted_syncs) { + allocation_info_map = AllocationInfoMap(synced_exprs, false); + } - StackBasedSharedMemAllocator(allocation_info_map).allocate(aliased_exprs); + assignSharedMemoryAllocations(synced_exprs, allocation_info_map); - return aliased_exprs; + return synced_exprs; } } // namespace nvfuser diff --git a/csrc/ir/interface_nodes.h b/csrc/ir/interface_nodes.h index 2e0ee3f9953..79cf79d736d 100644 --- a/csrc/ir/interface_nodes.h +++ b/csrc/ir/interface_nodes.h @@ -505,6 +505,26 @@ class TORCH_CUDA_CU_API TensorView : public Val { // ensure consistency. void commitLeafToRFactor(); + //! Request that we reclaim the memory of this tv before any subsequent + //! tensors are allocated. + //! + //! This method influences the shared memory allocator that assigns shared + //! memory addresses at lowering. It ensures that the proper synchronization + //! is present in the kernel to reuse memory and inserts new block + //! synchronizations if necessary. + void promoteReuse(bool b = true) { + TORCH_CHECK( + memory_type_ == MemoryType::Shared, + "promoteReuse should only be called on shared memory tensors"); + promote_reuse_ = b; + } + + //! Returns whether we should insert syncs if needed in order to reuse the + //! memory of this tensor. + bool shouldPromoteReuse() const { + return promote_reuse_; + } + protected: void setDomain(TensorDomain* td) { domain_ = td; @@ -567,6 +587,15 @@ class TORCH_CUDA_CU_API TensorView : public Val { //! transformed when there may actually be a producer tensor that //! may be computed at. unsigned int maybe_max_producer_pos_ = 0; + + //! When this is true, it indicates, if this is a shared memory tensor and + //! there other shared memory tensors whose lifetimes do not overlap and come + //! later than this tensor's lifetime, that we should ensure that thread + //! blocks are synchronized such that all threads have performed their last + //! read of this tensor (or any tensors aliasing in) before writing to the + //! current tensor. This will then allow us to safely reuse the memory + //! allocated to this tensor. + bool promote_reuse_ = false; }; //! A simple TensorView builder diff --git a/csrc/tensor_view.cpp b/csrc/tensor_view.cpp index 707e71f36a7..2a22988d736 100644 --- a/csrc/tensor_view.cpp +++ b/csrc/tensor_view.cpp @@ -252,7 +252,8 @@ TensorView::TensorView(const TensorView* src, IrCloner* ir_cloner) cpu_scalar_(src->cpu_scalar_), has_swizzle_op_(src->has_swizzle_op_), compute_with_consumers_(ir_cloner->clone(src->compute_with_consumers_)), - compute_with_pos_(src->compute_with_pos_) {} + compute_with_pos_(src->compute_with_pos_), + promote_reuse_(src->promote_reuse_) {} // sets cpu_scalar_ value, which is special handling for CPU based zero-dim // tensors (i.e. CPU Tensors that only have one value). This is only used if diff --git a/test/test_smem_reuse.cpp b/test/test_smem_reuse.cpp index 5781dfe754d..e891566f70f 100644 --- a/test/test_smem_reuse.cpp +++ b/test/test_smem_reuse.cpp @@ -107,6 +107,7 @@ TEST_F(SmemReuseTest, SimpleCase) { ExpressionEvaluator ee; int64_t smem_usage = 0; for (auto alloc : gpulw.kernel()->summary().dynamic_smem_allocations) { + EXPECT_NE(alloc->address(), nullptr); auto addr = ee.evaluate(alloc->address()).as(); auto size = ee.evaluate(alloc->size()).as() * dataTypeSize(alloc->buffer()->dtype()); @@ -139,39 +140,235 @@ TEST_F(SmemReuseTest, SimpleCase) { // +-+-----+ // a b c * d e f // +std::tuple needsReorderedPushDefinition(int64_t H) { + auto fusion = FusionGuard::getCurFusion(); + + auto tv0 = full( + {IrBuilder::create(H)}, + fusion->oneVal(), + DataType::Float); // pos = a. A = tv0 + tv0->setMemoryType(MemoryType::Shared); + + auto tv1 = + pad(tv0, {fusion->zeroVal(), fusion->oneVal()}); // pos = b. B = tv1 + tv1->setMemoryType(MemoryType::Shared); + + auto tv2 = mul(tv1, tv1); // pos = c + + auto tv3 = sum(tv2, {0}); // gap between b and c. Can parallelize to sync + + auto tv4 = broadcast(tv3, {true}); + auto tv5 = mul(tv4, tv1); // pos = d. C = tv5 + tv5->setMemoryType(MemoryType::Shared); + + auto tv6 = add(tv1, tv1); // pos = e + fusion->addOutput(tv6); + + auto tv7 = neg(tv5); // pos = f + fusion->addOutput(tv7); + + return {tv0, tv3}; +} + TEST_F(SmemReuseTest, NeedsReorderedPush) { auto fusion = std::make_unique(); FusionGuard fg(fusion.get()); - int64_t H_int = 5, W_int = 6; - auto H = IrBuilder::create(H_int); - auto W = IrBuilder::create(W_int); + int64_t H = 5; + auto [tv0, tv3] = needsReorderedPushDefinition(H); - auto tv0 = full({H}, fusion->oneVal(), DataType::Float); - auto tv1 = set(tv0); // pos = a. A = tv1 - tv1->setMemoryType(MemoryType::Shared); + { // This should not re-use memory + GpuLower gpulw(fusion.get()); + + ExpressionEvaluator ee; + std::unordered_set addresses; + int64_t smem_usage = 0; + for (auto alloc : gpulw.kernel()->summary().dynamic_smem_allocations) { + EXPECT_NE(alloc->address(), nullptr); + auto addr = ee.evaluate(alloc->address()).as(); + TORCH_CHECK( + addresses.insert(addr).second, + "Smem addresses should not be re-used"); + auto size = ee.evaluate(alloc->size()).as() * + dataTypeSize(alloc->buffer()->dtype()); + smem_usage = std::max(smem_usage, addr + size); + } + EXPECT_EQ( + smem_usage, alignInt(alignInt((H + 1) * 4) + (H + 1) * 4) + H * 4); + } + + { // Now introduce a block reduction and check that we re-use memory + tv3->axis(0)->parallelize(ParallelType::TIDx); + + GpuLower gpulw(fusion.get()); + ExpressionEvaluator ee; + int64_t smem_usage = 0; + for (auto alloc : gpulw.kernel()->summary().dynamic_smem_allocations) { + EXPECT_NE(alloc->address(), nullptr); + auto addr = ee.evaluate(alloc->address()).as(); + auto size = ee.evaluate(alloc->size()).as() * + dataTypeSize(alloc->buffer()->dtype()); + smem_usage = std::max(smem_usage, addr + size); + } + EXPECT_EQ(smem_usage, alignInt((H + 1) * 4) + (H + 1) * 4); + } +} - auto tv2 = full({W}, fusion->oneVal(), DataType::Float); // pos = b. B = tv2 +// Same as NeedsReorderedPush but C requests to reuse A instead of pre-existing +// sync +TEST_F(SmemReuseTest, PromoteReuse) { + auto fusion = std::make_unique(); + FusionGuard fg(fusion.get()); + + int64_t H = 5; + auto [tv0, tv3] = needsReorderedPushDefinition(H); + + { // This should not re-use memory + GpuLower gpulw(fusion.get()); + + ExpressionEvaluator ee; + std::unordered_set addresses; + int64_t smem_usage = 0; + for (auto alloc : gpulw.kernel()->summary().dynamic_smem_allocations) { + EXPECT_NE(alloc->address(), nullptr); + auto addr = ee.evaluate(alloc->address()).as(); + TORCH_CHECK( + addresses.insert(addr).second, + "Smem addresses should not be re-used"); + auto size = ee.evaluate(alloc->size()).as() * + dataTypeSize(alloc->buffer()->dtype()); + smem_usage = std::max(smem_usage, addr + size); + } + EXPECT_EQ( + smem_usage, alignInt(alignInt((H + 1) * 4) + (H + 1) * 4) + H * 4); + } + + { // Request that we re-use the allocation for tv0. This should place a + // __syncthreads() just before tv5 is written. + tv0->promoteReuse(); + + GpuLower gpulw(fusion.get()); + ExpressionEvaluator ee; + int64_t smem_usage = 0; + for (auto alloc : gpulw.kernel()->summary().dynamic_smem_allocations) { + EXPECT_NE(alloc->address(), nullptr); + auto addr = ee.evaluate(alloc->address()).as(); + auto size = ee.evaluate(alloc->size()).as() * + dataTypeSize(alloc->buffer()->dtype()); + smem_usage = std::max(smem_usage, addr + size); + } + EXPECT_EQ(smem_usage, alignInt((H + 1) * 4) + (H + 1) * 4); + } +} + +// In this example, we promote a single tensor for re-use in a Fusion with two +// downstream tensors that could use its memory. The first downstream tensor is +// not re-used since it is not promoted. +TEST_F(SmemReuseTest, PromoteReuseMultipleDownstream) { + auto fusion = std::make_unique(); + FusionGuard fg(fusion.get()); + + int64_t H = 7; + + auto tv0 = + full({IrBuilder::create(H)}, fusion->oneVal(), DataType::Float); + tv0->setMemoryType(MemoryType::Shared); + + auto tv1 = neg(tv0); + + auto tv2 = pad(tv1, {fusion->zeroVal(), fusion->oneVal()}); tv2->setMemoryType(MemoryType::Shared); - auto tv3 = add(tv1, tv1); // pos = c + auto tv3 = neg(tv2); - auto tv4 = sum(tv3, {0}); // gap between b and c - fusion->addOutput(tv4); + auto tv4 = pad(tv3, {fusion->zeroVal(), fusion->oneVal()}); + tv4->setMemoryType(MemoryType::Shared); - auto tv5 = broadcast(tv4, {true}); - auto tv6 = mul(tv5, tv3); + auto tv5 = neg(tv4); - auto tv7 = broadcast(tv6, {true, false}); - auto tv8 = broadcast(tv2, {false, true}); - auto tv9 = mul(tv7, tv8); // pos = d. C = tv9 - tv9->setMemoryType(MemoryType::Shared); + fusion->addOutput(tv5); - auto tv10 = add(tv2, tv2); // pos = e - fusion->addOutput(tv10); + { // This should not re-use memory + GpuLower gpulw(fusion.get()); - auto tv11 = neg(tv9); // pos = f - fusion->addOutput(tv11); + ExpressionEvaluator ee; + std::unordered_set addresses; + int64_t smem_usage = 0; + for (auto alloc : gpulw.kernel()->summary().dynamic_smem_allocations) { + EXPECT_NE(alloc->address(), nullptr); + auto addr = ee.evaluate(alloc->address()).as(); + TORCH_CHECK( + addresses.insert(addr).second, + "Smem addresses should not be re-used"); + auto size = ee.evaluate(alloc->size()).as() * + dataTypeSize(alloc->buffer()->dtype()); + smem_usage = std::max(smem_usage, addr + size); + } + EXPECT_EQ( + smem_usage, alignInt(alignInt((H + 2) * 4) + (H + 1) * 4) + H * 4); + } + + { // Request that we re-use the allocation for tv0. This should place a + // __syncthreads() just before tv2 is written. + tv0->promoteReuse(); + + GpuLower gpulw(fusion.get()); + ExpressionEvaluator ee; + int64_t smem_usage = 0; + for (auto alloc : gpulw.kernel()->summary().dynamic_smem_allocations) { + EXPECT_NE(alloc->address(), nullptr); + auto addr = ee.evaluate(alloc->address()).as(); + auto size = ee.evaluate(alloc->size()).as() * + dataTypeSize(alloc->buffer()->dtype()); + smem_usage = std::max(smem_usage, addr + size); + } + EXPECT_EQ(smem_usage, alignInt((H + 1) * 4) + (H + 2) * 4); + } +} + +// In this example, multiple smem tensors are promoted for re-use. We have +// non-overlapping smem allocations A B C D, and A and C are promoted for reuse. +// Because of that, B re-uses A, then C does not reuse B but stacks on top of +// it. Then D reuses C, and B is reclaimed in the process. Ultimately this means +// the assigned addresses are: +// +// A: 0. Assigned then reclaimed before assignment of B. +// B: alignInt((H + 2) * 4). Stacked on top of C +// C: 0. Assigned along with B in reverse order of last use +// D: 0. B and C are reclaimed before this assignment. +// +// Note that although B was not explicitly requested for re-use, since its +// lifetime ends before D is defined, we try and reclaim it at the same time C +// is reclaimed. They are also ordered on the stack at that point, in descending +// order of last use, meaning B is placed higher on the stack than C. +TEST_F(SmemReuseTest, MultiplePromoteReuse) { + auto fusion = std::make_unique(); + FusionGuard fg(fusion.get()); + + int64_t H = 7; + + auto tv0 = + full({IrBuilder::create(H)}, fusion->oneVal(), DataType::Float); + tv0->setMemoryType(MemoryType::Shared); + + auto tv1 = neg(neg(tv0)); + + auto tv2 = pad(tv1, {fusion->zeroVal(), fusion->oneVal()}); + tv2->setMemoryType(MemoryType::Shared); + + auto tv3 = neg(neg(tv2)); + + auto tv4 = pad(tv3, {fusion->zeroVal(), fusion->oneVal()}); + tv4->setMemoryType(MemoryType::Shared); + + auto tv5 = neg(neg(tv4)); + + auto tv6 = pad(tv5, {fusion->zeroVal(), fusion->oneVal()}); + tv6->setMemoryType(MemoryType::Shared); + + auto tv7 = neg(tv6); + + fusion->addOutput(tv7); { // This should not re-use memory GpuLower gpulw(fusion.get()); @@ -180,6 +377,7 @@ TEST_F(SmemReuseTest, NeedsReorderedPush) { std::unordered_set addresses; int64_t smem_usage = 0; for (auto alloc : gpulw.kernel()->summary().dynamic_smem_allocations) { + EXPECT_NE(alloc->address(), nullptr); auto addr = ee.evaluate(alloc->address()).as(); TORCH_CHECK( addresses.insert(addr).second, @@ -190,23 +388,26 @@ TEST_F(SmemReuseTest, NeedsReorderedPush) { } EXPECT_EQ( smem_usage, - alignInt(alignInt(H_int * W_int * 4) + W_int * 4) + H_int * 4); + alignInt(alignInt(alignInt((H + 3) * 4) + (H + 2) * 4) + (H + 1) * 4) + + H * 4); } - { // Now introduce a block reduction and check that we re-use memory - - tv4->axis(0)->parallelize(ParallelType::TIDx); + { // Request that we re-use A and C + tv0->promoteReuse(); + tv4->promoteReuse(); GpuLower gpulw(fusion.get()); ExpressionEvaluator ee; int64_t smem_usage = 0; for (auto alloc : gpulw.kernel()->summary().dynamic_smem_allocations) { + EXPECT_NE(alloc->address(), nullptr); auto addr = ee.evaluate(alloc->address()).as(); auto size = ee.evaluate(alloc->size()).as() * dataTypeSize(alloc->buffer()->dtype()); smem_usage = std::max(smem_usage, addr + size); } - EXPECT_EQ(smem_usage, alignInt(H_int * 4) + W_int * H_int * 4); + // High water mark has C stacked on top of B + EXPECT_EQ(smem_usage, alignInt((H + 2) * 4) + (H + 1) * 4); } }