diff --git a/csrc/fusion.cpp b/csrc/fusion.cpp index 1b0b157b6a6..f01da35c663 100644 --- a/csrc/fusion.cpp +++ b/csrc/fusion.cpp @@ -184,18 +184,19 @@ struct Fusion::ContainerMutator { } } - static int64_t numValsExcludingShortcuts(const Fusion* self) noexcept { - auto* c = self->ir_container(); - // Use direct field access. Avoids re-entering valsOwnedBy() which acquires - // shared_lock. - const auto it = c->per_fusion_vals_.find(self); - int64_t count = it != c->per_fusion_vals_.end() - ? static_cast(it->second.size()) - : 0; - count -= (self->zero_val_ != nullptr) + (self->one_val_ != nullptr) + - (self->true_val_ != nullptr) + (self->false_val_ != nullptr) + - (self->magic_zero_val_ != nullptr); - return count; + // Null out self's shortcut-val pointer cache if v is one of them. + static void nullOutShortcutIfNeeded(Fusion* self, Val* v) { + if (v == self->zero_val_) { + self->zero_val_ = nullptr; + } else if (v == self->one_val_) { + self->one_val_ = nullptr; + } else if (v == self->true_val_) { + self->true_val_ = nullptr; + } else if (v == self->false_val_) { + self->false_val_ = nullptr; + } else if (v == self->magic_zero_val_) { + self->magic_zero_val_ = nullptr; + } } static void removeStatementsCreatedAfter( @@ -204,42 +205,83 @@ struct Fusion::ContainerMutator { int64_t num_vals_before) { auto* c = self->ir_container(); - // Remove expressions before values because we need to change Val::uses_. - while (std::ssize(c->per_fusion_exprs_[self]) > num_exprs_before) { - // Pop from global deque back — statements created by this Fusion during - // the guard scope are at the tail (LIFO invariant). - Expr* e = c->exprs_up_.back().get(); - NVF_ERROR( - c->per_fusion_exprs_[self].count(e) > 0, - "removeStatementsCreatedAfter: tail expr belongs to another Fusion"); - for (Val* in : e->inputs()) { - in->removeUse(e); + // Use direct field access — hasMultipleFusions() acquires shared_lock which + // deadlocks since the caller already holds unique_lock on mutex_. + if (c->sharing_fusions_.size() <= 1) { + // Fast path: single Fusion owns this container, so the LIFO invariant + // holds — self's newest statements are always at the global deque tail. + // Remove expressions before values because we need to change Val::uses_. + while (std::ssize(c->per_fusion_exprs_[self]) > num_exprs_before) { + Expr* e = c->exprs_up_.back().get(); + NVF_ERROR( + c->per_fusion_exprs_[self].count(e) > 0, + "removeStatementsCreatedAfter: tail expr belongs to another Fusion"); + for (Val* out : e->outputs()) { + out->setDefinition(nullptr); + } + for (Val* in : e->inputs()) { + in->removeUse(e); + } + c->per_fusion_exprs_[self].erase(e); + c->exprs_.erase(e); + c->exprs_up_.pop_back(); } - c->per_fusion_exprs_[self].erase(e); - c->exprs_.erase(e); - c->exprs_up_.pop_back(); - } - - while (numValsExcludingShortcuts(self) > num_vals_before) { - Val* v = c->vals_up_.back().get(); - NVF_ERROR( - c->per_fusion_vals_[self].count(v) > 0, - "removeStatementsCreatedAfter: tail val belongs to another Fusion"); - // Null out shortcut caches if they point to vals about to be destroyed - if (v == self->zero_val_) { - self->zero_val_ = nullptr; - } else if (v == self->one_val_) { - self->one_val_ = nullptr; - } else if (v == self->true_val_) { - self->true_val_ = nullptr; - } else if (v == self->false_val_) { - self->false_val_ = nullptr; - } else if (v == self->magic_zero_val_) { - self->magic_zero_val_ = nullptr; + while (std::ssize(c->per_fusion_vals_[self]) > num_vals_before) { + Val* v = c->vals_up_.back().get(); + NVF_ERROR( + c->per_fusion_vals_[self].count(v) > 0, + "removeStatementsCreatedAfter: tail val belongs to another Fusion"); + nullOutShortcutIfNeeded(self, v); + c->per_fusion_vals_[self].erase(v); + c->vals_.erase(v); + c->vals_up_.pop_back(); } - c->per_fusion_vals_[self].erase(v); - c->vals_.erase(v); - c->vals_up_.pop_back(); + } else { + // Slow path: shared container — other Fusions' statements may be + // interleaved at the tail of the global deques. Use std::erase_if + // (C++20) to scan forward: skip the first num_before of self's + // statements (old, to keep), then erase the remainder (added during + // the guard scope). Entered whenever the container is shared, + // regardless of success or failure; if no new statements were added + // the scan completes trivially. O(total statements in container). + int64_t exprs_kept = 0; + std::erase_if(c->exprs_up_, [&](const std::unique_ptr& e_up) { + Expr* e = e_up.get(); + if (c->per_fusion_exprs_[self].count(e) == 0) { + return false; // belongs to another Fusion — keep + } + if (exprs_kept < num_exprs_before) { + ++exprs_kept; + return false; // self's old expr — keep + } + // self's new expr — remove (clean up uses and index maps first) + for (Val* out : e->outputs()) { + out->setDefinition(nullptr); + } + for (Val* in : e->inputs()) { + in->removeUse(e); + } + c->per_fusion_exprs_[self].erase(e); + c->exprs_.erase(e); + return true; + }); + + int64_t vals_kept = 0; + std::erase_if(c->vals_up_, [&](const std::unique_ptr& v_up) { + Val* v = v_up.get(); + if (c->per_fusion_vals_[self].count(v) == 0) { + return false; // belongs to another Fusion — keep + } + if (vals_kept < num_vals_before) { + ++vals_kept; + return false; // self's old val — keep + } + // self's new val — remove (null shortcut cache pointer if applicable) + nullOutShortcutIfNeeded(self, v); + c->per_fusion_vals_[self].erase(v); + c->vals_.erase(v); + return true; + }); } } }; @@ -622,10 +664,6 @@ void Fusion::removeStatementsCreatedAfter( this, num_exprs_before, num_vals_before); } -int64_t Fusion::numValsExcludingShortcuts() const noexcept { - return ContainerMutator::numValsExcludingShortcuts(this); -} - void Fusion::addInput(Val* input) { assertInContainer(input, "Cannot register input "); diff --git a/csrc/fusion.h b/csrc/fusion.h index 8d0bc3ced3c..fdeb588d9f1 100644 --- a/csrc/fusion.h +++ b/csrc/fusion.h @@ -556,13 +556,6 @@ class NVF_API Fusion : public PolymorphicBase { return std::ssize(ir_container()->valsOwnedBy(this)); } - //! Return per-Fusion val count excluding shortcut vals (zero_val_, etc.). - //! Shortcut vals are registered in both per_fusion_vals_ and vals_up_, but - //! since they're singletons that should persist across StatementGuard scopes, - //! this count excludes them so the LIFO pop-back in - //! removeStatementsCreatedAfter correctly skips over them. - int64_t numValsExcludingShortcuts() const noexcept; - // Shortcut values (frequently used constants) Val* zeroVal(); Val* oneVal(); diff --git a/csrc/statement_guard.cpp b/csrc/statement_guard.cpp index 15a3b4159c3..4575bb59076 100644 --- a/csrc/statement_guard.cpp +++ b/csrc/statement_guard.cpp @@ -20,7 +20,7 @@ StatementGuard::StatementGuard(Fusion* fusion) return fusion; }()), prev_num_exprs_(fusion_->numExprs()), - prev_num_vals_(fusion_->numValsExcludingShortcuts()) {} + prev_num_vals_(fusion_->numVals()) {} StatementGuard::~StatementGuard() { fusion_->removeStatementsCreatedAfter(prev_num_exprs_, prev_num_vals_);