Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 38 additions & 6 deletions csrc/fusion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -109,13 +109,22 @@ struct Fusion::ContainerMutator {
removeExpr(self, e);
}

c->per_fusion_vals_[self].erase(val);

// Multi-owner guard: only free if this is the last owning Fusion.
if (!val->removeOwningFusion(self)) {
// Other Fusions still own this Val — remove from our tracking but
// keep the Val alive in vals_up_ and vals_.
self->invalidateTvsAndUses();
return;
}

auto val_in_deque = std::ranges::find_if(
c->vals_up_,
[val](std::unique_ptr<Val>& val_up) { return val_up.get() == val; });
NVF_ERROR(
val_in_deque != c->vals_up_.end(),
"Wanted to remove a value but its unique ptr is missing.");
c->per_fusion_vals_[self].erase(val);
c->vals_.erase(val);
c->vals_up_.erase(val_in_deque);

Expand All @@ -137,6 +146,11 @@ struct Fusion::ContainerMutator {
c->vals_.insert(val);
c->per_fusion_vals_[self].insert(val);
val->setName(IrContainerPasskey(), self->getValName(val->vtype()));

// Seed owning_fusions_ with the registering Fusion (original creator).
if (val->owning_fusions_.empty()) {
val->owning_fusions_.push_back(self);
}
}

static void registerExpr(Fusion* self, Expr* expr) {
Expand All @@ -159,7 +173,9 @@ struct Fusion::ContainerMutator {
c->assertInContainerImpl(input, "Input to expr is invalid, ");
if (input->isA<TensorView>()) {
self->invalidateTvsAndUses();
} else {
} else if (!input->isShared()) {
// Don't track uses on shared scalars — their uses_ would accumulate
// Exprs from multiple Fusions, causing cross-Fusion DAG leakage.
input->addUse(expr);
}
}
Expand Down Expand Up @@ -212,7 +228,8 @@ struct Fusion::ContainerMutator {
Expr* e = c->exprs_up_.back().get();
NVF_ERROR(
c->per_fusion_exprs_[self].count(e) > 0,
"removeStatementsCreatedAfter: tail expr belongs to another Fusion");
"removeStatementsCreatedAfter: tail expr belongs to another "
"Fusion");
for (Val* out : e->outputs()) {
out->setDefinition(nullptr);
}
Expand All @@ -230,6 +247,11 @@ struct Fusion::ContainerMutator {
"removeStatementsCreatedAfter: tail val belongs to another Fusion");
nullOutShortcutIfNeeded(self, v);
c->per_fusion_vals_[self].erase(v);
if (!v->removeOwningFusion(self)) {
// Shared val — other Fusions still own it. Can't pop from deque
// tail since the val must stay alive. Move it out of the way.
break;
}
c->vals_.erase(v);
c->vals_up_.pop_back();
}
Expand Down Expand Up @@ -276,6 +298,10 @@ struct Fusion::ContainerMutator {
// self's new val — remove (null shortcut cache pointer if applicable)
nullOutShortcutIfNeeded(self, v);
c->per_fusion_vals_[self].erase(v);
// Multi-owner guard: only free if last owner.
if (!v->removeOwningFusion(self)) {
return false; // shared val — other Fusions still own it
}
c->vals_.erase(v);
return true;
});
Expand Down Expand Up @@ -457,10 +483,16 @@ IrCloner Fusion::copy(const Fusion* from, Fusion* to) {
}

// Wire up definitions and uses on cloned vals in deterministic order
// to ensure exprs are inserted into exprs_up_ deterministically
// to ensure exprs are inserted into exprs_up_ deterministically.
// Skip reused vals (shared scalars) — their definition/uses belong to
// the source Fusion and must not be overwritten.
for (auto val : from->deterministic_vals()) {
ir_cloner.clone(val)->setDefinition(ir_cloner.clone(val->definition_));
ir_cloner.clone(val)->setUses(ir_cloner.clone(val->uses_));
auto* cloned = ir_cloner.clone(val);
if (cloned == val) {
continue; // reused (shared scalar) — don't rewire
}
cloned->setDefinition(ir_cloner.clone(val->definition_));
cloned->setUses(ir_cloner.clone(val->uses_));
}

// Sync per-Fusion name counters from source to dest.
Expand Down
1 change: 1 addition & 0 deletions csrc/fusion.h
Original file line number Diff line number Diff line change
Expand Up @@ -591,6 +591,7 @@ class NVF_API Fusion : public PolymorphicBase {
friend SegmentCandidateFinder;
friend SegmentedFusion;
friend class TranslateApplicableWelford;
friend class IrCloner;
friend Val;

//! Constructor that shares an existing container. Creates an empty Fusion
Expand Down
4 changes: 2 additions & 2 deletions csrc/fusion_segmenter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1801,8 +1801,8 @@ std::pair<IrCloner, std::unique_ptr<Fusion>> SegmentedFusion::makeFusion(
SegmentedGroup* sg) const {
// TODO Optimize cloning step by only copying values and expressions between
// the fusion segment's inputs and outputs.
auto fusion_segment = std::unique_ptr<Fusion>(
new Fusion(completeFusion()->ir_container_ptr()));
auto fusion_segment =
std::unique_ptr<Fusion>(new Fusion(completeFusion()->ir_container_ptr()));

IrCloner complete_to_segment_map =
Fusion::copy(completeFusion(), fusion_segment.get());
Expand Down
31 changes: 31 additions & 0 deletions csrc/ir/base_nodes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,33 @@ kir::Kernel* Statement::kernel() const {

NVFUSER_DEFINE_CLONE(Val)

bool Val::isOwnedBy(const Fusion* f) const {
return std::find(owning_fusions_.begin(), owning_fusions_.end(), f) !=
owning_fusions_.end();
}

void Val::addOwningFusion(Fusion* f) {
if (!isOwnedBy(f)) {
bool was_unshared = !isShared();
owning_fusions_.push_back(f);
if (was_unshared && isShared()) {
// Transitioning to shared: clear uses_ so it doesn't hold stale
// Expr pointers from the original creator Fusion. registerExpr()
// already skips addUse() for shared scalars, so uses_ would never
// be updated again — keeping it non-empty would be misleading.
uses_.clear();
}
}
}

bool Val::removeOwningFusion(Fusion* f) {
auto it = std::find(owning_fusions_.begin(), owning_fusions_.end(), f);
if (it != owning_fusions_.end()) {
owning_fusions_.erase(it);
}
return owning_fusions_.empty();
}

void Val::addDependency(Val* dependency) {
NVF_ERROR(dependency != nullptr);

Expand All @@ -100,6 +127,10 @@ void Val::addDependency(Val* dependency) {
}

const std::vector<Expr*>& Val::uses() const {
// Shared scalars return an empty uses_ (cleared in addOwningFusion).
// registerExpr() skips addUse() for shared scalars, so uses_ is never
// populated after sharing. Callers get an empty vector, which is
// correct — per_fusion_exprs_ should be used for Fusion-specific traversal.
if (vtype_ == ValType::TensorView) {
if (!fusion()->isTVUseInfoValid() && !fusion()->isUpdatingTVUseInfo()) {
fusion()->resetTvUses();
Expand Down
21 changes: 21 additions & 0 deletions csrc/ir/base_nodes.h
Original file line number Diff line number Diff line change
Expand Up @@ -411,6 +411,24 @@ class NVF_API Val : public Statement {
definition_ = expr;
}

// Multi-owner tracking for Phase 3 scalar sharing.
// owning_fusions_[0] = original creator (by convention, set at registration).
// Grows to 2+ only for shared scalars (IrCloner reuse path).
bool isShared() const {
return owning_fusions_.size() > 1;
}

bool isOwnedBy(const Fusion* f) const;

void addOwningFusion(Fusion* f);

// Remove an owning Fusion. Returns true if this was the last owner.
bool removeOwningFusion(Fusion* f);

const std::vector<Fusion*>& owningFusions() const {
return owning_fusions_;
}

NVFUSER_DECLARE_CLONE

protected:
Expand Down Expand Up @@ -454,6 +472,9 @@ class NVF_API Val : public Statement {
// welford operations.
DataType dtype_;

// Tracks all Fusions that own this Val. Seeded at registration.
std::vector<Fusion*> owning_fusions_;

// Following is managed by Fusion and can change.
bool is_fusion_input_ = false;
bool is_fusion_output_ = false;
Expand Down
45 changes: 35 additions & 10 deletions csrc/ir/cloner.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,21 +24,46 @@ Statement* IrCloner::clone(const Statement* statement) {
return nullptr;
}

// Have we already cloned this node?
// Step 1: Cache check — already cloned or reused
const auto it = clones_map_.find(statement);
if (it != clones_map_.end()) {
return it->second;
} else {
auto new_node = handle(statement);

// The base cloning constructor (Statement) should have
// registered the new node. Failure to do so indicates
// that something went horribly wrong.
NVF_ERROR(new_node != nullptr);
NVF_ERROR(clones_map_[statement] == new_node);
}

return new_node;
// Step 2: Scalar reuse — share leaf scalars across Fusions
if (statement->isVal()) {
const Val* val = statement->as<Val>();
if (val->isScalar() && val->definition() == nullptr &&
!val->value().hasValue() && !val->isFusionInput() &&
!val->uses().empty()) {
Fusion* src_fusion = val->container();
if (src_fusion != nullptr && src_fusion != ir_container_ &&
src_fusion->ir_container() == ir_container_->ir_container()) {
Val* reused = const_cast<Val*>(val);

// (a) Cache so downstream Expr clones resolve this input
clones_map_[statement] = reused;

// (b) Register with dest Fusion's per-Fusion tracking
auto* c = ir_container_->ir_container();
{
std::unique_lock lock(c->mutex_);
c->per_fusion_vals_[ir_container_].insert(reused);
}

// (c) Track ownership for lifetime management
reused->addOwningFusion(ir_container_);

return reused;
}
}
}

// Step 3: Full clone (unchanged)
auto new_node = handle(statement);
NVF_ERROR(new_node != nullptr);
NVF_ERROR(clones_map_[statement] == new_node);
return new_node;
}

void IrCloner::registerClone(const Statement* src, Statement* clone) {
Expand Down
45 changes: 32 additions & 13 deletions csrc/ir/container.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,8 @@ int64_t IrContainer::numVals() const noexcept {
void IrContainer::addFusion(Fusion* fusion) {
std::unique_lock lock(mutex_);
sharing_fusions_.insert(fusion);
per_fusion_vals_[fusion]; // Pre-insert key so no outer-map rehash occurs during concurrent val/expr registration
per_fusion_vals_[fusion]; // Pre-insert key so no outer-map rehash occurs
// during concurrent val/expr registration
per_fusion_exprs_[fusion];
}

Expand Down Expand Up @@ -220,31 +221,49 @@ void IrContainer::transferStatementOwnership(

void IrContainer::removeStatementsOwnedBy(const Fusion* fusion) {
std::unique_lock lock(mutex_);
auto vals_it = per_fusion_vals_.find(fusion);
if (vals_it != per_fusion_vals_.end()) {
const auto& owned = vals_it->second;
std::erase_if(vals_up_, [&](const std::unique_ptr<Val>& v) {
if (owned.count(v.get()) > 0) {
vals_.erase(v.get());
return true;
}
return false;
});
per_fusion_vals_.erase(vals_it);
}

// Process Exprs FIRST — clean up uses_/definition_ pointers on Vals
// before freeing Exprs. This prevents dangling pointers in shared
// scalars' uses_ vectors (shared scalars survive Val cleanup via the
// multi-owner guard but their uses_ would reference freed Exprs).
auto exprs_it = per_fusion_exprs_.find(fusion);
if (exprs_it != per_fusion_exprs_.end()) {
const auto& owned = exprs_it->second;
std::erase_if(exprs_up_, [&](const std::unique_ptr<Expr>& e) {
if (owned.count(e.get()) > 0) {
for (Val* out : e->outputs()) {
out->setDefinition(nullptr);
}
for (Val* inp : e->inputs()) {
inp->removeUse(e.get());
}
exprs_.erase(e.get());
return true;
}
return false;
});
per_fusion_exprs_.erase(exprs_it);
}

// Then Vals — shared vals survive via multi-owner guard, now with
// clean uses_ (dangling Expr pointers already removed above).
auto vals_it = per_fusion_vals_.find(fusion);
if (vals_it != per_fusion_vals_.end()) {
const auto& owned = vals_it->second;
std::erase_if(vals_up_, [&](const std::unique_ptr<Val>& v) {
if (owned.count(v.get()) > 0) {
// Multi-owner guard: only free if this is the last owning Fusion.
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
if (!v->removeOwningFusion(const_cast<Fusion*>(fusion))) {
return false; // other Fusions still own this Val — keep alive
}
vals_.erase(v.get());
return true; // last owner gone → Val freed
}
return false;
});
per_fusion_vals_.erase(vals_it);
}
}

std::deque<Val*> IrContainer::deterministicValsOwnedBy(
Expand Down
1 change: 1 addition & 0 deletions csrc/ir/container.h
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@ class IrContainer {
protected:
// Let Fusion access IrContainer internals (mutex_, fields, Impl helpers)
friend class Fusion;
friend class IrCloner;

mutable std::shared_mutex mutex_;

Expand Down
13 changes: 11 additions & 2 deletions csrc/iter_visitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,13 @@ void IterVisitor::traverseBetween(
if (to.empty()) {
return;
}
Fusion* fusion = to.front()->fusion();
// Use the active FusionGuard rather than deriving the Fusion from a val.
// This avoids calling fusion() on shared scalars whose ir_container_
// points to the original creator Fusion, not the current traversal target.
Fusion* fusion = FusionGuard::getCurFusion();
if (fusion == nullptr) {
fusion = to.front()->fusion();
}
FusionGuard fg(fusion);

std::unordered_set<Statement*> visited;
Expand Down Expand Up @@ -468,7 +474,10 @@ void BackwardVisitor::traverseTo(
if (from.empty()) {
return;
}
Fusion* fusion = from.front()->fusion();
Fusion* fusion = FusionGuard::getCurFusion();
if (fusion == nullptr) {
fusion = from.front()->fusion();
}
FusionGuard fg(fusion);

// Reset members
Expand Down
Loading