From 2049e901565118412ca7c17f5540fbd00f2110bd Mon Sep 17 00:00:00 2001 From: Jingyue Wu Date: Sat, 2 Dec 2023 18:11:42 -0800 Subject: [PATCH] Remove the unnecessary fusion parameter. Partially rolls forward #1413. --- csrc/compute_at_map.cpp | 1 - csrc/contiguity.cpp | 12 +--- .../device_lower/analysis/divisible_split.cpp | 2 +- .../analysis/predicate_elimination.cpp | 2 +- .../analysis/sync_information.cpp | 4 +- .../analysis/thread_predicate.cpp | 2 +- csrc/device_lower/pass/alias_memory.cpp | 3 +- csrc/device_lower/pass/allocation.cpp | 2 +- csrc/device_lower/pass/expr_sort.cpp | 4 +- csrc/device_lower/pass/warp_reduce.cpp | 3 +- csrc/device_lower/validation.cpp | 3 +- csrc/dynamic_transform.cpp | 5 +- csrc/executor.cpp | 22 ++---- csrc/fusion.cpp | 6 +- csrc/fusion_segmenter.cpp | 4 +- csrc/index_compute.cpp | 4 +- csrc/ir/cloner.cpp | 8 +-- csrc/ir/cloner.h | 2 - csrc/ir/iostream.cpp | 1 + csrc/ir/nodes.cpp | 2 +- csrc/ir/utils.cpp | 6 +- csrc/iter_visitor.cpp | 71 +++++++------------ csrc/iter_visitor.h | 17 +---- csrc/multidevice/executor.cpp | 2 +- csrc/multidevice/pipeline.cpp | 2 +- csrc/non_divisible_split.cpp | 2 +- csrc/partial_split_map.cpp | 2 +- csrc/root_domain_map.cpp | 6 +- csrc/scheduler/transpose.cpp | 6 +- csrc/scheduler/utils.cpp | 10 +-- csrc/scheduler/vectorize_helper.cpp | 8 +-- csrc/tensor_metadata.cpp | 8 +-- csrc/tensor_view.cpp | 2 +- csrc/transform_iter.cpp | 21 +++--- test/test_gpu3.cpp | 2 +- test/test_iter_visitor.cpp | 4 +- test/test_optimization_pass.cpp | 6 +- test/test_swizzle.cpp | 1 - 38 files changed, 93 insertions(+), 175 deletions(-) diff --git a/csrc/compute_at_map.cpp b/csrc/compute_at_map.cpp index e457d8923a4..f6a7755fdc8 100644 --- a/csrc/compute_at_map.cpp +++ b/csrc/compute_at_map.cpp @@ -605,7 +605,6 @@ void IterDomainGraph::build(Fusion* fusion) { // Grab all the rfactor ids. for (auto consumer_tv : all_consumer_tvs) { auto exprs = StmtSort::getExprsTo( - fusion, {consumer_tv->getMaybeRFactorDomain().begin(), consumer_tv->getMaybeRFactorDomain().end()}); for (auto expr : exprs) { diff --git a/csrc/contiguity.cpp b/csrc/contiguity.cpp index a1df7807845..c41cf793644 100644 --- a/csrc/contiguity.cpp +++ b/csrc/contiguity.cpp @@ -39,9 +39,7 @@ OrderedIdInformation::OrderedIdInformation( // consistently_ordered_ids_, id_to_alloc_ids_, and // exclusively_consumes_allocs_ for all the IDs auto exprs = StmtSort::getExprsBetween( - ids[0]->fusion(), - {alloc_domain.begin(), alloc_domain.end()}, - {ids.begin(), ids.end()}); + {alloc_domain.begin(), alloc_domain.end()}, {ids.begin(), ids.end()}); for (auto expr : exprs) { OptInDispatch::dispatch(expr); @@ -386,9 +384,7 @@ NonDivisibleSplitDependencies::NonDivisibleSplitDependencies( return; } auto transforms = StmtSort::getExprsBetween( - ids[0]->fusion(), - {alloc_domain.begin(), alloc_domain.end()}, - {ids.begin(), ids.end()}); + {alloc_domain.begin(), alloc_domain.end()}, {ids.begin(), ids.end()}); for (auto transform : transforms) { auto inp_ids = ir_utils::filterByType(transform->inputs()); for (auto inp_id : inp_ids) { @@ -545,9 +541,7 @@ void ContigIDs::build(const std::vector& ids) { if (!contig_ids_.empty()) { auto exprs = StmtSort::getExprsBetween( - ids.at(0)->fusion(), - {alloc_domain_.begin(), alloc_domain_.end()}, - {ids.begin(), ids.end()}); + {alloc_domain_.begin(), alloc_domain_.end()}, {ids.begin(), ids.end()}); for (auto expr : exprs) { if (auto resize = dynamic_cast(expr)) { resize_deps_.insert(resize->out()); diff --git a/csrc/device_lower/analysis/divisible_split.cpp b/csrc/device_lower/analysis/divisible_split.cpp index 75d344a4dbd..5b62bf5de2d 100644 --- a/csrc/device_lower/analysis/divisible_split.cpp +++ b/csrc/device_lower/analysis/divisible_split.cpp @@ -38,7 +38,7 @@ std::unordered_set getAllDivisibleSplits( // Take the view transformations and add all the splits. Those splits are // the only divisible splits. auto view_exprs = - StmtSort::getExprsTo(fusion, {rfactor_dom.begin(), rfactor_dom.end()}); + StmtSort::getExprsTo({rfactor_dom.begin(), rfactor_dom.end()}); auto split_exprs = ir_utils::filterByType(view_exprs); all_divisible_splits.insert(split_exprs.begin(), split_exprs.end()); } diff --git a/csrc/device_lower/analysis/predicate_elimination.cpp b/csrc/device_lower/analysis/predicate_elimination.cpp index 676c5f7e695..e3078b326d4 100644 --- a/csrc/device_lower/analysis/predicate_elimination.cpp +++ b/csrc/device_lower/analysis/predicate_elimination.cpp @@ -863,7 +863,7 @@ class PredicateChcker : public IterVisitor { } // namespace PredicateElimination::PredicateElimination(Fusion* fusion) { - traverseTo(fusion, fusion->outputs()); + traverseTo(fusion->outputs()); } bool PredicateElimination::needsPredicate(Expr* expr) const { diff --git a/csrc/device_lower/analysis/sync_information.cpp b/csrc/device_lower/analysis/sync_information.cpp index f7c922ad5e8..36a5c891314 100644 --- a/csrc/device_lower/analysis/sync_information.cpp +++ b/csrc/device_lower/analysis/sync_information.cpp @@ -106,7 +106,6 @@ struct ProducerConsumerIndexingInfoCache { const auto& consumer_leaf_ids_shared_with_producer = getConsumerLeafIDsSharedWithProducer(); consumer_root_ids_shared_with_producer_ = InputsOf::outputs( - producer_tv_->fusion(), {consumer_leaf_ids_shared_with_producer.begin(), consumer_leaf_ids_shared_with_producer.end()}); } @@ -261,10 +260,9 @@ bool useSameIndex( // consumer_id. The goal of the analysis below is to find out if all // of the root IDs are indexed in the same way between the producer // and consumer tensors. - auto consumer_root_ids = InputsOf::output(consumer_id->fusion(), consumer_id); + auto consumer_root_ids = InputsOf::output(consumer_id); auto producer_root_vals = StmtSort::getStmtsBetween( - producer_id->fusion(), {producer_tv->getMaybeRFactorDomain().begin(), producer_tv->getMaybeRFactorDomain().end()}, {producer_id}); diff --git a/csrc/device_lower/analysis/thread_predicate.cpp b/csrc/device_lower/analysis/thread_predicate.cpp index 98cf4886297..fdb0b0806aa 100644 --- a/csrc/device_lower/analysis/thread_predicate.cpp +++ b/csrc/device_lower/analysis/thread_predicate.cpp @@ -366,7 +366,7 @@ class RedundantUseAnalysis : BackwardVisitor { public: RedundantUseAnalysis(Fusion* fusion, const ThreadPredicateMap& pred_map) : fusion_(fusion), pred_map_(pred_map) { - traverseTo(fusion, fusion->terminatingMathVals()); + traverseTo(fusion->terminatingMathVals()); } //! Returns a bit map signifying the parallel dimensions diff --git a/csrc/device_lower/pass/alias_memory.cpp b/csrc/device_lower/pass/alias_memory.cpp index 080967fb1aa..8e59da72dd6 100644 --- a/csrc/device_lower/pass/alias_memory.cpp +++ b/csrc/device_lower/pass/alias_memory.cpp @@ -122,8 +122,7 @@ bool isSerialBroadcastResolution( // traverse across view boundaries as we do in indexing. This // should not result in false aliasing but may miss safe aliasing // opportunities. - auto serial_loop_roots = - InputsOf::outputs(FusionGuard::getCurFusion(), serial_loop_concrete_ids); + auto serial_loop_roots = InputsOf::outputs(serial_loop_concrete_ids); // Collect exact concrete id's in producer's root domain std::unordered_set producer_exact_concrete_root_ids; diff --git a/csrc/device_lower/pass/allocation.cpp b/csrc/device_lower/pass/allocation.cpp index 69e7cb69ce8..8c3052e2751 100644 --- a/csrc/device_lower/pass/allocation.cpp +++ b/csrc/device_lower/pass/allocation.cpp @@ -218,7 +218,7 @@ class AllocationInserter : public kir::ExprMutator { [](IterDomain* dom) { return dom->as(); }); // Get all exprs involved in generating the allocation IDs - auto exprs = StmtSort::getExprsTo(tv->fusion(), start_vals); + auto exprs = StmtSort::getExprsTo(start_vals); // Get the halo extent if found auto getExtent = [this](IterDomain* id) { diff --git a/csrc/device_lower/pass/expr_sort.cpp b/csrc/device_lower/pass/expr_sort.cpp index 6cb221cb3f5..52fa723fdff 100644 --- a/csrc/device_lower/pass/expr_sort.cpp +++ b/csrc/device_lower/pass/expr_sort.cpp @@ -1509,9 +1509,7 @@ void ExprSegmentationSorter::sort() { // Not putting the exprs between allKnownVals() and fusion inputs here // because they are computed using the expr evaluator. auto all_exprs = StmtSort::getExprsBetween( - fusion_, - GpuLower::current()->allKnownVals(), - fusion_->getTerminatingOutputs()); + GpuLower::current()->allKnownVals(), fusion_->getTerminatingOutputs()); // Figure out all the values used as inputs to the expressions we're sorting // (to find terminating expressions). There could be branches of expressions diff --git a/csrc/device_lower/pass/warp_reduce.cpp b/csrc/device_lower/pass/warp_reduce.cpp index c41e72a690e..8a27af271f7 100644 --- a/csrc/device_lower/pass/warp_reduce.cpp +++ b/csrc/device_lower/pass/warp_reduce.cpp @@ -104,8 +104,7 @@ class EliminateDeadBroadcastAndAllocate { // Also find any TVs used in index expressions. // These expressions will likely not be in the Expr tree we are // provided, so we need to traverse to find them. - auto all_index_roots = - InputsOf::outputs(FusionGuard::getCurFusion(), {ti->index()}); + auto all_index_roots = InputsOf::outputs({ti->index()}); auto index_root_tis = ir_utils::filterByType(all_index_roots); for (auto rootti : index_root_tis) { diff --git a/csrc/device_lower/validation.cpp b/csrc/device_lower/validation.cpp index 9e2bc68e413..9e5df2bc26a 100644 --- a/csrc/device_lower/validation.cpp +++ b/csrc/device_lower/validation.cpp @@ -836,7 +836,7 @@ void validatePartialSplit(Fusion* fusion) { for (auto tv : ir_utils::allTvs(fusion)) { auto exprs = StmtSort::getExprsTo( - tv->fusion(), {tv->getLeafDomain().begin(), tv->getLeafDomain().end()}); + {tv->getLeafDomain().begin(), tv->getLeafDomain().end()}); for (auto split : ir_utils::filterByType(exprs)) { // When the start and stop offsets are not zero, make sure the // range defined by the split includes the required range to @@ -1276,7 +1276,6 @@ void validateResize(Fusion* fusion) { for (auto tv : ir_utils::filterByType(fusion_vals)) { // Make sure resize is only used as part of rfactor transformations auto rf_to_leaf_exprs = StmtSort::getExprsBetween( - fusion, {tv->getMaybeRFactorDomain().begin(), tv->getMaybeRFactorDomain().end()}, {tv->getLeafDomain().begin(), tv->getLeafDomain().end()}); diff --git a/csrc/dynamic_transform.cpp b/csrc/dynamic_transform.cpp index fb0c091632e..87550c6561c 100644 --- a/csrc/dynamic_transform.cpp +++ b/csrc/dynamic_transform.cpp @@ -90,7 +90,7 @@ class DynamicTransformInitialInfoBuilder : public IterVisitor { !fusion->isA(), "Invalid container. Kernel container not allowed.\n"); - traverseTo(fusion, fusion->getTerminatingOutputs(), false, false); + traverseTo(fusion->getTerminatingOutputs(), false, false); finalizeDynamicVals(); @@ -147,7 +147,7 @@ class DynamicTransformInitialInfoBuilder : public IterVisitor { //! Process vector of leaf dynamic values by finding inputs and recording the //! result into info_ void finalizeDynamicVals() { - const auto inputs = InputsOf::outputs(info_.fusion(), leaf_dynamic_vals_); + const auto inputs = InputsOf::outputs(leaf_dynamic_vals_); info_.root_dynamic_vals_.insert(inputs.begin(), inputs.end()); // initial_info_ provides a set of Vals that are used for concretization. @@ -621,7 +621,6 @@ void DynamicTransformConcretizer::mutate(TensorView* tv) { // Note that it is assumed that theres's no further expression // beyond the rfactor domain as asserted above auto all_id_exprs = StmtSort::getExprsBetween( - tv->fusion(), {tv->getRootDomain().begin(), tv->getRootDomain().end()}, {tv->getMaybeRFactorDomain().begin(), tv->getMaybeRFactorDomain().end()}); diff --git a/csrc/executor.cpp b/csrc/executor.cpp index 0bec5acd3ae..990dd498491 100644 --- a/csrc/executor.cpp +++ b/csrc/executor.cpp @@ -275,7 +275,7 @@ void FusionExecutor::compileFusion( } output_extents.emplace_back(extent); } - auto dependencies = InputsOf::outputs(fusion, output_extents); + auto dependencies = InputsOf::outputs(output_extents); if (std::any_of(dependencies.begin(), dependencies.end(), [](Val* val) { return val->isFusionInput(); })) { @@ -607,7 +607,6 @@ std::pair, std::vector> inferShapeOfOutput( class ForwardTraverseFromAllocToRFactor { at::Tensor tensor_; - TensorView* tv_; ExpressionEvaluator& ee_; std::list& frontier_; @@ -725,18 +724,15 @@ class ForwardTraverseFromAllocToRFactor { public: ForwardTraverseFromAllocToRFactor( at::Tensor tensor, - TensorView* tv, ExpressionEvaluator& ee, std::list& frontier) - : tensor_(std::move(tensor)), tv_(tv), ee_(ee), frontier_(frontier) {} + : tensor_(std::move(tensor)), ee_(ee), frontier_(frontier) {} at::Tensor run( const std::vector& rfactor, const std::vector& alloc) { auto forward_exprs = StmtSort::getExprsBetween( - tv_->fusion(), - {alloc.begin(), alloc.end()}, - {rfactor.begin(), rfactor.end()}); + {alloc.begin(), alloc.end()}, {rfactor.begin(), rfactor.end()}); for (auto expr : forward_exprs) { handle(expr); } @@ -748,7 +744,6 @@ class ForwardTraverseFromAllocToRFactor { // transformations. class BackwardTraverseFromAllocToRFactor { at::Tensor tensor_; - TensorView* tv_; ExpressionEvaluator& ee_; std::list& frontier_; @@ -853,18 +848,15 @@ class BackwardTraverseFromAllocToRFactor { public: BackwardTraverseFromAllocToRFactor( at::Tensor tensor, - TensorView* tv, ExpressionEvaluator& ee, std::list& frontier) - : tensor_(std::move(tensor)), tv_(tv), ee_(ee), frontier_(frontier) {} + : tensor_(std::move(tensor)), ee_(ee), frontier_(frontier) {} at::Tensor run( const std::vector& rfactor, const std::vector& alloc) { auto backward_exprs = StmtSort::getExprsBetween( - tv_->fusion(), - {rfactor.begin(), rfactor.end()}, - {alloc.begin(), alloc.end()}); + {rfactor.begin(), rfactor.end()}, {alloc.begin(), alloc.end()}); std::reverse(backward_exprs.begin(), backward_exprs.end()); for (auto expr : backward_exprs) { handle(expr); @@ -894,9 +886,9 @@ at::Tensor transformOutputFromAllocationToRFactor( // forward and a backward traverse. std::list frontier(alloc.begin(), alloc.end()); NVF_ERROR(tensor.dim() == (int64_t)frontier.size()); - tensor = ForwardTraverseFromAllocToRFactor(tensor, tv, ee, frontier) + tensor = ForwardTraverseFromAllocToRFactor(tensor, ee, frontier) .run(rfactor, alloc); - tensor = BackwardTraverseFromAllocToRFactor(tensor, tv, ee, frontier) + tensor = BackwardTraverseFromAllocToRFactor(tensor, ee, frontier) .run(rfactor, alloc); NVF_ERROR(frontier.size() == rfactor.size()); // Now that all affine transformations are handled, and frontiers should diff --git a/csrc/fusion.cpp b/csrc/fusion.cpp index f719cd3dbd1..bea5bf0aa75 100644 --- a/csrc/fusion.cpp +++ b/csrc/fusion.cpp @@ -380,7 +380,7 @@ bool Fusion::isNoOp() { } std::vector Fusion::inputsOf(Val* val) { - return InputsOf::output(this, val); + return InputsOf::output(val); } void Fusion::validateInputs() { @@ -533,7 +533,7 @@ void Fusion::printMath(bool from_outputs_only) { leaf_vals.push_back(val); } } - exprs_for_print = StmtSort::getExprsTo(this, leaf_vals); + exprs_for_print = StmtSort::getExprsTo(leaf_vals); } debug() << "\n%kernel_math {\n"; @@ -654,7 +654,7 @@ std::vector Fusion::usedMathVals() { // there can be vals that are created inside a fusion without using // anything from inputs. See, for example, tv0 in the // FusionOuterSplit test. - const auto inputs = InputsOf::outputs(this, outputs()); + const auto inputs = InputsOf::outputs(outputs()); auto used_math_vals = DependencyCheck::getAllValsBetween( {inputs.begin(), inputs.end()}, outputs()); // When an expre has multiple outputs and only some of them are diff --git a/csrc/fusion_segmenter.cpp b/csrc/fusion_segmenter.cpp index 06d57c28de6..8bacea97f75 100644 --- a/csrc/fusion_segmenter.cpp +++ b/csrc/fusion_segmenter.cpp @@ -3703,7 +3703,7 @@ void SegmentCandidateFinder::resolveInputsInGroup(SegmentedGroup* group) { group->input_vals = IterVisitor::getInputsTo(group->inputs()); // Grab all expressions needed to produce to_visit - auto input_exprs = StmtSort::getExprsTo(completeFusion(), to_visit); + auto input_exprs = StmtSort::getExprsTo(to_visit); // Insert those expressions at the beginning of the group group->exprs_.insert( @@ -3963,7 +3963,7 @@ class ForceHalfAnnotation : public IterVisitor { val->getDataType().value() == DataType::BFloat16); }); - annotation.traverseTo(fusion, fp16_outputs); + annotation.traverseTo(fp16_outputs); return annotation.force_fp16_tv_set_; } diff --git a/csrc/index_compute.cpp b/csrc/index_compute.cpp index 7149e327a79..de1c4423686 100644 --- a/csrc/index_compute.cpp +++ b/csrc/index_compute.cpp @@ -912,7 +912,7 @@ void IndexCompute::updateIndexMapFromPermissiveMap(const Expr* id_expr) { void IndexCompute::run() { const std::vector domain_vals(td_->leaf().begin(), td_->leaf().end()); - traverseTo(td_->fusion(), domain_vals, false); + traverseTo(domain_vals, false); } IterDomain* IndexCompute::maybeGetExactMapConcreteID(IterDomain* id) const { @@ -1019,7 +1019,7 @@ class UpdateLeafIndices : public IterVisitor { extent_map_(std::move(extent_map)) { const std::vector domain_vals(td_->leaf().begin(), td_->leaf().end()); - traverseTo(td_->fusion(), domain_vals, false); + traverseTo(domain_vals, false); } const std::unordered_map& indexMap() const { diff --git a/csrc/ir/cloner.cpp b/csrc/ir/cloner.cpp index 1fdac07954e..87898b3e8e2 100644 --- a/csrc/ir/cloner.cpp +++ b/csrc/ir/cloner.cpp @@ -59,8 +59,7 @@ TensorView* RecomputeTv::recompute( "Cannot recompute buffers that are inputs of the fusion."); // Grab all the expressions used to generate the TensorView - auto exprs = - StmtSort::getExprsBetween(tv->fusion(), from, {tv}, false, false); + auto exprs = StmtSort::getExprsBetween(from, {tv}, false, false); // Run the replicator RecomputeTv replicator(tv->fusion()); @@ -91,7 +90,7 @@ TensorView* RecomputeTv::recompute( return cloned_val->as(); } -RecomputeTv::RecomputeTv(Fusion* fusion) : IrCloner(fusion), fusion_(fusion) { +RecomputeTv::RecomputeTv(Fusion* fusion) : IrCloner(fusion) { // Add inputs to the clones map to prevent cloning them. for (const auto inp : fusion->inputs()) { clones_map_[inp] = inp; @@ -115,8 +114,7 @@ Statement* RecomputeTv::handle(const Statement* s) { Statement* RecomputeTv::handle(const TensorDomain* td) { // Make sure to recompute the history of the iteration domains, explicitly go // through the expressions and send them to IrCloner. - auto exprs = - StmtSort::getExprsTo(fusion_, {td->leaf().begin(), td->leaf().end()}); + auto exprs = StmtSort::getExprsTo({td->leaf().begin(), td->leaf().end()}); for (auto expr : exprs) { IrCloner::handle(expr); diff --git a/csrc/ir/cloner.h b/csrc/ir/cloner.h index 7997a239e94..9a3e4ec95cd 100644 --- a/csrc/ir/cloner.h +++ b/csrc/ir/cloner.h @@ -128,8 +128,6 @@ class RecomputeTv : private IrCloner { RecomputeTv(Fusion* fusion); Statement* handle(const Statement* s) override; Statement* handle(const TensorDomain*); - - Fusion* fusion_; }; //! Clone an IR node, forwarding the arguments to the IrCloner constructor. diff --git a/csrc/ir/iostream.cpp b/csrc/ir/iostream.cpp index 824e6cb16f1..b670d196c4f 100644 --- a/csrc/ir/iostream.cpp +++ b/csrc/ir/iostream.cpp @@ -13,6 +13,7 @@ #include #include #include +#include #include #include diff --git a/csrc/ir/nodes.cpp b/csrc/ir/nodes.cpp index 4cba0999ede..e0961bb8da6 100644 --- a/csrc/ir/nodes.cpp +++ b/csrc/ir/nodes.cpp @@ -3026,7 +3026,7 @@ std::pair IterDomain::swizzle( !in_x->isReduction() && !in_y->isReduction(), "swizzled reduction not yet supported"); - for (auto input : InputsOf::outputs(in_x->fusion(), {in_x, in_y})) { + for (auto input : InputsOf::outputs({in_x, in_y})) { NVF_CHECK( !input->as()->isBroadcast(), "swizzling broadcast axes not yet supported"); diff --git a/csrc/ir/utils.cpp b/csrc/ir/utils.cpp index ccbe0679b03..50b11ff0896 100644 --- a/csrc/ir/utils.cpp +++ b/csrc/ir/utils.cpp @@ -460,7 +460,7 @@ class ValReplacementMutator : private OptOutMutator { // typically not used by anything else. If we don't grab that count, then it // would be a tensorview that doesn't get updated extents. Therefore, first // grab all leaves towards outputs and grab stmts from there. - auto stmts = StmtSort::getStmtsTo(fusion, allLeafOuts(fusion), true, true); + auto stmts = StmtSort::getStmtsTo(allLeafOuts(fusion), true, true); // Some fusions, such as standalone rand_like, can have disconnected DAG, so // we need some mechanism to make sure our replacement set is as complete as @@ -478,7 +478,7 @@ class ValReplacementMutator : private OptOutMutator { more.emplace_back(v); } } - auto more_stmts = StmtSort::getStmtsTo(fusion, more, true, true); + auto more_stmts = StmtSort::getStmtsTo(more, true, true); more_stmts.insert(more_stmts.end(), stmts.begin(), stmts.end()); for (auto stmt : more_stmts) { @@ -797,7 +797,6 @@ bool hasResizedRfactor(const TensorView* tv) { return false; } auto root_to_rf_exprs = StmtSort::getExprsBetween( - tv->fusion(), {tv->getRootDomain().begin(), tv->getRootDomain().end()}, {tv->getRFactorDomain().begin(), tv->getRFactorDomain().end()}); return std::any_of( @@ -840,7 +839,6 @@ class ValidateDomainEquivalence : private IterVisitor { toDelimitedString(derived_domain)); traverseBetween( - initial_domain.at(0)->fusion(), {initial_domain.begin(), initial_domain.end()}, {derived_domain.begin(), derived_domain.end()}); diff --git a/csrc/iter_visitor.cpp b/csrc/iter_visitor.cpp index 0c3be5a2c7e..68f52787933 100644 --- a/csrc/iter_visitor.cpp +++ b/csrc/iter_visitor.cpp @@ -126,7 +126,7 @@ void IterVisitor::dispatch(Val* v) { // Implementation details: // We start with an entry in stmt_stack that is the outputs we want to -// process. We cannot process these outputs untill all Stmts in their history +// process. We cannot process these outputs until all Stmts in their history // have been processed, as those Stmts contain all dependencies to produce // these values. What we will do is traverse towards inputs until we hit a // leaf node. Once we hit a leaf node that node will be visited, then we will @@ -138,13 +138,16 @@ void IterVisitor::dispatch(Val* v) { // function to remove visited nodes from being re-added to the stack // (remove_visited). void IterVisitor::traverseBetween( - Fusion* fusion, const std::unordered_set& from, const std::vector& to, bool traverse_all_paths, bool traverse_into_members, bool traverse_attributes, bool traverse_siblings) { + if (to.empty()) { + return; + } + Fusion* fusion = to.front()->fusion(); FusionGuard fg(fusion); std::unordered_set visited; @@ -287,14 +290,12 @@ void IterVisitor::traverseBetween( } void IterVisitor::traverseTo( - Fusion* fusion, const std::vector& to, bool traverse_all_paths, bool traverse_into_members, bool traverse_attributes, bool traverse_siblings) { traverseBetween( - fusion, {}, to, traverse_all_paths, @@ -308,7 +309,7 @@ void IterVisitor::traverseHelper(Fusion* fusion, bool traverse_all_paths) { auto term_val_outs = fusion->getTerminatingOutputs(); if (!term_val_outs.empty()) { - traverseTo(fusion, term_val_outs, traverse_all_paths); + traverseTo(term_val_outs, traverse_all_paths); } } @@ -364,7 +365,7 @@ class Inputs : public IterVisitor { return {}; } Inputs inps(all_inputs); - inps.traverseTo(of[0]->fusion(), of); + inps.traverseTo(of); return inps.inputs_; } }; @@ -393,7 +394,7 @@ class AllVals : public IterVisitor { Fusion* fusion, const std::vector& from) { AllVals av; - av.traverseTo(fusion, from, false); + av.traverseTo(from, false); return av.vals; } }; @@ -451,21 +452,20 @@ void BackwardVisitor::dispatch(Val* val) { } void BackwardVisitor::traverseTo( - Fusion* fusion, const std::vector& from, bool traverseAllPaths) { + if (from.empty()) { + return; + } + Fusion* fusion = from.front()->fusion(); FusionGuard fg(fusion); // Reset members stmt_stack_.clear(); traversal_exprs_.clear(); - if (from.empty()) { - return; - } - auto vals = AllVals::get(fusion, from); - auto exprs = StmtSort::getExprsTo(fusion, from); + auto exprs = StmtSort::getExprsTo(from); { size_t pos = 0; @@ -603,7 +603,7 @@ struct Dependencies : public IterVisitor { std::unordered_set _dependencies, const std::vector& of) : dependencies_(std::move(_dependencies)) { - traverseTo(of[0]->fusion(), of, false); + traverseTo(of, false); }; public: @@ -650,7 +650,7 @@ struct FindOutputs : public IterVisitor { // tracing all paths like this. FindOutputs(const std::unordered_set& _of) : of_(_of) { auto fusion = (*of_.begin())->fusion(); - traverseTo(fusion, fusion->outputs(), true); + traverseTo(fusion->outputs(), true); }; static std::unordered_set getAllOutputsOf( @@ -719,7 +719,7 @@ class DependentVals : public IterVisitor { DependentVals(const std::unordered_set& _of) : of_(_of) { createBoundary(); auto fusion = (*of_.begin())->fusion(); - traverseTo(fusion, fusion->outputs(), false); + traverseTo(fusion->outputs(), false); }; public: @@ -755,7 +755,7 @@ class DependencyChains : public IterVisitor { DependencyChains(Val* _dependency, Val* _of, bool all_chains_ = false) : dependencies_({_dependency}) { - traverseTo(_of->fusion(), {_of}, all_chains_); + traverseTo({_of}, all_chains_); } DependencyChains(Val* _dependency, bool all_chains_ = false) @@ -882,7 +882,6 @@ std::vector StmtSort::getExprs( bool traverse_siblings) { auto terminating_outputs = fusion->getTerminatingOutputs(); return StmtSort::getExprsTo( - fusion, terminating_outputs, traverse_members, traverse_attributes, @@ -890,32 +889,25 @@ std::vector StmtSort::getExprs( } std::vector StmtSort::getExprsTo( - Fusion* fusion, const std::vector& to, bool traverse_members, bool traverse_attributes, bool traverse_siblings) { auto stmts = StmtSort::getStmtsTo( - fusion, to, traverse_members, traverse_attributes, traverse_siblings); + to, traverse_members, traverse_attributes, traverse_siblings); auto filter = ir_utils::filterByType(stmts.begin(), stmts.end()); std::vector exprs(filter.begin(), filter.end()); return exprs; } std::vector StmtSort::getExprsBetween( - Fusion* fusion, const std::vector& from, const std::vector& to, bool traverse_members, bool traverse_attributes, bool traverse_siblings) { auto stmts = StmtSort::getStmtsBetween( - fusion, - from, - to, - traverse_members, - traverse_attributes, - traverse_siblings); + from, to, traverse_members, traverse_attributes, traverse_siblings); auto filter = ir_utils::filterByType(stmts.begin(), stmts.end()); std::vector exprs(filter.begin(), filter.end()); return exprs; @@ -928,7 +920,6 @@ std::vector StmtSort::getStmts( bool traverse_siblings) { auto terminating_outputs = fusion->getTerminatingOutputs(); return StmtSort::getStmtsTo( - fusion, terminating_outputs, traverse_members, traverse_attributes, @@ -936,24 +927,17 @@ std::vector StmtSort::getStmts( } std::vector StmtSort::getStmtsTo( - Fusion* fusion, const std::vector& to, bool traverse_members, bool traverse_attributes, bool traverse_siblings) { StmtSort es; es.traverseTo( - fusion, - to, - false, - traverse_members, - traverse_attributes, - traverse_siblings); + to, false, traverse_members, traverse_attributes, traverse_siblings); return es.stmts; } std::vector StmtSort::getStmtsBetween( - Fusion* fusion, const std::vector& from, const std::vector& to, bool traverse_members, @@ -961,7 +945,6 @@ std::vector StmtSort::getStmtsBetween( bool traverse_siblings) { StmtSort es; es.traverseBetween( - fusion, {from.begin(), from.end()}, to, false, @@ -979,15 +962,13 @@ void InputsOf::dispatch(Val* v) { } } -std::vector InputsOf::output(Fusion* fusion, Val* output_) { - return outputs(fusion, {output_}); +std::vector InputsOf::output(Val* output_) { + return outputs({output_}); } -std::vector InputsOf::outputs( - Fusion* fusion, - const std::vector& outputs_) { +std::vector InputsOf::outputs(const std::vector& outputs_) { InputsOf io; - io.traverseTo(fusion, outputs_, false); + io.traverseTo(outputs_, false); return io.ordered_inputs; } @@ -995,14 +976,14 @@ std::vector InputsOf::outputs( bool DeadCodeRemover::run() { // First we build a set of all live Statements so that we can detect dead // branches. - for (auto stmt : StmtSort::getStmtsTo(fusion_, fusion_->outputs())) { + for (auto stmt : StmtSort::getStmtsTo(fusion_->outputs())) { markLive(stmt); } // Note that StmtSort::getStmtsTo() is also run in traverseTo. In the future, // we could potentially refactor this so that derived classes from // BackwardVisitor can make use of that traversal instead of repeating it. - traverseTo(fusion_, fusion_->outputs(), false); + traverseTo(fusion_->outputs(), false); // We do not remove Statements from the Fusion while traversing, to avoid // dereferencing invalid pointers. Instead, we wait until this point to do the diff --git a/csrc/iter_visitor.h b/csrc/iter_visitor.h index 2a0e5b92188..f1f6610b854 100644 --- a/csrc/iter_visitor.h +++ b/csrc/iter_visitor.h @@ -99,7 +99,6 @@ class IterVisitor : public OptOutDispatch { //! active multi-output expressions, even if those Expr outputs are not used //! in paths to Fusion outputs. void traverseTo( - Fusion* fusion, const std::vector& to, bool traverse_all_paths = false, bool traverse_into_members = false, @@ -126,7 +125,6 @@ class IterVisitor : public OptOutDispatch { //! active multi-output expressions, even if those Expr outputs are not used //! in paths to Fusion outputs. void traverseBetween( - Fusion* fusion, const std::unordered_set& from, const std::vector& to, bool traverse_all_paths = false, @@ -238,10 +236,7 @@ class BackwardVisitor : public OptOutDispatch { // traverseAllPaths = false only call handle on each Statement* once // traverseAllPaths = true traverses all paths from nodes in from to inputs. // Handle on a Statement* for every path from "from" nodes, to inputs. - void traverseTo( - Fusion* fusion, - const std::vector& from, - bool traverseAllPaths = false); + void traverseTo(const std::vector& from, bool traverseAllPaths = false); bool must_cover_all_expr_outputs_ = true; }; @@ -313,7 +308,6 @@ class StmtSort : public IterVisitor { // Returns ordered Statements required to produce 'to', including 'to'. static std::vector getStmtsTo( - Fusion* fusion, const std::vector& to, bool traverse_members = false, bool traverse_attributes = false, @@ -337,7 +331,6 @@ class StmtSort : public IterVisitor { // If traverse_members it will also extract all member nodes in the sorted // expr list in the fusion. i.e. all expressions on IterDomains, extents, etc static std::vector getStmtsBetween( - Fusion* fusion, const std::vector& from, const std::vector& to, bool traverse_members = false, @@ -353,7 +346,6 @@ class StmtSort : public IterVisitor { // Same as getStmts version but filters to only return the Expr*s static std::vector getExprsTo( - Fusion* fusion, const std::vector& to, bool traverse_members = false, bool traverse_attributes = false, @@ -361,7 +353,6 @@ class StmtSort : public IterVisitor { // Same as getStmts version but filters to only return the Expr*s static std::vector getExprsBetween( - Fusion* fusion, const std::vector& from, const std::vector& to, bool traverse_members = false, @@ -379,10 +370,8 @@ class InputsOf : public IterVisitor { void dispatch(Val* v) final; public: - static std::vector output(Fusion* fusion, Val* output_); - static std::vector outputs( - Fusion* fusion, - const std::vector& outputs_); + static std::vector output(Val* output_); + static std::vector outputs(const std::vector& outputs_); }; //! This is a generic traversal class that is used to modify a Fusion graph by diff --git a/csrc/multidevice/executor.cpp b/csrc/multidevice/executor.cpp index 3d8ae73a430..3a06355c77c 100644 --- a/csrc/multidevice/executor.cpp +++ b/csrc/multidevice/executor.cpp @@ -116,7 +116,7 @@ std::vector PipelineExecutor::runWithInput( } // Run through the stages to launch kernel - traverseTo(runtime_.pipeline_, runtime_.pipeline_->outputs()); + traverseTo(runtime_.pipeline_->outputs()); // Collect global outputs from context std::vector outputs; diff --git a/csrc/multidevice/pipeline.cpp b/csrc/multidevice/pipeline.cpp index ffe208b7250..e5a62cbeba2 100644 --- a/csrc/multidevice/pipeline.cpp +++ b/csrc/multidevice/pipeline.cpp @@ -282,7 +282,7 @@ class PipelinePrinter : public IterVisitor { string_ << "}\n"; string_ << "Pipeline's Traversal inputs --> outputs {\n"; - traverseTo(pipeline_, pipeline_->outputs()); + traverseTo(pipeline_->outputs()); string_ << "}\n"; string_ << "Pipeline's outputs:{\n"; diff --git a/csrc/non_divisible_split.cpp b/csrc/non_divisible_split.cpp index 406cb5525f4..ad741004573 100644 --- a/csrc/non_divisible_split.cpp +++ b/csrc/non_divisible_split.cpp @@ -26,7 +26,7 @@ void NonDivisibleSplitInfo::build(Fusion* fusion) { tv->getLeafDomain().begin(), tv->getLeafDomain().end()); current_tv_ = tv; clearReachability(); - traverseTo(fusion, domain_vals); + traverseTo(domain_vals); current_tv_ = nullptr; } diff --git a/csrc/partial_split_map.cpp b/csrc/partial_split_map.cpp index 9540c6b08a2..3eb23222e3a 100644 --- a/csrc/partial_split_map.cpp +++ b/csrc/partial_split_map.cpp @@ -16,7 +16,7 @@ void PartialSplitMap::build(Fusion* fusion) { for (auto tv : ir_utils::filterByType(used_vals)) { auto exprs = StmtSort::getExprsTo( - fusion, {tv->getLeafDomain().begin(), tv->getLeafDomain().end()}); + {tv->getLeafDomain().begin(), tv->getLeafDomain().end()}); for (auto split : ir_utils::filterByType(exprs)) { // Only needs to check root domains as partial split is only // allowed with root domains diff --git a/csrc/root_domain_map.cpp b/csrc/root_domain_map.cpp index 18d1269f7e0..05eff3f7f78 100644 --- a/csrc/root_domain_map.cpp +++ b/csrc/root_domain_map.cpp @@ -322,7 +322,7 @@ class FindInputDomains : BackwardVisitor { } DomainKeySet find() { - traverseTo(tv_->fusion(), {tv_}); + traverseTo({tv_}); return input_keys_; } @@ -782,7 +782,7 @@ ComputeAtRootDomainMapBuilder::ComputeAtRootDomainMapBuilder( map_through_reduction_(map_through_reduction) { Fusion* fusion = FusionGuard::getCurFusion(); NVF_ERROR(fusion != nullptr); - traverseTo(fusion, fusion->outputs(), false); + traverseTo(fusion->outputs(), false); if (!pending_map_.empty()) { std::stringstream ss; ss << "pending map:\n"; @@ -1241,7 +1241,7 @@ class ExactRootDomainMapBuilder : private IterVisitor { Fusion* fusion, DisjointSets& eq_sets) : eq_sets_(eq_sets) { - traverseTo(fusion, fusion->outputs()); + traverseTo(fusion->outputs()); } private: diff --git a/csrc/scheduler/transpose.cpp b/csrc/scheduler/transpose.cpp index 07eb09a1723..fe3a5b6f457 100644 --- a/csrc/scheduler/transpose.cpp +++ b/csrc/scheduler/transpose.cpp @@ -142,7 +142,7 @@ struct TransposeViewPropagator : public MaxInfoSpanningTree::Propagator { // propagation travelling across view op. Note this is a conservative check, // since view does NOT necessarily always introduce incoherent transform // that would break the propagation. - auto chain_exprs = StmtSort::getExprsBetween(from->fusion(), {from}, {to}); + auto chain_exprs = StmtSort::getExprsBetween({from}, {to}); if (!ir_utils::filterByType(chain_exprs).empty()) { should_reject = true; }; @@ -239,9 +239,7 @@ class DomainMap : public pointwise_utils::DomainMap { " in tensor ", tv); auto replay_exprs = StmtSort::getExprsBetween( - tv->fusion(), - {mapped_id}, - {tv->getLeafDomain().begin(), tv->getLeafDomain().end()}); + {mapped_id}, {tv->getLeafDomain().begin(), tv->getLeafDomain().end()}); // Project the root id to leaf id. Similar to projectIdToRFactor. for (auto expr : replay_exprs) { if (expr->isA()) { diff --git a/csrc/scheduler/utils.cpp b/csrc/scheduler/utils.cpp index 5fee073af39..70775a55939 100644 --- a/csrc/scheduler/utils.cpp +++ b/csrc/scheduler/utils.cpp @@ -1110,7 +1110,7 @@ IterDomain* projectIdToRoot( return reference_id; } - auto replay_exprs = StmtSort::getExprsTo(tv->fusion(), {reference_id}); + auto replay_exprs = StmtSort::getExprsTo({reference_id}); if (replay_exprs.empty()) { return reference_id; } @@ -1176,9 +1176,7 @@ IterDomain* projectIdToRFactor( } auto replay_exprs = StmtSort::getExprsTo( - tv->fusion(), - {tv->getRFactorDomain().begin(), tv->getRFactorDomain().end()}, - false); + {tv->getRFactorDomain().begin(), tv->getRFactorDomain().end()}, false); if (replay_exprs.empty()) { return reference_id; } @@ -1853,7 +1851,6 @@ DisjointSets disjointRFactorSets(Fusion* fusion) { // rfactor domains they should be considered "contaminated". for (auto tv : ir_utils::allTvs(fusion)) { for (auto expr : StmtSort::getExprsTo( - fusion, {tv->getMaybeRFactorDomain().begin(), tv->getMaybeRFactorDomain().end()})) { if (expr->isA()) { @@ -1904,7 +1901,7 @@ bool breakIsDisjoint(std::vector group_ids, int pos) { std::unordered_map domainReorderAsRfactorMap(TensorView* tv) { FusionGuard fg(tv->fusion()); auto transform_exprs = StmtSort::getExprsTo( - tv->fusion(), {tv->getLeafDomain().begin(), tv->getLeafDomain().end()}); + {tv->getLeafDomain().begin(), tv->getLeafDomain().end()}); // simply update this vector of id's as progressing through the transformation // expressions. We'll always insert the result of split in the location of the // input, and insert the merge result in the position of the inner dimension. @@ -1987,7 +1984,6 @@ void propagateReshapeTransforms(Fusion* fusion, const ComputeAtMap& ca_map) { // rfactor domains they should be considered "contaminated". for (auto tv : ir_utils::allTvs(fusion)) { for (auto expr : StmtSort::getExprsBetween( - fusion, {tv->getRootDomain().begin(), tv->getRootDomain().end()}, {tv->getMaybeRFactorDomain().begin(), tv->getMaybeRFactorDomain().end()})) { diff --git a/csrc/scheduler/vectorize_helper.cpp b/csrc/scheduler/vectorize_helper.cpp index be18e29ffb2..6d4ceb54ec5 100644 --- a/csrc/scheduler/vectorize_helper.cpp +++ b/csrc/scheduler/vectorize_helper.cpp @@ -377,9 +377,7 @@ std::vector ContiguousInnerDimensionsMapper::projectId( // empty backward exprs, vice versa. auto backward_exprs = StmtSort::getExprsBetween( - frontier.front()->fusion(), - {to.begin(), to.end()}, - {frontier.begin(), frontier.end()}); + {to.begin(), to.end()}, {frontier.begin(), frontier.end()}); // Mapping from rfactor to root, reverse expressions std::reverse(backward_exprs.begin(), backward_exprs.end()); @@ -407,9 +405,7 @@ std::vector ContiguousInnerDimensionsMapper::projectId( } auto forward_exprs = StmtSort::getExprsBetween( - frontier.front()->fusion(), - {frontier.begin(), frontier.end()}, - {to.begin(), to.end()}); + {frontier.begin(), frontier.end()}, {to.begin(), to.end()}); // Map forward through transforms since we're going from root to rfactor for (auto* expr : forward_exprs) { diff --git a/csrc/tensor_metadata.cpp b/csrc/tensor_metadata.cpp index 86406e01034..34d2a86930d 100644 --- a/csrc/tensor_metadata.cpp +++ b/csrc/tensor_metadata.cpp @@ -106,9 +106,7 @@ class ForwardTraverseFromRFactorToAlloc { const std::vector& rfactor, const std::vector& alloc) { auto forward_exprs = StmtSort::getExprsBetween( - tv->fusion(), - {rfactor.begin(), rfactor.end()}, - {alloc.begin(), alloc.end()}); + {rfactor.begin(), rfactor.end()}, {alloc.begin(), alloc.end()}); for (auto expr : forward_exprs) { handle(expr); } @@ -201,9 +199,7 @@ class BackwardTraverseFromRFactorToAlloc { const std::vector& rfactor, const std::vector& alloc) { auto backward_exprs = StmtSort::getExprsBetween( - tv->fusion(), - {alloc.begin(), alloc.end()}, - {rfactor.begin(), rfactor.end()}); + {alloc.begin(), alloc.end()}, {rfactor.begin(), rfactor.end()}); std::reverse(backward_exprs.begin(), backward_exprs.end()); for (auto expr : backward_exprs) { handle(expr); diff --git a/csrc/tensor_view.cpp b/csrc/tensor_view.cpp index 6c9657b2f80..ae57acd474c 100644 --- a/csrc/tensor_view.cpp +++ b/csrc/tensor_view.cpp @@ -820,7 +820,7 @@ TensorView* TensorView::swizzle( // Disable unsupported use cases at the current step. // Currently do not support reducing or broadcasting // swizzled dimensions. - auto all_inputs = InputsOf::outputs(fusion(), {axis(x), axis(y)}); + auto all_inputs = InputsOf::outputs({axis(x), axis(y)}); for (auto id : ir_utils::filterByType(all_inputs)) { NVF_ERROR( !id->isBroadcast() && !id->isReduction(), diff --git a/csrc/transform_iter.cpp b/csrc/transform_iter.cpp index d5490c29edd..76d256f5083 100644 --- a/csrc/transform_iter.cpp +++ b/csrc/transform_iter.cpp @@ -255,7 +255,7 @@ void ReplayTransformations::runReplay() { // Switch outDomain to a vector to start the traversal std::vector traversal_vals( target_domain_.begin(), target_domain_.end()); - traverseTo(traversal_vals[0]->fusion(), traversal_vals); + traverseTo(traversal_vals); if (error_on_failure_) { NVF_ERROR( @@ -319,9 +319,8 @@ BestEffortReplay::BestEffortReplay( } // Grab expr history of iter domains in target_domain - std::vector target_exprs = StmtSort::getExprsTo( - FusionGuard::getCurFusion(), - std::vector(target_domain.begin(), target_domain.end())); + std::vector target_exprs = + StmtSort::getExprsTo({target_domain.begin(), target_domain.end()}); // If we check how an IterDomain was generated, it should only use an // IterDomain in an expression once. We pull a map from the input @@ -330,9 +329,8 @@ BestEffortReplay::BestEffortReplay( // replay_domain map. // Map replay domain's IterDomains to the Exprs they're used in - std::vector replay_exprs = StmtSort::getExprsTo( - FusionGuard::getCurFusion(), - std::vector(replay_domain.begin(), replay_domain.end())); + std::vector replay_exprs = + StmtSort::getExprsTo({replay_domain.begin(), replay_domain.end()}); // Track which id's in replay have to be replayed to guarantee rfactor // transformations. The iteration domains in the rfactor axes don't have @@ -750,10 +748,7 @@ struct ForwardingInfo { // now forward those to include all id's in active_tv comprised of only axes // not in the inactive tensor. std::vector active_tv_history = StmtSort::getExprsTo( - FusionGuard::getCurFusion(), - std::vector( - active_tv->getLeafDomain().begin(), - active_tv->getLeafDomain().end())); + {active_tv->getLeafDomain().begin(), active_tv->getLeafDomain().end()}); auto isIdOnlyInActiveTv = [&forwarded_ids](IterDomain* input_id) { return forwarded_ids.count(input_id) > 0; @@ -900,8 +895,8 @@ void BestEffortReplay::addComplimentLeafIDs( } // Grab all exprs used to make the forwarded compliments - auto compliment_exprs = StmtSort::getExprsTo( - FusionGuard::getCurFusion(), {compliments.begin(), compliments.end()}); + auto compliment_exprs = + StmtSort::getExprsTo({compliments.begin(), compliments.end()}); // Figure out if there are any leaves in compliment_exprs that aren't // the forwarded id diff --git a/test/test_gpu3.cpp b/test/test_gpu3.cpp index 28305504542..9e741a8d410 100644 --- a/test/test_gpu3.cpp +++ b/test/test_gpu3.cpp @@ -6026,7 +6026,7 @@ TEST_F(NVFuserTest, FusionPropagateVectorizePredicate_CUDA) { // Make sure the index of the inner loop isn't used in the predicate NVF_ERROR(!for_loops_.empty()); auto loop_index = for_loops_.back()->index(); - auto cond_inputs = InputsOf::output(cond->fusion(), cond); + auto cond_inputs = InputsOf::output(cond); auto index_it = std::find(cond_inputs.begin(), cond_inputs.end(), loop_index); auto vec_factor_it = diff --git a/test/test_iter_visitor.cpp b/test/test_iter_visitor.cpp index 6e968610824..a0e5eacd2c4 100644 --- a/test/test_iter_visitor.cpp +++ b/test/test_iter_visitor.cpp @@ -67,7 +67,6 @@ TEST_F(IterVisitorTest, IterVisitorTraverseSiblings) { // Test getting statements "to" a tensor with siblings stmts = StmtSort::getStmtsTo( - &fusion, {wf.n}, /*traverse_all_paths=*/false, /*traverse_attributes=*/false, @@ -116,8 +115,7 @@ TEST_F(IterVisitorTest, NonTerminatingOutput) { // Even though `c` is a non-terminating output, `d` and `e` should still be // considered in between. This is because `StmtSort::getExprsBetween` // traverses from `to` along use-def chains until it hits `from`. - EXPECT_THAT(StmtSort::getExprsBetween(&fusion, {a}, {c, e}), - IsSupersetOf({d->definition(), e->definition()})); + EXPECT_THAT(StmtSort::getExprsBetween({a}, {c, e}), IsSupersetOf({d->definition(), e->definition()})); } } // namespace nvfuser diff --git a/test/test_optimization_pass.cpp b/test/test_optimization_pass.cpp index e81e676f911..b4c961e1f76 100644 --- a/test/test_optimization_pass.cpp +++ b/test/test_optimization_pass.cpp @@ -75,9 +75,7 @@ TEST_F(NVFuserTest, FusionCyclicGraph_CUDA) { ir_utils::checkCycle(fusion.get()).size() == 6, "cycle of size 6 should be detected in fusion"); EXPECT_THAT( - [&]() { - StmtSort::getStmtsBetween(fusion.get(), {}, fusion->outputs()); - }, + [&]() { StmtSort::getStmtsBetween({}, fusion->outputs()); }, ::testing::ThrowsMessage( ::testing::HasSubstr("cycle detected"))); } @@ -115,7 +113,7 @@ TEST_F(NVFuserTest, FusionCyclicGraph_CUDA) { to.push_back(tv1); // cycle should be detected, since dead branch is in our check path EXPECT_THAT( - [&]() { StmtSort::getStmtsBetween(fusion.get(), {}, to); }, + [&]() { StmtSort::getStmtsBetween({}, to); }, ::testing::ThrowsMessage( ::testing::HasSubstr("cycle detected"))); diff --git a/test/test_swizzle.cpp b/test/test_swizzle.cpp index a6d70bd26bb..5a1c1281f9f 100644 --- a/test/test_swizzle.cpp +++ b/test/test_swizzle.cpp @@ -644,7 +644,6 @@ TEST_F(SwizzleTest, TransformPropagatorSkipSwizzleOnTarget) { MaxRootDomainInfoSpanningTree(tv0).traverse(&propagator); auto exprs = StmtSort::getExprsBetween( - tv1->fusion(), {tv1->getRootDomain().begin(), tv1->getRootDomain().end()}, {tv1->getLeafDomain().begin(), tv1->getLeafDomain().end()}); EXPECT_TRUE(std::any_of(exprs.begin(), exprs.end(), [](Expr* expr) {