Skip to content

[Experimental] Scheduling with AlmostExactSplit graph#4404

Closed
naoyam wants to merge 4 commits intomainfrom
almost_exact_split
Closed

[Experimental] Scheduling with AlmostExactSplit graph#4404
naoyam wants to merge 4 commits intomainfrom
almost_exact_split

Conversation

@naoyam
Copy link
Collaborator

@naoyam naoyam commented May 9, 2025

No description provided.

@github-actions
Copy link

github-actions bot commented May 9, 2025

Review updated until commit 429a845

Description

  • Added mapAlmostExactSplits function to id_model.cpp

  • Updated LoopDomainScheduler to accept an optional ValGraph

  • Added new tests for AlmostExactSplitGraph


Changes walkthrough 📝

Relevant files
Enhancement
id_model.cpp
Implement AlmostExactSplit Graph Mapping                                 

csrc/id_model/id_model.cpp

  • Added mapAlmostExactSplits function to handle almost exact split graph
    mapping
  • Added logic to handle graph mapping in mapAlmostExactSplits
  • +139/-0 
    loop_domain_scheduler.cpp
    Update LoopDomainScheduler for Graph Mapping                         

    csrc/scheduler/tools/loop_domain_scheduler.cpp

  • Updated LoopDomainScheduler constructor to accept an optional ValGraph
  • Modified graph() method to use the provided ValGraph if available
  • +17/-9   
    id_model.h
    Declare mapAlmostExactSplits in Header                                     

    csrc/id_model/id_model.h

    • Declared mapAlmostExactSplits function in header file
    +2/-0     
    loop_domain_scheduler.h
    Update scheduleLoopDomainsLike for Graph Mapping                 

    csrc/scheduler/tools/loop_domain_scheduler.h

  • Updated scheduleLoopDomainsLike function to accept an optional
    ValGraph
  • +2/-1     
    Tests
    test_id_model.cpp
    Add Tests for AlmostExactSplitGraph                                           

    tests/cpp/test_id_model.cpp

    • Added multiple test cases for AlmostExactSplitGraph
    +229/-0 

    PR Reviewer Guide 🔍

    Here are some key observations to aid the review process:

    🧪 PR contains tests
    ⚡ Recommended focus areas for review

    Debugging Output

    The code contains multiple std::cerr statements which are used for debugging. These should be removed or replaced with proper logging before merging.

            const ValGroup& vg) -> std::vector<std::pair<ExprGroup, ExprGroup>> {
      std::vector<std::pair<ExprGroup, ExprGroup>> l1_r2_splits;
    
      if (!new_graph.hasUses(vg)) {
        return {};
      }
    
      for (const ExprGroup& use_of_vg : new_graph.getUses(vg)) {
        auto split_of_vg = dynamic_cast<Split*>(use_of_vg->front());
        if (split_of_vg == nullptr) {
          continue;
        }
    
        // mn
        const ValGroup& inner_group = new_graph.toGroup(split_of_vg->inner());
    
        if (!new_graph.hasUses(inner_group)) {
          return {};
        }
    
        for (const ExprGroup& use_of_inner_group :
             new_graph.getUses(inner_group)) {
          auto split_of_inner_group =
              dynamic_cast<Split*>(use_of_inner_group->front());
          if (split_of_inner_group == nullptr) {
            continue;
          }
    
          // This split needs to be divisible
          auto extent = split_of_inner_group->in()->extent();
          auto factor = split_of_inner_group->factor();
          if (extent->isConstScalar() && factor->isConstScalar() &&
              (extent->evaluate().as<int64_t>() %
                   factor->evaluate().as<int64_t>() !=
               0)) {
            continue;
          }
    
          l1_r2_splits.emplace_back(use_of_vg, use_of_inner_group);
    
          std::cerr << "L1R2 found: " << split_of_vg->toString()
                    << split_of_inner_group->toString();
        }
      }
    
      return l1_r2_splits;
    };
    
    auto get_matching_l2r1_splits =
        [&new_graph](
            const ValGroup& vg, const std::pair<ExprGroup, ExprGroup>& l1_r2)
        -> std::optional<std::pair<ExprGroup, ExprGroup>> {
      auto m = l1_r2.second->front()->as<Split>()->outer()->extent();
      auto n = l1_r2.second->front()->as<Split>()->inner()->extent();
    
      for (const ExprGroup& use_of_vg : new_graph.getUses(vg)) {
        auto split_of_vg = dynamic_cast<Split*>(use_of_vg->front());
        if (split_of_vg == nullptr) {
          continue;
        }
    
        if (!split_of_vg->inner()->extent()->sameAs(n)) {
          continue;
        }
    
        // I0/n
        const ValGroup& outer_group = new_graph.toGroup(split_of_vg->outer());
    
        if (!new_graph.hasUses(outer_group)) {
          return {};
        }
    
        for (const ExprGroup& use_of_outer_group :
             new_graph.getUses(outer_group)) {
          auto split_of_outer_group =
              dynamic_cast<Split*>(use_of_outer_group->front());
          if (split_of_outer_group == nullptr) {
            continue;
          }
    
          if (!split_of_outer_group->inner()->extent()->sameAs(m)) {
            continue;
          }
    
          std::cerr << "Matching L2R1 found: " << split_of_vg->toString()
                    << split_of_outer_group->toString();
          return std::make_pair(use_of_vg, use_of_outer_group);
        }
      }
    
      return std::nullopt;
    };
    
    std::vector<std::pair<ValGroup, ValGroup>> groups_to_map;
    
    for (const ValGroup& vg : new_graph.disjointValSets().disjointSets()) {
      const auto all_l1r2_splits = get_l1r2_splits(vg);
      for (const auto& l1r2 : all_l1r2_splits) {
        std::cerr << "L1R2: " << l1r2.first->front()->toString()
                  << l1r2.second->front()->toString();
        auto l2r1 = get_matching_l2r1_splits(vg, l1r2);
        if (!l2r1.has_value()) {
          continue;
        }
    
        std::cerr << "Found\n";
    
        auto l1r2_first_outputs = new_graph.outputGroups(l1r2.first);
        auto l1r2_second_outputs = new_graph.outputGroups(l1r2.second);
    
        auto l2r1_first_outputs = new_graph.outputGroups(l2r1->first);
        auto l2r1_second_outputs = new_graph.outputGroups(l2r1->second);
    
        groups_to_map.emplace_back(
            l1r2_first_outputs.at(0), l2r1_second_outputs.at(0));
        groups_to_map.emplace_back(
            l1r2_second_outputs.at(0), l2r1_second_outputs.at(1));
        groups_to_map.emplace_back(
            l1r2_second_outputs.at(1), l2r1_first_outputs.at(1));
      }
    }
    
    for (const auto& [vg1, vg2] : groups_to_map) {
      std::cerr << "Mapping " << nvfuser::toString(vg1) << ", "
                << vg1->front()->toString() << " and " << nvfuser::toString(vg2)
                << ", " << vg2->front()->toString() << "\n";
      new_graph.mapVals(vg1->front(), vg2->front());
    }
    Error Handling

    The function mapAlmostExactSplits returns an empty vector if certain conditions are not met. It would be beneficial to add more detailed error handling or logging to understand why the function might fail to find the expected splits.

    ValGraph mapAlmostExactSplits(const ValGraph& graph) {
      auto new_graph = graph;
    
      // vg: I0
      auto get_l1r2_splits =
          [&new_graph](
              const ValGroup& vg) -> std::vector<std::pair<ExprGroup, ExprGroup>> {
        std::vector<std::pair<ExprGroup, ExprGroup>> l1_r2_splits;
    
        if (!new_graph.hasUses(vg)) {
          return {};
        }
    
        for (const ExprGroup& use_of_vg : new_graph.getUses(vg)) {
          auto split_of_vg = dynamic_cast<Split*>(use_of_vg->front());
          if (split_of_vg == nullptr) {
            continue;
          }
    
          // mn
          const ValGroup& inner_group = new_graph.toGroup(split_of_vg->inner());
    
          if (!new_graph.hasUses(inner_group)) {
            return {};
          }
    
          for (const ExprGroup& use_of_inner_group :
               new_graph.getUses(inner_group)) {
            auto split_of_inner_group =
                dynamic_cast<Split*>(use_of_inner_group->front());
            if (split_of_inner_group == nullptr) {
              continue;
            }
    
            // This split needs to be divisible
            auto extent = split_of_inner_group->in()->extent();
            auto factor = split_of_inner_group->factor();
            if (extent->isConstScalar() && factor->isConstScalar() &&
                (extent->evaluate().as<int64_t>() %
                     factor->evaluate().as<int64_t>() !=
                 0)) {
              continue;
            }
    
            l1_r2_splits.emplace_back(use_of_vg, use_of_inner_group);
    
            std::cerr << "L1R2 found: " << split_of_vg->toString()
                      << split_of_inner_group->toString();
          }
        }
    
        return l1_r2_splits;
      };
    
      auto get_matching_l2r1_splits =
          [&new_graph](
              const ValGroup& vg, const std::pair<ExprGroup, ExprGroup>& l1_r2)
          -> std::optional<std::pair<ExprGroup, ExprGroup>> {
        auto m = l1_r2.second->front()->as<Split>()->outer()->extent();
        auto n = l1_r2.second->front()->as<Split>()->inner()->extent();
    
        for (const ExprGroup& use_of_vg : new_graph.getUses(vg)) {
          auto split_of_vg = dynamic_cast<Split*>(use_of_vg->front());
          if (split_of_vg == nullptr) {
            continue;
          }
    
          if (!split_of_vg->inner()->extent()->sameAs(n)) {
            continue;
          }
    
          // I0/n
          const ValGroup& outer_group = new_graph.toGroup(split_of_vg->outer());
    
          if (!new_graph.hasUses(outer_group)) {
            return {};
          }
    
          for (const ExprGroup& use_of_outer_group :
               new_graph.getUses(outer_group)) {
            auto split_of_outer_group =
                dynamic_cast<Split*>(use_of_outer_group->front());
            if (split_of_outer_group == nullptr) {
              continue;
            }
    
            if (!split_of_outer_group->inner()->extent()->sameAs(m)) {
              continue;
            }
    
            std::cerr << "Matching L2R1 found: " << split_of_vg->toString()
                      << split_of_outer_group->toString();
            return std::make_pair(use_of_vg, use_of_outer_group);
          }
        }
    
        return std::nullopt;
      };
    
      std::vector<std::pair<ValGroup, ValGroup>> groups_to_map;
    
      for (const ValGroup& vg : new_graph.disjointValSets().disjointSets()) {
        const auto all_l1r2_splits = get_l1r2_splits(vg);
        for (const auto& l1r2 : all_l1r2_splits) {
          std::cerr << "L1R2: " << l1r2.first->front()->toString()
                    << l1r2.second->front()->toString();
          auto l2r1 = get_matching_l2r1_splits(vg, l1r2);
          if (!l2r1.has_value()) {
            continue;
          }
    
          std::cerr << "Found\n";
    
          auto l1r2_first_outputs = new_graph.outputGroups(l1r2.first);
          auto l1r2_second_outputs = new_graph.outputGroups(l1r2.second);
    
          auto l2r1_first_outputs = new_graph.outputGroups(l2r1->first);
          auto l2r1_second_outputs = new_graph.outputGroups(l2r1->second);
    
          groups_to_map.emplace_back(
              l1r2_first_outputs.at(0), l2r1_second_outputs.at(0));
          groups_to_map.emplace_back(
              l1r2_second_outputs.at(0), l2r1_second_outputs.at(1));
          groups_to_map.emplace_back(
              l1r2_second_outputs.at(1), l2r1_first_outputs.at(1));
        }
      }
    
      for (const auto& [vg1, vg2] : groups_to_map) {
        std::cerr << "Mapping " << nvfuser::toString(vg1) << ", "
                  << vg1->front()->toString() << " and " << nvfuser::toString(vg2)
                  << ", " << vg2->front()->toString() << "\n";
        new_graph.mapVals(vg1->front(), vg2->front());
      }
    
      return new_graph;
    Test Coverage

    While multiple tests are added, it would be beneficial to add more edge cases and ensure that the tests cover a wide range of scenarios to validate the correctness of the mapAlmostExactSplits function.

          id_model.idGraph(IdMappingMode::LOOP).toGroup(tv2->axis(-1)));
      EXPECT_TRUE(promotion_id->isBroadcast())
          << "Should not be promoted a non-broadcast ID: "
          << promotion_id->toString();
    }
    
    TEST_F(IdModelTest, AlmostExactSplitGraph1) {
      auto fusion_ptr = std::make_unique<Fusion>();
      auto& fusion = *fusion_ptr;
      FusionGuard fg(fusion_ptr.get());
    
      auto tv0 = makeContigConcreteTensor({3 * 4 * 5});
      fusion.addInput(tv0);
    
      auto tv1 = set(tv0);
    
      auto tv2 = reshape(tv1, {3 * 4 * 5}, {3, 4, 5});
      // Outer split 3*4*5 by 3
      // Outer split 4*5 by 4
    
      fusion.addOutput(tv2);
    
      tv0->split(0, 5);
      // [3*4, 5]
      tv0->split(0, 4);
      // [3, 4, 5]
    
      fusion.print();
    
      IdModel id_model(&fusion);
    
      std::cerr << id_model.maybeBuildGraph(IdMappingMode::EXACT).toString();
    
      auto almost_exact_split_graph =
          mapAlmostExactSplits(id_model.maybeBuildGraph(IdMappingMode::EXACT));
    
      std::cerr << almost_exact_split_graph.toString();
    
      scheduler_tools::scheduleLoopDomainsLike(
          {tv1, tv2},
          tv0->getLoopDomain(),
          /*update_loop_domain_only=*/true,
          &almost_exact_split_graph);
    
      fusion.print();
    }
    
    TEST_F(IdModelTest, AlmostExactSplitGraph2) {
      auto fusion_ptr = std::make_unique<Fusion>();
      auto& fusion = *fusion_ptr;
      FusionGuard fg(fusion_ptr.get());
    
      auto tv0 = makeContigConcreteTensor({3 * 4 * 5});
      fusion.addInput(tv0);
    
      auto tv1 = set(tv0);
    
      auto tv2 = reshape(tv1, {3 * 4 * 5}, {3, 4, 5});
      // Outer split 3*4*5 by 3
      // Outer split 4*5 by 4
    
      fusion.addOutput(tv2);
    
      tv0->split(0, 5);
      // [3*4, 5]
      tv0->split(0, 4);
      // [3, 4, 5]
    
      tv0->merge(1, 2);
    
      fusion.print();
    
      IdModel id_model(&fusion);
    
      std::cerr << id_model.maybeBuildGraph(IdMappingMode::EXACT).toString();
    
      auto almost_exact_split_graph =
          mapAlmostExactSplits(id_model.maybeBuildGraph(IdMappingMode::EXACT));
    
      std::cerr << almost_exact_split_graph.toString();
    
      scheduler_tools::scheduleLoopDomainsLike(
          {tv1, tv2},
          tv0->getLoopDomain(),
          /*update_loop_domain_only=*/true,
          &almost_exact_split_graph);
    
      fusion.print();
    }
    
    TEST_F(IdModelTest, AlmostExactSplitGraph3) {
      auto fusion_ptr = std::make_unique<Fusion>();
      auto& fusion = *fusion_ptr;
      FusionGuard fg(fusion_ptr.get());
    
      auto tv0 = makeContigConcreteTensor({3 * 4 * 5});
      fusion.addInput(tv0);
    
      auto tv1 = set(tv0);
    
      auto tv2 = reshape(tv1, {3 * 4 * 5}, {3, 4, 5});
      // Outer split 3*4*5 by 3
      // Outer split 4*5 by 4
    
      fusion.addOutput(tv2);
    
      tv0->split(0, 5);
      // [3*4, 5]
      tv0->split(0, 4);
      // [3, 4, 5]
    
      tv0->merge(1, 2);
    
      tv1->split(0, 5);
    
      fusion.print();
    
      IdModel id_model(&fusion);
    
      std::cerr << id_model.maybeBuildGraph(IdMappingMode::EXACT).toString();
    
      auto almost_exact_split_graph =
          mapAlmostExactSplits(id_model.maybeBuildGraph(IdMappingMode::EXACT));
    
      std::cerr << almost_exact_split_graph.toString();
    
      scheduler_tools::scheduleLoopDomainsLike(
          {tv1, tv2},
          tv0->getLoopDomain(),
          /*update_loop_domain_only=*/true,
          &almost_exact_split_graph);
    
      fusion.print();
    }
    
    TEST_F(IdModelTest, AlmostExactSplitGraph4) {
      auto fusion_ptr = std::make_unique<Fusion>();
      auto& fusion = *fusion_ptr;
      FusionGuard fg(fusion_ptr.get());
    
      auto tv0 = makeContigConcreteTensor({6, 5});
      fusion.addInput(tv0);
    
      auto tv1 = set(tv0);
    
      auto tv2 = reshape(tv1, {6, 5}, {30});
      // Merge 6, 5 -> 30
    
      auto tv3 = set(tv2);
    
      fusion.addOutput(tv3);
    
      tv0->outer_split(0, 2);
      // [2, 3, 5]
    
      tv2->outer_split(0, 2);
      // [2, 15]
      tv2->outer_split(1, 3);
      // [2, 3, 5]
    
      fusion.print();
    
      IdModel id_model(&fusion);
    
      std::cerr << id_model.maybeBuildGraph(IdMappingMode::EXACT).toString();
    
      auto graph = id_model.maybeBuildGraph(IdMappingMode::EXACT);
    
      for (const auto i : arange(tv0->nDims())) {
        graph.mapVals(tv0->axis(i), tv2->axis(i));
      }
    
      std::cerr << graph.toString();
    
      scheduler_tools::scheduleLoopDomainsLike(
          {tv1, tv2, tv3},
          tv0->getLoopDomain(),
          /*update_loop_domain_only=*/true,
          &graph);
    
      fusion.print();
    }
    
    TEST_F(IdModelTest, AlmostExactSplitGraph5) {
      auto fusion_ptr = std::make_unique<Fusion>();
      auto& fusion = *fusion_ptr;
      FusionGuard fg(fusion_ptr.get());
    
      int64_t h = 60;
      int64_t a = 6;
      int64_t d = 2;
    
      auto tv0 = makeContigConcreteTensor({h});
      fusion.addInput(tv0);
    
      auto tv1 = reshape(tv0, {h}, {a, h / a});
      auto tv2 = set(tv1);
      auto tv3 = reshape(tv2, {a, h / a}, {h});
    
      fusion.addOutput(tv3);
    
      tv0->outer_split(0, d);
      // [d, h/d]
      tv0->split(1, h / a);
      // [d, a/d, h/a]
    
      tv1->outer_split(0, d);
      // [d, a/d, h/a]
    
      fusion.print();
    
      IdModel id_model(&fusion);
    
      std::cerr << id_model.maybeBuildGraph(IdMappingMode::EXACT).toString();
    
      auto graph = id_model.maybeBuildGraph(IdMappingMode::EXACT);
    
      for (const auto i : arange(tv0->nDims())) {
        graph.mapVals(tv0->axis(i), tv1->axis(i));
      }
    
      graph.mapVals(tv0->getLogicalDomain().at(0), tv3->getLogicalDomain().at(0));
    
      std::cerr << graph.toString();
    
      scheduler_tools::scheduleLoopDomainsLike(
          {tv1, tv2, tv3},
          tv0->getLoopDomain(),
          /*update_loop_domain_only=*/true,
          &graph);
    
      fusion.print();

    Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

    Labels

    None yet

    Projects

    None yet

    Development

    Successfully merging this pull request may close these issues.

    1 participant