Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,7 @@ list(APPEND NVFUSER_SRCS
${NVFUSER_SRCS_DIR}/optimization/pre_segmenter.cpp
${NVFUSER_SRCS_DIR}/optimization/remove_empty.cpp
${NVFUSER_SRCS_DIR}/val_graph.cpp
${NVFUSER_SRCS_DIR}/val_graph_visitor.cpp
)

# We don't link CUPTI for MSVC
Expand Down
8 changes: 8 additions & 0 deletions csrc/disjoint_set.h
Original file line number Diff line number Diff line change
Expand Up @@ -242,6 +242,14 @@ class VectorOfUniqueEntries {
return vector_.end();
}

// Returns a mutable reference to the entry at index `pos`, forwarding to
// vector_.at(pos).
// NOTE(review): assuming vector_ is a std::vector, an out-of-range `pos`
// throws std::out_of_range rather than being undefined behavior -- confirm
// against the member declaration.
T& at(size_t pos) {
return vector_.at(pos);
}

// Const overload: returns a read-only reference to the entry at index `pos`,
// forwarding to vector_.at(pos).
// NOTE(review): assuming vector_ is a std::vector, an out-of-range `pos`
// throws std::out_of_range -- confirm against the member declaration.
const T& at(size_t pos) const {
return vector_.at(pos);
}

std::string toString() const {
std::stringstream ss;
ss << "{ ";
Expand Down
52 changes: 52 additions & 0 deletions csrc/val_graph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,58 @@ std::vector<ValGroup> ValGraph::inputGroups(const ExprGroup& expr) const {
return input_groups;
}

// Returns every Val group of this graph that is not produced by any Expr
// group, where "produced" deliberately ignores self-loops.
//
// A Val group can show up as both an input and an output of the same Expr
// group. For instance, the AlmostExact graph maps the input of a split with
// its outer output when the split factor is one:
//
//   [i0, i1]
//     split by 1
//   [i0/1, 1, i1]
//     merge
//   [i0/1, 1*i1]
//
// i0 and i0/1 end up in one group, {i0, i0/1}, which the split both consumes
// and produces. Such an output is not counted as "produced" by that Expr
// group, so the group is still reported as a terminating input as long as no
// other Expr group produces it. ValGraphVisitor relies on this to find its
// starting points.
ValGroups ValGraph::getTerminatingInputs() const {
  // Every Expr group of the graph is a potential producer.
  const ExprGroups producer_candidates{
      disjointExprSets().disjointSets().begin(),
      disjointExprSets().disjointSets().end()};

  // Val groups that some Expr group produces without also consuming them.
  ValGroups produced_groups;
  for (const ExprGroup& producer : producer_candidates) {
    const std::vector<ValGroup> consumed = inputGroups(producer);
    const std::vector<ValGroup> produced = outputGroups(producer);
    const std::unordered_set<ValGroup> consumed_lookup{
        consumed.begin(), consumed.end()};

    for (const ValGroup& candidate : produced) {
      // Outputs that are also inputs of the same Expr group form a trivial
      // cycle and must not disqualify the group.
      if (consumed_lookup.count(candidate) == 0) {
        produced_groups.pushBack(candidate);
      }
    }
  }

  // Whatever is left after removing the produced groups has no (non-cyclic)
  // definition.
  ValGroups every_val_group{
      disjointValSets().disjointSets().begin(),
      disjointValSets().disjointSets().end()};
  return every_val_group.computeSubtract(produced_groups);
}

ExprGroups ValGraph::allUsesOf(const ValGroups& of) const {
DequeOfExprGroup to_visit;
for (const ValGroup& of_val_group : of) {
Expand Down
3 changes: 3 additions & 0 deletions csrc/val_graph.h
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,9 @@ class ValGraph {
// NOTE(review): presumably these return the Val groups that the given Expr
// group produces (outputGroups) or consumes (inputGroups); the definitions
// are not fully visible from here -- confirm against val_graph.cpp.
std::vector<ValGroup> outputGroups(const ExprGroup& expr) const;
std::vector<ValGroup> inputGroups(const ExprGroup& expr) const;

// Return the Val groups that are not an output of any Expr group. An output
// that is also an input of the same Expr group (a trivial cycle) does not
// count, so such a group is still returned unless some other Expr group
// produces it.
ValGroups getTerminatingInputs() const;

// Recursively traverses the uses of the Val groups in 'of' and returns all
// ExprGroups that have, in their definition, a use of the provided groups.
ExprGroups allUsesOf(const ValGroups& of) const;
Expand Down
152 changes: 152 additions & 0 deletions csrc/val_graph_visitor.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,152 @@
// clang-format off
/*
* SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES.
* All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*/
// clang-format on
#include <val_graph_visitor.h>

#include <id_model/to_string.h>

namespace nvfuser {

// Walks graph() forward from its terminating inputs, calling handle() at most
// once on each Val group and each Expr group that becomes ready.
//
// Two work lists are maintained. An Expr group is handled once all of its
// input Val groups are visited (or empty), after which its output Val groups
// are queued. A Val group is handled once is_val_ready() accepts it, after
// which its use Expr groups are queued. The outer loop repeats until a full
// pass handles nothing; Val groups still queued at that point are reported
// through NVF_ERROR.
void ValGraphVisitor::traverse() {
  const ValGroups terminating_inputs = graph().getTerminatingInputs();
  std::deque<ValGroup> to_visit_vals(
      terminating_inputs.begin(), terminating_inputs.end());
  ValGroups visited_vals;

  std::deque<ExprGroup> to_visit_exprs;
  ExprGroups visited_exprs;

  // An Expr group is ready when each of its input Val groups has been visited
  // or is empty. Groups are taken by const reference so that the predicate
  // does not copy the group handle for every element it inspects.
  auto is_expr_ready = [&](const ExprGroup& expr_group) -> bool {
    const auto inp_groups = graph().inputGroups(expr_group);
    return std::all_of(
        inp_groups.begin(), inp_groups.end(), [&](const ValGroup& val_group) {
          return visited_vals.has(val_group) || val_group->empty();
        });
  };

  // If any input of the def expr is mapped with the val
  // group itself, i.e., a trivial expr, allow visiting the
  // val group first. The trivial expr group will be visited
  // after the val group.
  //
  // Example:
  //
  // [i0, 1]
  // merge
  // [i0*1]
  // map i0 and i0*1
  // ValGroups: {{i0, i0*1}, {1}}
  //
  // Then, {i0, i0*1} and {1} would be visited first, then the merge
  // expr group would be visited. {i0, i0*1} is also an output group
  // of the merge but since it's already in the visited set, it would
  // not be visited again.
  //
  // See also IdModelTest.ValGraphStmtSort3 for a concrete example.
  auto is_val_ready = [&](const ValGroup& val_group) -> bool {
    if (terminating_inputs.has(val_group)) {
      return true;
    }
    const ExprGroups& unique_defs = graph().getDefinitions(val_group);
    return std::all_of(
        unique_defs.begin(),
        unique_defs.end(),
        [&](const ExprGroup& expr_group) {
          if (expr_group->empty() || visited_exprs.has(expr_group)) {
            return true;
          }
          // Handle ExprGroups that return one or some of its input ValGroups
          // as output. This expr_group is not visited yet, which means
          // there're input ValGroups that are not yet visited. If those
          // not-visited inputs are actually the same as val_group, visit
          // val_group at this point to resolve the circular dependency.
          //
          // NOTE(review): as written, the check below rejects val_group only
          // when an unvisited input group other than val_group is *empty*.
          // is_expr_ready treats empty groups as already satisfied, and the
          // comment above suggests `!input_group->empty()` may have been
          // intended. Left unchanged because flipping it alters which groups
          // are visited and in what order -- confirm before changing.
          for (const ValGroup& input_group : graph().inputGroups(expr_group)) {
            if (input_group != val_group && !visited_vals.has(input_group) &&
                input_group->empty()) {
              return false;
            }
          }
          return true;
        });
  };

  // Detect if nothing has been processed which would put us in an infinite
  // loop
  bool something_was_processed = false;

  do {
    something_was_processed = false;

    // Process expressions first as all definitions of vals have to be
    // processed before we can process that val.
    //
    // Expr groups that are not ready are dropped here rather than re-queued;
    // they are pushed again when another of their input Val groups is
    // visited.
    while (!to_visit_exprs.empty()) {
      ExprGroup current_expr_group = to_visit_exprs.front();
      to_visit_exprs.pop_front();
      NVF_ERROR(!current_expr_group->empty());
      if (visited_exprs.has(current_expr_group)) {
        continue;
      }

      if (is_expr_ready(current_expr_group)) {
        handle(current_expr_group);

        something_was_processed = true;
        visited_exprs.pushBack(current_expr_group);

        for (const ValGroup& output_group :
             graph().outputGroups(current_expr_group)) {
          to_visit_vals.push_back(output_group);
        }
      }
    }

    // Val groups that are not ready yet are carried over to the next pass.
    std::deque<ValGroup> still_to_visit_vals;
    while (!to_visit_vals.empty()) {
      auto current_val_group = to_visit_vals.front();
      to_visit_vals.pop_front();
      NVF_ERROR(!current_val_group->empty());
      if (visited_vals.has(current_val_group)) {
        continue;
      }

      if (is_val_ready(current_val_group)) {
        handle(current_val_group);

        something_was_processed = true;
        visited_vals.pushBack(current_val_group);

        for (const ExprGroup& use_group : graph().getUses(current_val_group)) {
          to_visit_exprs.push_back(use_group);
        }
      } else {
        still_to_visit_vals.push_back(current_val_group);
      }
    }

    std::swap(to_visit_vals, still_to_visit_vals);

  } while (something_was_processed);

  if (!to_visit_vals.empty()) {
    std::stringstream ss;
    ss << "The graph has an infinite loop. The following Vals should be visited but are never ready:";
    for (const ValGroup& vg : to_visit_vals) {
      ss << " " << nvfuser::toString(vg);
    }
    NVF_ERROR(false, ss.str());
  }

  // NOTE(review): to_visit_exprs is fully drained at the top of every pass and
  // is refilled only when a Val group is handled, which forces another pass.
  // It is therefore always empty here, and Expr groups that never became
  // ready are not reported by this check.
  if (!to_visit_exprs.empty()) {
    std::stringstream ss;
    ss << "The graph has an infinite loop. The following Exprs should be visited but are never ready:";
    for (const ExprGroup& eg : to_visit_exprs) {
      ss << " " << nvfuser::toString(eg);
    }
    NVF_ERROR(false, ss.str());
  }
}

} // namespace nvfuser
121 changes: 121 additions & 0 deletions csrc/val_graph_visitor.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
// clang-format off
/*
* SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES.
* All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*/
// clang-format on
#pragma once

#include <disjoint_set.h>
#include <ir/all_nodes.h>
#include <val_graph.h>

namespace nvfuser {

// Iterates through a Val Graph in topological order, calling handle on
// all Val and all Expr groups in a forward topological order.
//
// Warning: A ValGraph is not guaranteed to be a DAG. In fact, the
// AlmostExact and Permissive graphs would have cycles with a ValGroup
// and an ExprGroup. For example:
//
// [i0, 1]
// merge
// [i0*1]
// Current ValGroups: {{i0}, {1}, {i0*1}}
// map i0 and i0*1 as they effectively have the same extent
// Final ValGroups: {{i0, i0*1}, {1}}
//
// Here, the merge expr is the user of i0 and the definition of
// i0*1. Since i0 and i0*1 are mapped, the dependency chain looks
// like:
//
// {i0, i0*1} ----> {merge} ----> {i0, i0*1}
// use def
//
// These ExprGroups are called trivial ExprGroups (see also
// ValGraph::isTrivialExprGroup).
//
// Strictly speaking, these cycles mean there's no valid topological
// order anymore. In our use cases for IdModel, however, it's likely
// sufficient to return an ordering such as:
//
// {i0, i0*1} -> {merge}
//
// I.e., we visit {i0, i0*1} first even though {merge} is technically
// a definition.
//
// Another alternative may be simply giving up when such a cycle is
// detected, which may be preferable as it would be less
// confusing. At this moment, this visitor is only used with graphs
// with no such cycle. Should be revisited when necessary.
//
// Warning: This is not a great iterator if there's a desire to minimize paths
// traveled to simply visit all ValGroups in order. See ExprsBetween to see how
// we might minimize paths.
class ValGraphVisitor {
 public:
  // A visitor is always bound to a graph, so there is no default state.
  ValGraphVisitor() = delete;

  // Assignment is disabled: the graph is held through a reference member,
  // which cannot be rebound.
  ValGraphVisitor& operator=(const ValGraphVisitor& other) = delete;

  ValGraphVisitor& operator=(ValGraphVisitor&& other) = delete;

  virtual ~ValGraphVisitor() = default;

 protected:
  // Binds the visitor to `val_graph`. Only a reference is stored, so the
  // graph must outlive the visitor.
  ValGraphVisitor(const ValGraph& val_graph) : val_graph_(val_graph) {}

  ValGraphVisitor(const ValGraphVisitor& other) = default;

  ValGraphVisitor(ValGraphVisitor&& other) = default;

  // Callbacks invoked by traverse() for each Val group and each Expr group it
  // visits. Derived classes implement the per-group work here.
  virtual void handle(const ValGroup& val_group) = 0;
  virtual void handle(const ExprGroup& expr_group) = 0;

  // Runs the traversal over graph(), dispatching to the handle() overloads.
  void traverse();

  // Read-only access to the bound graph. Const-qualified so that it can also
  // be called from const member functions of derived classes.
  const ValGraph& graph() const {
    return val_graph_;
  }

 private:
  // Non-owning reference to the graph being visited.
  const ValGraph& val_graph_;
};

// Statement sorting based on ValGraphVisitor, see warnings to ValGraph Visitor.
class ValGraphStmtSort : public ValGraphVisitor {
public:
ValGraphStmtSort(const ValGraph& val_graph) : ValGraphVisitor(val_graph) {
ValGraphVisitor::traverse();
}

// Return non-reference so that code like below can work
// for (auto expr_group: ValGraphStmtSort(graph).exprs())
ExprGroups exprs() const {
return sorted_exprs_;
}

ValGroups vals() const {
return sorted_vals_;
}

~ValGraphStmtSort() override = default;

protected:
using ValGraphVisitor::handle;

void handle(const ValGroup& val_group) override {
sorted_vals_.pushBack(val_group);
}

void handle(const ExprGroup& expr_group) override {
sorted_exprs_.pushBack(expr_group);
}

ExprGroups sorted_exprs_;
ValGroups sorted_vals_;
};

} // namespace nvfuser
Loading