
Propagate Stream parallel type in allocation #5353

Merged
Priya2698 merged 21 commits into main from pm/alloc_stream
Oct 16, 2025

Conversation

@Priya2698
Collaborator

@Priya2698 Priya2698 commented Oct 8, 2025

Issue #5309
Unlike device parallelization, a Stream-parallel tensorview (in its loop domain) may or may not have a Stream-parallel allocation domain.

We propagate based on the following:

  1. If it is a device parallel type -> always propagate
  2. If the tensorview is a fusion input or output -> the ID is not stream-parallelized in allocation
  3. If the stream ID in a tensorview is not mapped to a stream ID in all of its consumers -> the ID is not stream-parallelized in allocation
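The three rules above can be sketched as a small stand-alone decision function. This is a minimal model: `IdInfo`, `TvInfo`, and `propagatesToAllocation` are illustrative stand-ins invented here, not the pass's actual API, which operates on nvFuser's `IterDomain` and `TensorView` IR types in `finalize_multidevice_domains.cpp`.

```cpp
#include <algorithm>
#include <vector>

// Illustrative stand-in for an IterDomain's parallelization state.
struct IdInfo {
  bool is_stream = false;  // parallelized on ParallelType::Stream
  bool is_device = false;  // device parallel type (e.g. DIDx)
};

// Illustrative stand-in for the per-tensorview facts the rules consult.
struct TvInfo {
  bool is_fusion_io = false;               // fusion input or output
  std::vector<bool> consumer_maps_stream;  // per consumer: stream ID mapped?
};

// Rule 1: device parallel types always propagate to allocation.
// Rule 2: fusion inputs/outputs never get a Stream-parallel allocation.
// Rule 3: a Stream ID propagates only if every consumer maps the stream ID.
bool propagatesToAllocation(const IdInfo& id, const TvInfo& tv) {
  if (id.is_device) {
    return true;  // rule 1
  }
  if (!id.is_stream) {
    return false;  // nothing stream-related to decide
  }
  if (tv.is_fusion_io) {
    return false;  // rule 2
  }
  return std::all_of(  // rule 3
      tv.consumer_maps_stream.begin(),
      tv.consumer_maps_stream.end(),
      [](bool mapped) { return mapped; });
}
```

For example, a Stream-parallel ID on an intermediate whose consumers all map it would shard the allocation, while the same ID on a fusion output would be replicated.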

For cases like:
https://github.com/NVIDIA/Fuser/blob/f8e84e52296cdecd318dd2ce904139616d7bd434/tests/cpp/test_overlap.cpp#L155,
we want to start with replicating the Stream-parallel ID, that is, the allocation is not parallelized. However, this ID will appear in the logical domain due to rfactor and, with the current contract, be allocated fully regardless of parallelization. So I am not making this a condition in the pass yet.

This can be changed in future when we need.

Depends on #5363

@Priya2698
Collaborator Author

!test

@github-actions

github-actions bot commented Oct 8, 2025

Review updated until commit 46253fb

Description

  • Propagate Stream parallelization to allocation domain conditionally

  • Prevent allocation sharding for non-device, non-Stream tensors

  • Add tests for Stream-parallel allocation behavior

  • Print debug transforms in finalize pass for diagnostics


Changes walkthrough 📝

Relevant files

Enhancement
  finalize_multidevice_domains.cpp — Implement conditional Stream allocation sharding
  csrc/preseg_passes/finalize_multidevice_domains.cpp
  • Introduced shardAllocation to handle device and Stream parallelization
  • Added shouldParallelizeAllocationOnStream to check Stream consumer consistency
  • Added isLoopStreamParallelized to detect Stream in loop domain
  • Skip sharding if no device mesh and not Stream-loop parallelized
  • Print debug transform logging when enabled
  • +57/-30

Bug fix
  test_multidevice_lower_communication.cpp — Fix allgather test device mesh setup
  tests/cpp/test_multidevice_lower_communication.cpp
  • Move setDeviceMesh call before split for correctness
  • Remove manual split and allocation on output tensor
  • Use setDeviceMesh on output to enable proper propagation
  • +2/-3

Tests
  test_stream.cpp — Add Stream allocation propagation tests
  tests/cpp/test_stream.cpp
  • Add #include
  • Add ShardedAllocation test for Stream allocation in loops
  • Add ReplicatedAllocation test when Stream not in consumers
  • Verify allocation domain matches logical or loop domain
  • +58/-0

PR Reviewer Guide 🔍

Here are some key observations to aid the review process:

🧪 PR contains tests
⚡ Recommended focus areas for review

Loop Stream Check

The function isLoopStreamParallelized checks if any loop domain ID is stream-parallel, but it may not account for nested or conditional loop structures where stream parallelization is context-dependent. This could lead to incorrect propagation decisions in complex loop scenarios.

bool isLoopStreamParallelized(const TensorView* tv) {
  return std::any_of(
      tv->getLoopDomain().begin(),
      tv->getLoopDomain().end(),
      [](IterDomain* id) { return id->isStream(); });
}

Allocation Sharding Logic

The shardAllocation function skips splitting for stream-parallel outer dimensions when shouldParallelizeAllocationOnStream returns false, but it does not handle cases where partial stream parallelization exists across consumer tensorviews, potentially leading to inconsistent memory layouts.

if (split->outer()->isStream() &&
    !shouldParallelizeAllocationOnStream(tv)) {
  continue;
}
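As a minimal model of what this skip means for memory (an assumed, illustrative function, not the pass's code): taking the Stream split shards the allocation per stream, while skipping it leaves the full logical extent allocated, i.e. replicated.

```cpp
// Illustrative model, not nvFuser code: a dimension of logical extent `n` is
// outer-split by `stream_factor` and Stream-parallelized. If the Stream ID
// propagates to allocation, each stream allocates only its shard
// (n / stream_factor); otherwise the tensor stays fully allocated.
long long allocatedExtent(
    long long n, long long stream_factor, bool stream_propagates) {
  return stream_propagates ? n / stream_factor : n;
}
```

Under this model, an extent-8 dimension split by 2 streams allocates 4 elements per stream when sharded and 8 when replicated.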
Test Coverage

The new tests ShardedAllocation and ReplicatedAllocation verify basic behavior but do not test edge cases such as tensorviews with mixed parallel types or multiple stream-parallel dimensions, which could expose flaws in the propagation logic.

TEST_F(StreamTest, ShardedAllocation) {
  auto fusion = std::make_unique<Fusion>();
  FusionGuard fg(fusion.get());

  const int64_t s = 2;

  TensorView* tv0 = makeContigTensor(3);
  TensorView* tv1 = add(tv0, IrBuilder::create<Val>(1.0));
  TensorView* tv2 = sum(tv1, {2});
  TensorView* tv3 = div(tv1, IrBuilder::create<Val>(2.0));
  fusion->addInput(tv0);
  fusion->addOutput(tv2);
  fusion->addOutput(tv3);

  tv0->outer_split(0, s);
  tv0->axis(0)->parallelize(ParallelType::Stream);

  preseg_passes::OptimizationPass<preseg_passes::PreSegmenter>::runPass(
      fusion.get());

  for (auto* tv : {tv0, tv1, tv2, tv3}) {
    EXPECT_TRUE(tv->axis(0)->isStream()) << tv;
    if (tv->isFusionOutput() || tv->isFusionInput()) {
      EXPECT_EQ(tv->getAllocationDomain(), tv->getLogicalDomain());
    } else {
      EXPECT_EQ(tv->getAllocationDomain(), tv->getLoopDomain());
    }
  }
}

TEST_F(StreamTest, ReplicatedAllocation) {
  auto fusion = std::make_unique<Fusion>();
  FusionGuard fg(fusion.get());

  const int64_t s = 2;

  TensorView* tv0 = makeContigTensor(3);
  TensorView* tv1 = add(tv0, IrBuilder::create<Val>(1.0));
  TensorView* tv2 = sum(tv1, {2});
  TensorView* tv3 = div(tv1, IrBuilder::create<Val>(2.0));
  fusion->addInput(tv0);
  fusion->addOutput(tv2);
  fusion->addOutput(tv3);

  tv0->outer_split(0, s);
  tv0->axis(0)->parallelize(ParallelType::Stream);
  tv2->outer_split(1, s);
  tv2->axis(1)->parallelize(ParallelType::Stream);

  preseg_passes::OptimizationPass<preseg_passes::PreSegmenter>::runPass(
      fusion.get());
  for (auto* tv : {tv0, tv1, tv2, tv3}) {
    EXPECT_TRUE(tv->axis(0)->isStream()) << tv;
    EXPECT_EQ(tv->getAllocationDomain(), tv->getLogicalDomain());
  }
}

--global and others added 2 commits October 8, 2025 14:44
@Priya2698
Collaborator Author

!test

@Priya2698
Collaborator Author

!test

Base automatically changed from pm/index_compute to main October 9, 2025 00:57

@Priya2698
Collaborator Author

!test

@Priya2698 Priya2698 marked this pull request as ready for review October 9, 2025 01:31
@Priya2698 Priya2698 requested review from wujingyue and removed request for wujingyue October 9, 2025 01:32
@Priya2698 Priya2698 marked this pull request as draft October 9, 2025 05:42
Collaborator

@wujingyue wujingyue left a comment

LGTM otherwise

@Priya2698
Collaborator Author

!test

@Priya2698 Priya2698 requested a review from wujingyue October 14, 2025 16:05
@Priya2698 Priya2698 marked this pull request as ready for review October 14, 2025 16:05
Collaborator

@wujingyue wujingyue left a comment

Looks great!

Priya2698 and others added 5 commits October 14, 2025 12:20
Co-authored-by: Jingyue Wu <wujingyue@gmail.com>
@Priya2698
Collaborator Author

!test

@Priya2698
Collaborator Author

!test

@Priya2698
Collaborator Author

!test

@Priya2698 Priya2698 merged commit 851a0e6 into main Oct 16, 2025
64 of 67 checks passed
@Priya2698 Priya2698 deleted the pm/alloc_stream branch October 16, 2025 16:06
    split != nullptr,
    "Expected all transform exprs to be a split between allocation and "
    "loop domain during sharding propagation.");
if (split->outer()->isStream() &&

Collaborator

Nit: I believe you can move this filter to loop_stream_device_view as well. This way, we put all the filters in one location.

Collaborator Author

This PR was merged, but I'll do it in a follow-up!

tbqh pushed a commit that referenced this pull request Nov 12, 2025