
Fix ReplaySelf to reuse transforms for loop and allocation. #4585

Merged
wujingyue merged 2 commits into main from wjy/replay on Jun 5, 2025
Conversation

@wujingyue (Collaborator)

No description provided.

@wujingyue wujingyue requested review from Priya2698 and jjsjann123 June 5, 2025 22:02
@wujingyue (Collaborator, Author)

!test

github-actions bot commented Jun 5, 2025

Description

  • Reuse ReplaySelf instance for loop and allocation

  • Ensure loop and allocation share transforms

  • Add test for loop and allocation replay


Changes walkthrough 📝

Relevant files:

Enhancement — csrc/transform_replay.cpp (+38/-32): Reuse ReplaySelf for loop and allocation

  • Removed redundant map comment
  • Created single ReplaySelf instance for loop and allocation
  • Reordered loop and allocation replay logic
  • Removed redundant ReplaySelf instantiation for allocation

Tests — tests/cpp/test_replay.cpp (+21/-0): Add test for loop and allocation replay

  • Included additional testing headers
  • Added test case for loop and allocation replay

PR Reviewer Guide 🔍

Here are some key observations to aid the review process:

🧪 PR contains tests
⚡ Recommended focus areas for review

Code Duplication

The code for replaying loop and allocation is very similar. Consider refactoring to avoid duplication.

    // We create one ReplaySelf instance to replay loop and allocation. This way,
    // loop and allocation share the same transforms if they are split the same
    // way.
    //
    // We use `self_loop` as the target domain because loop post-dominates
    // allocation.
    const std::vector<IterDomain*>& self_loop = self->loop();
    ReplaySelf replay(self_loop, axis_map);
    
    // Replay loop.
    if (self_loop != self->logical()) {
      std::vector<IterDomain*> new_loop;
      if (ignore_reductions) {
        for (auto* id : new_self->logical()) {
          if (id->isReduction()) {
            new_loop.push_back(id);
          }
        }
      }
    
      for (IterDomain* loop_id : self_loop) {
        if (ignore_reductions && loop_id->isReduction()) {
          continue;
        }
    
        auto it = replay.getReplay().find(loop_id);
        NVF_ERROR(
            it != replay.getReplay().end(),
            "failed to replay IterDomain: ",
            loop_id);
        it->second->parallelize(loop_id->getParallelType());
        new_loop.push_back(it->second);
      }
    
      new_self->setLoopDomain(new_loop);
    }
    
    // Replay allocation.
    if (self->hasAllocation()) {
      const std::vector<IterDomain*>& self_allocation = self->allocation();
      const std::vector<std::optional<bool>>& self_contiguity =
          self->contiguity();
      NVF_ERROR_EQ(self_allocation.size(), self_contiguity.size());
    
      std::vector<IterDomain*> new_alloc_domain;
      std::vector<std::optional<bool>> new_contiguity;
      new_alloc_domain.reserve(self_allocation.size());
      new_contiguity.reserve(self_contiguity.size());
    
      // Push back the reduction IDs that are not mapped
      if (ignore_reductions) {
        for (auto* id : new_self->logical()) {
          if (id->isReduction()) {
            new_alloc_domain.push_back(id);
            // NOLINTNEXTLINE(modernize-use-emplace)
            new_contiguity.push_back(std::nullopt);
          }
        }
      }
    
      // Pushing the mapped IDs and corresponding contiguity flags
      for (auto&& [alloc_id, contiguity] :
           zip(self_allocation, self_contiguity)) {
        if (ignore_reductions && alloc_id->isReduction()) {
          continue;
        }
        auto it = replay.getReplay().find(alloc_id);
        NVF_ERROR(
            it != replay.getReplay().end(),
            "failed to replay IterDomain: ",
            alloc_id);
        NVF_ERROR_EQ(
            it->second->isBroadcast(),
            !contiguity.has_value(),
            "Contiguity should be nullopt iff broadcast.");
        new_contiguity.push_back(contiguity);
        it->second->parallelize(alloc_id->getParallelType());
        new_alloc_domain.push_back(it->second);
      }
    
      new_self->setAllocationDomain(new_alloc_domain, new_contiguity);
    }

Logic Consistency

Ensure that the logic for handling reductions in both loop and allocation replay is consistent and correct.

    std::vector<IterDomain*> new_loop;
    if (ignore_reductions) {
      for (auto* id : new_self->logical()) {
        if (id->isReduction()) {
          new_loop.push_back(id);
        }
      }
    }
    
    for (IterDomain* loop_id : self_loop) {
      if (ignore_reductions && loop_id->isReduction()) {
        continue;
      }

Test Coverage

Verify that the new test covers all edge cases and provides sufficient coverage for the changes made.

    TEST_F(ReplayTest, LoopAndAllocation) {
      Fusion fusion;
      FusionGuard fg(&fusion);
      TensorView* in = makeSymbolicTensor(1);
      TensorView* out = set(in);
      fusion.addInput(in);
      fusion.addOutput(out);
    
      constexpr int d = 2;
      in->setDeviceMesh(DeviceMesh::createForNumDevices(d));
      in->outer_split(0, d);
      in->setAllocationDomain(in->getLoopDomain(), true);
    
      TransformReplay::selfReplay(in->domain(), out->domain());
      EXPECT_THAT(out->getLoopDomain(), SizeIs(2));
      EXPECT_THAT(out->getLoopDomain(), ContainerEq(out->getAllocationDomain()));
      EXPECT_THAT(out->getContiguity(), Each(Optional(IsTrue())));
    }

    // way.
    //
    // We use `self_loop` as the target domain because loop post-dominates
    // allocation.
Collaborator:
Even though this felt a bit artificially restrictive, I think it simplifies our analysis and is the de facto protocol in the codebase.
    BTW, @wolfcomos

Contributor:

Thanks! This helps a lot with writing tests for allocation domain transformation.

wujingyue merged commit 43c9e13 into main on Jun 5, 2025
43 of 46 checks passed
wujingyue deleted the wjy/replay branch on June 5, 2025 23:34
nsarka pushed a commit to nsarka/Fuser that referenced this pull request on Jul 28, 2025

4 participants