Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
138 changes: 88 additions & 50 deletions csrc/fusion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -184,18 +184,19 @@ struct Fusion::ContainerMutator {
}
}

static int64_t numValsExcludingShortcuts(const Fusion* self) noexcept {
auto* c = self->ir_container();
// Use direct field access. Avoids re-entering valsOwnedBy() which acquires
// shared_lock.
const auto it = c->per_fusion_vals_.find(self);
int64_t count = it != c->per_fusion_vals_.end()
? static_cast<int64_t>(it->second.size())
: 0;
count -= (self->zero_val_ != nullptr) + (self->one_val_ != nullptr) +
(self->true_val_ != nullptr) + (self->false_val_ != nullptr) +
(self->magic_zero_val_ != nullptr);
return count;
// Resets the cached shortcut-val pointer on `self` that refers to `v`, if any.
// The caches are checked in the order zero, one, true, false, magic-zero and
// at most one of them is reset per call. When `v` matches none of them,
// `self` is left untouched.
static void nullOutShortcutIfNeeded(Fusion* self, Val* v) {
  if (v == self->zero_val_) {
    self->zero_val_ = nullptr;
    return;
  }
  if (v == self->one_val_) {
    self->one_val_ = nullptr;
    return;
  }
  if (v == self->true_val_) {
    self->true_val_ = nullptr;
    return;
  }
  if (v == self->false_val_) {
    self->false_val_ = nullptr;
    return;
  }
  if (v == self->magic_zero_val_) {
    self->magic_zero_val_ = nullptr;
  }
}

static void removeStatementsCreatedAfter(
Expand All @@ -204,42 +205,83 @@ struct Fusion::ContainerMutator {
int64_t num_vals_before) {
auto* c = self->ir_container();

// Remove expressions before values because we need to change Val::uses_.
while (std::ssize(c->per_fusion_exprs_[self]) > num_exprs_before) {
// Pop from global deque back — statements created by this Fusion during
// the guard scope are at the tail (LIFO invariant).
Expr* e = c->exprs_up_.back().get();
NVF_ERROR(
c->per_fusion_exprs_[self].count(e) > 0,
"removeStatementsCreatedAfter: tail expr belongs to another Fusion");
for (Val* in : e->inputs()) {
in->removeUse(e);
// Use direct field access — hasMultipleFusions() acquires shared_lock which
// deadlocks since the caller already holds unique_lock on mutex_.
if (c->sharing_fusions_.size() <= 1) {
// Fast path: single Fusion owns this container, so the LIFO invariant
// holds — self's newest statements are always at the global deque tail.
// Remove expressions before values because we need to change Val::uses_.
while (std::ssize(c->per_fusion_exprs_[self]) > num_exprs_before) {
Expr* e = c->exprs_up_.back().get();
NVF_ERROR(
c->per_fusion_exprs_[self].count(e) > 0,
"removeStatementsCreatedAfter: tail expr belongs to another Fusion");
for (Val* out : e->outputs()) {
out->setDefinition(nullptr);
}
for (Val* in : e->inputs()) {
in->removeUse(e);
}
c->per_fusion_exprs_[self].erase(e);
c->exprs_.erase(e);
c->exprs_up_.pop_back();
}
c->per_fusion_exprs_[self].erase(e);
c->exprs_.erase(e);
c->exprs_up_.pop_back();
}

while (numValsExcludingShortcuts(self) > num_vals_before) {
Val* v = c->vals_up_.back().get();
NVF_ERROR(
c->per_fusion_vals_[self].count(v) > 0,
"removeStatementsCreatedAfter: tail val belongs to another Fusion");
// Null out shortcut caches if they point to vals about to be destroyed
if (v == self->zero_val_) {
self->zero_val_ = nullptr;
} else if (v == self->one_val_) {
self->one_val_ = nullptr;
} else if (v == self->true_val_) {
self->true_val_ = nullptr;
} else if (v == self->false_val_) {
self->false_val_ = nullptr;
} else if (v == self->magic_zero_val_) {
self->magic_zero_val_ = nullptr;
while (std::ssize(c->per_fusion_vals_[self]) > num_vals_before) {
Val* v = c->vals_up_.back().get();
NVF_ERROR(
c->per_fusion_vals_[self].count(v) > 0,
"removeStatementsCreatedAfter: tail val belongs to another Fusion");
nullOutShortcutIfNeeded(self, v);
c->per_fusion_vals_[self].erase(v);
c->vals_.erase(v);
c->vals_up_.pop_back();
}
c->per_fusion_vals_[self].erase(v);
c->vals_.erase(v);
c->vals_up_.pop_back();
} else {
// Slow path: shared container — other Fusions' statements may be
// interleaved at the tail of the global deques. Use std::erase_if
// (C++20) to scan forward over each deque: statements owned by other
// Fusions are kept, the first num_exprs_before / num_vals_before
// statements owned by self are kept (old, created before the guard
// scope), and the remainder of self's statements are erased (added
// during the guard scope). Entered whenever the container is shared,
// regardless of success or failure; if no new statements were added
// the scan erases nothing. O(total statements in container).
int64_t exprs_kept = 0;
std::erase_if(c->exprs_up_, [&](const std::unique_ptr<Expr>& e_up) {
Expr* e = e_up.get();
if (c->per_fusion_exprs_[self].count(e) == 0) {
return false; // belongs to another Fusion — keep
}
if (exprs_kept < num_exprs_before) {
++exprs_kept;
return false; // self's old expr — keep
}
// self's new expr — remove (clean up uses and index maps first)
for (Val* out : e->outputs()) {
out->setDefinition(nullptr);
}
for (Val* in : e->inputs()) {
in->removeUse(e);
}
c->per_fusion_exprs_[self].erase(e);
c->exprs_.erase(e);
return true;
});

int64_t vals_kept = 0;
std::erase_if(c->vals_up_, [&](const std::unique_ptr<Val>& v_up) {
Val* v = v_up.get();
if (c->per_fusion_vals_[self].count(v) == 0) {
return false; // belongs to another Fusion — keep
}
if (vals_kept < num_vals_before) {
++vals_kept;
return false; // self's old val — keep
}
// self's new val — remove (null shortcut cache pointer if applicable)
nullOutShortcutIfNeeded(self, v);
c->per_fusion_vals_[self].erase(v);
c->vals_.erase(v);
return true;
});
}
}
};
Expand Down Expand Up @@ -622,10 +664,6 @@ void Fusion::removeStatementsCreatedAfter(
this, num_exprs_before, num_vals_before);
}

// Member-function entry point for the shortcut-excluding val count: passes
// `this` to ContainerMutator::numValsExcludingShortcuts and returns its
// result unchanged.
int64_t Fusion::numValsExcludingShortcuts() const noexcept {
return ContainerMutator::numValsExcludingShortcuts(this);
}

void Fusion::addInput(Val* input) {
assertInContainer(input, "Cannot register input ");

Expand Down
7 changes: 0 additions & 7 deletions csrc/fusion.h
Original file line number Diff line number Diff line change
Expand Up @@ -556,13 +556,6 @@ class NVF_API Fusion : public PolymorphicBase {
return std::ssize(ir_container()->valsOwnedBy(this));
}

//! Return per-Fusion val count excluding shortcut vals (zero_val_, etc.).
//! Shortcut vals are registered in both per_fusion_vals_ and vals_up_, but
//! since they're singletons that should persist across StatementGuard scopes,
//! this count excludes them so the LIFO pop-back in
//! removeStatementsCreatedAfter correctly skips over them.
int64_t numValsExcludingShortcuts() const noexcept;

// Shortcut values (frequently used constants)
Val* zeroVal();
Val* oneVal();
Expand Down
2 changes: 1 addition & 1 deletion csrc/statement_guard.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ StatementGuard::StatementGuard(Fusion* fusion)
return fusion;
}()),
prev_num_exprs_(fusion_->numExprs()),
prev_num_vals_(fusion_->numValsExcludingShortcuts()) {}
prev_num_vals_(fusion_->numVals()) {}

StatementGuard::~StatementGuard() {
fusion_->removeStatementsCreatedAfter(prev_num_exprs_, prev_num_vals_);
Expand Down
Loading