diff --git a/csrc/id_model/id_model.cpp b/csrc/id_model/id_model.cpp
index 8ca3804d082..ea7eb5db42a 100644
--- a/csrc/id_model/id_model.cpp
+++ b/csrc/id_model/id_model.cpp
@@ -1271,4 +1271,170 @@ Val* IdModel::getLoopIndexVariable(
   return getLoopIndexVariable(loop_group, circular_buffer_loop_stage);
 }
 
+// Returns a copy of `graph` in which the outputs of two split-split chains
+// that start from the same ValGroup are additionally mapped. `graph` itself
+// is not modified.
+//
+// With a ValGroup `vg` in the role of I0, the two chains looked for are:
+//
+//   L1R2: vg -> split -> [outer, mn], then mn -> split -> [m, n]
+//   L2R1: vg -> split -> [I0/n, n'], then I0/n -> split -> [outer', m']
+//
+// An L2R1 chain matches an L1R2 chain when the extent of n' is sameAs() that
+// of n and the extent of m' is sameAs() that of m. For each match, outer and
+// outer', m and m', and n and n' are mapped in the returned graph.
+//
+// Only the front() expr of each ExprGroup is inspected, and entries 0 / 1 of
+// a split's output groups are taken to be its outer / inner output
+// (NOTE(review): confirm that ValGraph::outputGroups keeps that order).
+//
+// Background:
+// https://github.com/NVIDIA/Fuser/blob/main/doc/reading/iterdomain.md#2-properties-of-iterdomain-transformations
+ValGraph mapAlmostExactSplits(const ValGraph& graph) {
+  using SplitPair = std::pair<ExprGroup, ExprGroup>;
+
+  ValGraph new_graph = graph;
+
+  // Front expr of the group as a Split, or nullptr if it is not one.
+  auto as_split = [](const ExprGroup& expr_group) -> Split* {
+    return dynamic_cast<Split*>(expr_group->front());
+  };
+
+  // Collects every L1R2 chain rooted at vg (vg: I0) as a
+  // (split of vg, split of that split's inner output) pair.
+  auto get_l1r2_splits =
+      [&new_graph, &as_split](const ValGroup& vg) -> std::vector<SplitPair> {
+    std::vector<SplitPair> l1_r2_splits;
+
+    if (!new_graph.hasUses(vg)) {
+      return l1_r2_splits;
+    }
+
+    for (const ExprGroup& use_of_vg : new_graph.getUses(vg)) {
+      Split* split_of_vg = as_split(use_of_vg);
+      if (split_of_vg == nullptr) {
+        continue;
+      }
+
+      // mn
+      const ValGroup& inner_group = new_graph.toGroup(split_of_vg->inner());
+
+      // This use cannot start a chain, but another use of vg still may, so
+      // move on instead of returning and dropping the pairs found so far.
+      if (!new_graph.hasUses(inner_group)) {
+        continue;
+      }
+
+      for (const ExprGroup& use_of_inner_group :
+           new_graph.getUses(inner_group)) {
+        Split* split_of_inner_group = as_split(use_of_inner_group);
+        if (split_of_inner_group == nullptr) {
+          continue;
+        }
+
+        // This split needs to be divisible. That can only be checked here
+        // when both the extent and the factor are constant; splits with a
+        // non-constant extent or factor are accepted unchecked.
+        auto extent = split_of_inner_group->in()->extent();
+        auto factor = split_of_inner_group->factor();
+        if (extent->isConstScalar() && factor->isConstScalar()) {
+          const auto extent_val = extent->evaluate().as<int64_t>();
+          const auto factor_val = factor->evaluate().as<int64_t>();
+          // Test for a zero factor first: the modulo below would otherwise
+          // be undefined behavior.
+          if (factor_val == 0 || extent_val % factor_val != 0) {
+            continue;
+          }
+        }
+
+        l1_r2_splits.emplace_back(use_of_vg, use_of_inner_group);
+      }
+    }
+
+    return l1_r2_splits;
+  };
+
+  // Finds the first L2R1 chain rooted at vg whose inner extents match the
+  // given L1R2 chain, or std::nullopt if there is none.
+  auto get_matching_l2r1_splits =
+      [&new_graph, &as_split](
+          const ValGroup& vg,
+          const SplitPair& l1_r2) -> std::optional<SplitPair> {
+    auto l1r2_second_split = l1_r2.second->front()->as<Split>();
+    auto m = l1r2_second_split->outer()->extent();
+    auto n = l1r2_second_split->inner()->extent();
+
+    for (const ExprGroup& use_of_vg : new_graph.getUses(vg)) {
+      Split* split_of_vg = as_split(use_of_vg);
+      if (split_of_vg == nullptr) {
+        continue;
+      }
+
+      if (!split_of_vg->inner()->extent()->sameAs(n)) {
+        continue;
+      }
+
+      // I0/n
+      const ValGroup& outer_group = new_graph.toGroup(split_of_vg->outer());
+
+      // As above: a dead end for this use must not end the whole search.
+      if (!new_graph.hasUses(outer_group)) {
+        continue;
+      }
+
+      for (const ExprGroup& use_of_outer_group :
+           new_graph.getUses(outer_group)) {
+        Split* split_of_outer_group = as_split(use_of_outer_group);
+        if (split_of_outer_group == nullptr) {
+          continue;
+        }
+
+        if (!split_of_outer_group->inner()->extent()->sameAs(m)) {
+          continue;
+        }
+
+        return SplitPair{use_of_vg, use_of_outer_group};
+      }
+    }
+
+    return std::nullopt;
+  };
+
+  std::vector<std::pair<ValGroup, ValGroup>> groups_to_map;
+
+  for (const ValGroup& vg : new_graph.disjointValSets().disjointSets()) {
+    const auto all_l1r2_splits = get_l1r2_splits(vg);
+    for (const SplitPair& l1r2 : all_l1r2_splits) {
+      const auto l2r1 = get_matching_l2r1_splits(vg, l1r2);
+      if (!l2r1.has_value()) {
+        continue;
+      }
+
+      auto l1r2_first_outputs = new_graph.outputGroups(l1r2.first);
+      auto l1r2_second_outputs = new_graph.outputGroups(l1r2.second);
+
+      auto l2r1_first_outputs = new_graph.outputGroups(l2r1->first);
+      auto l2r1_second_outputs = new_graph.outputGroups(l2r1->second);
+
+      // outer of L1R2's first split <-> outer of L2R1's second split
+      groups_to_map.emplace_back(
+          l1r2_first_outputs.at(0), l2r1_second_outputs.at(0));
+      // m: outer of L1R2's second split <-> inner of L2R1's second split
+      groups_to_map.emplace_back(
+          l1r2_second_outputs.at(0), l2r1_second_outputs.at(1));
+      // n: inner of L1R2's second split <-> inner of L2R1's first split
+      groups_to_map.emplace_back(
+          l1r2_second_outputs.at(1), l2r1_first_outputs.at(1));
+    }
+  }
+
+  // The mappings are applied only after the scan above has finished, so
+  // new_graph is not modified while its groups are being iterated.
+  for (const auto& [vg1, vg2] : groups_to_map) {
+    new_graph.mapVals(vg1->front(), vg2->front());
+  }
+
+  return new_graph;
+}
+
 } // namespace nvfuser
diff --git a/csrc/id_model/id_model.h b/csrc/id_model/id_model.h index 26d849db5b4..b32fd53eac6 100644 --- a/csrc/id_model/id_model.h +++ b/csrc/id_model/id_model.h @@ -363,4 +363,6 @@ std::unordered_map updateValGroupIdMap( const std::unordered_map& stale_map, ValGraph& new_graph); +ValGraph mapAlmostExactSplits(const ValGraph& graph); + } // namespace nvfuser diff --git a/csrc/scheduler/tools/loop_domain_scheduler.cpp b/csrc/scheduler/tools/loop_domain_scheduler.cpp index e1421059a7c..cd40404c3fa 100644 --- a/csrc/scheduler/tools/loop_domain_scheduler.cpp +++ b/csrc/scheduler/tools/loop_domain_scheduler.cpp @@ -180,14 +180,19 @@ class LoopDomainScheduler { public: LoopDomainScheduler( std::vector ref_loop_dom, - bool update_loop_domain_only = false) + bool update_loop_domain_only = false, + const ValGraph* scheduling_graph = nullptr) : ref_loop_dom_(std::move(ref_loop_dom)), - update_loop_domain_only_(update_loop_domain_only) { + update_loop_domain_only_(update_loop_domain_only), + graph_(scheduling_graph) { NVF_ERROR(!ref_loop_dom_.empty()); - Fusion* fusion = ref_loop_dom_.front()->fusion(); - id_model_ = std::make_unique(fusion, /*build_graphs=*/false); - id_model_->buildExactGraph(); + if (graph_ == nullptr) { + Fusion* fusion = ref_loop_dom_.front()->fusion(); + id_model_ = std::make_unique(fusion, /*build_graphs=*/false); + id_model_->buildExactGraph(); + graph_ = &(id_model_->idGraph(IdMappingMode::EXACT)); + } ref_id_groups_ = graph().toGroups(ref_loop_dom_); @@ -205,8 +210,9 @@ class LoopDomainScheduler { void schedule(TensorView* tv) const; private: - ValGraph& graph() const { - return id_model_->idGraph(IdMappingMode::EXACT); + const ValGraph& graph() const { + NVF_ERROR(graph_ != nullptr); + return *graph_; } ValGraphBFS::ExprPath getReplayPath(TensorView* tv) const; @@ -250,6 +256,7 @@ class LoopDomainScheduler { // updates it to make it look like the
given reference loop domain bool update_loop_domain_only_ = false; std::unique_ptr id_model_; + const ValGraph* graph_ = nullptr; ValGroups ref_id_groups_; ValGroups all_ancestors_of_ref_; }; @@ -477,12 +484,13 @@ ValGraphBFS::ExprPath LoopDomainScheduler::getReplayPath(TensorView* tv) const { void scheduleLoopDomainsLike( const std::vector& tvs, const std::vector& ref_loop_dom, - bool update_loop_domain_only) { + bool update_loop_domain_only, + const ValGraph* graph) { if (tvs.empty()) { return; } - LoopDomainScheduler scheduler(ref_loop_dom, update_loop_domain_only); + LoopDomainScheduler scheduler(ref_loop_dom, update_loop_domain_only, graph); for (auto tv : tvs) { // Loop domain of fusion inputs should have no meaning, diff --git a/csrc/scheduler/tools/loop_domain_scheduler.h b/csrc/scheduler/tools/loop_domain_scheduler.h index cedfbb06d19..f8ee1405b6a 100644 --- a/csrc/scheduler/tools/loop_domain_scheduler.h +++ b/csrc/scheduler/tools/loop_domain_scheduler.h @@ -31,7 +31,8 @@ namespace scheduler_tools { void scheduleLoopDomainsLike( const std::vector& tvs, const std::vector& ref_loop_dom, - bool update_loop_domain_only = false); + bool update_loop_domain_only = false, + const ValGraph* graph = nullptr); // Replay a transform expr on the loop domain of each of the given // tensors. 
If the replay direction is specified, the expr is replayed diff --git a/tests/cpp/test_id_model.cpp b/tests/cpp/test_id_model.cpp index dea9ffd7e1f..8dd42d1ff2e 100644 --- a/tests/cpp/test_id_model.cpp +++ b/tests/cpp/test_id_model.cpp @@ -21,6 +21,7 @@ #include #include #include +#include #include #include #include @@ -3129,4 +3130,232 @@ TEST_F(IdModelTest, BroadcastOnlyNoLoopPromotion) { << promotion_id->toString(); } +TEST_F(IdModelTest, AlmostExactSplitGraph1) { + auto fusion_ptr = std::make_unique(); + auto& fusion = *fusion_ptr; + FusionGuard fg(fusion_ptr.get()); + + auto tv0 = makeContigConcreteTensor({3 * 4 * 5}); + fusion.addInput(tv0); + + auto tv1 = set(tv0); + + auto tv2 = reshape(tv1, {3 * 4 * 5}, {3, 4, 5}); + // Outer split 3*4*5 by 3 + // Outer split 4*5 by 4 + + fusion.addOutput(tv2); + + tv0->split(0, 5); + // [3*4, 5] + tv0->split(0, 4); + // [3, 4, 5] + + fusion.print(); + + IdModel id_model(&fusion); + + std::cerr << id_model.maybeBuildGraph(IdMappingMode::EXACT).toString(); + + auto almost_exact_split_graph = + mapAlmostExactSplits(id_model.maybeBuildGraph(IdMappingMode::EXACT)); + + std::cerr << almost_exact_split_graph.toString(); + + scheduler_tools::scheduleLoopDomainsLike( + {tv1, tv2}, + tv0->getLoopDomain(), + /*update_loop_domain_only=*/true, + &almost_exact_split_graph); + + fusion.print(); +} + +TEST_F(IdModelTest, AlmostExactSplitGraph2) { + auto fusion_ptr = std::make_unique(); + auto& fusion = *fusion_ptr; + FusionGuard fg(fusion_ptr.get()); + + auto tv0 = makeContigConcreteTensor({3 * 4 * 5}); + fusion.addInput(tv0); + + auto tv1 = set(tv0); + + auto tv2 = reshape(tv1, {3 * 4 * 5}, {3, 4, 5}); + // Outer split 3*4*5 by 3 + // Outer split 4*5 by 4 + + fusion.addOutput(tv2); + + tv0->split(0, 5); + // [3*4, 5] + tv0->split(0, 4); + // [3, 4, 5] + + tv0->merge(1, 2); + + fusion.print(); + + IdModel id_model(&fusion); + + std::cerr << id_model.maybeBuildGraph(IdMappingMode::EXACT).toString(); + + auto almost_exact_split_graph = + 
mapAlmostExactSplits(id_model.maybeBuildGraph(IdMappingMode::EXACT)); + + std::cerr << almost_exact_split_graph.toString(); + + scheduler_tools::scheduleLoopDomainsLike( + {tv1, tv2}, + tv0->getLoopDomain(), + /*update_loop_domain_only=*/true, + &almost_exact_split_graph); + + fusion.print(); +} + +TEST_F(IdModelTest, AlmostExactSplitGraph3) { + auto fusion_ptr = std::make_unique(); + auto& fusion = *fusion_ptr; + FusionGuard fg(fusion_ptr.get()); + + auto tv0 = makeContigConcreteTensor({3 * 4 * 5}); + fusion.addInput(tv0); + + auto tv1 = set(tv0); + + auto tv2 = reshape(tv1, {3 * 4 * 5}, {3, 4, 5}); + // Outer split 3*4*5 by 3 + // Outer split 4*5 by 4 + + fusion.addOutput(tv2); + + tv0->split(0, 5); + // [3*4, 5] + tv0->split(0, 4); + // [3, 4, 5] + + tv0->merge(1, 2); + + tv1->split(0, 5); + + fusion.print(); + + IdModel id_model(&fusion); + + std::cerr << id_model.maybeBuildGraph(IdMappingMode::EXACT).toString(); + + auto almost_exact_split_graph = + mapAlmostExactSplits(id_model.maybeBuildGraph(IdMappingMode::EXACT)); + + std::cerr << almost_exact_split_graph.toString(); + + scheduler_tools::scheduleLoopDomainsLike( + {tv1, tv2}, + tv0->getLoopDomain(), + /*update_loop_domain_only=*/true, + &almost_exact_split_graph); + + fusion.print(); +} + +TEST_F(IdModelTest, AlmostExactSplitGraph4) { + auto fusion_ptr = std::make_unique(); + auto& fusion = *fusion_ptr; + FusionGuard fg(fusion_ptr.get()); + + auto tv0 = makeContigConcreteTensor({6, 5}); + fusion.addInput(tv0); + + auto tv1 = set(tv0); + + auto tv2 = reshape(tv1, {6, 5}, {30}); + // Merge 6, 5 -> 30 + + auto tv3 = set(tv2); + + fusion.addOutput(tv3); + + tv0->outer_split(0, 2); + // [2, 3, 5] + + tv2->outer_split(0, 2); + // [2, 15] + tv2->outer_split(1, 3); + // [2, 3, 5] + + fusion.print(); + + IdModel id_model(&fusion); + + std::cerr << id_model.maybeBuildGraph(IdMappingMode::EXACT).toString(); + + auto graph = id_model.maybeBuildGraph(IdMappingMode::EXACT); + + for (const auto i : arange(tv0->nDims())) 
{ + graph.mapVals(tv0->axis(i), tv2->axis(i)); + } + + std::cerr << graph.toString(); + + scheduler_tools::scheduleLoopDomainsLike( + {tv1, tv2, tv3}, + tv0->getLoopDomain(), + /*update_loop_domain_only=*/true, + &graph); + + fusion.print(); +} + +TEST_F(IdModelTest, AlmostExactSplitGraph5) { + auto fusion_ptr = std::make_unique(); + auto& fusion = *fusion_ptr; + FusionGuard fg(fusion_ptr.get()); + + int64_t h = 60; + int64_t a = 6; + int64_t d = 2; + + auto tv0 = makeContigConcreteTensor({h}); + fusion.addInput(tv0); + + auto tv1 = reshape(tv0, {h}, {a, h / a}); + auto tv2 = set(tv1); + auto tv3 = reshape(tv2, {a, h / a}, {h}); + + fusion.addOutput(tv3); + + tv0->outer_split(0, d); + // [d, h/d] + tv0->split(1, h / a); + // [d, a/d, h/a] + + tv1->outer_split(0, d); + // [d, a/d, h/a] + + fusion.print(); + + IdModel id_model(&fusion); + + std::cerr << id_model.maybeBuildGraph(IdMappingMode::EXACT).toString(); + + auto graph = id_model.maybeBuildGraph(IdMappingMode::EXACT); + + for (const auto i : arange(tv0->nDims())) { + graph.mapVals(tv0->axis(i), tv1->axis(i)); + } + + graph.mapVals(tv0->getLogicalDomain().at(0), tv3->getLogicalDomain().at(0)); + + std::cerr << graph.toString(); + + scheduler_tools::scheduleLoopDomainsLike( + {tv1, tv2, tv3}, + tv0->getLoopDomain(), + /*update_loop_domain_only=*/true, + &graph); + + fusion.print(); +} + } // namespace nvfuser