From 7399d1d2194e1ceeb4485b38ddcd9ec0a2958d12 Mon Sep 17 00:00:00 2001 From: Jacob Hinkle Date: Fri, 14 Jul 2023 15:06:37 -0400 Subject: [PATCH 1/9] Add test for concretization in repro of #418 This does not test execution, which is hitting a separate issue. --- test/test_dynamic_transform.cpp | 45 +++++++++++++++++++++++++++++++++ 1 file changed, 45 insertions(+) diff --git a/test/test_dynamic_transform.cpp b/test/test_dynamic_transform.cpp index 7530fba589e..1520bbc6346 100644 --- a/test/test_dynamic_transform.cpp +++ b/test/test_dynamic_transform.cpp @@ -1070,4 +1070,49 @@ TEST_F(NVFuserTest, FusionDynamicEmptyCat2_CUDA) { EXPECT_EQ(output_def->input(0), seg_fusion->inputs()[0]); } +// Repro of https://github.com/NVIDIA/Fuser/issues/418 +TEST_F(NVFuserTest, DynamicTransformIssue418Concretization_CUDA) { + auto fusion = std::make_unique(); + FusionGuard fg(fusion.get()); + + auto tv0 = makeSymbolicTensor(4); + fusion->addInput(tv0); + auto s0 = IrBuilder::create(DataType::Int); + fusion->addInput(s0); + + auto v00 = tv0->axis(0)->extent(); + auto v01 = tv0->axis(1)->extent(); + auto v02 = tv0->axis(2)->extent(); + auto v03 = tv0->axis(3)->extent(); + + auto tv1 = reshape(tv0, {v00, div(v01, s0), s0, v02, v03}); + auto vm = variance_mean(tv1, {2, 3, 4}, 0, true); + fusion->addOutput(vm.mean); + fusion->addOutput(vm.var); + + { + ExpressionEvaluator expr_eval; + + expr_eval.bind(tv0->axis(0)->extent(), 256L); + expr_eval.bind(tv0->axis(1)->extent(), 128L); + expr_eval.bind(tv0->axis(2)->extent(), 28L); + expr_eval.bind(tv0->axis(3)->extent(), 28L); + expr_eval.bind(s0, 4L); + + auto initial_info = DynamicTransform::getInitialInfo(fusion.get()); + auto info = DynamicTransformConcretizationInfo(&initial_info, &expr_eval); + + TORCH_CHECK( + info.getReshapeTransforms().size() == 1, + "Expected to have one reshape transform: ", + info.toString()); + + DynamicTransform::concretizeFusion(fusion.get(), &info); + + TORCH_CHECK( + !fusion->hasDynamicTransform(), + "Expected to have no dynamic transform"); + } +} + } // namespace nvfuser From aa5e8d2e1d9f15ed29ad7607674a8e333f35f8e7 Mon Sep 17 00:00:00 2001 From: Jacob Hinkle Date: Fri, 14 Jul 2023 15:07:18 -0400 Subject: [PATCH 2/9] Explicitly mutate all expr outputs in concretization. It would be better to address this in StmtSort/IterVisitor --- csrc/dynamic_transform.cpp | 16 +++++++++++++++- 1 file changed, 15 insertions(+), 1 deletion(-) diff --git a/csrc/dynamic_transform.cpp b/csrc/dynamic_transform.cpp index 27f91a66c68..46a98877304 100644 --- a/csrc/dynamic_transform.cpp +++ b/csrc/dynamic_transform.cpp @@ -457,8 +457,22 @@ void DynamicTransformConcretizer::concretize() { // Finally, propagate concretized domains auto all_stmts = StmtSort::getStmts(info_->fusion()); + std::unordered_set visited_tvs; for (auto tv : ir_utils::filterByType(all_stmts)) { - mutate(tv); + if (tv->definition()) { + for (auto outp : + ir_utils::filterByType(tv->definition()->outputs())) { + if (visited_tvs.find(outp) == visited_tvs.end()) { + mutate(outp); + visited_tvs.insert(tv); + } + } + } else { + if (visited_tvs.find(tv) == visited_tvs.end()) { + mutate(tv); + visited_tvs.insert(tv); + } + } } } From cb48d7cefbbe1ec4c53fc5a45f5d695364f72e70 Mon Sep 17 00:00:00 2001 From: Jacob Hinkle Date: Fri, 14 Jul 2023 16:05:32 -0400 Subject: [PATCH 3/9] Add traverse_siblings option to IterVisitor and StmtSort --- csrc/dynamic_transform.cpp | 22 ++++-------- csrc/iter_visitor.cpp | 72 +++++++++++++++++++++++++++++--------- csrc/iter_visitor.h | 30 +++++++++++----- csrc/scheduler/utils.cpp | 2 +- 4 files changed, 85 insertions(+), 41 deletions(-) diff --git a/csrc/dynamic_transform.cpp b/csrc/dynamic_transform.cpp index 46a98877304..c4b8cbbc997 100644 --- a/csrc/dynamic_transform.cpp +++ b/csrc/dynamic_transform.cpp @@ -456,23 +456,13 @@ void DynamicTransformConcretizer::concretize() { concretizeEmptyExtents(); // Finally, propagate concretized domains - auto all_stmts = StmtSort::getStmts(info_->fusion()); - std::unordered_set visited_tvs; + auto all_stmts = StmtSort::getStmts( + info_->fusion(), + /*traverse_members*/ false, + /*traverse_attributes*/ false, + /*traverse_siblings*/ true); for (auto tv : ir_utils::filterByType(all_stmts)) { - if (tv->definition()) { - for (auto outp : - ir_utils::filterByType(tv->definition()->outputs())) { - if (visited_tvs.find(outp) == visited_tvs.end()) { - mutate(outp); - visited_tvs.insert(tv); - } - } - } else { - if (visited_tvs.find(tv) == visited_tvs.end()) { - mutate(tv); - visited_tvs.insert(tv); - } - } + mutate(tv); } } diff --git a/csrc/iter_visitor.cpp b/csrc/iter_visitor.cpp index e71c86fda6e..99653ada777 100644 --- a/csrc/iter_visitor.cpp +++ b/csrc/iter_visitor.cpp @@ -142,7 +142,8 @@ void IterVisitor::traverseBetween( const std::vector& to, bool traverse_all_paths, bool traverse_into_members, - bool traverse_attributes) { + bool traverse_attributes, + bool traverse_siblings) { FusionGuard fg(fusion); std::unordered_set visited; @@ -227,6 +228,17 @@ void IterVisitor::traverseBetween( // If we don't want to retraverse, remove nodes we already visisted. remove_visited(next_stmts, visited); } + + if (traverse_siblings) { + for (auto next_val : ir_utils::filterByType(next_stmts)) { + for (auto sib : ir_utils::siblingValsOf(next_val)) { + if (visited.find(sib) == visited.end()) { + next_stmts.push_back(sib); + } + } + } + } + if (next_stmts.empty()) { // If there's nothing to visit because it was all already visited, mark // to process @@ -263,14 +275,16 @@ void IterVisitor::traverseTo( const std::vector& to, bool traverse_all_paths, bool traverse_into_members, - bool traverse_attributes) { + bool traverse_attributes, + bool traverse_siblings) { traverseBetween( fusion, {}, to, traverse_all_paths, traverse_into_members, - traverse_attributes); + traverse_attributes, + traverse_siblings); } void IterVisitor::traverseHelper(Fusion* fusion, bool traverse_all_paths) { @@ -846,19 +860,25 @@ void StmtSort::handle(Statement* stmt) { std::vector StmtSort::getExprs( Fusion* fusion, bool traverse_members, - bool traverse_attributes) { + bool traverse_attributes, + bool traverse_siblings) { auto terminating_outputs = fusion->getTerminatingOutputs(); return StmtSort::getExprs( - fusion, terminating_outputs, traverse_members, traverse_attributes); + fusion, + terminating_outputs, + traverse_members, + traverse_attributes, + traverse_siblings); } std::vector StmtSort::getExprs( Fusion* fusion, const std::vector& to, bool traverse_members, - bool traverse_attributes) { - auto stmts = - StmtSort::getStmts(fusion, to, traverse_members, traverse_attributes); + bool traverse_attributes, + bool traverse_siblings) { + auto stmts = StmtSort::getStmts( + fusion, to, traverse_members, traverse_attributes, traverse_siblings); auto filter = ir_utils::filterByType(stmts.begin(), stmts.end()); std::vector exprs(filter.begin(), filter.end()); return exprs; @@ -869,9 +889,15 @@ std::vector StmtSort::getExprsBetween( const std::vector& from, const std::vector& to, bool traverse_members, - bool traverse_attributes) { + bool traverse_attributes, + bool traverse_siblings) { auto stmts = StmtSort::getStmtsBetween( - fusion, from, to, traverse_members, traverse_attributes); + fusion, + from, + to, + traverse_members, + traverse_attributes, + traverse_siblings); auto filter = ir_utils::filterByType(stmts.begin(), stmts.end()); std::vector exprs(filter.begin(), filter.end()); return exprs; @@ -880,19 +906,31 @@ std::vector StmtSort::getExprsBetween( std::vector StmtSort::getStmts( Fusion* fusion, bool traverse_members, - bool traverse_attributes) { + bool traverse_attributes, + bool traverse_siblings) { auto terminating_outputs = fusion->getTerminatingOutputs(); return StmtSort::getStmts( - fusion, terminating_outputs, traverse_members, traverse_attributes); + fusion, + terminating_outputs, + traverse_members, + traverse_attributes, + traverse_siblings); } std::vector StmtSort::getStmts( Fusion* fusion, const std::vector& to, bool traverse_members, - bool traverse_attributes) { + bool traverse_attributes, + bool traverse_siblings) { StmtSort es; - es.traverseTo(fusion, to, false, traverse_members, traverse_attributes); + es.traverseTo( + fusion, + to, + false, + traverse_members, + traverse_attributes, + traverse_siblings); return es.stmts; } @@ -901,7 +939,8 @@ std::vector StmtSort::getStmtsBetween( const std::vector& from, const std::vector& to, bool traverse_members, - bool traverse_attributes) { + bool traverse_attributes, + bool traverse_siblings) { StmtSort es; es.traverseBetween( fusion, @@ -909,7 +948,8 @@ std::vector StmtSort::getStmtsBetween( to, false, traverse_members, - traverse_attributes); + traverse_attributes, + traverse_siblings); return es.stmts; } diff --git a/csrc/iter_visitor.h b/csrc/iter_visitor.h index 425463cbc36..a16a23e74ac 100644 --- a/csrc/iter_visitor.h +++ b/csrc/iter_visitor.h @@ -94,12 +94,16 @@ class TORCH_CUDA_CU_API IterVisitor : public OptOutDispatch { //! \param traverse_attributes When true, traverse into expr //! attributes. Note that attributes of template type Attribute are //! not traversed as there's no dispatch support. + //! \param traverse_siblings When true, traverse all outputs of + //! active multi-output expressions, even if those Expr outputs are not used + //! in paths to Fusion outputs. void traverseTo( Fusion* fusion, const std::vector& to, bool traverse_all_paths = false, bool traverse_into_members = false, - bool traverse_attributes = false); + bool traverse_attributes = false, + bool traverse_siblings = false); //! Traverses nodes in Fusion from inputs in topological order to "to". i.e. //! from inputs towards outputs. @@ -117,13 +121,17 @@ class TORCH_CUDA_CU_API IterVisitor : public OptOutDispatch { //! \param traverse_attributes When true, traverse into expr //! attributes. Note that attributes of template type Attribute are //! not traversed as there's no dispatch support. + //! \param traverse_siblings When true, traverse all outputs of + //! active multi-output expressions, even if those Expr outputs are not used + //! in paths to Fusion outputs. void traverseBetween( Fusion* fusion, const std::unordered_set& from, const std::vector& to, bool traverse_all_paths = false, bool traverse_into_members = false, - bool traverse_attributes = false); + bool traverse_attributes = false, + bool traverse_siblings = false); // Iterates from terminating outputs registered with the fusion. Terminating // means value is not used to generate any other value used in producing @@ -299,14 +307,16 @@ class StmtSort : public IterVisitor { static std::vector getStmts( Fusion* fusion, bool traverse_members = false, - bool traverse_attributes = false); + bool traverse_attributes = false, + bool traverse_siblings = false); // Returns ordered Statements required to produce 'to', including 'to'. static std::vector getStmts( Fusion* fusion, const std::vector& to, bool traverse_members = false, - bool traverse_attributes = false); + bool traverse_attributes = false, + bool traverse_siblings = false); // Returns ordered Statements required to produce from, including from. // Stops traversal once hiting any Statements in to. Includes Statements in @@ -330,20 +340,23 @@ class StmtSort : public IterVisitor { const std::vector& from, const std::vector& to, bool traverse_members = false, - bool traverse_attributes = false); + bool traverse_attributes = false, + bool traverse_siblings = false); // Same as getStmts version but filters to only return the Expr*s static std::vector getExprs( Fusion* fusion, bool traverse_members = false, - bool traverse_attributes = false); + bool traverse_attributes = false, + bool traverse_siblings = false); // Same as getStmts version but filters to only return the Expr*s static std::vector getExprs( Fusion* fusion, const std::vector& to, bool traverse_members = false, - bool traverse_attributes = false); + bool traverse_attributes = false, + bool traverse_siblings = false); // Same as getStmts version but filters to only return the Expr*s static std::vector getExprsBetween( @@ -351,7 +364,8 @@ class StmtSort : public IterVisitor { const std::vector& from, const std::vector& to, bool traverse_members = false, - bool traverse_attributes = false); + bool traverse_attributes = false, + bool traverse_siblings = false); }; class TORCH_CUDA_CU_API InputsOf : public IterVisitor { diff --git a/csrc/scheduler/utils.cpp b/csrc/scheduler/utils.cpp index 89d79ea91d4..b13b633df4c 100644 --- a/csrc/scheduler/utils.cpp +++ b/csrc/scheduler/utils.cpp @@ -1058,7 +1058,7 @@ IterDomain* projectIdToRoot( } auto replay_exprs = - StmtSort::getExprs(tv->fusion(), {reference_id}, false, false); + StmtSort::getExprs(tv->fusion(), {reference_id}, false, false, false); if (replay_exprs.empty()) { return reference_id; } From d4987e64a0b47797424821064661e2cf5e3c88c6 Mon Sep 17 00:00:00 2001 From: Jacob Hinkle Date: Mon, 17 Jul 2023 08:16:36 -0400 Subject: [PATCH 4/9] Fix segfault from modifying vector while looping --- csrc/iter_visitor.cpp | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/csrc/iter_visitor.cpp b/csrc/iter_visitor.cpp index 99653ada777..b45b3382378 100644 --- a/csrc/iter_visitor.cpp +++ b/csrc/iter_visitor.cpp @@ -230,13 +230,19 @@ void IterVisitor::traverseBetween( } if (traverse_siblings) { + // Add unvisited siblings to next_stmts + std::vector unvisited_sibs; for (auto next_val : ir_utils::filterByType(next_stmts)) { for (auto sib : ir_utils::siblingValsOf(next_val)) { if (visited.find(sib) == visited.end()) { - next_stmts.push_back(sib); + // Push to separate vector so that we don't modify next_stmts + // while looping + unvisited_sibs.push_back(sib); } } } + next_stmts.insert( + next_stmts.end(), unvisited_sibs.begin(), unvisited_sibs.end()); } if (next_stmts.empty()) { From cc388adfcf84a0a790e06e454595540f710826b5 Mon Sep 17 00:00:00 2001 From: Jacob Hinkle Date: Mon, 17 Jul 2023 14:42:11 -0400 Subject: [PATCH 5/9] Rename overloads getExprsTo, getStmtsTo This prevents accidental use of the wrong overload, wherein a vector of Vals might be converted to bool. --- csrc/compute_at_map.cpp | 2 +- csrc/device_lower/analysis/divisible_split.cpp | 2 +- csrc/device_lower/pass/allocation.cpp | 2 +- csrc/device_lower/validation.cpp | 2 +- csrc/fusion.cpp | 2 +- csrc/fusion_segmenter.cpp | 2 +- csrc/ir/cloner.cpp | 2 +- csrc/ir/utils.cpp | 4 ++-- csrc/iter_visitor.cpp | 16 ++++++++-------- csrc/iter_visitor.h | 4 ++-- csrc/partial_split_map.cpp | 2 +- csrc/scheduler/utils.cpp | 9 ++++----- csrc/transform_iter.cpp | 8 ++++---- 13 files changed, 28 insertions(+), 29 deletions(-) diff --git a/csrc/compute_at_map.cpp b/csrc/compute_at_map.cpp index 6ca15f84c7f..3cf5c67733f 100644 --- a/csrc/compute_at_map.cpp +++ b/csrc/compute_at_map.cpp @@ -605,7 +605,7 @@ void IterDomainGraph::build(Fusion* fusion) { // Grab all the rfactor ids. for (auto consumer_tv : all_consumer_tvs) { - auto exprs = StmtSort::getExprs( + auto exprs = StmtSort::getExprsTo( fusion, {consumer_tv->getMaybeRFactorDomain().begin(), consumer_tv->getMaybeRFactorDomain().end()}); diff --git a/csrc/device_lower/analysis/divisible_split.cpp b/csrc/device_lower/analysis/divisible_split.cpp index 875df3e3414..75d344a4dbd 100644 --- a/csrc/device_lower/analysis/divisible_split.cpp +++ b/csrc/device_lower/analysis/divisible_split.cpp @@ -38,7 +38,7 @@ std::unordered_set getAllDivisibleSplits( // Take the view transformations and add all the splits. Those splits are // the only divisible splits. auto view_exprs = - StmtSort::getExprs(fusion, {rfactor_dom.begin(), rfactor_dom.end()}); + StmtSort::getExprsTo(fusion, {rfactor_dom.begin(), rfactor_dom.end()}); auto split_exprs = ir_utils::filterByType(view_exprs); all_divisible_splits.insert(split_exprs.begin(), split_exprs.end()); } diff --git a/csrc/device_lower/pass/allocation.cpp b/csrc/device_lower/pass/allocation.cpp index c06ca9c4022..f04327d2060 100644 --- a/csrc/device_lower/pass/allocation.cpp +++ b/csrc/device_lower/pass/allocation.cpp @@ -216,7 +216,7 @@ class AllocationInserter : public kir::ExprMutator { [](IterDomain* dom) { return dom->as(); }); // Get all exprs involved in generating the allocation IDs - auto exprs = StmtSort::getExprs(tv->fusion(), start_vals); + auto exprs = StmtSort::getExprsTo(tv->fusion(), start_vals); // Get the halo extent if found auto getExtent = [this](IterDomain* id) { diff --git a/csrc/device_lower/validation.cpp b/csrc/device_lower/validation.cpp index c2c2250d0b3..1dc550d0b0c 100644 --- a/csrc/device_lower/validation.cpp +++ b/csrc/device_lower/validation.cpp @@ -828,7 +828,7 @@ void validatePartialSplit(Fusion* fusion) { auto range_info = getLiveRangeOffsets(fusion); for (auto tv : ir_utils::allTvs(fusion)) { - auto exprs = StmtSort::getExprs( + auto exprs = StmtSort::getExprsTo( tv->fusion(), {tv->getLeafDomain().begin(), tv->getLeafDomain().end()}); for (auto split : ir_utils::filterByType(exprs)) { // When the start and stop offsets are not zero, make sure the diff --git a/csrc/fusion.cpp b/csrc/fusion.cpp index 8fa3605d498..2affce69e7b 100644 --- a/csrc/fusion.cpp +++ b/csrc/fusion.cpp @@ -490,7 +490,7 @@ void Fusion::printMath(bool from_outputs_only) { leaf_vals.push_back(val); } } - exprs_for_print = StmtSort::getExprs(this, leaf_vals); + exprs_for_print = StmtSort::getExprsTo(this, leaf_vals); } debug() << "\n%kernel_math {\n"; diff --git a/csrc/fusion_segmenter.cpp b/csrc/fusion_segmenter.cpp index 3b4458ea8f9..d8687968657 100644 --- a/csrc/fusion_segmenter.cpp +++ b/csrc/fusion_segmenter.cpp @@ -3600,7 +3600,7 @@ void SegmentCandidateFinder::resolveInputsInGroup(SegmentedGroup* group) { group->input_vals = IterVisitor::getInputsTo(group->inputs()); // Grab all expressions needed to produce to_visit - auto input_exprs = StmtSort::getExprs(completeFusion(), to_visit); + auto input_exprs = StmtSort::getExprsTo(completeFusion(), to_visit); // Insert those expressions at the beginning of the group group->exprs_.insert( diff --git a/csrc/ir/cloner.cpp b/csrc/ir/cloner.cpp index 66295857303..61082a6391a 100644 --- a/csrc/ir/cloner.cpp +++ b/csrc/ir/cloner.cpp @@ -116,7 +116,7 @@ Statement* RecomputeTv::handle(const TensorDomain* td) { // Make sure to recompute the history of the iteration domains, explicitly go // through the expressions and send them to IrCloner. auto exprs = - StmtSort::getExprs(fusion_, {td->leaf().begin(), td->leaf().end()}); + StmtSort::getExprsTo(fusion_, {td->leaf().begin(), td->leaf().end()}); for (auto expr : exprs) { IrCloner::handle(expr); diff --git a/csrc/ir/utils.cpp b/csrc/ir/utils.cpp index 8f3be4f934c..71287011070 100644 --- a/csrc/ir/utils.cpp +++ b/csrc/ir/utils.cpp @@ -463,7 +463,7 @@ class ValReplacementMutator : private OptOutMutator { // typically not used by anything else. If we don't grab that count, then it // would be a tensorview that doesn't get updated extents. Therefore, first // grab all leaves towards outputs and grab stmts from there. - auto stmts = StmtSort::getStmts(fusion, allLeafOuts(fusion), true, true); + auto stmts = StmtSort::getStmtsTo(fusion, allLeafOuts(fusion), true, true); // Some fusions, such as standalone rand_like, can have disconnected DAG, so // we need some mechanism to make sure our replacement set is as complete as @@ -481,7 +481,7 @@ class ValReplacementMutator : private OptOutMutator { more.emplace_back(v); } } - auto more_stmts = StmtSort::getStmts(fusion, more, true, true); + auto more_stmts = StmtSort::getStmtsTo(fusion, more, true, true); more_stmts.insert(more_stmts.end(), stmts.begin(), stmts.end()); for (auto stmt : more_stmts) { diff --git a/csrc/iter_visitor.cpp b/csrc/iter_visitor.cpp index b45b3382378..d43c93ff411 100644 --- a/csrc/iter_visitor.cpp +++ b/csrc/iter_visitor.cpp @@ -455,7 +455,7 @@ void BackwardVisitor::traverseTo( } auto vals = AllVals::get(fusion, from); - auto exprs = StmtSort::getExprs(fusion, from); + auto exprs = StmtSort::getExprsTo(fusion, from); { size_t pos = 0; @@ -869,7 +869,7 @@ std::vector StmtSort::getExprs( bool traverse_attributes, bool traverse_siblings) { auto terminating_outputs = fusion->getTerminatingOutputs(); - return StmtSort::getExprs( + return StmtSort::getExprsTo( fusion, terminating_outputs, traverse_members, @@ -877,13 +877,13 @@ std::vector StmtSort::getExprs( traverse_siblings); } -std::vector StmtSort::getExprs( +std::vector StmtSort::getExprsTo( Fusion* fusion, const std::vector& to, bool traverse_members, bool traverse_attributes, bool traverse_siblings) { - auto stmts = StmtSort::getStmts( + auto stmts = StmtSort::getStmtsTo( fusion, to, traverse_members, traverse_attributes, traverse_siblings); auto filter = ir_utils::filterByType(stmts.begin(), stmts.end()); std::vector exprs(filter.begin(), filter.end()); @@ -915,7 +915,7 @@ std::vector StmtSort::getStmts( bool traverse_attributes, bool traverse_siblings) { auto terminating_outputs = fusion->getTerminatingOutputs(); - return StmtSort::getStmts( + return StmtSort::getStmtsTo( fusion, terminating_outputs, traverse_members, @@ -923,7 +923,7 @@ std::vector StmtSort::getStmts( traverse_siblings); } -std::vector StmtSort::getStmts( +std::vector StmtSort::getStmtsTo( Fusion* fusion, const std::vector& to, bool traverse_members, @@ -983,11 +983,11 @@ std::vector InputsOf::outputs( bool DeadCodeRemover::run() { // First we build a set of all live Statements so that we can detect dead // branches. - for (auto stmt : StmtSort::getStmts(fusion_, fusion_->outputs())) { + for (auto stmt : StmtSort::getStmtsTo(fusion_, fusion_->outputs())) { markLive(stmt); } - // Note that StmtSort::getStmts() is also run in traverseTo. In the future, + // Note that StmtSort::getStmtsTo() is also run in traverseTo. In the future, // we could potentially refactor this so that derived classes from // BackwardVisitor can make use of that traversal instead of repeating it. traverseTo(fusion_, fusion_->outputs(), false); diff --git a/csrc/iter_visitor.h b/csrc/iter_visitor.h index a16a23e74ac..8222cdf8839 100644 --- a/csrc/iter_visitor.h +++ b/csrc/iter_visitor.h @@ -311,7 +311,7 @@ class StmtSort : public IterVisitor { bool traverse_siblings = false); // Returns ordered Statements required to produce 'to', including 'to'. - static std::vector getStmts( + static std::vector getStmtsTo( Fusion* fusion, const std::vector& to, bool traverse_members = false, @@ -351,7 +351,7 @@ class StmtSort : public IterVisitor { bool traverse_siblings = false); // Same as getStmts version but filters to only return the Expr*s - static std::vector getExprs( + static std::vector getExprsTo( Fusion* fusion, const std::vector& to, bool traverse_members = false, diff --git a/csrc/partial_split_map.cpp b/csrc/partial_split_map.cpp index c46b258fe3a..9540c6b08a2 100644 --- a/csrc/partial_split_map.cpp +++ b/csrc/partial_split_map.cpp @@ -15,7 +15,7 @@ void PartialSplitMap::build(Fusion* fusion) { auto used_vals = ir_utils::allTvs(fusion); for (auto tv : ir_utils::filterByType(used_vals)) { - auto exprs = StmtSort::getExprs( + auto exprs = StmtSort::getExprsTo( fusion, {tv->getLeafDomain().begin(), tv->getLeafDomain().end()}); for (auto split : ir_utils::filterByType(exprs)) { // Only needs to check root domains as partial split is only diff --git a/csrc/scheduler/utils.cpp b/csrc/scheduler/utils.cpp index b13b633df4c..cce61fd9b0c 100644 --- a/csrc/scheduler/utils.cpp +++ b/csrc/scheduler/utils.cpp @@ -1057,8 +1057,7 @@ IterDomain* projectIdToRoot( return reference_id; } - auto replay_exprs = - StmtSort::getExprs(tv->fusion(), {reference_id}, false, false, false); + auto replay_exprs = StmtSort::getExprsTo(tv->fusion(), {reference_id}); if (replay_exprs.empty()) { return reference_id; } @@ -1118,7 +1117,7 @@ IterDomain* projectIdToRFactor( return reference_id; } - auto replay_exprs = StmtSort::getExprs( + auto replay_exprs = StmtSort::getExprsTo( tv->fusion(), {tv->getRFactorDomain().begin(), tv->getRFactorDomain().end()}, false); @@ -1786,7 +1785,7 @@ DisjointSets disjointRFactorSets(Fusion* fusion) { // If iter domains are involved in any transformation from root domains to // rfactor domains they should be considered "contaminated". for (auto tv : ir_utils::allTvs(fusion)) { - for (auto expr : StmtSort::getExprs( + for (auto expr : StmtSort::getExprsTo( fusion, {tv->getMaybeRFactorDomain().begin(), tv->getMaybeRFactorDomain().end()})) { @@ -1837,7 +1836,7 @@ bool breakIsDisjoint(std::vector group_ids, int pos) { std::unordered_map domainReorderAsRfactorMap(TensorView* tv) { FusionGuard fg(tv->fusion()); - auto transform_exprs = StmtSort::getExprs( + auto transform_exprs = StmtSort::getExprsTo( tv->fusion(), {tv->getLeafDomain().begin(), tv->getLeafDomain().end()}); // simply update this vector of id's as progressing through the transformation // expressions. We'll always insert the result of split in the location of the diff --git a/csrc/transform_iter.cpp b/csrc/transform_iter.cpp index cbae774fe52..14a64f6e204 100644 --- a/csrc/transform_iter.cpp +++ b/csrc/transform_iter.cpp @@ -322,7 +322,7 @@ BestEffortReplay::BestEffortReplay( } // Grab expr history of iter domains in target_domain - std::vector target_exprs = StmtSort::getExprs( + std::vector target_exprs = StmtSort::getExprsTo( FusionGuard::getCurFusion(), std::vector(target_domain.begin(), target_domain.end())); @@ -333,7 +333,7 @@ BestEffortReplay::BestEffortReplay( // replay_domain map. // Map replay domain's IterDomains to the Exprs they're used in - std::vector replay_exprs = StmtSort::getExprs( + std::vector replay_exprs = StmtSort::getExprsTo( FusionGuard::getCurFusion(), std::vector(replay_domain.begin(), replay_domain.end())); @@ -752,7 +752,7 @@ struct ForwardingInfo { // We have root axes in active_tv that don't exist in the inactive tensor, // now forward those to include all id's in active_tv comprised of only axes // not in the inactive tensor. - std::vector active_tv_history = StmtSort::getExprs( + std::vector active_tv_history = StmtSort::getExprsTo( FusionGuard::getCurFusion(), std::vector( active_tv->getLeafDomain().begin(), @@ -903,7 +903,7 @@ void BestEffortReplay::addComplimentLeafIDs( } // Grab all exprs used to make the forwarded compliments - auto compliment_exprs = StmtSort::getExprs( + auto compliment_exprs = StmtSort::getExprsTo( FusionGuard::getCurFusion(), {compliments.begin(), compliments.end()}); // Figure out if there are any leaves in compliment_exprs that aren't From d19575da023a9c58f5955803ca26217284f344c8 Mon Sep 17 00:00:00 2001 From: Jacob Hinkle Date: Mon, 17 Jul 2023 15:18:15 -0400 Subject: [PATCH 6/9] Handle orphaned siblings at end of loop This is still a topological ordering since these siblings have no active uses (by definition), and this method prevents changing the traversal ordering by changing this flag when there are no orphaned siblings in the Fusion. --- csrc/iter_visitor.cpp | 18 +++++++++++------- 1 file changed, 11 insertions(+), 7 deletions(-) diff --git a/csrc/iter_visitor.cpp b/csrc/iter_visitor.cpp index d43c93ff411..79dc82dbe49 100644 --- a/csrc/iter_visitor.cpp +++ b/csrc/iter_visitor.cpp @@ -148,6 +148,7 @@ void IterVisitor::traverseBetween( std::unordered_set visited; std::unordered_set nodes_on_path; + std::vector maybe_orphaned_sibs; stmt_stack.clear(); stmt_stack.emplace_back(to.rbegin(), to.rend()); @@ -231,18 +232,13 @@ void IterVisitor::traverseBetween( if (traverse_siblings) { // Add unvisited siblings to next_stmts - std::vector unvisited_sibs; for (auto next_val : ir_utils::filterByType(next_stmts)) { for (auto sib : ir_utils::siblingValsOf(next_val)) { - if (visited.find(sib) == visited.end()) { - // Push to separate vector so that we don't modify next_stmts - // while looping - unvisited_sibs.push_back(sib); + if (traverse_all_paths || visited.find(sib) == visited.end()) { + maybe_orphaned_sibs.push_back(sib); } } } - next_stmts.insert( - next_stmts.end(), unvisited_sibs.begin(), unvisited_sibs.end()); } if (next_stmts.empty()) { @@ -274,6 +270,14 @@ void IterVisitor::traverseBetween( } } } + // Handle any sibling Vals that have not yet been handled + // If traverse_siblings is false, this vector will be empty + for (auto val : maybe_orphaned_sibs) { + if (visited.find(val) != visited.end()) { + visited.insert(val); + handle(val); + } + } } void IterVisitor::traverseTo( From f12856d7ee05adf1605b67ca2f138c2b58d384b7 Mon Sep 17 00:00:00 2001 From: Jacob Hinkle Date: Tue, 18 Jul 2023 10:09:17 -0400 Subject: [PATCH 7/9] Fix typo in traverseBetween --- csrc/iter_visitor.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/csrc/iter_visitor.cpp b/csrc/iter_visitor.cpp index 79dc82dbe49..22b03c63f44 100644 --- a/csrc/iter_visitor.cpp +++ b/csrc/iter_visitor.cpp @@ -273,7 +273,7 @@ void IterVisitor::traverseBetween( // Handle any sibling Vals that have not yet been handled // If traverse_siblings is false, this vector will be empty for (auto val : maybe_orphaned_sibs) { - if (visited.find(val) != visited.end()) { + if (visited.find(val) == visited.end()) { visited.insert(val); handle(val); } From 48157a187796639a6bd1d378af22a1a6bf45b7e2 Mon Sep 17 00:00:00 2001 From: Jacob Hinkle Date: Tue, 18 Jul 2023 10:09:27 -0400 Subject: [PATCH 8/9] Add IterVisitorTraverseSiblings_CUDA test --- test/test_gpu3.cpp | 29 +++++++++++++++++++++++++++++ 1 file changed, 29 insertions(+) diff --git a/test/test_gpu3.cpp b/test/test_gpu3.cpp index 25124901093..68e3a6b8202 100644 --- a/test/test_gpu3.cpp +++ b/test/test_gpu3.cpp @@ -9301,6 +9301,35 @@ TEST_F(NVFuserTest, FusionDanglingUnaryOp_CUDA) { __FILE__); } +// Test that traversing siblings with IterVisitor visits "orphans", i.e. unused +// outputs of multi-output Exprs. +TEST_F(NVFuserTest, IterVisitorTraverseSiblings_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto tv0 = makeSymbolicTensor(1); + fusion.addInput(tv0); + + auto wf = Welford(tv0, {0}); + // wf.var_sum is used, but wf.avg and wf.n are orphaned + auto tv1 = neg(wf.var_sum); + fusion.addOutput(tv1); + + auto stmts = StmtSort::getStmts( + &fusion, + /*traverse_all_paths*/ false, + /*traverse_attributes*/ false, + /*traverse_siblings*/ true); + + // Make sure the expansion parameters of tv1_resize are visited + TORCH_CHECK( + std::find(stmts.begin(), stmts.end(), wf.avg) != stmts.end(), + "Welford avg not traversed"); + TORCH_CHECK( + std::find(stmts.begin(), stmts.end(), wf.n) != stmts.end(), + "Welford n not traversed"); +} + // Test file size should be up to 10K LoC. Create a new file for more tests. } // namespace nvfuser From 0cd1b253fb39ee996549e95ccaf95b3df246f18a Mon Sep 17 00:00:00 2001 From: Jacob Hinkle Date: Tue, 18 Jul 2023 10:17:23 -0400 Subject: [PATCH 9/9] Fix bug where siblings of outputs were not added --- csrc/iter_visitor.cpp | 10 ++++++++++ test/test_gpu3.cpp | 15 +++++++++++++++ 2 files changed, 25 insertions(+) diff --git a/csrc/iter_visitor.cpp b/csrc/iter_visitor.cpp index 22b03c63f44..61e19c063c2 100644 --- a/csrc/iter_visitor.cpp +++ b/csrc/iter_visitor.cpp @@ -153,6 +153,16 @@ void IterVisitor::traverseBetween( stmt_stack.clear(); stmt_stack.emplace_back(to.rbegin(), to.rend()); + if (traverse_siblings) { + // Append siblings of entries in "to" to bottom of stack + auto& bottom_stack = stmt_stack.back(); + for (auto val : ir_utils::filterByType(bottom_stack)) { + for (auto sib : ir_utils::siblingValsOf(val)) { + maybe_orphaned_sibs.push_back(sib); + } + } + } + bool all_inputs_visited = false; while (!stmt_stack.empty()) { diff --git a/test/test_gpu3.cpp b/test/test_gpu3.cpp index 68e3a6b8202..1ced886b8ff 100644 --- a/test/test_gpu3.cpp +++ b/test/test_gpu3.cpp @@ -9328,6 +9328,21 @@ TEST_F(NVFuserTest, IterVisitorTraverseSiblings_CUDA) { TORCH_CHECK( std::find(stmts.begin(), stmts.end(), wf.n) != stmts.end(), "Welford n not traversed"); + + // Test getting statements "to" a tensor with siblings + stmts = StmtSort::getStmtsTo( + &fusion, + {wf.n}, + /*traverse_all_paths*/ false, + /*traverse_attributes*/ false, + /*traverse_siblings*/ true); + // Make sure the expansion parameters of tv1_resize are visited + TORCH_CHECK( + std::find(stmts.begin(), stmts.end(), wf.avg) != stmts.end(), + "Welford avg not traversed in getStmtsTo({n})"); + TORCH_CHECK( + std::find(stmts.begin(), stmts.end(), wf.var_sum) != stmts.end(), + "Welford var_sum not traversed in getStmtsTo({n})"); } // Test file size should be up to 10K LoC. Create a new file for more tests.