Skip to content

Segmenter Shared Container Fix#6025

Merged
mdavis36 merged 6 commits intomd/segmenter-container-sharingfrom
md/seg-cont-fix
Mar 4, 2026
Merged

Segmenter Shared Container Fix#6025
mdavis36 merged 6 commits intomd/segmenter-container-sharingfrom
md/seg-cont-fix

Conversation

@mdavis36
Copy link
Collaborator

@mdavis36 mdavis36 commented Mar 4, 2026

Statements cleaned up by statement guard need to be popped from the specific fusion only, not the entire IrContainer

@mdavis36
Copy link
Collaborator Author

mdavis36 commented Mar 4, 2026

!test

@mdavis36
Copy link
Collaborator Author

mdavis36 commented Mar 4, 2026

!test

@mdavis36 mdavis36 marked this pull request as ready for review March 4, 2026 15:25
@greptile-apps
Copy link
Contributor

greptile-apps bot commented Mar 4, 2026

Greptile Summary

This PR fixes a correctness bug in removeStatementsCreatedAfter when multiple Fusion objects share an IrContainer. The old implementation relied on a LIFO pop-back on the global exprs_up_/vals_up_ deques, which is invalid when statements from different Fusions are interleaved in those deques.

Key changes:

  • Fast path (single owner, sharing_fusions_.size() <= 1): preserves the original LIFO pop-back behaviour, with the LIFO invariant still valid.
  • Slow path (shared container): uses std::erase_if with a counter (exprs_kept / vals_kept) to scan the full deque and selectively remove only self's statements that were created after the guard was activated, correctly skipping interleaved entries from other Fusions.
  • nullOutShortcutIfNeeded helper: extracts the repeated shortcut-pointer-nulling logic and is invoked in both paths, correctly handling the case where a shortcut val (e.g. zero_val_) was lazily created during the guard scope and must be rolled back.
  • numValsExcludingShortcuts removed: StatementGuard now uses numVals() for prev_num_vals_. Shortcut vals that existed before the guard are included in the baseline count and thus protected by the threshold condition; shortcuts created during the guard scope are correctly rolled back via nullOutShortcutIfNeeded.

The erase_if ordering logic (counting only self's deque entries in insertion order to find the boundary between "old" and "new") is correct because deque entries are always appended at the tail in creation order, so "first N of self's entries in deque order" == "self's N oldest entries".

Confidence Score: 4/5

  • The fix is logically sound and directly addresses the stated bug; safe to merge with minor review comments considered.
  • The fast/slow path split is clean and correct. The erase_if counter correctly identifies old vs. new statements even with interleaved multi-Fusion deques. Shortcut-val handling is more uniform and more correct than the old exclusion-based approach. The change from numValsExcludingShortcuts to numVals() is semantically equivalent in the normal case and more correct in the edge case where a shortcut is lazily initialised during the guard scope. Previously-discussed concerns (dangling definition_ between the two erase_if passes, and the misleading comment) have been addressed in prior review threads.
  • No files require special attention; csrc/fusion.cpp contains all substantive logic and has been carefully reviewed.

Important Files Changed

Filename Overview
csrc/fusion.cpp Core fix: introduces fast/slow paths in removeStatementsCreatedAfter; slow path uses std::erase_if to correctly handle interleaved deque entries in shared-container scenarios. Extracts nullOutShortcutIfNeeded as a helper. Logic is sound and correctly handles ordering, shortcut vals, and expression cleanup.
csrc/fusion.h Removes the now-unnecessary numValsExcludingShortcuts() declaration. Trivial cleanup change.
csrc/statement_guard.cpp Changes prev_num_vals_ to use numVals() instead of the removed numValsExcludingShortcuts(). Semantically correct: shortcut vals created before the guard scope are included in num_vals_before, so they are protected by the threshold in both fast and slow paths.

Flowchart

%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A["StatementGuard::~StatementGuard()"] --> B["fusion_->removeStatementsCreatedAfter\n(prev_num_exprs_, prev_num_vals_)"]
    B --> C["Acquire unique_lock on mutex_"]
    C --> D["ContainerMutator::removeStatementsCreatedAfter"]
    D --> E{"sharing_fusions_.size() <= 1?"}

    E -- "Yes (fast path)" --> F["LIFO pop-back from\nglobal deque tail"]
    F --> G["while per_fusion_exprs_size > num_exprs_before:\n  assert tail belongs to self\n  setDefinition(nullptr) on outputs\n  removeUse() on inputs\n  erase from per_fusion_exprs_, exprs_\n  pop_back exprs_up_"]
    G --> H["while per_fusion_vals_size > num_vals_before:\n  assert tail belongs to self\n  nullOutShortcutIfNeeded()\n  erase from per_fusion_vals_, vals_\n  pop_back vals_up_"]

    E -- "No (slow path)" --> I["std::erase_if on exprs_up_\n(deque may have interleaved entries)"]
    I --> J["For each entry:\n  skip if belongs to another Fusion\n  keep if exprs_kept < num_exprs_before\n  else: cleanup + erase + return true"]
    J --> K["std::erase_if on vals_up_"]
    K --> L["For each entry:\n  skip if belongs to another Fusion\n  keep if vals_kept < num_vals_before\n  else: nullOutShortcutIfNeeded + erase + return true"]
Loading

Last reviewed commit: e2b2d34

csrc/fusion.cpp Outdated
Comment on lines +259 to +263
// Slow path: shared container — other Fusions' statements may be
// interleaved at the tail of the global deques. Use std::erase_if
// (C++20) to scan forward: skip the first num_before of self's
// statements (old, to keep), then erase the remainder (added during
// the guard scope). Only taken on the error/rollback path when
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The comment states this path is "Only taken on the error/rollback path when segment compilation fails," but StatementGuard::~StatementGuard() unconditionally calls removeStatementsCreatedAfter — it does not distinguish success from failure. The slow path is entered whenever the container is shared (sharing_fusions_.size() > 1), regardless of whether it is a compilation error. On a success path where no new statements were added, the two erase_if calls complete trivially (finding nothing to remove), but the O(n) scan still occurs.

Consider updating the comment to something like:

// Slow path: shared container — other Fusions' statements may be
// interleaved at the tail of the global deques. Use std::erase_if
// (C++20) to scan forward: skip the first num_before of self's
// statements (old, to keep), then erase the remainder (added during
// the guard scope). O(total statements in container); in the common
// case where no statements were added the scan is a no-op.

Comment on lines +275 to +281
// self's new expr — remove (clean up uses and index maps first)
for (Val* in : e->inputs()) {
in->removeUse(e);
}
c->per_fusion_exprs_[self].erase(e);
c->exprs_.erase(e);
return true;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

When return true; is returned from the erase_if predicate (line 282), the owning unique_ptr<Expr> is destroyed immediately, freeing the expression's memory. The output Vals of that expression still have their definition_ pointer set (to the now-freed memory) until the subsequent vals_up_ erase_if pass removes them (lines 284–300). Between the two erase_if calls there is nothing that dereferences definition_, so this is not a live bug today, but the fast path has the same gap and there is no assertion/documentation establishing this invariant.

This is consistent with the fast-path behaviour, but worth considering adding out->setDefinition(nullptr) here (as removeExpr does) to make the rollback more defensively correct, or at least a comment noting the transient dangling state is intentional.

…ullptr)

Fix inaccurate comment claiming the slow path is only taken on the
error/rollback path — it runs unconditionally whenever the container
is shared. Add out->setDefinition(nullptr) for expr outputs before
destruction in the slow path, matching removeExpr's behavior and
eliminating the transient dangling definition_ pointer.
@mdavis36
Copy link
Collaborator Author

mdavis36 commented Mar 4, 2026

!test

@github-actions
Copy link

github-actions bot commented Mar 4, 2026

Review updated until commit e2b2d34

Description

  • Fix statement cleanup in shared IrContainer to only remove statements owned by the specific Fusion, not all statements from the deque tail.

  • Add fast/slow path in removeStatementsCreatedAfter: fast path for single Fusion (LIFO pop), slow path for shared container (forward scan with std::erase_if).

  • Fix shortcut-val rollback bug: record num_vals_before as numVals() instead of numValsExcludingShortcuts(), and null shortcut cache pointers in both paths.

  • Add setDefinition(nullptr) for expr outputs in both fast and slow paths to eliminate dangling definition_ pointers.

Changes walkthrough

Relevant files
Bug fix
fusion.cpp
Add fast/slow paths for removeStatementsCreatedAfter         

csrc/fusion.cpp

  • Replace numValsExcludingShortcuts with nullOutShortcutIfNeeded helper.
  • Add fast path (single Fusion) using LIFO pop from deque tail.
  • Add slow path (shared container) using std::erase_if to scan forward
    and remove only self's new statements.
  • Add out->setDefinition(nullptr) for expr outputs before removal in
    both paths.
  • Use std::ssize(per_fusion_vals_[self]) instead of
    numValsExcludingShortcuts in fast path.
  • +88/-50 
    statement_guard.cpp
    Use total val count for statement guard rollback                 

    csrc/statement_guard.cpp

  • Change prev_num_vals_ initialization from numValsExcludingShortcuts()
    to numVals().
  • +1/-1     
    Cleanup
    fusion.h
    Remove numValsExcludingShortcuts declaration                         

    csrc/fusion.h

  • Remove numValsExcludingShortcuts() method declaration and its
    documentation comment.
  • +0/-7     

    PR Reviewer Guide

    Here are some key observations to aid the review process:

    🧪 PR contains tests
    ⚡ Recommended focus areas for review
    Potential performance regression

    The "slow path" uses std::erase_if which iterates through the entire global deques
    (exprs_up_, vals_up_) each time StatementGuard cleanup is needed in shared container
    scenarios. This is O(total statements in container) as noted in the code comments.
    For containers with many statements, this could be slow. Consider if there's a more
    efficient data structure that could track per-fusion statement ordering.

    } else {
      // Slow path: shared container — other Fusions' statements may be
      // interleaved at the tail of the global deques. Use std::erase_if
      // (C++20) to scan forward: skip the first num_before of self's
      // statements (old, to keep), then erase the remainder (added during
      // the guard scope). Entered whenever the container is shared,
      // regardless of success or failure; if no new statements were added
      // the scan completes trivially. O(total statements in container).
      int64_t exprs_kept = 0;
      std::erase_if(c->exprs_up_, [&](const std::unique_ptr<Expr>& e_up) {
        Expr* e = e_up.get();
        if (c->per_fusion_exprs_[self].count(e) == 0) {
          return false; // belongs to another Fusion — keep
        }
        if (exprs_kept < num_exprs_before) {
          ++exprs_kept;
          return false; // self's old expr — keep
        }
        // self's new expr — remove (clean up uses and index maps first)
        for (Val* out : e->outputs()) {
          out->setDefinition(nullptr);
        }
        for (Val* in : e->inputs()) {
          in->removeUse(e);
        }
        c->per_fusion_exprs_[self].erase(e);
        c->exprs_.erase(e);
        return true;
      });
    
      int64_t vals_kept = 0;
      std::erase_if(c->vals_up_, [&](const std::unique_ptr<Val>& v_up) {
        Val* v = v_up.get();
        if (c->per_fusion_vals_[self].count(v) == 0) {
          return false; // belongs to another Fusion — keep
        }
        if (vals_kept < num_vals_before) {
          ++vals_kept;
          return false; // self's old val — keep
        }
        // self's new val — remove (null shortcut cache pointer if applicable)
        nullOutShortcutIfNeeded(self, v);
        c->per_fusion_vals_[self].erase(v);
        c->vals_.erase(v);
        return true;
      });
    }
    Correctness of fast path assumption

    The fast path assumes that when sharing_fusions_.size() <= 1, the LIFO invariant holds
    and self's newest statements are at the global deque tail. However, there's a subtle
    race condition: another fusion could be created (making size > 1) between when we check
    and when we access the deques. The code holds unique_lock on mutex_, so this should be
    safe, but worth verifying the locking semantics are correct.

    if (c->sharing_fusions_.size() <= 1) {
      // Fast path: single Fusion owns this container, so the LIFO invariant
      // holds — self's newest statements are always at the global deque tail.
      // Remove expressions before values because we need to change Val::uses_.
      while (std::ssize(c->per_fusion_exprs_[self]) > num_exprs_before) {
        Expr* e = c->exprs_up_.back().get();
        NVF_ERROR(
            c->per_fusion_exprs_[self].count(e) > 0,
            "removeStatementsCreatedAfter: tail expr belongs to another Fusion");
        for (Val* out : e->outputs()) {
          out->setDefinition(nullptr);
        }
        for (Val* in : e->inputs()) {
          in->removeUse(e);
        }
        c->per_fusion_exprs_[self].erase(e);
        c->exprs_.erase(e);
        c->exprs_up_.pop_back();
      }
      while (std::ssize(c->per_fusion_vals_[self]) > num_vals_before) {
        Val* v = c->vals_up_.back().get();
        NVF_ERROR(
            c->per_fusion_vals_[self].count(v) > 0,
            "removeStatementsCreatedAfter: tail val belongs to another Fusion");
        nullOutShortcutIfNeeded(self, v);
        c->per_fusion_vals_[self].erase(v);
        c->vals_.erase(v);
        c->vals_up_.pop_back();
      }
    Input/Output order in cleanup

    In the old code (lines 220-224 in old hunk), inputs were cleaned up before outputs.
    In the new fast path (lines 219-224 in new hunk), outputs are cleaned up first, then
    inputs. In the slow path (lines 258-263), outputs are also cleaned up first. This
    change in order should be verified to not cause any issues with Val::uses_ tracking.

    for (Val* out : e->outputs()) {
      out->setDefinition(nullptr);
    }
    for (Val* in : e->inputs()) {
      in->removeUse(e);
    }

    Apply the same setDefinition(nullptr) fix to the fast path for
    consistency with removeExpr and the slow path.
    @greptile-apps
    Copy link
    Contributor

    greptile-apps bot commented Mar 4, 2026

    Additional Comments (1)

    csrc/fusion.cpp, line 306
    Shortcut-val rollback behaviour diverges from the fast path

    The isShortcutVal guard here means that shortcut vals (e.g. zero_val_, magic_zero_val_) are always kept, even if they were created inside the current StatementGuard scope.

    The fast path behaves differently: numValsExcludingShortcuts subtracts only the currently-non-null shortcut pointers, so if a shortcut is at the tail of vals_up_ it will be popped and nullOutShortcutIfNeeded will null the cache pointer.

    Concrete example:

    • Before guard: vals_up_ = [A], zero_val_ = nullptr, num_vals_before = 1
    • During guard: B created, then zero_val_ lazy-initialised → vals_up_ = [A, B, zero_val_]
    • Fast path: pops zero_val_ (nulls cache), then pops B; both are rolled back.
    • Slow path: isShortcutVal returns true for zero_val_, so it is retained; only B is removed.

    This diverges from the design intent stated at line 216–217: "These persist across StatementGuard scopes and must not be removed on rollback." Consider either:

    1. Aligning the slow path to match the fast path (null out + erase the shortcut if needed), or
    2. Aligning the fast path to match the slow path's stated intent (also skip shortcuts via isShortcutVal-style check), and updating the comment on numValsExcludingShortcuts to document the new invariant.

    At minimum, add a brief note in the slow-path comment acknowledging the intentional divergence to prevent future confusion.

    StatementGuard recorded num_vals_before as numValsExcludingShortcuts(),
    which subtracts the count of non-null shortcut pointers. When a shortcut
    val (e.g. oneVal()) is lazily created inside a guard scope, the total
    count rises by 1 but non-null shortcuts also rises by 1, so
    numValsExcludingShortcuts stays flat — neither path detected the new
    shortcut via the count condition. The fast path was masked by LIFO
    ordering (new shortcuts sit at the deque tail and get popped while
    removing other new vals). The slow path's unconditional isShortcutVal
    keep made the bug real: shortcut vals created inside the guard were
    permanently retained.
    
    Fix: record num_vals_before as the total val count (numVals()), making
    counting uniform. Fast path switches to std::ssize(per_fusion_vals_[self]).
    Slow path drops the isShortcutVal skip and adds nullOutShortcutIfNeeded
    when removing, so shortcuts created inside the guard are rolled back and
    their cache pointers nulled, matching the fast path's behavior.
    @greptile-apps
    Copy link
    Contributor

    greptile-apps bot commented Mar 4, 2026

    Additional Comments (3)

    csrc/fusion.cpp, line 224
    isShortcutVal is dead code

    isShortcutVal is defined here but is never called anywhere in the codebase (confirmed by a full-repo search). Every callsite that needs to test shortcut membership already calls nullOutShortcutIfNeeded directly. If this predicate is intended for a future use, a comment should explain that; otherwise it should be removed to avoid dead-code confusion.

    
    
    ---
    
    **`csrc/fusion.cpp`, line 262**
    **Fast-path val count now includes shortcuts — stale loop variable type**
    
    The loop condition changed from `numValsExcludingShortcuts(self) > num_vals_before` to `std::ssize(c->per_fusion_vals_[self]) > num_vals_before`, and `num_vals_before` now stores `numVals()` (inclusive of shortcuts). This is semantically correct, but the intermediate type used here differs from the slow path's `int64_t vals_kept` pattern and from the type returned by `std::ssize` (`ptrdiff_t` / `ssize_t`). More importantly, `num_vals_before` is declared as `int64_t` in the function signature while `std::ssize()` returns a signed platform-sized integer — on a 64-bit host these are the same, but the comparison is technically between `ptrdiff_t` and `int64_t`. Consider casting for clarity:
    
    ```suggestion
          while (static_cast<int64_t>(c->per_fusion_vals_[self].size()) >
                 num_vals_before) {
    

    csrc/fusion.h, line 564
    numValsExcludingShortcuts() is now dead code — comment is also stale

    After this PR, StatementGuard calls numVals() instead of numValsExcludingShortcuts(), and a full-repo search finds no remaining callers of this public method outside of fusion.cpp's own delegation. The docstring still references removeStatementsCreatedAfter as its consumer, which is no longer accurate.

    If the method is kept as part of a public API, the comment should be updated to reflect its actual purpose. If it has no remaining callers, it (and the ContainerMutator::numValsExcludingShortcuts static implementation) can be removed to avoid confusion.

    @greptile-apps
    Copy link
    Contributor

    greptile-apps bot commented Mar 4, 2026

    Additional Comments (1)

    csrc/fusion.cpp, line 199
    Dead code: numValsExcludingShortcuts is now unreachable

    ContainerMutator::numValsExcludingShortcuts is still defined here, but it is no longer called anywhere in the codebase after this PR:

    • Fusion::numValsExcludingShortcuts() (its only caller) was removed from fusion.h and fusion.cpp.
    • The fast-path loop condition in removeStatementsCreatedAfter was rewritten to use std::ssize(c->per_fusion_vals_[self]) directly.
    • statement_guard.cpp now calls numVals() instead.

    The grep confirms the symbol appears only in fusion.cpp. Consider removing the dead function to keep the file clean:

      // nullOutShortcutIfNeeded replaces the removed numValsExcludingShortcuts helper.
    

    Or simply delete lines 187–199 entirely.

    @mdavis36
    Copy link
    Collaborator Author

    mdavis36 commented Mar 4, 2026

    !test

    @mdavis36 mdavis36 merged commit b1873c8 into md/segmenter-container-sharing Mar 4, 2026
    40 of 44 checks passed
    @mdavis36 mdavis36 deleted the md/seg-cont-fix branch March 4, 2026 19:48
    mdavis36 added a commit that referenced this pull request Mar 5, 2026
    Statements cleaned up by statement guard need to be popped from the
    specific fusion only, not the entire IrContainer.
    mdavis36 added a commit that referenced this pull request Mar 5, 2026
    Statements cleaned up by statement guard need to be popped from the
    specific fusion only, not the entire IrContainer.
    Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

    Labels

    None yet

    Projects

    None yet

    Development

    Successfully merging this pull request may close these issues.

    1 participant