
Do not attempt to cancel reshape when not all tensors are dominated#4823

Merged
naoyam merged 7 commits into main from cancel_reshape_war on Jul 25, 2025
Conversation

naoyam (Collaborator) commented Jul 23, 2025

I decided to disable the cancellation of reshape in the resize scheduler. It was originally added in #3679.

Disabling it results in about a 10% perf regression in the RoPE benchmarks (http://nv/eO-).

The optimization should be re-enabled, but rather than ad-hoc patching, I feel we should investigate fixing the root cause of the issue, which is cycles in the exact graph. Tracking issue: #4839
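The root cause named above, cycles in the exact graph, is what prevents `scheduleLoopDomainsLike` from freely propagating transformations. As a generic illustration only (this is not nvFuser code; `has_cycle` and its adjacency-dict input are hypothetical), a guard like the following shows how one could detect a cycle in a directed graph before attempting a transformation that assumes acyclicity:

```python
def has_cycle(adjacency):
    """Return True if the directed graph {node: [successors]} contains
    a cycle, using an iterative three-color depth-first search."""
    WHITE, GRAY, BLACK = 0, 1, 2
    color = {}
    for start in adjacency:
        if color.get(start, WHITE) != WHITE:
            continue
        color[start] = GRAY
        stack = [(start, iter(adjacency.get(start, ())))]
        while stack:
            node, successors = stack[-1]
            advanced = False
            for succ in successors:
                c = color.get(succ, WHITE)
                if c == GRAY:
                    return True  # back edge to an in-progress node: cycle
                if c == WHITE:
                    color[succ] = GRAY
                    stack.append((succ, iter(adjacency.get(succ, ()))))
                    advanced = True
                    break
            if not advanced:
                color[node] = BLACK  # fully explored
                stack.pop()
    return False
```

In this sketch, a scheduler would fall back to the conservative path (as this PR does unconditionally) whenever the check reports a cycle, rather than disabling the optimization for all fusions.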

naoyam (Collaborator, Author) commented Jul 23, 2025

!test --diff

github-actions bot commented Jul 23, 2025

Review updated until commit 6c0e7f8

Description

  • Disabled reshape cancellation in ResizeScheduler to avoid scheduling errors.

  • Added performance regression comment in ResizeScheduler.

  • Skipped tests related to reshape cancellation.

  • Added a repro test for reshape cancellation issue.


Changes walkthrough 📝

Relevant files

Enhancement

  csrc/scheduler/resize.cpp: Disabled reshape cancellation in ResizeScheduler (+17/-3)

  • Commented out the cancelReshapeInLoopDomains call.
  • Added comments explaining the performance regression and the need to
    address cyclic exact graphs.

Tests

  tests/cpp/test_rope.cpp: Skipped reshape cancellation tests (+3/-0)

  • Skipped EndingRepeat and EndingRepeatWithNoBroadcastOp tests due to
    disabled reshape cancellation.

  tests/python/test_repro.py: Added reshape cancellation repro test (+162/-0)

  • Added a new test test_reshape_cancellation to reproduce the reshape
    cancellation issue.

    PR Reviewer Guide 🔍

    Here are some key observations to aid the review process:

    🧪 PR contains tests
    ⚡ Recommended focus areas for review

    Performance Regression

    The PR disables the cancellation of reshape in the resize scheduler, resulting in up to 10% performance loss in RoPE benchmarks. It is important to investigate and address the root cause of the issue to potentially re-enable the optimization.

    // Disabled for now to avoid scheduling errors in some HF
    // models. Up to 10% perf loss is observed with the RoPE
    // benchmarks. To re-enable the optimization, it probably makes
    // more sense to first address the issue due to cyclic exact
    // graphs. That is, scheduleLoopDomainsLike is potentially fairly
    // powerful but due to cycles, only the update mode is used when
    // propagating transformations from the reference tensor. This
    // restriction makes it difficult to use more aggressive
    // scheduling like setting the loop domain of a reshape output
    // tensor as its root domain, which is what
    // cancelReshapeInLoopDomains does. See test_reshape_cancellation
    // for a repro.
    //
    // scheduler_tools::cancelReshapeInLoopDomains(
    //     largest_input, /*skip_innermost_id=*/true);
    Disabled Tests

    Two tests, EndingRepeat and EndingRepeatWithNoBroadcastOp, are disabled due to the cancellation of reshape being disabled. It is crucial to understand the impact of this change on these specific test cases.

    TEST_F(RopeTest, EndingRepeat) {
  GTEST_SKIP() << "Disabled as cancelReshape is disabled";
      auto fusion_ptr = std::make_unique<Fusion>();
      FusionGuard fg(fusion_ptr.get());
      Fusion& fusion = *fusion_ptr;
    
      std::vector<int64_t> shape1{8, 126};
    
      auto tv0 = makeContigConcreteTensor(shape1);
      fusion.addInput(tv0);
    
      auto tv1 = pad(tv0, {fusion.oneVal(), fusion.oneVal()});
      auto tv2 = repeat(tv1, {2, 1});
      auto tv3 = segment_set(tv2);
      fusion.addOutput(tv3);
    
      auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
      auto t0 = at::randn(shape1, options);
    
      FusionExecutorCache executor_cache(std::move(fusion_ptr));
      auto outputs = executor_cache.runFusionWithInputs({t0});
      testValidate(&fusion, outputs, {t0}, __LINE__, __FILE__);
    
      FusionKernelRuntime* runtime = executor_cache.getMostRecentKernelRuntime();
      EXPECT_FALSE(runtime->isSegmented());
      const auto& heuristic_param =
          runtime->schedulerHeuristics()->heuristicsList().front();
      EXPECT_EQ(heuristic_param->scheduler_type, SchedulerType::Resize);
      Fusion* scheduled_fusion = runtime->executors()
                                     .at(0)
                                     ->as<KernelExecutor>()
                                     ->compiledKernel()
                                     ->kernel();
    
      // Check the loop domain of the reference. It should look like:
      //
      // T4_g_float[iS19{2 ex 2}, iblockIdx.x22{8}, ithreadIdx.x23{128}] ca_pos( 3 )
      // produce_pos( 3 )
      //  logical domain : (iS17{( 2 * 8 )}, iS18{128})
      //  contiguity: t t
      //   Merge: iS20{8} and iS18{128} -> iS21{1024}
      //   Split: iS21{1024} by factor 128 -> iblockIdx.x22{8}, ithreadIdx.x23{128}
      //  loop domain : (iS19{2 ex 2}, iblockIdx.x22{8}, ithreadIdx.x23{128})
      //
      // iS19 is the repeat ID, which should be just a Serial ID with an
      // extent of 2.
      auto ref_tv = scheduled_fusion->outputs().at(0)->as<TensorView>();
      // The outermost loop ID should be a Serial ID with an extent of 2.
      EXPECT_EQ(
          ref_tv->getLoopDomain().at(0)->getParallelType(), ParallelType::Serial);
      EXPECT_TRUE(ref_tv->getLoopDomain().at(0)->extent()->isConstInt());
      EXPECT_EQ(
          ref_tv->getLoopDomain().at(0)->extent()->evaluate().as<int64_t>(), 2L);
    
      IdModel id_model(scheduled_fusion, /*build_graphs=*/false);
      const auto& exact_graph = id_model.buildExactGraph();
    
      const auto ref_loop = exact_graph.toGroups(ref_tv->getLoopDomain());
    
      // The other tensors, except for the pad output, should be fully inlined into
      // the reference tensor.
      for (auto tv : scheduled_fusion->allTvs()) {
        if (tv->isFusionInput()) {
          continue;
        }
        auto tv_loop = exact_graph.toGroups(tv->getLoopDomain());
        if (tv->definition() != nullptr && tv->definition()->isA<PadOp>()) {
          ValGroups ref_groups{ref_loop.begin() + 1, ref_loop.end()};
          // In the case of pad, the loop domain of the output tensor
          // should be mapped with the loop domain of the reference
          // without the outermost ID.
          EXPECT_EQ(tv_loop, ref_groups);
        } else {
          EXPECT_EQ(tv_loop, ref_loop);
          EXPECT_EQ(tv->getLoopDomain().size(), tv->getComputeAtPosition());
        }
      }
    }
    
    // Similar to EndingRepeat but with a broadcast ID already found in an
    // input tensor. A similar pattern appears in the LitGPT Llama RoPE
    // module.
    TEST_F(RopeTest, EndingRepeatWithNoBroadcastOp) {
      GTEST_SKIP() << "Disabled as cancelReshape is disabled";
    
      auto fusion_ptr = std::make_unique<Fusion>();
    Repro Test

    A new test test_reshape_cancellation is added to reproduce the issue. Ensure that this test accurately captures the problem and that the performance regression is consistently observed.

    # Repro of https://github.com/NVIDIA/Fuser/pull/4823
    def test_reshape_cancellation(self):
        def nvfuser_fusion_id1(fd: FusionDefinition) -> None:
            T0 = fd.define_tensor(
                shape=[1, 2048, 24, 32],
                contiguity=[None, True, True, False],
                dtype=DataType.BFloat16,
                is_cpu=False,
                stride_order=[3, 2, 1, 0],
            )
            T1 = fd.define_tensor(
                shape=[1, 2048, 24, 32],
                contiguity=[None, True, True, False],
                dtype=DataType.BFloat16,
                is_cpu=False,
                stride_order=[3, 2, 1, 0],
            )
            T2 = fd.define_tensor(
                shape=[1, 2048, 24, 32],
                contiguity=[None, True, True, False],
                dtype=DataType.BFloat16,
                is_cpu=False,
                stride_order=[3, 2, 1, 0],
            )
            T3 = fd.define_tensor(
                shape=[1, 2048, 4, 4608],
                contiguity=[None, True, True, True],
                dtype=DataType.BFloat16,
                is_cpu=False,
                stride_order=[3, 2, 1, 0],
            )
            T4 = fd.define_tensor(
                shape=[1, 2048, 24, 32],
                contiguity=[None, True, True, False],
                dtype=DataType.BFloat16,
                is_cpu=False,
                stride_order=[3, 2, 1, 0],
            )
            T5 = fd.define_tensor(
                shape=[1, 2048, 24, 64],
                contiguity=[None, True, None, True],
                dtype=DataType.Float,
                is_cpu=False,
                stride_order=[3, 2, 1, 0],
            )
            T6 = fd.define_tensor(
                shape=[1, 2048, 24, 64],
                contiguity=[None, True, True, True],
                dtype=DataType.Float,
                is_cpu=False,
                stride_order=[3, 2, 1, 0],
            )
            T7 = fd.define_tensor(
                shape=[1, 2048, 24, 64],
                contiguity=[None, True, True, True],
                dtype=DataType.Float,
                is_cpu=False,
                stride_order=[3, 2, 1, 0],
            )
            T8 = fd.ops.cast(T0, dtype=DataType.Float)
            T9 = fd.ops.neg(T8)
            T10 = fd.ops.cast(T9, dtype=DataType.BFloat16)
            T17 = fd.ops.broadcast_in_dim(
                T1, shape=[1, 2048, 24, 32, 1], broadcast_dims=[0, 1, 2, 3]
            )
            T24 = fd.ops.broadcast_in_dim(
                T10, shape=[1, 2048, 24, 32, 1], broadcast_dims=[0, 1, 2, 3]
            )
            T25 = fd.ops.cast(T2, dtype=DataType.Float)
            T41 = fd.ops.slice(
                T3,
                start_indices=[0, 0, 0, 3072],
                end_indices=[1, 2048, 4, 4608],
                strides=[1, 1, 1, 1],
                manual_normalization=0,
            )
            T42 = fd.ops.cat([T24, T17], dim=-1, manual_padding=0)
            T43 = fd.ops.neg(T25)
            T50 = fd.ops.reshape(T41, new_shape=[1, 2048, 4, 6, 256])
            T56 = fd.ops.reshape(T42, new_shape=[1, 2048, 24, 64])
            T57 = fd.ops.cast(T43, dtype=DataType.BFloat16)
            T63 = fd.ops.reshape(T50, new_shape=[1, 2048, 24, 256])
            T64 = fd.ops.cast(T56, dtype=DataType.Float)
            T71 = fd.ops.broadcast_in_dim(
                T4, shape=[1, 2048, 24, 32, 1], broadcast_dims=[0, 1, 2, 3]
            )
            T78 = fd.ops.broadcast_in_dim(
                T57, shape=[1, 2048, 24, 32, 1], broadcast_dims=[0, 1, 2, 3]
            )
            T94 = fd.ops.slice(
                T63,
                start_indices=[0, 0, 0, 64],
                end_indices=[1, 2048, 24, 256],
                strides=[1, 1, 1, 1],
                manual_normalization=0,
            )
            T95 = fd.ops.mul(T64, T5)
            T111 = fd.ops.slice(
                T3,
                start_indices=[0, 0, 0, 0],
                end_indices=[1, 2048, 4, 1536],
                strides=[1, 1, 1, 1],
                manual_normalization=0,
            )
            T112 = fd.ops.cat([T78, T71], dim=-1, manual_padding=0)
            T113 = fd.ops.cast(T94, dtype=DataType.Float)
            T114 = fd.ops.add(T6, T95)
            T121 = fd.ops.reshape(T111, new_shape=[1, 2048, 4, 6, 256])
            T127 = fd.ops.reshape(T112, new_shape=[1, 2048, 24, 64])
            T128 = fd.ops.cat([T114, T113], dim=-1, manual_padding=0)
            T134 = fd.ops.reshape(T121, new_shape=[1, 2048, 24, 256])
            T135 = fd.ops.cast(T127, dtype=DataType.Float)
            T136 = fd.ops.permute(T128, dims=[0, 2, 1, 3])
            T152 = fd.ops.slice(
                T134,
                start_indices=[0, 0, 0, 64],
                end_indices=[1, 2048, 24, 256],
                strides=[1, 1, 1, 1],
                manual_normalization=0,
            )
            T153 = fd.ops.mul(T135, T5)
            T154 = fd.ops.cast(T136, dtype=DataType.BFloat16)
            T155 = fd.ops.cast(T152, dtype=DataType.Float)
            T156 = fd.ops.add(T7, T153)
            T157 = fd.ops.cat([T156, T155], dim=-1, manual_padding=0)
            T158 = fd.ops.permute(T136, dims=[0, 1, 3, 2])
            T159 = fd.ops.permute(T157, dims=[0, 2, 1, 3])
            fd.add_output(T159)
            fd.add_output(T154)
            fd.add_output(T158)
    
        with FusionDefinition() as fd:
            nvfuser_fusion_id1(fd)
    
        inputs = [
            torch.randn(3145727, dtype=torch.bfloat16, device="cuda:0").as_strided(
                (1, 2048, 24, 32), (3145728, 1536, 64, 2)
            ),
            torch.randn(3145727, dtype=torch.bfloat16, device="cuda:0").as_strided(
                (1, 2048, 24, 32), (3145728, 1536, 64, 2)
            ),
            torch.randn(3145727, dtype=torch.bfloat16, device="cuda:0").as_strided(
                (1, 2048, 24, 32), (3145728, 1536, 64, 2)
            ),
            torch.testing.make_tensor(
                (1, 2048, 4, 4608), dtype=torch.bfloat16, device="cuda:0"
            ),
            torch.randn(3145727, dtype=torch.bfloat16, device="cuda:0").as_strided(
                (1, 2048, 24, 32), (3145728, 1536, 64, 2)
            ),
            torch.randn(131072, dtype=torch.float32, device="cuda:0").as_strided(
                (1, 2048, 24, 64), (131072, 64, 0, 1)
            ),
            torch.testing.make_tensor(
                (1, 2048, 24, 64), dtype=torch.float32, device="cuda:0"
            ),
            torch.testing.make_tensor(
                (1, 2048, 24, 64), dtype=torch.float32, device="cuda:0"
            ),
        ]
        fd.execute(inputs)
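The strided bf16 inputs above (last-dim stride 2 over a flat buffer) mimic the interleaved layouts a RoPE module produces. As a sanity check of the repro's buffer sizes (`min_storage_len` is a hypothetical helper, not part of the test), the flat buffer of 3145727 elements is exactly the minimum that `as_strided` needs for shape (1, 2048, 24, 32) with strides (3145728, 1536, 64, 2):

```python
def min_storage_len(shape, strides):
    """Smallest flat-buffer length that an as_strided view with the
    given shape and non-negative strides (storage offset 0) can
    legally address: 1 + sum over dims of (size - 1) * stride."""
    return 1 + sum((s - 1) * st for s, st in zip(shape, strides))

# 1 + 2047*1536 + 23*64 + 31*2 = 3145727, matching torch.randn(3145727, ...)
print(min_storage_len((1, 2048, 24, 32), (3145728, 1536, 64, 2)))  # 3145727
```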

    naoyam (Collaborator, Author) commented Jul 24, 2025

    !test --pybench

    naoyam (Collaborator, Author) commented Jul 24, 2025

    !test

    naoyam (Collaborator, Author) commented Jul 24, 2025

    !test

    @naoyam naoyam marked this pull request as ready for review July 24, 2025 17:51
    @naoyam naoyam requested a review from jjsjann123 July 24, 2025 17:51
    jjsjann123 (Collaborator) left a comment


    Stamping to unblock. I'll let you make your decision on whether to add a separate repro for tracking the follow-up.

    // cancelReshapeInLoopDomains does.
    //
    // scheduler_tools::cancelReshapeInLoopDomains(
    // largest_input, /*skip_innermost_id=*/true);
    jjsjann123 (Collaborator) commented:

    nitpick: if we are adding a repro, let's put a comment here linking to the repro that led to the decision to disable the optimization.

    naoyam (Collaborator, Author) commented Jul 24, 2025

    !test

    @naoyam naoyam merged commit 1c66e47 into main Jul 25, 2025
    46 of 49 checks passed
    @naoyam naoyam deleted the cancel_reshape_war branch July 25, 2025 04:05
    nsarka pushed a commit to nsarka/Fuser that referenced this pull request Jul 28, 2025
    Do not attempt to cancel reshape when not all tensors are dominated (NVIDIA#4823)
    
    I decided to disable the cancellation of reshape in the resize
    scheduler. It was originally added in NVIDIA#3679.
    
    It results in about 10% perf regression in the RoPE benchmarks
    http://nv/eO-.
    
    The optimization should be reenabled but rather than ad-hoc patching, I
    feel we should investigate fixing the root cause of the issue, which is
    cycles in the exact graph. Tracking issue: NVIDIA#4839


    2 participants