
Propagate stream in loop irrespective of device mesh #5363

Merged

Priya2698 merged 8 commits into main from pm/propagate_stream_bug on Oct 14, 2025
Conversation

@Priya2698 (Collaborator) commented Oct 9, 2025

When filtering the reference inputs, inputs without a device mesh were removed. This caused fusions with only stream-parallel tensorviews to skip propagation.

@Priya2698 (Collaborator, Author): !test

@github-actions bot commented Oct 9, 2025

Review updated until commit f356951

Description

  • Propagate stream parallelization in loop domains

  • Skip sharding propagation for scatter op outputs

  • Sort TensorViews by device/stream dimensions

  • Add tests for stream propagation


Changes walkthrough 📝

Relevant files

Enhancement: csrc/preseg_passes/propagate_shardings.cpp (Enhance sharding propagation with stream support, +45/-63)

  • Introduced sortTvsByParallelDims to sort TensorViews by device/stream dimensions
  • Updated input/output ordering to consider both device and stream dimensions
  • Added check to skip sharding propagation for scatter op outputs
  • Removed redundant null and device mesh checks

Tests: tests/cpp/test_stream.cpp (Add tests for stream propagation, +45/-1)

  • Added ForwardPropagation test for stream parallelization
  • Added BackwardPropagation test for stream sharding
  • Updated includes to use all_ops.h and propagate_shardings.h

PR Reviewer Guide 🔍

Here are some key observations to aid the review process:

🧪 PR contains tests

⚡ Recommended focus areas for review

Possible Issue

The function sortTvsByParallelDims includes both device and stream dimensions in its count, whereas the previous version, sortTvsByDeviceDims, considered only device dimensions. This change may alter the priority ordering of TensorViews during sharding propagation, potentially affecting correctness when stream and device dimensions interact.

    auto num_parallel_dims = [](TensorView* tv) {
      return std::count_if(
          tv->getLoopDomain().begin(),
          tv->getLoopDomain().end(),
          [](IterDomain* id) {
            return !id->isReduction() && (id->isStream() || id->isDeviceDim());
          });
    };
    
    std::vector<TensorView*> tvs_vec(tvs.begin(), tvs.end());
    
    std::ranges::stable_sort(tvs_vec, [&num_parallel_dims](auto a, auto b) {
      return std::make_pair(num_parallel_dims(a), a->getDeviceMesh().rank()) >
          std::make_pair(num_parallel_dims(b), b->getDeviceMesh().rank());
    });
    
    return tvs_vec;
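The sort key in the snippet above, the pair (parallel-dim count, mesh rank) compared in descending order, can be demonstrated in isolation with plain pairs (a sketch, not the real `TensorView` type):

```cpp
#include <algorithm>
#include <utility>
#include <vector>

// Sketch of the sort key used above: (parallel dim count, mesh rank),
// compared in descending order so the most parallel candidate comes first,
// with mesh rank breaking ties.
std::vector<std::pair<int, int>> sortByKey(
    std::vector<std::pair<int, int>> keys) {
  std::stable_sort(keys.begin(), keys.end(),
                   [](const auto& a, const auto& b) { return a > b; });
  return keys;
}
```

Because std::pair compares lexicographically, a higher parallel-dim count always wins, and the mesh rank only matters between equally parallel candidates.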

Logic Gap

The removal of NVF_ERROR checks for device-mesh presence in getOrderedReferenceInputs may allow propagation from inputs without a device mesh, leading to undefined behavior during sharding propagation if such inputs are encountered.

    for (auto* ref_input : reference_inputs) {
      NVF_ERROR(ref_input != nullptr);
    
      // Consider out [M, N] = linear (inp [M, K], weight (N,
      // K)) with inp sharded on M ([DIDx(d), M/d, K]) and weight sharded on N
      // ([DIDy(d), N/d, K]). We propagate from weights first, so the output

Missing Validation

The backward propagation loop skips user-sharded tensors, but does not validate whether the reference output has a device mesh before use, which could lead to dereferencing a null or invalid mesh in transformLoopDomain.

    for (Expr* expr : exprs | std::views::reverse) {
      const auto& outputs = ir_utils::filterByType<TensorView>(expr->outputs());
      if (outputs.empty()) {
        continue;
      }
      // All outputs of an expression (Welford, SDPA) should be uniformly sharded.
      // We pick the most parallel output as the reference.
      // This is to avoid picking seed/offset tvs in SDPA.
      std::vector<TensorView*> sorted_outputs = sortTvsByParallelDims(outputs);
      TensorView* ref_output = sorted_outputs.front();
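The reference-output selection above (pick the most parallel output, so auxiliary tensors such as SDPA seed/offset are not chosen) can be mimicked as a sketch; the `Out` struct and `pickReference` below are hypothetical, not the real nvFuser API:

```cpp
#include <algorithm>
#include <string>
#include <vector>

// Hypothetical stand-in: pick the most parallel output as the reference,
// so scalar-like auxiliary outputs (e.g. SDPA seed/offset) are skipped.
struct Out {
  std::string name;
  int parallel_dims;  // combined device + stream loop dimensions
};

std::string pickReference(const std::vector<Out>& outputs) {
  auto it = std::max_element(
      outputs.begin(), outputs.end(), [](const Out& a, const Out& b) {
        return a.parallel_dims < b.parallel_dims;
      });
  return it->name;
}
```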

@Priya2698 (Collaborator, Author): !test

@Priya2698 (Collaborator, Author): !test

@Priya2698 (Collaborator, Author): !test

@Priya2698 changed the title from "Propagate stream irrespective of device mesh" to "Propagate stream in loop irrespective of device mesh" on Oct 10, 2025
(Review comment on the diff below)

    extent_val = promoteSize(extent_val, id->extent());
    if (iter_type.has_value()) {
      iter_type = promoteIterType(iter_type.value(), id->getIterType());
    } else if (id->isGatherScatter()) {

@Priya2698 (Collaborator, Author): PR #5365

@Priya2698 marked this pull request as ready for review October 10, 2025 02:12
@Priya2698 requested a review from wujingyue October 10, 2025 02:12

@Priya2698 (Collaborator, Author): !test

@Priya2698 merged commit de077d7 into main Oct 14, 2025
64 of 65 checks passed
@Priya2698 deleted the pm/propagate_stream_bug branch October 14, 2025 00:17
Priya2698 added a commit that referenced this pull request Oct 16, 2025:

Issue #5309

Unlike device parallelization, a stream-parallel tensorview (in the loop domain) may or may not have a stream-parallel allocation domain.

We propagate based on the following:
1. If it is a device parallel type -> always propagate
2. If it is a fusion input or output -> the id is not stream parallelized
3. If the stream ID in a tensorview is not mapped to a stream ID in all of its consumers -> the id is not stream parallelized

For cases like https://github.com/NVIDIA/Fuser/blob/f8e84e52296cdecd318dd2ce904139616d7bd434/tests/cpp/test_overlap.cpp#L155, we want to start by replicating the Stream-parallel ID, i.e., the allocation is not parallelized. However, this ID will appear in the logical domain due to rfactor and, under the current contract, be allocated fully regardless of parallelization. So I am not making this a condition in the pass yet.

This can be changed in the future when needed.

Depends on #5363

---------

Co-authored-by: Jingyue Wu <wujingyue@gmail.com>
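The three propagation rules in the commit message above can be sketched as a predicate; the `Id` and `TvInfo` structs and all field names below are hypothetical stand-ins for illustration, not the real nvFuser API:

```cpp
// Hypothetical stand-ins for illustrating the three propagation rules.
struct Id {
  bool is_device_parallel;  // DID-parallel iter-domain
  bool is_stream;           // Stream-parallel iter-domain
};

struct TvInfo {
  bool is_fusion_input_or_output;
  // For a stream ID: true iff it maps to a stream ID in every consumer.
  bool stream_id_mapped_in_all_consumers;
};

// Rule 1: device-parallel IDs always propagate.
// Rule 2: on fusion inputs/outputs, an ID is not stream parallelized.
// Rule 3: a stream ID must map to a stream ID in all consumers to stay
//         stream parallelized.
bool propagatesAsParallel(const Id& id, const TvInfo& tv) {
  if (id.is_device_parallel) {
    return true;  // rule 1
  }
  if (!id.is_stream) {
    return false;  // only device/stream IDs are considered here
  }
  if (tv.is_fusion_input_or_output) {
    return false;  // rule 2
  }
  return tv.stream_id_mapped_in_all_consumers;  // rule 3
}
```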
tbqh pushed a commit that referenced this pull request Nov 12, 2025:

When filtering the reference inputs, inputs without a device mesh were removed. This caused fusions with only stream-parallel tensorviews to skip propagation.
tbqh pushed a commit that referenced this pull request Nov 12, 2025 (same commit message as the Oct 16 commit above).