From a8a40a05600a57c96d4deb56c405beb7f214becc Mon Sep 17 00:00:00 2001 From: Michael Davis Date: Tue, 3 Mar 2026 20:43:17 -0800 Subject: [PATCH 1/6] We need to ensure statements removed are only popped from the owning Fusion, not the entire IrContainer. --- csrc/fusion.cpp | 123 +++++++++++++++++++++++++++++++++++------------- 1 file changed, 89 insertions(+), 34 deletions(-) diff --git a/csrc/fusion.cpp b/csrc/fusion.cpp index 1b0b157b6a6..01fa937d599 100644 --- a/csrc/fusion.cpp +++ b/csrc/fusion.cpp @@ -198,48 +198,103 @@ struct Fusion::ContainerMutator { return count; } + // Null out self's shortcut-val pointer cache if v is one of them. + static void nullOutShortcutIfNeeded(Fusion* self, Val* v) { + if (v == self->zero_val_) { + self->zero_val_ = nullptr; + } else if (v == self->one_val_) { + self->one_val_ = nullptr; + } else if (v == self->true_val_) { + self->true_val_ = nullptr; + } else if (v == self->false_val_) { + self->false_val_ = nullptr; + } else if (v == self->magic_zero_val_) { + self->magic_zero_val_ = nullptr; + } + } + + // Returns true if v is one of self's shortcut singleton vals. These persist + // across StatementGuard scopes and must not be removed on rollback. + static bool isShortcutVal(const Fusion* self, const Val* v) { + return v == self->zero_val_ || v == self->one_val_ || + v == self->true_val_ || v == self->false_val_ || + v == self->magic_zero_val_; + } + static void removeStatementsCreatedAfter( Fusion* self, int64_t num_exprs_before, int64_t num_vals_before) { auto* c = self->ir_container(); - // Remove expressions before values because we need to change Val::uses_. - while (std::ssize(c->per_fusion_exprs_[self]) > num_exprs_before) { - // Pop from global deque back — statements created by this Fusion during - // the guard scope are at the tail (LIFO invariant). - Expr* e = c->exprs_up_.back().get(); - NVF_ERROR( - c->per_fusion_exprs_[self].count(e) > 0, - "removeStatementsCreatedAfter: tail expr belongs to another Fusion"); - for (Val* in : e->inputs()) { - in->removeUse(e); + if (!c->hasMultipleFusions()) { + // Fast path: single Fusion owns this container, so the LIFO invariant + // holds — self's newest statements are always at the global deque tail. + // Remove expressions before values because we need to change Val::uses_. + while (std::ssize(c->per_fusion_exprs_[self]) > num_exprs_before) { + Expr* e = c->exprs_up_.back().get(); + NVF_ERROR( + c->per_fusion_exprs_[self].count(e) > 0, + "removeStatementsCreatedAfter: tail expr belongs to another Fusion"); + for (Val* in : e->inputs()) { + in->removeUse(e); + } + c->per_fusion_exprs_[self].erase(e); + c->exprs_.erase(e); + c->exprs_up_.pop_back(); } - c->per_fusion_exprs_[self].erase(e); - c->exprs_.erase(e); - c->exprs_up_.pop_back(); - } - - while (numValsExcludingShortcuts(self) > num_vals_before) { - Val* v = c->vals_up_.back().get(); - NVF_ERROR( - c->per_fusion_vals_[self].count(v) > 0, - "removeStatementsCreatedAfter: tail val belongs to another Fusion"); - // Null out shortcut caches if they point to vals about to be destroyed - if (v == self->zero_val_) { - self->zero_val_ = nullptr; - } else if (v == self->one_val_) { - self->one_val_ = nullptr; - } else if (v == self->true_val_) { - self->true_val_ = nullptr; - } else if (v == self->false_val_) { - self->false_val_ = nullptr; - } else if (v == self->magic_zero_val_) { - self->magic_zero_val_ = nullptr; + while (numValsExcludingShortcuts(self) > num_vals_before) { + Val* v = c->vals_up_.back().get(); + NVF_ERROR( + c->per_fusion_vals_[self].count(v) > 0, + "removeStatementsCreatedAfter: tail val belongs to another Fusion"); + nullOutShortcutIfNeeded(self, v); + c->per_fusion_vals_[self].erase(v); + c->vals_.erase(v); + c->vals_up_.pop_back(); } - c->per_fusion_vals_[self].erase(v); - c->vals_.erase(v); - c->vals_up_.pop_back(); + } else { + // Slow path: shared container — other Fusions' statements may be + // interleaved at the tail of the global deques. Use std::erase_if + // (C++20) to scan forward: skip the first num_before of self's + // statements (old, to keep), then erase the remainder (added during + // the guard scope). Only taken on the error/rollback path when + // segment compilation fails; O(total statements in container). + int64_t exprs_kept = 0; + std::erase_if(c->exprs_up_, [&](const std::unique_ptr& e_up) { + Expr* e = e_up.get(); + if (c->per_fusion_exprs_[self].count(e) == 0) { + return false; // belongs to another Fusion — keep + } + if (exprs_kept < num_exprs_before) { + ++exprs_kept; + return false; // self's old expr — keep + } + // self's new expr — remove (clean up uses and index maps first) + for (Val* in : e->inputs()) { + in->removeUse(e); + } + c->per_fusion_exprs_[self].erase(e); + c->exprs_.erase(e); + return true; + }); + + int64_t vals_kept = 0; + std::erase_if(c->vals_up_, [&](const std::unique_ptr& v_up) { + Val* v = v_up.get(); + if (c->per_fusion_vals_[self].count(v) == 0 || + isShortcutVal(self, v)) { + return false; // another Fusion's val, or a persistent shortcut — keep + } + if (vals_kept < num_vals_before) { + ++vals_kept; + return false; // self's old val — keep + } + // self's new val — remove + c->per_fusion_vals_[self].erase(v); + c->vals_.erase(v); + return true; + }); } } }; From 6a0b14ee4cfeb2666a64c341148a6df2edc4a5a3 Mon Sep 17 00:00:00 2001 From: Michael Davis Date: Tue, 3 Mar 2026 21:45:07 -0800 Subject: [PATCH 2/6] deadlock issue. --- csrc/fusion.cpp | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/csrc/fusion.cpp b/csrc/fusion.cpp index 01fa937d599..c91aa80cced 100644 --- a/csrc/fusion.cpp +++ b/csrc/fusion.cpp @@ -227,7 +227,9 @@ struct Fusion::ContainerMutator { int64_t num_vals_before) { auto* c = self->ir_container(); - if (!c->hasMultipleFusions()) { + // Use direct field access — hasMultipleFusions() acquires shared_lock which + // deadlocks since the caller already holds unique_lock on mutex_. + if (c->sharing_fusions_.size() <= 1) { // Fast path: single Fusion owns this container, so the LIFO invariant // holds — self's newest statements are always at the global deque tail. // Remove expressions before values because we need to change Val::uses_. From 7e9236ec2a0e34f99cde8f64efbddc3b5866f2f9 Mon Sep 17 00:00:00 2001 From: Michael Davis Date: Wed, 4 Mar 2026 07:56:33 -0800 Subject: [PATCH 3/6] Address review comments: fix comment accuracy and add setDefinition(nullptr) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Fix inaccurate comment claiming the slow path is only taken on the error/rollback path — it runs unconditionally whenever the container is shared. Add out->setDefinition(nullptr) for expr outputs before destruction in the slow path, matching removeExpr's behavior and eliminating the transient dangling definition_ pointer. --- csrc/fusion.cpp | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/csrc/fusion.cpp b/csrc/fusion.cpp index c91aa80cced..7f3a8cc91a4 100644 --- a/csrc/fusion.cpp +++ b/csrc/fusion.cpp @@ -260,8 +260,9 @@ struct Fusion::ContainerMutator { // interleaved at the tail of the global deques. Use std::erase_if // (C++20) to scan forward: skip the first num_before of self's // statements (old, to keep), then erase the remainder (added during - // the guard scope). Only taken on the error/rollback path when - // segment compilation fails; O(total statements in container). + // the guard scope). Entered whenever the container is shared, + // regardless of success or failure; if no new statements were added + // the scan completes trivially. O(total statements in container). int64_t exprs_kept = 0; std::erase_if(c->exprs_up_, [&](const std::unique_ptr& e_up) { Expr* e = e_up.get(); @@ -273,6 +274,9 @@ struct Fusion::ContainerMutator { return false; // self's old expr — keep } // self's new expr — remove (clean up uses and index maps first) + for (Val* out : e->outputs()) { + out->setDefinition(nullptr); + } for (Val* in : e->inputs()) { in->removeUse(e); } From f4d6d2eb7b5ff05a093ab865a3f837249a4f0ace Mon Sep 17 00:00:00 2001 From: Michael Davis Date: Wed, 4 Mar 2026 08:05:54 -0800 Subject: [PATCH 4/6] Also reset definition_ on fast path of removeStatementsCreatedAfter Apply the same setDefinition(nullptr) fix to the fast path for consistency with removeExpr and the slow path. --- csrc/fusion.cpp | 3 +++ 1 file changed, 3 insertions(+) diff --git a/csrc/fusion.cpp b/csrc/fusion.cpp index 7f3a8cc91a4..8e71f430f81 100644 --- a/csrc/fusion.cpp +++ b/csrc/fusion.cpp @@ -238,6 +238,9 @@ struct Fusion::ContainerMutator { NVF_ERROR( c->per_fusion_exprs_[self].count(e) > 0, "removeStatementsCreatedAfter: tail expr belongs to another Fusion"); + for (Val* out : e->outputs()) { + out->setDefinition(nullptr); + } for (Val* in : e->inputs()) { in->removeUse(e); } From 28857396f014e2323b812c08af1bb9597bbb3085 Mon Sep 17 00:00:00 2001 From: Michael Davis Date: Wed, 4 Mar 2026 09:57:04 -0800 Subject: [PATCH 5/6] Fix shortcut-val rollback divergence between fast and slow paths MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit StatementGuard recorded num_vals_before as numValsExcludingShortcuts(), which subtracts the count of non-null shortcut pointers. When a shortcut val (e.g. oneVal()) is lazily created inside a guard scope, the total count rises by 1 but non-null shortcuts also rises by 1, so numValsExcludingShortcuts stays flat — neither path detected the new shortcut via the count condition. The fast path was masked by LIFO ordering (new shortcuts sit at the deque tail and get popped while removing other new vals). The slow path's unconditional isShortcutVal keep made the bug real: shortcut vals created inside the guard were permanently retained. Fix: record num_vals_before as the total val count (numVals()), making counting uniform. Fast path switches to std::ssize(per_fusion_vals_[self]). Slow path drops the isShortcutVal skip and adds nullOutShortcutIfNeeded when removing, so shortcuts created inside the guard are rolled back and their cache pointers nulled, matching the fast path's behavior. --- csrc/fusion.cpp | 16 +++++++++------- csrc/statement_guard.cpp | 2 +- 2 files changed, 10 insertions(+), 8 deletions(-) diff --git a/csrc/fusion.cpp b/csrc/fusion.cpp index 8e71f430f81..3ea14eb327d 100644 --- a/csrc/fusion.cpp +++ b/csrc/fusion.cpp @@ -213,8 +213,10 @@ struct Fusion::ContainerMutator { } } - // Returns true if v is one of self's shortcut singleton vals. These persist - // across StatementGuard scopes and must not be removed on rollback. + // Returns true if v is one of self's shortcut singleton vals. Shortcuts + // created before a StatementGuard scope are kept on rollback; those created + // inside the scope are rolled back like any other val (nulling the cache + // pointer via nullOutShortcutIfNeeded). static bool isShortcutVal(const Fusion* self, const Val* v) { return v == self->zero_val_ || v == self->one_val_ || v == self->true_val_ || v == self->false_val_ || @@ -248,7 +250,7 @@ struct Fusion::ContainerMutator { c->exprs_.erase(e); c->exprs_up_.pop_back(); } - while (numValsExcludingShortcuts(self) > num_vals_before) { + while (std::ssize(c->per_fusion_vals_[self]) > num_vals_before) { Val* v = c->vals_up_.back().get(); NVF_ERROR( c->per_fusion_vals_[self].count(v) > 0, @@ -291,15 +293,15 @@ struct Fusion::ContainerMutator { int64_t vals_kept = 0; std::erase_if(c->vals_up_, [&](const std::unique_ptr& v_up) { Val* v = v_up.get(); - if (c->per_fusion_vals_[self].count(v) == 0 || - isShortcutVal(self, v)) { - return false; // another Fusion's val, or a persistent shortcut — keep + if (c->per_fusion_vals_[self].count(v) == 0) { + return false; // belongs to another Fusion — keep } if (vals_kept < num_vals_before) { ++vals_kept; return false; // self's old val — keep } - // self's new val — remove + // self's new val — remove (null shortcut cache pointer if applicable) + nullOutShortcutIfNeeded(self, v); c->per_fusion_vals_[self].erase(v); c->vals_.erase(v); return true; diff --git a/csrc/statement_guard.cpp b/csrc/statement_guard.cpp index 15a3b4159c3..4575bb59076 100644 --- a/csrc/statement_guard.cpp +++ b/csrc/statement_guard.cpp @@ -20,7 +20,7 @@ StatementGuard::StatementGuard(Fusion* fusion) return fusion; }()), prev_num_exprs_(fusion_->numExprs()), - prev_num_vals_(fusion_->numValsExcludingShortcuts()) {} + prev_num_vals_(fusion_->numVals()) {} StatementGuard::~StatementGuard() { fusion_->removeStatementsCreatedAfter(prev_num_exprs_, prev_num_vals_); From e2b2d341bc041907ce581c83467cda2e47e1d3fd Mon Sep 17 00:00:00 2001 From: Michael Davis Date: Wed, 4 Mar 2026 10:11:06 -0800 Subject: [PATCH 6/6] Dead code. --- csrc/fusion.cpp | 28 ---------------------------- csrc/fusion.h | 7 ------- 2 files changed, 35 deletions(-) diff --git a/csrc/fusion.cpp b/csrc/fusion.cpp index 3ea14eb327d..f01da35c663 100644 --- a/csrc/fusion.cpp +++ b/csrc/fusion.cpp @@ -184,20 +184,6 @@ struct Fusion::ContainerMutator { } } - static int64_t numValsExcludingShortcuts(const Fusion* self) noexcept { - auto* c = self->ir_container(); - // Use direct field access. Avoids re-entering valsOwnedBy() which acquires - // shared_lock. - const auto it = c->per_fusion_vals_.find(self); - int64_t count = it != c->per_fusion_vals_.end() - ? static_cast(it->second.size()) - : 0; - count -= (self->zero_val_ != nullptr) + (self->one_val_ != nullptr) + - (self->true_val_ != nullptr) + (self->false_val_ != nullptr) + - (self->magic_zero_val_ != nullptr); - return count; - } - // Null out self's shortcut-val pointer cache if v is one of them. static void nullOutShortcutIfNeeded(Fusion* self, Val* v) { if (v == self->zero_val_) { @@ -213,16 +199,6 @@ struct Fusion::ContainerMutator { } } - // Returns true if v is one of self's shortcut singleton vals. Shortcuts - // created before a StatementGuard scope are kept on rollback; those created - // inside the scope are rolled back like any other val (nulling the cache - // pointer via nullOutShortcutIfNeeded). - static bool isShortcutVal(const Fusion* self, const Val* v) { - return v == self->zero_val_ || v == self->one_val_ || - v == self->true_val_ || v == self->false_val_ || - v == self->magic_zero_val_; - } - static void removeStatementsCreatedAfter( Fusion* self, int64_t num_exprs_before, @@ -688,10 +664,6 @@ void Fusion::removeStatementsCreatedAfter( this, num_exprs_before, num_vals_before); } -int64_t Fusion::numValsExcludingShortcuts() const noexcept { - return ContainerMutator::numValsExcludingShortcuts(this); -} - void Fusion::addInput(Val* input) { assertInContainer(input, "Cannot register input "); diff --git a/csrc/fusion.h b/csrc/fusion.h index 8d0bc3ced3c..fdeb588d9f1 100644 --- a/csrc/fusion.h +++ b/csrc/fusion.h @@ -556,13 +556,6 @@ class NVF_API Fusion : public PolymorphicBase { return std::ssize(ir_container()->valsOwnedBy(this)); } - //! Return per-Fusion val count excluding shortcut vals (zero_val_, etc.). - //! Shortcut vals are registered in both per_fusion_vals_ and vals_up_, but - //! since they're singletons that should persist across StatementGuard scopes, - //! this count excludes them so the LIFO pop-back in - //! removeStatementsCreatedAfter correctly skips over them. - int64_t numValsExcludingShortcuts() const noexcept; - // Shortcut values (frequently used constants) Val* zeroVal(); Val* oneVal();