Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion csrc/compute_at_map.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -605,7 +605,7 @@ void IterDomainGraph::build(Fusion* fusion) {

// Grab all the rfactor ids.
for (auto consumer_tv : all_consumer_tvs) {
auto exprs = StmtSort::getExprs(
auto exprs = StmtSort::getExprsTo(
fusion,
{consumer_tv->getMaybeRFactorDomain().begin(),
consumer_tv->getMaybeRFactorDomain().end()});
Expand Down
2 changes: 1 addition & 1 deletion csrc/device_lower/analysis/divisible_split.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ std::unordered_set<Split*> getAllDivisibleSplits(
// Take the view transformations and add all the splits. Those splits are
// the only divisible splits.
auto view_exprs =
StmtSort::getExprs(fusion, {rfactor_dom.begin(), rfactor_dom.end()});
StmtSort::getExprsTo(fusion, {rfactor_dom.begin(), rfactor_dom.end()});
auto split_exprs = ir_utils::filterByType<Split>(view_exprs);
all_divisible_splits.insert(split_exprs.begin(), split_exprs.end());
}
Expand Down
2 changes: 1 addition & 1 deletion csrc/device_lower/pass/allocation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -216,7 +216,7 @@ class AllocationInserter : public kir::ExprMutator {
[](IterDomain* dom) { return dom->as<Val>(); });

// Get all exprs involved in generating the allocation IDs
auto exprs = StmtSort::getExprs(tv->fusion(), start_vals);
auto exprs = StmtSort::getExprsTo(tv->fusion(), start_vals);

// Get the halo extent if found
auto getExtent = [this](IterDomain* id) {
Expand Down
2 changes: 1 addition & 1 deletion csrc/device_lower/validation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -828,7 +828,7 @@ void validatePartialSplit(Fusion* fusion) {
auto range_info = getLiveRangeOffsets(fusion);

for (auto tv : ir_utils::allTvs(fusion)) {
auto exprs = StmtSort::getExprs(
auto exprs = StmtSort::getExprsTo(
tv->fusion(), {tv->getLeafDomain().begin(), tv->getLeafDomain().end()});
for (auto split : ir_utils::filterByType<Split>(exprs)) {
// When the start and stop offsets are not zero, make sure the
Expand Down
6 changes: 5 additions & 1 deletion csrc/dynamic_transform.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -452,7 +452,11 @@ void DynamicTransformConcretizer::concretize() {
concretizeEmptyExtents();

// Finally, propagate concretized domains
auto all_stmts = StmtSort::getStmts(info_->fusion());
auto all_stmts = StmtSort::getStmts(
info_->fusion(),
/*traverse_members*/ false,
/*traverse_attributes*/ false,
/*traverse_siblings*/ true);
for (auto tv : ir_utils::filterByType<TensorView>(all_stmts)) {
mutate(tv);
}
Expand Down
2 changes: 1 addition & 1 deletion csrc/fusion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -490,7 +490,7 @@ void Fusion::printMath(bool from_outputs_only) {
leaf_vals.push_back(val);
}
}
exprs_for_print = StmtSort::getExprs(this, leaf_vals);
exprs_for_print = StmtSort::getExprsTo(this, leaf_vals);
}

debug() << "\n%kernel_math {\n";
Expand Down
2 changes: 1 addition & 1 deletion csrc/fusion_segmenter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3600,7 +3600,7 @@ void SegmentCandidateFinder::resolveInputsInGroup(SegmentedGroup* group) {
group->input_vals = IterVisitor::getInputsTo(group->inputs());

// Grab all expressions needed to produce to_visit
auto input_exprs = StmtSort::getExprs(completeFusion(), to_visit);
auto input_exprs = StmtSort::getExprsTo(completeFusion(), to_visit);

// Insert those expressions at the beginning of the group
group->exprs_.insert(
Expand Down
2 changes: 1 addition & 1 deletion csrc/ir/cloner.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ Statement* RecomputeTv::handle(const TensorDomain* td) {
// Make sure to recompute the history of the iteration domains, explicitly go
// through the expressions and send them to IrCloner.
auto exprs =
StmtSort::getExprs(fusion_, {td->leaf().begin(), td->leaf().end()});
StmtSort::getExprsTo(fusion_, {td->leaf().begin(), td->leaf().end()});

for (auto expr : exprs) {
IrCloner::handle(expr);
Expand Down
4 changes: 2 additions & 2 deletions csrc/ir/utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -463,7 +463,7 @@ class ValReplacementMutator : private OptOutMutator {
// typically not used by anything else. If we don't grab that count, then it
// would be a tensorview that doesn't get updated extents. Therefore, first
// grab all leaves towards outputs and grab stmts from there.
auto stmts = StmtSort::getStmts(fusion, allLeafOuts(fusion), true, true);
auto stmts = StmtSort::getStmtsTo(fusion, allLeafOuts(fusion), true, true);

// Some fusions, such as standalone rand_like, can have disconnected DAG, so
// we need some mechanism to make sure our replacement set is as complete as
Expand All @@ -481,7 +481,7 @@ class ValReplacementMutator : private OptOutMutator {
more.emplace_back(v);
}
}
auto more_stmts = StmtSort::getStmts(fusion, more, true, true);
auto more_stmts = StmtSort::getStmtsTo(fusion, more, true, true);
more_stmts.insert(more_stmts.end(), stmts.begin(), stmts.end());

for (auto stmt : more_stmts) {
Expand Down
106 changes: 83 additions & 23 deletions csrc/iter_visitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -142,15 +142,27 @@ void IterVisitor::traverseBetween(
const std::vector<Val*>& to,
bool traverse_all_paths,
bool traverse_into_members,
bool traverse_attributes) {
bool traverse_attributes,
bool traverse_siblings) {
FusionGuard fg(fusion);

std::unordered_set<Statement*> visited;
std::unordered_set<Statement*> nodes_on_path;
std::vector<Statement*> maybe_orphaned_sibs;

stmt_stack.clear();
stmt_stack.emplace_back(to.rbegin(), to.rend());

if (traverse_siblings) {
// Record siblings of entries in "to" so they are handled after the traversal
auto& bottom_stack = stmt_stack.back();
for (auto val : ir_utils::filterByType<Val>(bottom_stack)) {
for (auto sib : ir_utils::siblingValsOf(val)) {
maybe_orphaned_sibs.push_back(sib);
}
}
}

bool all_inputs_visited = false;

while (!stmt_stack.empty()) {
Expand Down Expand Up @@ -222,6 +234,18 @@ void IterVisitor::traverseBetween(
// If we don't want to retraverse, remove nodes we already visited.
remove_visited(next_stmts, visited);
}

if (traverse_siblings) {
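// Record siblings of next_stmts Vals (all of them if traverse_all_paths,
// otherwise only unvisited ones) so they are handled after the traversal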
for (auto next_val : ir_utils::filterByType<Val>(next_stmts)) {
for (auto sib : ir_utils::siblingValsOf(next_val)) {
if (traverse_all_paths || visited.find(sib) == visited.end()) {
maybe_orphaned_sibs.push_back(sib);
}
}
}
}

if (next_stmts.empty()) {
// If there's nothing to visit because it was all already visited, mark
// to process
Expand Down Expand Up @@ -251,21 +275,31 @@ void IterVisitor::traverseBetween(
}
}
}
// Handle any sibling Vals that have not yet been handled
// If traverse_siblings is false, this vector will be empty
for (auto val : maybe_orphaned_sibs) {
if (visited.find(val) == visited.end()) {
visited.insert(val);
handle(val);
}
}
}

void IterVisitor::traverseTo(
Fusion* fusion,
const std::vector<Val*>& to,
bool traverse_all_paths,
bool traverse_into_members,
bool traverse_attributes) {
bool traverse_attributes,
bool traverse_siblings) {
traverseBetween(
fusion,
{},
to,
traverse_all_paths,
traverse_into_members,
traverse_attributes);
traverse_attributes,
traverse_siblings);
}

void IterVisitor::traverseHelper(Fusion* fusion, bool traverse_all_paths) {
Expand Down Expand Up @@ -430,7 +464,7 @@ void BackwardVisitor::traverseTo(
}

auto vals = AllVals::get(fusion, from);
auto exprs = StmtSort::getExprs(fusion, from);
auto exprs = StmtSort::getExprsTo(fusion, from);

{
size_t pos = 0;
Expand Down Expand Up @@ -841,19 +875,25 @@ void StmtSort::handle(Statement* stmt) {
std::vector<Expr*> StmtSort::getExprs(
Fusion* fusion,
bool traverse_members,
bool traverse_attributes) {
bool traverse_attributes,
bool traverse_siblings) {
auto terminating_outputs = fusion->getTerminatingOutputs();
return StmtSort::getExprs(
fusion, terminating_outputs, traverse_members, traverse_attributes);
return StmtSort::getExprsTo(
fusion,
terminating_outputs,
traverse_members,
traverse_attributes,
traverse_siblings);
}

std::vector<Expr*> StmtSort::getExprs(
std::vector<Expr*> StmtSort::getExprsTo(
Fusion* fusion,
const std::vector<Val*>& to,
bool traverse_members,
bool traverse_attributes) {
auto stmts =
StmtSort::getStmts(fusion, to, traverse_members, traverse_attributes);
bool traverse_attributes,
bool traverse_siblings) {
auto stmts = StmtSort::getStmtsTo(
fusion, to, traverse_members, traverse_attributes, traverse_siblings);
auto filter = ir_utils::filterByType<Expr>(stmts.begin(), stmts.end());
std::vector<Expr*> exprs(filter.begin(), filter.end());
return exprs;
Expand All @@ -864,9 +904,15 @@ std::vector<Expr*> StmtSort::getExprsBetween(
const std::vector<Val*>& from,
const std::vector<Val*>& to,
bool traverse_members,
bool traverse_attributes) {
bool traverse_attributes,
bool traverse_siblings) {
auto stmts = StmtSort::getStmtsBetween(
fusion, from, to, traverse_members, traverse_attributes);
fusion,
from,
to,
traverse_members,
traverse_attributes,
traverse_siblings);
auto filter = ir_utils::filterByType<Expr>(stmts.begin(), stmts.end());
std::vector<Expr*> exprs(filter.begin(), filter.end());
return exprs;
Expand All @@ -875,19 +921,31 @@ std::vector<Expr*> StmtSort::getExprsBetween(
std::vector<Statement*> StmtSort::getStmts(
Fusion* fusion,
bool traverse_members,
bool traverse_attributes) {
bool traverse_attributes,
bool traverse_siblings) {
auto terminating_outputs = fusion->getTerminatingOutputs();
return StmtSort::getStmts(
fusion, terminating_outputs, traverse_members, traverse_attributes);
return StmtSort::getStmtsTo(
fusion,
terminating_outputs,
traverse_members,
traverse_attributes,
traverse_siblings);
}

std::vector<Statement*> StmtSort::getStmts(
std::vector<Statement*> StmtSort::getStmtsTo(
Fusion* fusion,
const std::vector<Val*>& to,
bool traverse_members,
bool traverse_attributes) {
bool traverse_attributes,
bool traverse_siblings) {
StmtSort es;
es.traverseTo(fusion, to, false, traverse_members, traverse_attributes);
es.traverseTo(
fusion,
to,
false,
traverse_members,
traverse_attributes,
traverse_siblings);
return es.stmts;
}

Expand All @@ -896,15 +954,17 @@ std::vector<Statement*> StmtSort::getStmtsBetween(
const std::vector<Val*>& from,
const std::vector<Val*>& to,
bool traverse_members,
bool traverse_attributes) {
bool traverse_attributes,
bool traverse_siblings) {
StmtSort es;
es.traverseBetween(
fusion,
{from.begin(), from.end()},
to,
false,
traverse_members,
traverse_attributes);
traverse_attributes,
traverse_siblings);
return es.stmts;
}

Expand Down Expand Up @@ -932,11 +992,11 @@ std::vector<Val*> InputsOf::outputs(
bool DeadCodeRemover::run() {
// First we build a set of all live Statements so that we can detect dead
// branches.
for (auto stmt : StmtSort::getStmts(fusion_, fusion_->outputs())) {
for (auto stmt : StmtSort::getStmtsTo(fusion_, fusion_->outputs())) {
markLive(stmt);
}

// Note that StmtSort::getStmts() is also run in traverseTo. In the future,
// Note that StmtSort::getStmtsTo() is also run in traverseTo. In the future,
// we could potentially refactor this so that derived classes from
// BackwardVisitor can make use of that traversal instead of repeating it.
traverseTo(fusion_, fusion_->outputs(), false);
Expand Down
Loading