diff --git a/csrc/compute_at_map.cpp b/csrc/compute_at_map.cpp index 6ca15f84c7f..3cf5c67733f 100644 --- a/csrc/compute_at_map.cpp +++ b/csrc/compute_at_map.cpp @@ -605,7 +605,7 @@ void IterDomainGraph::build(Fusion* fusion) { // Grab all the rfactor ids. for (auto consumer_tv : all_consumer_tvs) { - auto exprs = StmtSort::getExprs( + auto exprs = StmtSort::getExprsTo( fusion, {consumer_tv->getMaybeRFactorDomain().begin(), consumer_tv->getMaybeRFactorDomain().end()}); diff --git a/csrc/device_lower/analysis/divisible_split.cpp b/csrc/device_lower/analysis/divisible_split.cpp index 875df3e3414..75d344a4dbd 100644 --- a/csrc/device_lower/analysis/divisible_split.cpp +++ b/csrc/device_lower/analysis/divisible_split.cpp @@ -38,7 +38,7 @@ std::unordered_set getAllDivisibleSplits( // Take the view transformations and add all the splits. Those splits are // the only divisible splits. auto view_exprs = - StmtSort::getExprs(fusion, {rfactor_dom.begin(), rfactor_dom.end()}); + StmtSort::getExprsTo(fusion, {rfactor_dom.begin(), rfactor_dom.end()}); auto split_exprs = ir_utils::filterByType(view_exprs); all_divisible_splits.insert(split_exprs.begin(), split_exprs.end()); } diff --git a/csrc/device_lower/pass/allocation.cpp b/csrc/device_lower/pass/allocation.cpp index 32b72caf607..efcb9c8e82f 100644 --- a/csrc/device_lower/pass/allocation.cpp +++ b/csrc/device_lower/pass/allocation.cpp @@ -216,7 +216,7 @@ class AllocationInserter : public kir::ExprMutator { [](IterDomain* dom) { return dom->as(); }); // Get all exprs involved in generating the allocation IDs - auto exprs = StmtSort::getExprs(tv->fusion(), start_vals); + auto exprs = StmtSort::getExprsTo(tv->fusion(), start_vals); // Get the halo extent if found auto getExtent = [this](IterDomain* id) { diff --git a/csrc/device_lower/validation.cpp b/csrc/device_lower/validation.cpp index c2c2250d0b3..1dc550d0b0c 100644 --- a/csrc/device_lower/validation.cpp +++ b/csrc/device_lower/validation.cpp @@ -828,7 +828,7 @@ void validatePartialSplit(Fusion* fusion) { auto range_info = getLiveRangeOffsets(fusion); for (auto tv : ir_utils::allTvs(fusion)) { - auto exprs = StmtSort::getExprs( + auto exprs = StmtSort::getExprsTo( tv->fusion(), {tv->getLeafDomain().begin(), tv->getLeafDomain().end()}); for (auto split : ir_utils::filterByType(exprs)) { // When the start and stop offsets are not zero, make sure the diff --git a/csrc/dynamic_transform.cpp b/csrc/dynamic_transform.cpp index 6ce6f592e9d..f9963b059f4 100644 --- a/csrc/dynamic_transform.cpp +++ b/csrc/dynamic_transform.cpp @@ -452,7 +452,11 @@ void DynamicTransformConcretizer::concretize() { concretizeEmptyExtents(); // Finally, propagate concretized domains - auto all_stmts = StmtSort::getStmts(info_->fusion()); + auto all_stmts = StmtSort::getStmts( + info_->fusion(), + /*traverse_members*/ false, + /*traverse_attributes*/ false, + /*traverse_siblings*/ true); for (auto tv : ir_utils::filterByType(all_stmts)) { mutate(tv); } diff --git a/csrc/fusion.cpp b/csrc/fusion.cpp index 8fa3605d498..2affce69e7b 100644 --- a/csrc/fusion.cpp +++ b/csrc/fusion.cpp @@ -490,7 +490,7 @@ void Fusion::printMath(bool from_outputs_only) { leaf_vals.push_back(val); } } - exprs_for_print = StmtSort::getExprs(this, leaf_vals); + exprs_for_print = StmtSort::getExprsTo(this, leaf_vals); } debug() << "\n%kernel_math {\n"; diff --git a/csrc/fusion_segmenter.cpp b/csrc/fusion_segmenter.cpp index 3b4458ea8f9..d8687968657 100644 --- a/csrc/fusion_segmenter.cpp +++ b/csrc/fusion_segmenter.cpp @@ -3600,7 +3600,7 @@ void SegmentCandidateFinder::resolveInputsInGroup(SegmentedGroup* group) { group->input_vals = IterVisitor::getInputsTo(group->inputs()); // Grab all expressions needed to produce to_visit - auto input_exprs = StmtSort::getExprs(completeFusion(), to_visit); + auto input_exprs = StmtSort::getExprsTo(completeFusion(), to_visit); // Insert those expressions at the beginning of the group group->exprs_.insert( diff --git a/csrc/ir/cloner.cpp b/csrc/ir/cloner.cpp index 66295857303..61082a6391a 100644 --- a/csrc/ir/cloner.cpp +++ b/csrc/ir/cloner.cpp @@ -116,7 +116,7 @@ Statement* RecomputeTv::handle(const TensorDomain* td) { // Make sure to recompute the history of the iteration domains, explicitly go // through the expressions and send them to IrCloner. auto exprs = - StmtSort::getExprs(fusion_, {td->leaf().begin(), td->leaf().end()}); + StmtSort::getExprsTo(fusion_, {td->leaf().begin(), td->leaf().end()}); for (auto expr : exprs) { IrCloner::handle(expr); diff --git a/csrc/ir/utils.cpp b/csrc/ir/utils.cpp index 8f3be4f934c..71287011070 100644 --- a/csrc/ir/utils.cpp +++ b/csrc/ir/utils.cpp @@ -463,7 +463,7 @@ class ValReplacementMutator : private OptOutMutator { // typically not used by anything else. If we don't grab that count, then it // would be a tensorview that doesn't get updated extents. Therefore, first // grab all leaves towards outputs and grab stmts from there. - auto stmts = StmtSort::getStmts(fusion, allLeafOuts(fusion), true, true); + auto stmts = StmtSort::getStmtsTo(fusion, allLeafOuts(fusion), true, true); // Some fusions, such as standalone rand_like, can have disconnected DAG, so // we need some mechanism to make sure our replacement set is as complete as @@ -481,7 +481,7 @@ class ValReplacementMutator : private OptOutMutator { more.emplace_back(v); } } - auto more_stmts = StmtSort::getStmts(fusion, more, true, true); + auto more_stmts = StmtSort::getStmtsTo(fusion, more, true, true); more_stmts.insert(more_stmts.end(), stmts.begin(), stmts.end()); for (auto stmt : more_stmts) { diff --git a/csrc/iter_visitor.cpp b/csrc/iter_visitor.cpp index 932c876a214..7569d7a9f2b 100644 --- a/csrc/iter_visitor.cpp +++ b/csrc/iter_visitor.cpp @@ -142,15 +142,27 @@ void IterVisitor::traverseBetween( const std::vector& to, bool traverse_all_paths, bool traverse_into_members, - bool traverse_attributes) { + bool traverse_attributes, + bool traverse_siblings) { FusionGuard fg(fusion); std::unordered_set visited; std::unordered_set nodes_on_path; + std::vector maybe_orphaned_sibs; stmt_stack.clear(); stmt_stack.emplace_back(to.rbegin(), to.rend()); + if (traverse_siblings) { + // Append siblings of entries in "to" to bottom of stack + auto& bottom_stack = stmt_stack.back(); + for (auto val : ir_utils::filterByType(bottom_stack)) { + for (auto sib : ir_utils::siblingValsOf(val)) { + maybe_orphaned_sibs.push_back(sib); + } + } + } + bool all_inputs_visited = false; while (!stmt_stack.empty()) { @@ -222,6 +234,18 @@ void IterVisitor::traverseBetween( // If we don't want to retraverse, remove nodes we already visisted. remove_visited(next_stmts, visited); } + + if (traverse_siblings) { + // Add unvisited siblings to next_stmts + for (auto next_val : ir_utils::filterByType(next_stmts)) { + for (auto sib : ir_utils::siblingValsOf(next_val)) { + if (traverse_all_paths || visited.find(sib) == visited.end()) { + maybe_orphaned_sibs.push_back(sib); + } + } + } + } + if (next_stmts.empty()) { // If there's nothing to visit because it was all already visited, mark // to process @@ -251,6 +275,14 @@ void IterVisitor::traverseBetween( } } } + // Handle any sibling Vals that have not yet been handled + // If traverse_siblings is false, this vector will be empty + for (auto val : maybe_orphaned_sibs) { + if (visited.find(val) == visited.end()) { + visited.insert(val); + handle(val); + } + } } void IterVisitor::traverseTo( @@ -258,14 +290,16 @@ void IterVisitor::traverseTo( const std::vector& to, bool traverse_all_paths, bool traverse_into_members, - bool traverse_attributes) { + bool traverse_attributes, + bool traverse_siblings) { traverseBetween( fusion, {}, to, traverse_all_paths, traverse_into_members, - traverse_attributes); + traverse_attributes, + traverse_siblings); } void IterVisitor::traverseHelper(Fusion* fusion, bool traverse_all_paths) { @@ -430,7 +464,7 @@ void BackwardVisitor::traverseTo( } auto vals = AllVals::get(fusion, from); - auto exprs = StmtSort::getExprs(fusion, from); + auto exprs = StmtSort::getExprsTo(fusion, from); { size_t pos = 0; @@ -841,19 +875,25 @@ void StmtSort::handle(Statement* stmt) { std::vector StmtSort::getExprs( Fusion* fusion, bool traverse_members, - bool traverse_attributes) { + bool traverse_attributes, + bool traverse_siblings) { auto terminating_outputs = fusion->getTerminatingOutputs(); - return StmtSort::getExprs( - fusion, terminating_outputs, traverse_members, traverse_attributes); + return StmtSort::getExprsTo( + fusion, + terminating_outputs, + traverse_members, + traverse_attributes, + traverse_siblings); } -std::vector StmtSort::getExprs( +std::vector StmtSort::getExprsTo( Fusion* fusion, const std::vector& to, bool traverse_members, - bool traverse_attributes) { - auto stmts = - StmtSort::getStmts(fusion, to, traverse_members, traverse_attributes); + bool traverse_attributes, + bool traverse_siblings) { + auto stmts = StmtSort::getStmtsTo( + fusion, to, traverse_members, traverse_attributes, traverse_siblings); auto filter = ir_utils::filterByType(stmts.begin(), stmts.end()); std::vector exprs(filter.begin(), filter.end()); return exprs; @@ -864,9 +904,15 @@ std::vector StmtSort::getExprsBetween( const std::vector& from, const std::vector& to, bool traverse_members, - bool traverse_attributes) { + bool traverse_attributes, + bool traverse_siblings) { auto stmts = StmtSort::getStmtsBetween( - fusion, from, to, traverse_members, traverse_attributes); + fusion, + from, + to, + traverse_members, + traverse_attributes, + traverse_siblings); auto filter = ir_utils::filterByType(stmts.begin(), stmts.end()); std::vector exprs(filter.begin(), filter.end()); return exprs; @@ -875,19 +921,31 @@ std::vector StmtSort::getExprsBetween( std::vector StmtSort::getStmts( Fusion* fusion, bool traverse_members, - bool traverse_attributes) { + bool traverse_attributes, + bool traverse_siblings) { auto terminating_outputs = fusion->getTerminatingOutputs(); - return StmtSort::getStmts( - fusion, terminating_outputs, traverse_members, traverse_attributes); + return StmtSort::getStmtsTo( + fusion, + terminating_outputs, + traverse_members, + traverse_attributes, + traverse_siblings); } -std::vector StmtSort::getStmts( +std::vector StmtSort::getStmtsTo( Fusion* fusion, const std::vector& to, bool traverse_members, - bool traverse_attributes) { + bool traverse_attributes, + bool traverse_siblings) { StmtSort es; - es.traverseTo(fusion, to, false, traverse_members, traverse_attributes); + es.traverseTo( + fusion, + to, + false, + traverse_members, + traverse_attributes, + traverse_siblings); return es.stmts; } @@ -896,7 +954,8 @@ std::vector StmtSort::getStmtsBetween( const std::vector& from, const std::vector& to, bool traverse_members, - bool traverse_attributes) { + bool traverse_attributes, + bool traverse_siblings) { StmtSort es; es.traverseBetween( fusion, @@ -904,7 +963,8 @@ std::vector StmtSort::getStmtsBetween( to, false, traverse_members, - traverse_attributes); + traverse_attributes, + traverse_siblings); return es.stmts; } @@ -932,11 +992,11 @@ std::vector InputsOf::outputs( bool DeadCodeRemover::run() { // First we build a set of all live Statements so that we can detect dead // branches. - for (auto stmt : StmtSort::getStmts(fusion_, fusion_->outputs())) { + for (auto stmt : StmtSort::getStmtsTo(fusion_, fusion_->outputs())) { markLive(stmt); } - // Note that StmtSort::getStmts() is also run in traverseTo. In the future, + // Note that StmtSort::getStmtsTo() is also run in traverseTo. In the future, // we could potentially refactor this so that derived classes from // BackwardVisitor can make use of that traversal instead of repeating it. traverseTo(fusion_, fusion_->outputs(), false); diff --git a/csrc/iter_visitor.h b/csrc/iter_visitor.h index 425463cbc36..8222cdf8839 100644 --- a/csrc/iter_visitor.h +++ b/csrc/iter_visitor.h @@ -94,12 +94,16 @@ class TORCH_CUDA_CU_API IterVisitor : public OptOutDispatch { //! \param traverse_attributes When true, traverse into expr //! attributes. Note that attributes of template type Attribute are //! not traversed as there's no dispatch support. + //! \param traverse_siblings When true, traverse all outputs of + //! active multi-output expressions, even if those Expr outputs are not used + //! in paths to Fusion outputs. void traverseTo( Fusion* fusion, const std::vector& to, bool traverse_all_paths = false, bool traverse_into_members = false, - bool traverse_attributes = false); + bool traverse_attributes = false, + bool traverse_siblings = false); //! Traverses nodes in Fusion from inputs in topological order to "to". i.e. //! from inputs towards outputs. @@ -117,13 +121,17 @@ class TORCH_CUDA_CU_API IterVisitor : public OptOutDispatch { //! \param traverse_attributes When true, traverse into expr //! attributes. Note that attributes of template type Attribute are //! not traversed as there's no dispatch support. + //! \param traverse_siblings When true, traverse all outputs of + //! active multi-output expressions, even if those Expr outputs are not used + //! in paths to Fusion outputs. void traverseBetween( Fusion* fusion, const std::unordered_set& from, const std::vector& to, bool traverse_all_paths = false, bool traverse_into_members = false, - bool traverse_attributes = false); + bool traverse_attributes = false, + bool traverse_siblings = false); // Iterates from terminating outputs registered with the fusion. Terminating // means value is not used to generate any other value used in producing @@ -299,14 +307,16 @@ class StmtSort : public IterVisitor { static std::vector getStmts( Fusion* fusion, bool traverse_members = false, - bool traverse_attributes = false); + bool traverse_attributes = false, + bool traverse_siblings = false); // Returns ordered Statements required to produce 'to', including 'to'. - static std::vector getStmts( + static std::vector getStmtsTo( Fusion* fusion, const std::vector& to, bool traverse_members = false, - bool traverse_attributes = false); + bool traverse_attributes = false, + bool traverse_siblings = false); // Returns ordered Statements required to produce from, including from. // Stops traversal once hiting any Statements in to. Includes Statements in @@ -330,20 +340,23 @@ class StmtSort : public IterVisitor { const std::vector& from, const std::vector& to, bool traverse_members = false, - bool traverse_attributes = false); + bool traverse_attributes = false, + bool traverse_siblings = false); // Same as getStmts version but filters to only return the Expr*s static std::vector getExprs( Fusion* fusion, bool traverse_members = false, - bool traverse_attributes = false); + bool traverse_attributes = false, + bool traverse_siblings = false); // Same as getStmts version but filters to only return the Expr*s - static std::vector getExprs( + static std::vector getExprsTo( Fusion* fusion, const std::vector& to, bool traverse_members = false, - bool traverse_attributes = false); + bool traverse_attributes = false, + bool traverse_siblings = false); // Same as getStmts version but filters to only return the Expr*s static std::vector getExprsBetween( @@ -351,7 +364,8 @@ class StmtSort : public IterVisitor { const std::vector& from, const std::vector& to, bool traverse_members = false, - bool traverse_attributes = false); + bool traverse_attributes = false, + bool traverse_siblings = false); }; class TORCH_CUDA_CU_API InputsOf : public IterVisitor { diff --git a/csrc/partial_split_map.cpp b/csrc/partial_split_map.cpp index c46b258fe3a..9540c6b08a2 100644 --- a/csrc/partial_split_map.cpp +++ b/csrc/partial_split_map.cpp @@ -15,7 +15,7 @@ void PartialSplitMap::build(Fusion* fusion) { auto used_vals = ir_utils::allTvs(fusion); for (auto tv : ir_utils::filterByType(used_vals)) { - auto exprs = StmtSort::getExprs( + auto exprs = StmtSort::getExprsTo( fusion, {tv->getLeafDomain().begin(), tv->getLeafDomain().end()}); for (auto split : ir_utils::filterByType(exprs)) { // Only needs to check root domains as partial split is only diff --git a/csrc/scheduler/utils.cpp b/csrc/scheduler/utils.cpp index 89d79ea91d4..cce61fd9b0c 100644 --- a/csrc/scheduler/utils.cpp +++ b/csrc/scheduler/utils.cpp @@ -1057,8 +1057,7 @@ IterDomain* projectIdToRoot( return reference_id; } - auto replay_exprs = - StmtSort::getExprs(tv->fusion(), {reference_id}, false, false); + auto replay_exprs = StmtSort::getExprsTo(tv->fusion(), {reference_id}); if (replay_exprs.empty()) { return reference_id; } @@ -1118,7 +1117,7 @@ IterDomain* projectIdToRFactor( return reference_id; } - auto replay_exprs = StmtSort::getExprs( + auto replay_exprs = StmtSort::getExprsTo( tv->fusion(), {tv->getRFactorDomain().begin(), tv->getRFactorDomain().end()}, false); @@ -1786,7 +1785,7 @@ DisjointSets disjointRFactorSets(Fusion* fusion) { // If iter domains are involved in any transformation from root domains to // rfactor domains they should be considered "contaminated". for (auto tv : ir_utils::allTvs(fusion)) { - for (auto expr : StmtSort::getExprs( + for (auto expr : StmtSort::getExprsTo( fusion, {tv->getMaybeRFactorDomain().begin(), tv->getMaybeRFactorDomain().end()})) { @@ -1837,7 +1836,7 @@ bool breakIsDisjoint(std::vector group_ids, int pos) { std::unordered_map domainReorderAsRfactorMap(TensorView* tv) { FusionGuard fg(tv->fusion()); - auto transform_exprs = StmtSort::getExprs( + auto transform_exprs = StmtSort::getExprsTo( tv->fusion(), {tv->getLeafDomain().begin(), tv->getLeafDomain().end()}); // simply update this vector of id's as progressing through the transformation // expressions. We'll always insert the result of split in the location of the diff --git a/csrc/transform_iter.cpp b/csrc/transform_iter.cpp index cbae774fe52..14a64f6e204 100644 --- a/csrc/transform_iter.cpp +++ b/csrc/transform_iter.cpp @@ -322,7 +322,7 @@ BestEffortReplay::BestEffortReplay( } // Grab expr history of iter domains in target_domain - std::vector target_exprs = StmtSort::getExprs( + std::vector target_exprs = StmtSort::getExprsTo( FusionGuard::getCurFusion(), std::vector(target_domain.begin(), target_domain.end())); @@ -333,7 +333,7 @@ BestEffortReplay::BestEffortReplay( // replay_domain map. // Map replay domain's IterDomains to the Exprs they're used in - std::vector replay_exprs = StmtSort::getExprs( + std::vector replay_exprs = StmtSort::getExprsTo( FusionGuard::getCurFusion(), std::vector(replay_domain.begin(), replay_domain.end())); @@ -752,7 +752,7 @@ struct ForwardingInfo { // We have root axes in active_tv that don't exist in the inactive tensor, // now forward those to include all id's in active_tv comprised of only axes // not in the inactive tensor. - std::vector active_tv_history = StmtSort::getExprs( + std::vector active_tv_history = StmtSort::getExprsTo( FusionGuard::getCurFusion(), std::vector( active_tv->getLeafDomain().begin(), @@ -903,7 +903,7 @@ void BestEffortReplay::addComplimentLeafIDs( } // Grab all exprs used to make the forwarded compliments - auto compliment_exprs = StmtSort::getExprs( + auto compliment_exprs = StmtSort::getExprsTo( FusionGuard::getCurFusion(), {compliments.begin(), compliments.end()}); // Figure out if there are any leaves in compliment_exprs that aren't diff --git a/test/test_dynamic_transform.cpp b/test/test_dynamic_transform.cpp index 01ac2c7bde4..4dde5f84e43 100644 --- a/test/test_dynamic_transform.cpp +++ b/test/test_dynamic_transform.cpp @@ -1068,6 +1068,51 @@ TEST_F(NVFuserTest, FusionDynamicEmptyCat2_CUDA) { EXPECT_EQ(output_def->input(0), seg_fusion->inputs()[0]); } +// Repro of https://github.com/NVIDIA/Fuser/issues/418 +TEST_F(NVFuserTest, DynamicTransformIssue418Concretization_CUDA) { + auto fusion = std::make_unique(); + FusionGuard fg(fusion.get()); + + auto tv0 = makeSymbolicTensor(4); + fusion->addInput(tv0); + auto s0 = IrBuilder::create(DataType::Int); + fusion->addInput(s0); + + auto v00 = tv0->axis(0)->extent(); + auto v01 = tv0->axis(1)->extent(); + auto v02 = tv0->axis(2)->extent(); + auto v03 = tv0->axis(3)->extent(); + + auto tv1 = reshape(tv0, {v00, div(v01, s0), s0, v02, v03}); + auto vm = variance_mean(tv1, {2, 3, 4}, 0, true); + fusion->addOutput(vm.mean); + fusion->addOutput(vm.var); + + { + ExpressionEvaluator expr_eval; + + expr_eval.bind(tv0->axis(0)->extent(), 256L); + expr_eval.bind(tv0->axis(1)->extent(), 128L); + expr_eval.bind(tv0->axis(2)->extent(), 28L); + expr_eval.bind(tv0->axis(3)->extent(), 28L); + expr_eval.bind(s0, 4L); + + auto initial_info = DynamicTransform::getInitialInfo(fusion.get()); + auto info = DynamicTransformConcretizationInfo(&initial_info, &expr_eval); + + TORCH_CHECK( + info.getReshapeTransforms().size() == 1, + "Expected to have one reshape transform: ", + info.toString()); + + DynamicTransform::concretizeFusion(fusion.get(), &info); + + TORCH_CHECK( + !fusion->hasDynamicTransform(), + "Expected to have no dynamic transform"); + } +} + TEST_F(NVFuserTest, Issue249_CUDA) { std::unique_ptr fusion_ptr = std::make_unique(); Fusion& fusion = *fusion_ptr.get(); diff --git a/test/test_gpu3.cpp b/test/test_gpu3.cpp index bd5bf6277ee..4e58835b43c 100644 --- a/test/test_gpu3.cpp +++ b/test/test_gpu3.cpp @@ -9303,6 +9303,50 @@ TEST_F(NVFuserTest, FusionDanglingUnaryOp_CUDA) { __FILE__); } +// Test that traversing siblings with IterVisitor visits "orphans", i.e. unused +// outputs of multi-output Exprs. +TEST_F(NVFuserTest, IterVisitorTraverseSiblings_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto tv0 = makeSymbolicTensor(1); + fusion.addInput(tv0); + + auto wf = Welford(tv0, {0}); + // wf.var_sum is used, but wf.avg and wf.n are orphaned + auto tv1 = neg(wf.var_sum); + fusion.addOutput(tv1); + + auto stmts = StmtSort::getStmts( + &fusion, + /*traverse_all_paths*/ false, + /*traverse_attributes*/ false, + /*traverse_siblings*/ true); + + // Make sure the expansion parameters of tv1_resize are visited + TORCH_CHECK( + std::find(stmts.begin(), stmts.end(), wf.avg) != stmts.end(), + "Welford avg not traversed"); + TORCH_CHECK( + std::find(stmts.begin(), stmts.end(), wf.n) != stmts.end(), + "Welford n not traversed"); + + // Test getting statements "to" a tensor with siblings + stmts = StmtSort::getStmtsTo( + &fusion, + {wf.n}, + /*traverse_all_paths*/ false, + /*traverse_attributes*/ false, + /*traverse_siblings*/ true); + // Make sure the expansion parameters of tv1_resize are visited + TORCH_CHECK( + std::find(stmts.begin(), stmts.end(), wf.avg) != stmts.end(), + "Welford avg not traversed in getStmtsTo({n})"); + TORCH_CHECK( + std::find(stmts.begin(), stmts.end(), wf.var_sum) != stmts.end(), + "Welford var_sum not traversed in getStmtsTo({n})"); +} + // Test file size should be up to 10K LoC. Create a new file for more tests. } // namespace nvfuser