Forward full op #4269

Merged
naoyam merged 9 commits into main from forward_factory_op
Apr 24, 2025

Conversation

naoyam (Collaborator) commented Apr 17, 2025

The fusion segmenter sets aside certain sequences of unary ops starting from fusion inputs, a mechanism we call forwarding. It effectively works as an optimization: instead of passing tensors from one segment to another, each segment recomputes the (cheap) unary ops itself.

This PR extends the forwarding optimization to those starting with factory methods. Here's a motivating example (Litgpt Llama 3 RoPE backward):

llama_bwd

The T81 tensor is the output of a full op. The tensor is used inside both the yellow and gray segments. The op itself is in the yellow segment, so the tensor is created there and then passed, through gmem, to the gray segment. Obviously, cheap ops like this should just be replicated in the gray segment instead of passing a full tensor around. Here's another way to see it:

g{(resize)
group id: 4
inputs:
  T1_g___bfloat[bS3{1}, iS4{32}, iS5{8192}, iS6{128}] __bfloat
  T3_g___bfloat[bS11{1}, iS12{8192}, iS13{128}] __bfloat
  T9_g___bfloat[bS38{1}, bS39{1 ex 32}, iS40{8192}, iS41{128}] __bfloat
outputs:
  T25_g___bfloat[bS107{1}, bS108{1 ex 32}, iS109{8192}, iS110{128}] __bfloat
  T54_g___bfloat[bS233{1}, iS238{8}rf, iS239{4}rf, iS235{8192}, iS236{128}] __bfloat
  T81_g___bfloat[bS366{1}, iS367{32}, iS368{8192}, iS369{128}] __bfloat


T81_g___bfloat[bS366{1}, iS367{32}, iS368{8192}, iS369{128}]
   = full({1, 32, 8192, 128}, __bfloat(0));
(121)
...

And T81 is used in the next segment:

g{(resize)
group id: 3
inputs:
  T18_g___bfloat[bS79{1}, iS80{32}, iS81{8192}, iS82{128}] __bfloat
  T25_g___bfloat[bS107{1}, bS108{1 ex 32}, iS109{8192}, iS110{128}] __bfloat
  T34_g___bfloat[bS144{1}, iS145{32}, iS146{8192}, iS147{128}] __bfloat
  T81_g___bfloat[bS366{1}, iS367{32}, iS368{8192}, iS369{128}] __bfloat
outputs:
  T75_g___bfloat[bS328{1}, iS329{8}, iS331{6}rf, iS332{8192}, iS333{128}] __bfloat


T50_l___bfloat[bS212{1}, iS213{32}, iS214{8192}, iS216{64}rf]
   = slice( T34_g___bfloat[bS144{1}, iS145{32}, iS146{8192}, iS147{128}], { {0, 1, 1} {0, 32, 1} {0, 8192, 1} {64, 128, 1} } )
(52)
T55_g___bfloat[bS240{1}, iS241{32}, iS242{8192}, iS244{128}rf]
   = pad( T50_l___bfloat[bS212{1}, iS213{32}, iS214{8192}, iS216{64}rf], {0, 0, 0, 0, 0, 0, 0, 64} )
(61)
T39_g___bfloat[bS166{1}, iS167{32}, iS168{8192}, iS170{64}rf]
   = slice( T34_g___bfloat[bS144{1}, iS145{32}, iS146{8192}, iS147{128}], { {0, 1, 1} {0, 32, 1} {0, 8192, 1} {0, 64, 1} } )
(39)
T43_l_float[bS184{1}, iS185{32}, iS186{8192}, iS187{64}]
   = __bfloat2float(T39_g___bfloat[bS166{1}, iS167{32}, iS168{8192}, iS170{64}rf]);
(44)
T46_l_float[bS196{1}, iS197{32}, iS198{8192}, iS199{64}]
   = -T43_l_float[bS184{1}, iS185{32}, iS186{8192}, iS187{64}];
(47)
T48_g___bfloat[bS204{1}, iS205{32}, iS206{8192}, iS207{64}]
   = __float2bfloat(T46_l_float[bS196{1}, iS197{32}, iS198{8192}, iS199{64}]);
(49)
T51_g___bfloat[bS217{1}, iS218{32}, iS219{8192}, iS221{128}rf]
   = pad( T48_g___bfloat[bS204{1}, iS205{32}, iS206{8192}, iS207{64}], {0, 0, 0, 0, 0, 0, 64, 0} )
(54)
T38_l_float[bS162{1}, iS163{32}, iS164{8192}, iS165{128}]
   = __bfloat2float(T81_g___bfloat[bS366{1}, iS367{32}, iS368{8192}, iS369{128}]);
(101)
...

There are multiple ways to achieve that. What makes the most sense to me is to extend the existing forwarding method to handle cases like this. The existing method only considers ops starting from fusion inputs, which do not include factory-created tensors.

This PR applies a small change to the forwarding logic to include factory ops as well. The end result of this change, with the above example, is that the full result is no longer passed around. Here's the first segment:

g{(resize)
group id: 3
inputs:
  T0_g___bfloat[bS0{1}, iS1{8192}, iS2{128}] __bfloat
  T1_g___bfloat[bS3{1}, iS4{32}, iS5{8192}, iS6{128}] __bfloat
  T3_g___bfloat[bS11{1}, iS12{8192}, iS13{128}] __bfloat
outputs:
  T49_g___bfloat[bS208{1}, iS209{32}, iS210{8192}, iS211{128}] __bfloat


T20_l___bfloat[bS87{1}, bS88{1}, iS89{8192}, iS90{128}]
   = broadcast( T3_g___bfloat[bS11{1}, iS12{8192}, iS13{128}], flags = {false, true, false, false} )
(16)
T25_g___bfloat[bS107{1}, bS108{1 ex 32}, iS109{8192}, iS110{128}] = expand( T20_l___bfloat[bS87{1}, bS88{1}, iS89{8192}, iS90{128}], {1, 32, 8192, 128} )
(129)
T5_l___bfloat[bS18{1}, bS19{1}, iS20{8192}, iS21{128}]
   = broadcast( T0_g___bfloat[bS0{1}, iS1{8192}, iS2{128}], flags = {false, true, false, false} )
(0)
T9_g___bfloat[bS38{1}, bS39{1 ex 32}, iS40{8192}, iS41{128}] = expand( T5_l___bfloat[bS18{1}, bS19{1}, iS20{8192}, iS21{128}], {1, 32, 8192, 128} )
(128)
T81_g___bfloat[bS366{1}, iS367{32}, iS368{8192}, iS369{128}]
   = full({1, 32, 8192, 128}, __bfloat(0));
...

Notice that T81 is no longer a segment output. And the second segment is:

g{(resize)
group id: 4
inputs:
  T0_g___bfloat[bS0{1}, iS1{8192}, iS2{128}] __bfloat
  T2_g___bfloat[bS7{1}, iS8{32}, iS9{8192}, iS10{128}] __bfloat
  T3_g___bfloat[bS11{1}, iS12{8192}, iS13{128}] __bfloat
outputs:
  T74_g___bfloat[bS321{1}, iS326{8}rf, iS327{4}rf, iS323{8192}, iS324{128}] __bfloat


T20_l___bfloat[bS87{1}, bS88{1}, iS89{8192}, iS90{128}]
   = broadcast( T3_g___bfloat[bS11{1}, iS12{8192}, iS13{128}], flags = {false, true, false, false} )
(16)
T25_g___bfloat[bS107{1}, bS108{1 ex 32}, iS109{8192}, iS110{128}] = expand( T20_l___bfloat[bS87{1}, bS88{1}, iS89{8192}, iS90{128}], {1, 32, 8192, 128} )
(129)
T5_l___bfloat[bS18{1}, bS19{1}, iS20{8192}, iS21{128}]
   = broadcast( T0_g___bfloat[bS0{1}, iS1{8192}, iS2{128}], flags = {false, true, false, false} )
(0)
T9_g___bfloat[bS38{1}, bS39{1 ex 32}, iS40{8192}, iS41{128}] = expand( T5_l___bfloat[bS18{1}, bS19{1}, iS20{8192}, iS21{128}], {1, 32, 8192, 128} )
(128)
T81_g___bfloat[bS366{1}, iS367{32}, iS368{8192}, iS369{128}]
   = full({1, 32, 8192, 128}, __bfloat(0));
(121)
...

naoyam commented Apr 17, 2025

!test --diff

github-actions bot commented Apr 17, 2025

Review updated until commit 329f19a

Description

  • Extend forwarding optimization to factory methods like FullOp

  • Prevent merging of auxiliary input groups

  • Add test for forwarding FullOp across segments


Changes walkthrough 📝

Relevant files

Enhancement

csrc/fusion_segmenter.cpp — Extend forwarding to factory methods (+59/-6)
  • Prevent merging of auxiliary input groups
  • Extend forwarding optimization to include FullOp
  • Update logic to handle factory-created tensors

csrc/fusion_segmenter.h — Add method for auxiliary groups (+3/-0)
  • Add method to get auxiliary input groups

Tests

tests/cpp/test_segmentation.cpp — Add test for FullOp forwarding (+57/-2)
  • Rename test cases to use SegmentationTest
  • Add test for forwarding FullOp across segments
PR Reviewer Guide 🔍

Here are some key observations to aid the review process:

🧪 PR contains tests
⚡ Recommended focus areas for review

Performance Evaluation

Ensure a thorough performance evaluation is provided to demonstrate the benefits of extending the forwarding optimization to factory methods.

    void SegmentCandidateFinder::forwardInputs() {
      excluded_inp_unary_exprs_ = {};
      input2group_.clear();
    
      std::vector<Val*> extended_fusion_inputs = completeFusion()->inputs();
    
      // Grab factory ops that should be forwarded. Add created tensors to
      // the fusion input list to make them handled like fusion inputs
      // TODO: Handle more factory methods such as IotaOp, EyeOp,
      // TensorConstruct. Probably should not include relatively expensive
      // ops like RNGOp.
      for (auto expr : completeFusion()->exprs()) {
        if (expr->isA<FullOp>() &&
            // Don't bother if it's a fusion output
            !expr->output(0)->isFusionOutput()) {
          extended_fusion_inputs.push_back(expr->output(0));
          excluded_inp_unary_exprs_.pushBack(expr);
        }
      }
    
      // "Terminating" outputs from the excluded input unary exprs, these will be
      // treated as complete fusion inputs.
      VectorOfUniqueEntries<Val*> forwarded_inputs;
      {
        std::deque<UnaryOp*> to_visit;
        for (Val* inp : extended_fusion_inputs) {
          if (UnaryOp* unary_use = shouldForward(inp)) {
            to_visit.push_back(unary_use);
          }
        }
    
        while (!to_visit.empty()) {
          UnaryOp* uop = to_visit.front();
          to_visit.pop_front();
    
          if (UnaryOp* unary_use = shouldForward(uop->out())) {
            to_visit.push_back(unary_use);
          } else {
            // We cannot extend the chain of unary ops, so we finalize this chain by
            // saving its output as a forwarded input.
            forwarded_inputs.pushBack(uop->out());
          }
          // Either way, `uop` is excluded from merging until
Factory Method Handling

Consider handling more factory methods such as IotaOp, EyeOp, and TensorConstruct, and ensure that relatively expensive ops like RNGOp are not included.

      if (expr->isA<FullOp>() &&
          // Don't bother if it's a fusion output
          !expr->output(0)->isFusionOutput()) {
        extended_fusion_inputs.push_back(expr->output(0));
        excluded_inp_unary_exprs_.pushBack(expr);
      }
    }
Test Coverage

Ensure that the new test cases cover a variety of scenarios and edge cases to validate the correctness of the forwarding optimization for factory methods.

    TEST_F(SegmentationTest, ForwardFull) {
      auto fusion_ptr = std::make_unique<Fusion>();
      auto& fusion = *fusion_ptr;
      FusionGuard fg(fusion_ptr.get());
    
      auto tv0 = makeSymbolicTensor(1);
      fusion.addInput(tv0);
    
      // FullOp that is used in two segments
      auto tv1 = full({tv0->axis(0)->extent()}, fusion.oneVal(), DataType::Float);
    
      auto tv2 = add(tv0, tv1);
      auto tv3 = segment_set(tv2);
    
      auto tv4 = add(tv3, tv1);
      fusion.addOutput(tv4);
    
      auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
      auto t0 = at::randn({1024}, options);
    
      FusionExecutorCache executor_cache(std::move(fusion_ptr));
      auto outputs = executor_cache.runFusionWithInputs({t0});
      testValidate(&fusion, outputs, {t0}, __LINE__, __FILE__);
    
      FusionKernelRuntime* runtime = executor_cache.getMostRecentKernelRuntime();
      EXPECT_THAT(runtime->fusionSegments()->groups(), SizeIs(2));
    
  // Make sure the full op output is not a segment output
      for (const auto& executor : runtime->executors()) {
        auto ke = dynamic_cast<KernelExecutor*>(executor.get());
        ASSERT_NE(ke, nullptr);
        kir::Kernel* kernel = ke->compiledKernel()->kernel();
        bool full_op_found = false;
        for (auto expr : KernelExprVisitor::getAllExprs(kernel)) {
          auto out_tv = ir_utils::getTvOutput(expr);
          if (out_tv == nullptr) {
            continue;
          }
          auto full_op = dynamic_cast<FullOp*>(out_tv->definition());
          if (full_op == nullptr) {
            continue;
          }
          full_op_found = true;
          auto output_it =
              std::ranges::find_if(kernel->outputs(), [&](Val* output) {
                return output->isA<TensorView>() &&
                    output->name() == out_tv->name();
              });
          EXPECT_EQ(output_it, kernel->outputs().end())
          << "FullOp output should not be a segment output";
        }
        EXPECT_TRUE(full_op_found) << "Each segment has its own FullOp";
      }
    }

@naoyam naoyam added the rope label Apr 17, 2025

naoyam commented Apr 17, 2025

!test --diff

naoyam commented Apr 18, 2025

!test --diff

@naoyam naoyam marked this pull request as ready for review April 23, 2025 02:28
@naoyam naoyam requested a review from jjsjann123 April 23, 2025 02:28
    SegmentedGroup* group = to_visit.front();
    to_visit.pop_front();

    if (group->exprs().empty()) {

naoyam (Author) commented:

Not strictly related, but while testing this PR, I was hitting a bug that would have been easily spotted if this check had been in place.


    // Grab factory ops that should be forwarded. Add created tensors to
    // the fusion input list to make them handled like fusion inputs
    // TODO: Handle more factory methods such as IotaOp, EyeOp,

naoyam (Author) commented:

I was originally planning to add support for other factory ops but decided to leave them as follow-up extensions.

naoyam commented Apr 23, 2025

!test --diff

naoyam commented Apr 23, 2025

Code diff results

TestNvFuserFrontend::test_issue1872

This seems expected because the full-op result is used in two slices.

jit_codegen_diff_20_6/7, jit_codegen_diff_20_7/7

Something seems off in these two cases. They are showing lots of new and removed tests. Maybe the script somehow fails to match the two results? CC: @jacobhinkle

    for (auto expr : completeFusion()->exprs()) {
      if (expr->isA<FullOp>() &&
          // Don't bother if it's a fusion output
          !expr->output(0)->isFusionOutput()) {

A collaborator (reviewer) commented:

Should we forward these even when they are fusion outputs? That would still save bandwidth on reading gmem in the other consumers, right?

naoyam (Author) replied:

That's true in theory, but in practice it'll likely be cached anyway, so the impact would be minimal.

Implementation-wise, we would need to further extend the forwarding logic, since it would likely cause something unexpected if multiple segments had the same op producing the same fusion output.

I just don't think it's worthwhile to support such a case.

@naoyam naoyam merged commit a958bfc into main Apr 24, 2025
57 of 60 checks passed
@naoyam naoyam deleted the forward_factory_op branch April 24, 2025 00:00