Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion csrc/compute_at_map.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -605,7 +605,6 @@ void IterDomainGraph::build(Fusion* fusion) {
// Grab all the rfactor ids.
for (auto consumer_tv : all_consumer_tvs) {
auto exprs = StmtSort::getExprsTo(
fusion,
{consumer_tv->getMaybeRFactorDomain().begin(),
consumer_tv->getMaybeRFactorDomain().end()});
for (auto expr : exprs) {
Expand Down
12 changes: 3 additions & 9 deletions csrc/contiguity.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,9 +39,7 @@ OrderedIdInformation::OrderedIdInformation(
// consistently_ordered_ids_, id_to_alloc_ids_, and
// exclusively_consumes_allocs_ for all the IDs
auto exprs = StmtSort::getExprsBetween(
ids[0]->fusion(),
{alloc_domain.begin(), alloc_domain.end()},
{ids.begin(), ids.end()});
{alloc_domain.begin(), alloc_domain.end()}, {ids.begin(), ids.end()});

for (auto expr : exprs) {
OptInDispatch::dispatch(expr);
Expand Down Expand Up @@ -386,9 +384,7 @@ NonDivisibleSplitDependencies::NonDivisibleSplitDependencies(
return;
}
auto transforms = StmtSort::getExprsBetween(
ids[0]->fusion(),
{alloc_domain.begin(), alloc_domain.end()},
{ids.begin(), ids.end()});
{alloc_domain.begin(), alloc_domain.end()}, {ids.begin(), ids.end()});
for (auto transform : transforms) {
auto inp_ids = ir_utils::filterByType<IterDomain>(transform->inputs());
for (auto inp_id : inp_ids) {
Expand Down Expand Up @@ -545,9 +541,7 @@ void ContigIDs::build(const std::vector<IterDomain*>& ids) {

if (!contig_ids_.empty()) {
auto exprs = StmtSort::getExprsBetween(
ids.at(0)->fusion(),
{alloc_domain_.begin(), alloc_domain_.end()},
{ids.begin(), ids.end()});
{alloc_domain_.begin(), alloc_domain_.end()}, {ids.begin(), ids.end()});
for (auto expr : exprs) {
if (auto resize = dynamic_cast<Resize*>(expr)) {
resize_deps_.insert(resize->out());
Expand Down
2 changes: 1 addition & 1 deletion csrc/device_lower/analysis/divisible_split.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ std::unordered_set<Split*> getAllDivisibleSplits(
// Take the view transformations and add all the splits. Those splits are
// the only divisible splits.
auto view_exprs =
StmtSort::getExprsTo(fusion, {rfactor_dom.begin(), rfactor_dom.end()});
StmtSort::getExprsTo({rfactor_dom.begin(), rfactor_dom.end()});
auto split_exprs = ir_utils::filterByType<Split>(view_exprs);
all_divisible_splits.insert(split_exprs.begin(), split_exprs.end());
}
Expand Down
4 changes: 2 additions & 2 deletions csrc/device_lower/analysis/predicate_elimination.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -516,7 +516,7 @@ class PredicateChcker : public IterVisitor {
// provided.
bool predicateNonDivisibleRootDomains(Expr* expr) const {
for (auto output : ir_utils::filterByType<TensorView>(expr->outputs())) {
const auto all_exprs = DependencyCheck::getAllExprsBetween(
const auto all_exprs = StmtSort::getExprsBetween(
{output->getMaybeRFactorDomain().begin(),
output->getMaybeRFactorDomain().end()},
{output->getLeafDomain().begin(), output->getLeafDomain().end()});
Expand Down Expand Up @@ -863,7 +863,7 @@ class PredicateChcker : public IterVisitor {
} // namespace

PredicateElimination::PredicateElimination(Fusion* fusion) {
traverseTo(fusion, fusion->outputs());
traverseTo(fusion->outputs());
}

bool PredicateElimination::needsPredicate(Expr* expr) const {
Expand Down
2 changes: 1 addition & 1 deletion csrc/device_lower/analysis/shift.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -358,7 +358,7 @@ void HaloInfo::setHaloWidth(IterDomain* id, int halo_width) {

// Propagate extent information from root axes to descendants
void HaloInfo::build(TensorDomain* td) {
auto exprs = DependencyCheck::getAllExprsBetween(
auto exprs = StmtSort::getExprsBetween(
{td->maybeRFactor().begin(), td->maybeRFactor().end()},
{td->leaf().begin(), td->leaf().end()});

Expand Down
4 changes: 1 addition & 3 deletions csrc/device_lower/analysis/sync_information.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,6 @@ struct ProducerConsumerIndexingInfoCache {
const auto& consumer_leaf_ids_shared_with_producer =
getConsumerLeafIDsSharedWithProducer();
consumer_root_ids_shared_with_producer_ = InputsOf::outputs(
producer_tv_->fusion(),
{consumer_leaf_ids_shared_with_producer.begin(),
consumer_leaf_ids_shared_with_producer.end()});
}
Expand Down Expand Up @@ -261,10 +260,9 @@ bool useSameIndex(
// consumer_id. The goal of the analysis below is to find out if all
// of the root IDs are indexed in the same way between the producer
// and consumer tensors.
auto consumer_root_ids = InputsOf::output(consumer_id->fusion(), consumer_id);
auto consumer_root_ids = InputsOf::output(consumer_id);

auto producer_root_vals = StmtSort::getStmtsBetween(
producer_id->fusion(),
{producer_tv->getMaybeRFactorDomain().begin(),
producer_tv->getMaybeRFactorDomain().end()},
{producer_id});
Expand Down
15 changes: 8 additions & 7 deletions csrc/device_lower/analysis/thread_predicate.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -239,8 +239,9 @@ void ThreadPredicateMap::updateBitSet(const Expr* expr) {

// Run through inputs and update bitsets
for (const auto* inp : expr->inputs()) {
if (!ir_utils::isTV(inp))
if (!ir_utils::isTV(inp)) {
continue;
}

auto tv_inp = inp->as<TensorView>();

Expand Down Expand Up @@ -365,7 +366,7 @@ class RedundantUseAnalysis : BackwardVisitor {
public:
RedundantUseAnalysis(Fusion* fusion, const ThreadPredicateMap& pred_map)
: fusion_(fusion), pred_map_(pred_map) {
traverseTo(fusion, fusion->terminatingMathVals());
traverseTo(fusion->terminatingMathVals());
}

//! Returns a bit map signifying the parallel dimensions
Expand Down Expand Up @@ -619,14 +620,14 @@ class ConcretizedBroadcastRedundantWriteRemover {

// Find all the root domains that are merged to the leaf domain.
// e.g. Root: [I1,B2,B3] -> Leaf: [I1*B2*B3]
std::vector<IterDomain*> getRootDomainsMergedToLeaf(IterDomain* ld) {
std::vector<IterDomain*> getRootDomainsMergedToLeaf(IterDomain* id) {
std::vector<IterDomain*> merged_root_domains;
std::vector<int> index_root_domain;
std::vector<IterDomain*> intermediate_domains = root_domain_;
auto all_exp = DependencyCheck::getAllExprsBetween(
{root_domain_.begin(), root_domain_.end()}, {ld});
for (auto expr : all_exp) {
if (auto merge = dynamic_cast<Merge*>(expr)) {
auto all_exp = StmtSort::getExprsBetween(
{root_domain_.begin(), root_domain_.end()}, {id});
for (Expr* expr : all_exp) {
if (auto* merge = dynamic_cast<Merge*>(expr)) {
auto outer_iter =
std::find(root_domain_.begin(), root_domain_.end(), merge->outer());
auto inner_iter =
Expand Down
3 changes: 1 addition & 2 deletions csrc/device_lower/pass/alias_memory.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -122,8 +122,7 @@ bool isSerialBroadcastResolution(
// traverse across view boundaries as we do in indexing. This
// should not result in false aliasing but may miss safe aliasing
// opportunities.
auto serial_loop_roots =
InputsOf::outputs(FusionGuard::getCurFusion(), serial_loop_concrete_ids);
auto serial_loop_roots = InputsOf::outputs(serial_loop_concrete_ids);

// Collect exact concrete id's in producer's root domain
std::unordered_set<IterDomain*> producer_exact_concrete_root_ids;
Expand Down
2 changes: 1 addition & 1 deletion csrc/device_lower/pass/allocation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,7 @@ class AllocationInserter : public kir::ExprMutator {
[](IterDomain* dom) { return dom->as<Val>(); });

// Get all exprs involved in generating the allocation IDs
auto exprs = StmtSort::getExprsTo(tv->fusion(), start_vals);
auto exprs = StmtSort::getExprsTo(start_vals);

// Get the halo extent if found
auto getExtent = [this](IterDomain* id) {
Expand Down
10 changes: 5 additions & 5 deletions csrc/device_lower/pass/expr_sort.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -406,14 +406,16 @@ std::string ExprGroup::toString() const {
os << " ca_ids {";
for (size_t i = 0; i < payload()->ca_domains.size(); i++) {
os << payload()->ca_domains[i];
if (i + 1 != payload()->ca_domains.size())
if (i + 1 != payload()->ca_domains.size()) {
os << ", ";
}
}
os << "} pa_ids {";
for (size_t i = 0; i < payload()->pa_domains.size(); i++) {
os << payload()->pa_domains[i];
if (i + 1 != payload()->pa_domains.size())
if (i + 1 != payload()->pa_domains.size()) {
os << ", ";
}
}
os << "}";
os << "\nExprs {\n";
Expand Down Expand Up @@ -1507,9 +1509,7 @@ void ExprSegmentationSorter::sort() {
// Not putting the exprs between allKnownVals() and fusion inputs here
// because they are computed using the expr evaluator.
auto all_exprs = StmtSort::getExprsBetween(
fusion_,
GpuLower::current()->allKnownVals(),
fusion_->getTerminatingOutputs());
GpuLower::current()->allKnownVals(), fusion_->getTerminatingOutputs());

// Figure out all the values used as inputs to the expressions we're sorting
// (to find terminating expressions). There could be branches of expressions
Expand Down
3 changes: 1 addition & 2 deletions csrc/device_lower/pass/warp_reduce.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -104,8 +104,7 @@ class EliminateDeadBroadcastAndAllocate {
// Also find any TVs used in index expressions.
// These expressions will likely not be in the Expr tree we are
// provided, so we need to traverse to find them.
auto all_index_roots =
InputsOf::outputs(FusionGuard::getCurFusion(), {ti->index()});
auto all_index_roots = InputsOf::outputs({ti->index()});
auto index_root_tis =
ir_utils::filterByType<kir::TensorIndex>(all_index_roots);
for (auto rootti : index_root_tis) {
Expand Down
2 changes: 1 addition & 1 deletion csrc/device_lower/utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -635,7 +635,7 @@ std::vector<Expr*> replaceInputsInExpr(
std::vector<Expr*> getAllSwizzlesBetween(
std::vector<IterDomain*> from,
std::vector<IterDomain*> to) {
auto all_expr = DependencyCheck::getAllExprsBetween(
auto all_expr = StmtSort::getExprsBetween(
{from.begin(), from.end()}, {to.begin(), to.end()});

std::vector<Expr*> all_swizzles;
Expand Down
5 changes: 2 additions & 3 deletions csrc/device_lower/validation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -324,7 +324,7 @@ class VectorizeValidator : public OptInDispatch {
IterDomain* v_id,
TensorView* tv,
std::string name) {
auto replay_exprs = DependencyCheck::getAllExprsBetween(
auto replay_exprs = StmtSort::getExprsBetween(
{tv->getMaybeAllocationDomain().begin(),
tv->getMaybeAllocationDomain().end()},
{v_id});
Expand Down Expand Up @@ -836,7 +836,7 @@ void validatePartialSplit(Fusion* fusion) {

for (auto tv : ir_utils::allTvs(fusion)) {
auto exprs = StmtSort::getExprsTo(
tv->fusion(), {tv->getLeafDomain().begin(), tv->getLeafDomain().end()});
{tv->getLeafDomain().begin(), tv->getLeafDomain().end()});
for (auto split : ir_utils::filterByType<Split>(exprs)) {
// When the start and stop offsets are not zero, make sure the
// range defined by the split includes the required range to
Expand Down Expand Up @@ -1255,7 +1255,6 @@ void validateResize(Fusion* fusion) {
for (auto tv : ir_utils::filterByType<TensorView>(fusion_vals)) {
// Make sure resize is only used as part of rfactor transformations
auto rf_to_leaf_exprs = StmtSort::getExprsBetween(
fusion,
{tv->getMaybeRFactorDomain().begin(),
tv->getMaybeRFactorDomain().end()},
{tv->getLeafDomain().begin(), tv->getLeafDomain().end()});
Expand Down
5 changes: 2 additions & 3 deletions csrc/dynamic_transform.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ class DynamicTransformInitialInfoBuilder : public IterVisitor {
!fusion->isA<kir::Kernel>(),
"Invalid container. Kernel container not allowed.\n");

traverseTo(fusion, fusion->getTerminatingOutputs(), false, false);
traverseTo(fusion->getTerminatingOutputs(), false, false);

finalizeDynamicVals();

Expand Down Expand Up @@ -147,7 +147,7 @@ class DynamicTransformInitialInfoBuilder : public IterVisitor {
//! Process vector of leaf dynamic values by finding inputs and recording the
//! result into info_
void finalizeDynamicVals() {
const auto inputs = InputsOf::outputs(info_.fusion(), leaf_dynamic_vals_);
const auto inputs = InputsOf::outputs(leaf_dynamic_vals_);
info_.root_dynamic_vals_.insert(inputs.begin(), inputs.end());

// initial_info_ provides a set of Vals that are used for concretization.
Expand Down Expand Up @@ -621,7 +621,6 @@ void DynamicTransformConcretizer::mutate(TensorView* tv) {
// Note that it is assumed that theres's no further expression
// beyond the rfactor domain as asserted above
auto all_id_exprs = StmtSort::getExprsBetween(
tv->fusion(),
{tv->getRootDomain().begin(), tv->getRootDomain().end()},
{tv->getMaybeRFactorDomain().begin(),
tv->getMaybeRFactorDomain().end()});
Expand Down
22 changes: 7 additions & 15 deletions csrc/executor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -275,7 +275,7 @@ void FusionExecutor::compileFusion(
}
output_extents.emplace_back(extent);
}
auto dependencies = InputsOf::outputs(fusion, output_extents);
auto dependencies = InputsOf::outputs(output_extents);
if (std::any_of(dependencies.begin(), dependencies.end(), [](Val* val) {
return val->isFusionInput();
})) {
Expand Down Expand Up @@ -607,7 +607,6 @@ std::pair<std::vector<int64_t>, std::vector<int64_t>> inferShapeOfOutput(

class ForwardTraverseFromAllocToRFactor {
at::Tensor tensor_;
TensorView* tv_;
ExpressionEvaluator& ee_;
std::list<IterDomain*>& frontier_;

Expand Down Expand Up @@ -725,18 +724,15 @@ class ForwardTraverseFromAllocToRFactor {
public:
ForwardTraverseFromAllocToRFactor(
at::Tensor tensor,
TensorView* tv,
ExpressionEvaluator& ee,
std::list<IterDomain*>& frontier)
: tensor_(std::move(tensor)), tv_(tv), ee_(ee), frontier_(frontier) {}
: tensor_(std::move(tensor)), ee_(ee), frontier_(frontier) {}

at::Tensor run(
const std::vector<IterDomain*>& rfactor,
const std::vector<IterDomain*>& alloc) {
auto forward_exprs = StmtSort::getExprsBetween(
tv_->fusion(),
{alloc.begin(), alloc.end()},
{rfactor.begin(), rfactor.end()});
{alloc.begin(), alloc.end()}, {rfactor.begin(), rfactor.end()});
for (auto expr : forward_exprs) {
handle(expr);
}
Expand All @@ -748,7 +744,6 @@ class ForwardTraverseFromAllocToRFactor {
// transformations.
class BackwardTraverseFromAllocToRFactor {
at::Tensor tensor_;
TensorView* tv_;
ExpressionEvaluator& ee_;
std::list<IterDomain*>& frontier_;

Expand Down Expand Up @@ -853,18 +848,15 @@ class BackwardTraverseFromAllocToRFactor {
public:
BackwardTraverseFromAllocToRFactor(
at::Tensor tensor,
TensorView* tv,
ExpressionEvaluator& ee,
std::list<IterDomain*>& frontier)
: tensor_(std::move(tensor)), tv_(tv), ee_(ee), frontier_(frontier) {}
: tensor_(std::move(tensor)), ee_(ee), frontier_(frontier) {}

at::Tensor run(
const std::vector<IterDomain*>& rfactor,
const std::vector<IterDomain*>& alloc) {
auto backward_exprs = StmtSort::getExprsBetween(
tv_->fusion(),
{rfactor.begin(), rfactor.end()},
{alloc.begin(), alloc.end()});
{rfactor.begin(), rfactor.end()}, {alloc.begin(), alloc.end()});
std::reverse(backward_exprs.begin(), backward_exprs.end());
for (auto expr : backward_exprs) {
handle(expr);
Expand Down Expand Up @@ -894,9 +886,9 @@ at::Tensor transformOutputFromAllocationToRFactor(
// forward and a backward traverse.
std::list<IterDomain*> frontier(alloc.begin(), alloc.end());
NVF_ERROR(tensor.dim() == (int64_t)frontier.size());
tensor = ForwardTraverseFromAllocToRFactor(tensor, tv, ee, frontier)
tensor = ForwardTraverseFromAllocToRFactor(tensor, ee, frontier)
.run(rfactor, alloc);
tensor = BackwardTraverseFromAllocToRFactor(tensor, tv, ee, frontier)
tensor = BackwardTraverseFromAllocToRFactor(tensor, ee, frontier)
.run(rfactor, alloc);
NVF_ERROR(frontier.size() == rfactor.size());
// Now that all affine transformations are handled, and frontiers should
Expand Down
6 changes: 3 additions & 3 deletions csrc/fusion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -375,7 +375,7 @@ bool Fusion::isNoOp() {
}

std::vector<Val*> Fusion::inputsOf(Val* val) {
return InputsOf::output(this, val);
return InputsOf::output(val);
}

void Fusion::validateInputs() {
Expand Down Expand Up @@ -528,7 +528,7 @@ void Fusion::printMath(bool from_outputs_only) {
leaf_vals.push_back(val);
}
}
exprs_for_print = StmtSort::getExprsTo(this, leaf_vals);
exprs_for_print = StmtSort::getExprsTo(leaf_vals);
}

debug() << "\n%kernel_math {\n";
Expand Down Expand Up @@ -649,7 +649,7 @@ std::vector<Val*> Fusion::usedMathVals() {
// there can be vals that are created inside a fusion without using
// anything from inputs. See, for example, tv0 in the
// FusionOuterSplit test.
const auto inputs = InputsOf::outputs(this, outputs());
const auto inputs = InputsOf::outputs(outputs());
auto used_math_vals = DependencyCheck::getAllValsBetween(
{inputs.begin(), inputs.end()}, outputs());
// When an expre has multiple outputs and only some of them are
Expand Down
11 changes: 3 additions & 8 deletions csrc/fusion_segmenter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2207,12 +2207,7 @@ std::optional<std::unique_ptr<SchedulerEntry>> SegmentedGroup::
}

void SegmentedGroup::resetExprList() {
auto input_group_vec = getAllInputs(this);
std::unordered_set<Val*> input_group_set(
input_group_vec.begin(), input_group_vec.end());
auto expr_set =
DependencyCheck::getAllExprsBetween(input_group_set, getAllOutputs(this));
exprs_ = std::vector<Expr*>(expr_set.begin(), expr_set.end());
exprs_ = StmtSort::getExprsBetween(getAllInputs(this), getAllOutputs(this));
}

// Custom merge node passes:
Expand Down Expand Up @@ -3703,7 +3698,7 @@ void SegmentCandidateFinder::resolveInputsInGroup(SegmentedGroup* group) {
group->input_vals = IterVisitor::getInputsTo(group->inputs());

// Grab all expressions needed to produce to_visit
auto input_exprs = StmtSort::getExprsTo(completeFusion(), to_visit);
auto input_exprs = StmtSort::getExprsTo(to_visit);

// Insert those expressions at the beginning of the group
group->exprs_.insert(
Expand Down Expand Up @@ -3963,7 +3958,7 @@ class ForceHalfAnnotation : public IterVisitor {
val->getDataType().value() == DataType::BFloat16);
});

annotation.traverseTo(fusion, fp16_outputs);
annotation.traverseTo(fp16_outputs);
return annotation.force_fp16_tv_set_;
}

Expand Down
Loading