diff --git a/csrc/fusion.cpp b/csrc/fusion.cpp index b95644d843f..c2166226dfb 100644 --- a/csrc/fusion.cpp +++ b/csrc/fusion.cpp @@ -109,13 +109,22 @@ struct Fusion::ContainerMutator { removeExpr(self, e); } + c->per_fusion_vals_[self].erase(val); + + // Multi-owner guard: only free if this is the last owning Fusion. + if (!val->removeOwningFusion(self)) { + // Other Fusions still own this Val — remove from our tracking but + // keep the Val alive in vals_up_ and vals_. + self->invalidateTvsAndUses(); + return; + } + auto val_in_deque = std::ranges::find_if( c->vals_up_, [val](std::unique_ptr& val_up) { return val_up.get() == val; }); NVF_ERROR( val_in_deque != c->vals_up_.end(), "Wanted to remove a value but its unique ptr is missing."); - c->per_fusion_vals_[self].erase(val); c->vals_.erase(val); c->vals_up_.erase(val_in_deque); @@ -137,6 +146,11 @@ struct Fusion::ContainerMutator { c->vals_.insert(val); c->per_fusion_vals_[self].insert(val); val->setName(IrContainerPasskey(), self->getValName(val->vtype())); + + // Seed owning_fusions_ with the registering Fusion (original creator). + if (val->owning_fusions_.empty()) { + val->owning_fusions_.push_back(self); + } } static void registerExpr(Fusion* self, Expr* expr) { @@ -159,7 +173,9 @@ struct Fusion::ContainerMutator { c->assertInContainerImpl(input, "Input to expr is invalid, "); if (input->isA()) { self->invalidateTvsAndUses(); - } else { + } else if (!input->isShared()) { + // Don't track uses on shared scalars — their uses_ would accumulate + // Exprs from multiple Fusions, causing cross-Fusion DAG leakage. 
input->addUse(expr); } } @@ -212,7 +228,8 @@ struct Fusion::ContainerMutator { Expr* e = c->exprs_up_.back().get(); NVF_ERROR( c->per_fusion_exprs_[self].count(e) > 0, - "removeStatementsCreatedAfter: tail expr belongs to another Fusion"); + "removeStatementsCreatedAfter: tail expr belongs to another " + "Fusion"); for (Val* out : e->outputs()) { out->setDefinition(nullptr); } @@ -230,6 +247,11 @@ struct Fusion::ContainerMutator { "removeStatementsCreatedAfter: tail val belongs to another Fusion"); nullOutShortcutIfNeeded(self, v); c->per_fusion_vals_[self].erase(v); + if (!v->removeOwningFusion(self)) { + // Shared val — other Fusions still own it. Can't pop from deque + // tail since the val must stay alive, so stop trimming here. + break; + } c->vals_.erase(v); c->vals_up_.pop_back(); } @@ -276,6 +298,10 @@ struct Fusion::ContainerMutator { // self's new val — remove (null shortcut cache pointer if applicable) nullOutShortcutIfNeeded(self, v); c->per_fusion_vals_[self].erase(v); + // Multi-owner guard: only free if last owner. + if (!v->removeOwningFusion(self)) { + return false; // shared val — other Fusions still own it + } c->vals_.erase(v); return true; }); @@ -457,10 +483,16 @@ IrCloner Fusion::copy(const Fusion* from, Fusion* to) { } // Wire up definitions and uses on cloned vals in deterministic order - // to ensure exprs are inserted into exprs_up_ deterministically + // to ensure exprs are inserted into exprs_up_ deterministically. + // Skip reused vals (shared scalars) — their definition/uses belong to + // the source Fusion and must not be overwritten. 
for (auto val : from->deterministic_vals()) { - ir_cloner.clone(val)->setDefinition(ir_cloner.clone(val->definition_)); - ir_cloner.clone(val)->setUses(ir_cloner.clone(val->uses_)); + auto* cloned = ir_cloner.clone(val); + if (cloned == val) { + continue; // reused (shared scalar) — don't rewire + } + cloned->setDefinition(ir_cloner.clone(val->definition_)); + cloned->setUses(ir_cloner.clone(val->uses_)); } // Sync per-Fusion name counters from source to dest. diff --git a/csrc/fusion.h b/csrc/fusion.h index c4436e11747..d6d2a3aa8f1 100644 --- a/csrc/fusion.h +++ b/csrc/fusion.h @@ -591,6 +591,7 @@ class NVF_API Fusion : public PolymorphicBase { friend SegmentCandidateFinder; friend SegmentedFusion; friend class TranslateApplicableWelford; + friend class IrCloner; friend Val; //! Constructor that shares an existing container. Creates an empty Fusion diff --git a/csrc/fusion_segmenter.cpp b/csrc/fusion_segmenter.cpp index 0736b560078..16f0a09c867 100644 --- a/csrc/fusion_segmenter.cpp +++ b/csrc/fusion_segmenter.cpp @@ -1801,8 +1801,8 @@ std::pair> SegmentedFusion::makeFusion( SegmentedGroup* sg) const { // TODO Optimize cloning step by only copying values and expressions between // the fusion segment's inputs and outputs. - auto fusion_segment = std::unique_ptr( - new Fusion(completeFusion()->ir_container_ptr())); + auto fusion_segment = + std::unique_ptr(new Fusion(completeFusion()->ir_container_ptr())); IrCloner complete_to_segment_map = Fusion::copy(completeFusion(), fusion_segment.get()); @@ -2803,8 +2803,10 @@ bool TranslateApplicableWelford::wouldTranslateToPersistent( [fusion](WelfordOp* welford) { return welford->fusion() == fusion; }), "Welfords in given vector not in the same fusion"); - // Make initial `in-progress copy` - auto test_copy = std::make_unique(); + // Make initial `in-progress copy` — share the source IrContainer so that + // traversal via getCurFusion() sees all Vals in the same container. 
+ auto test_copy = + std::unique_ptr(new Fusion(fusion->ir_container_ptr())); auto original_to_test_map = Fusion::copy(fusion, test_copy.get()); std::vector copied_welfords; diff --git a/csrc/host_ir/container.h b/csrc/host_ir/container.h index 4a8456858fe..940c1aec01c 100644 --- a/csrc/host_ir/container.h +++ b/csrc/host_ir/container.h @@ -23,6 +23,8 @@ namespace nvfuser::hir { class HostIrContainer final : public Fusion { public: HostIrContainer() = default; + explicit HostIrContainer(std::shared_ptr container) + : Fusion(std::move(container)) {} HostIrContainer(const HostIrContainer&) = delete; HostIrContainer& operator=(const HostIrContainer&) = delete; diff --git a/csrc/host_ir/lower.cpp b/csrc/host_ir/lower.cpp index d4fab7017de..bcfcaa6c348 100644 --- a/csrc/host_ir/lower.cpp +++ b/csrc/host_ir/lower.cpp @@ -118,7 +118,8 @@ std::unique_ptr HostIrLower::lower( RuntimeWorkSpace workspace = prepareRuntimeOrder(*staged_fusion); // Create the HostIrContainer representing the host program. 
Each segment of // the segmented fusion will be translated to a HostIR - auto hic = std::make_unique(); + auto hic = std::make_unique( + staged_fusion->completeFusion()->ir_container_ptr()); FusionGuard fg(hic.get()); IrCloner ir_cloner(hic.get()); auto clone = diff --git a/csrc/host_ir/lowering.cpp b/csrc/host_ir/lowering.cpp index bb14c8c6eae..492043664d4 100644 --- a/csrc/host_ir/lowering.cpp +++ b/csrc/host_ir/lowering.cpp @@ -396,7 +396,8 @@ std::unique_ptr lowerSegmentedFusionToHostIr( const SegmentedFusion& segmented_fusion, const std::vector& launch_params_per_segment, std::vector>& executors) { - auto hic = std::make_unique(); + auto hic = std::make_unique( + segmented_fusion.completeFusion()->ir_container_ptr()); IrCloner ir_cloner = Fusion::copy(segmented_fusion.completeFusion(), hic.get()); diff --git a/csrc/ir/base_nodes.cpp b/csrc/ir/base_nodes.cpp index 9ea67cfea86..3e3cf7790ed 100644 --- a/csrc/ir/base_nodes.cpp +++ b/csrc/ir/base_nodes.cpp @@ -90,6 +90,33 @@ kir::Kernel* Statement::kernel() const { NVFUSER_DEFINE_CLONE(Val) +bool Val::isOwnedBy(const Fusion* f) const { + return std::find(owning_fusions_.begin(), owning_fusions_.end(), f) != + owning_fusions_.end(); +} + +void Val::addOwningFusion(Fusion* f) { + if (!isOwnedBy(f)) { + bool was_unshared = !isShared(); + owning_fusions_.push_back(f); + if (was_unshared && isShared()) { + // Transitioning to shared: clear uses_ so it doesn't hold stale + // Expr pointers from the original creator Fusion. registerExpr() + // already skips addUse() for shared scalars, so uses_ would never + // be updated again — keeping it non-empty would be misleading. 
+ uses_.clear(); + } + } +} + +bool Val::removeOwningFusion(Fusion* f) { + auto it = std::find(owning_fusions_.begin(), owning_fusions_.end(), f); + if (it != owning_fusions_.end()) { + owning_fusions_.erase(it); + } + return owning_fusions_.empty(); +} + void Val::addDependency(Val* dependency) { NVF_ERROR(dependency != nullptr); @@ -100,6 +127,10 @@ void Val::addDependency(Val* dependency) { } const std::vector& Val::uses() const { + // Shared scalars return an empty uses_ (cleared in addOwningFusion). + // registerExpr() skips addUse() for shared scalars, so uses_ is not + // populated while shared. Callers get an empty vector, which is + // correct — per_fusion_exprs_ should be used for Fusion-specific traversal. if (vtype_ == ValType::TensorView) { if (!fusion()->isTVUseInfoValid() && !fusion()->isUpdatingTVUseInfo()) { fusion()->resetTvUses(); diff --git a/csrc/ir/base_nodes.h b/csrc/ir/base_nodes.h index c7d359dbae1..823ee61f563 100644 --- a/csrc/ir/base_nodes.h +++ b/csrc/ir/base_nodes.h @@ -411,6 +411,24 @@ class NVF_API Val : public Statement { definition_ = expr; } + // Multi-owner tracking for Phase 3 scalar sharing. + // owning_fusions_[0] = original creator (by convention, set at registration). + // Grows to 2+ only for shared scalars (IrCloner reuse path). + bool isShared() const { + return owning_fusions_.size() > 1; + } + + bool isOwnedBy(const Fusion* f) const; + + void addOwningFusion(Fusion* f); + + // Remove an owning Fusion. Returns true if no owners remain afterwards. + bool removeOwningFusion(Fusion* f); + + const std::vector& owningFusions() const { + return owning_fusions_; + } + NVFUSER_DECLARE_CLONE protected: @@ -454,6 +472,9 @@ class NVF_API Val : public Statement { // welford operations. DataType dtype_; + // Tracks all Fusions that own this Val. Seeded at registration. + std::vector owning_fusions_; + // Following is managed by Fusion and can change. 
bool is_fusion_input_ = false; bool is_fusion_output_ = false; diff --git a/csrc/ir/cloner.cpp b/csrc/ir/cloner.cpp index c71c04f082c..beb7fb2cb43 100644 --- a/csrc/ir/cloner.cpp +++ b/csrc/ir/cloner.cpp @@ -24,21 +24,46 @@ Statement* IrCloner::clone(const Statement* statement) { return nullptr; } - // Have we already cloned this node? + // Step 1: Cache check — already cloned or reused const auto it = clones_map_.find(statement); if (it != clones_map_.end()) { return it->second; - } else { - auto new_node = handle(statement); - - // The base cloning constructor (Statement) should have - // registered the new node. Failure to do so indicates - // that something went horribly wrong. - NVF_ERROR(new_node != nullptr); - NVF_ERROR(clones_map_[statement] == new_node); + } - return new_node; + // Step 2: Scalar reuse — share leaf scalars across Fusions + if (statement->isVal()) { + const Val* val = statement->as(); + if (val->isScalar() && val->definition() == nullptr && + !val->value().hasValue() && !val->isFusionInput() && + !val->uses().empty()) { + Fusion* src_fusion = val->container(); + if (src_fusion != nullptr && src_fusion != ir_container_ && + src_fusion->ir_container() == ir_container_->ir_container()) { + Val* reused = const_cast(val); + + // (a) Cache so downstream Expr clones resolve this input + clones_map_[statement] = reused; + + // (b) Register with dest Fusion's per-Fusion tracking + auto* c = ir_container_->ir_container(); + { + std::unique_lock lock(c->mutex_); + c->per_fusion_vals_[ir_container_].insert(reused); + } + + // (c) Track ownership for lifetime management + reused->addOwningFusion(ir_container_); + + return reused; + } + } } + + // Step 3: Full clone; handle() must register the clone in clones_map_ + auto new_node = handle(statement); + NVF_ERROR(new_node != nullptr); + NVF_ERROR(clones_map_[statement] == new_node); + return new_node; } void IrCloner::registerClone(const Statement* src, Statement* clone) { diff --git a/csrc/ir/container.cpp b/csrc/ir/container.cpp index 
d33a8af4eef..ef2c7db9b9b 100644 --- a/csrc/ir/container.cpp +++ b/csrc/ir/container.cpp @@ -153,7 +153,8 @@ int64_t IrContainer::numVals() const noexcept { void IrContainer::addFusion(Fusion* fusion) { std::unique_lock lock(mutex_); sharing_fusions_.insert(fusion); - per_fusion_vals_[fusion]; // Pre-insert key so no outer-map rehash occurs during concurrent val/expr registration + per_fusion_vals_[fusion]; // Pre-insert key so no outer-map rehash occurs + // during concurrent val/expr registration per_fusion_exprs_[fusion]; } @@ -220,24 +221,22 @@ void IrContainer::transferStatementOwnership( void IrContainer::removeStatementsOwnedBy(const Fusion* fusion) { std::unique_lock lock(mutex_); - auto vals_it = per_fusion_vals_.find(fusion); - if (vals_it != per_fusion_vals_.end()) { - const auto& owned = vals_it->second; - std::erase_if(vals_up_, [&](const std::unique_ptr& v) { - if (owned.count(v.get()) > 0) { - vals_.erase(v.get()); - return true; - } - return false; - }); - per_fusion_vals_.erase(vals_it); - } + // Process Exprs FIRST — clean up uses_/definition_ pointers on Vals + // before freeing Exprs. This prevents dangling pointers in shared + // scalars' uses_ vectors (shared scalars survive Val cleanup via the + // multi-owner guard but their uses_ would reference freed Exprs). auto exprs_it = per_fusion_exprs_.find(fusion); if (exprs_it != per_fusion_exprs_.end()) { const auto& owned = exprs_it->second; std::erase_if(exprs_up_, [&](const std::unique_ptr& e) { if (owned.count(e.get()) > 0) { + for (Val* out : e->outputs()) { + out->setDefinition(nullptr); + } + for (Val* inp : e->inputs()) { + inp->removeUse(e.get()); + } exprs_.erase(e.get()); return true; } @@ -245,6 +244,26 @@ void IrContainer::removeStatementsOwnedBy(const Fusion* fusion) { }); per_fusion_exprs_.erase(exprs_it); } + + // Then Vals — shared vals survive via multi-owner guard, now with + // clean uses_ (dangling Expr pointers already removed above). 
+ auto vals_it = per_fusion_vals_.find(fusion); + if (vals_it != per_fusion_vals_.end()) { + const auto& owned = vals_it->second; + std::erase_if(vals_up_, [&](const std::unique_ptr& v) { + if (owned.count(v.get()) > 0) { + // Multi-owner guard: only free if this is the last owning Fusion. + // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast) + if (!v->removeOwningFusion(const_cast(fusion))) { + return false; // other Fusions still own this Val — keep alive + } + vals_.erase(v.get()); + return true; // last owner gone → Val freed + } + return false; + }); + per_fusion_vals_.erase(vals_it); + } } std::deque IrContainer::deterministicValsOwnedBy( diff --git a/csrc/ir/container.h b/csrc/ir/container.h index a9555ae3305..c8824f8c954 100644 --- a/csrc/ir/container.h +++ b/csrc/ir/container.h @@ -81,6 +81,7 @@ class IrContainer { protected: // Let Fusion access IrContainer internals (mutex_, fields, Impl helpers) friend class Fusion; + friend class IrCloner; mutable std::shared_mutex mutex_; diff --git a/csrc/iter_visitor.cpp b/csrc/iter_visitor.cpp index 22484f1b859..cb630ad6122 100644 --- a/csrc/iter_visitor.cpp +++ b/csrc/iter_visitor.cpp @@ -157,7 +157,13 @@ void IterVisitor::traverseBetween( if (to.empty()) { return; } - Fusion* fusion = to.front()->fusion(); + // Use the active FusionGuard rather than deriving the Fusion from a val. + // This avoids calling fusion() on shared scalars whose ir_container_ + // points to the original creator Fusion, not the current traversal target. 
+ Fusion* fusion = FusionGuard::getCurFusion(); + if (fusion == nullptr) { + fusion = to.front()->fusion(); + } FusionGuard fg(fusion); std::unordered_set visited; @@ -468,7 +474,10 @@ void BackwardVisitor::traverseTo( if (from.empty()) { return; } - Fusion* fusion = from.front()->fusion(); + Fusion* fusion = FusionGuard::getCurFusion(); + if (fusion == nullptr) { + fusion = from.front()->fusion(); + } FusionGuard fg(fusion); // Reset members diff --git a/csrc/runtime/communication_executor.cpp b/csrc/runtime/communication_executor.cpp index a87990794ea..20b026388e0 100644 --- a/csrc/runtime/communication_executor.cpp +++ b/csrc/runtime/communication_executor.cpp @@ -54,7 +54,8 @@ void CommunicationExecutor::compile(Fusion* fusion) { FusionProfiler::segment(group_id_).startCompile(); } - host_ir_container_ = std::make_unique(); + host_ir_container_ = + std::make_unique(fusion->ir_container_ptr()); IrCloner cloner = Fusion::copy(fusion, host_ir_container_.get()); if (fusion->isA()) { for (Expr* e : fusion->as()->topLevelExprs()) {