Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
38 commits
Select commit Hold shift + click to select a range
212d522
WIP
naoyam Feb 16, 2025
29c2489
Merge remote-tracking branch 'origin/main' into traverse_all_paths
naoyam Feb 25, 2025
0e829b8
WAR for vectorization of resize
naoyam Feb 26, 2025
82d274d
Revert size changes
naoyam Feb 26, 2025
fc9452b
Merge remote-tracking branch 'origin/main' into traverse_all_paths
naoyam Feb 26, 2025
1fffd78
fix
naoyam Feb 26, 2025
6426244
build fix
naoyam Feb 26, 2025
dbb47ca
build fix
naoyam Feb 26, 2025
bd4ad8b
cleanup
naoyam Feb 27, 2025
01c70e3
WIP
naoyam Feb 27, 2025
0368167
Merge remote-tracking branch 'origin/main' into traverse_all_paths
naoyam Feb 27, 2025
a918efc
Merge remote-tracking branch 'origin/main' into traverse_all_paths
naoyam Feb 28, 2025
63e4d68
test fix
naoyam Feb 28, 2025
7d4d82b
fix
naoyam Feb 28, 2025
6e265d3
build fix
naoyam Feb 28, 2025
b890e89
Merge remote-tracking branch 'origin/main' into traverse_all_paths
naoyam Feb 28, 2025
f4f5b89
cleanup
naoyam Feb 28, 2025
386d01c
remove debug print
naoyam Feb 28, 2025
23f8b54
cleanup
naoyam Mar 1, 2025
a6b29c8
skip failing test
naoyam Mar 1, 2025
abf346f
cleanup
naoyam Mar 1, 2025
d93c12f
Merge remote-tracking branch 'origin/main' into traverse_all_paths
naoyam Mar 1, 2025
31f0ed4
cleanup
naoyam Mar 1, 2025
bb92176
fix
naoyam Mar 1, 2025
0d15c9d
Merge remote-tracking branch 'origin/main' into traverse_all_paths
naoyam Mar 1, 2025
357f684
Cache
naoyam Mar 1, 2025
51f6efc
comments
naoyam Mar 1, 2025
c4a0b87
cleanup
naoyam Mar 1, 2025
0fbf0f0
Merge remote-tracking branch 'origin/main' into traverse_all_paths
naoyam Mar 1, 2025
220c5b1
Merge remote-tracking branch 'origin/main' into traverse_all_paths
naoyam Mar 3, 2025
3c71a92
revert
naoyam Mar 3, 2025
8967676
rephrase
naoyam Mar 12, 2025
b739b41
Merge remote-tracking branch 'origin/main' into traverse_all_paths
naoyam Mar 12, 2025
d8e3f05
comment
naoyam Mar 12, 2025
928e555
cleanup
naoyam Mar 14, 2025
999ce45
cleanup
naoyam Mar 14, 2025
546d7ea
cleanup
naoyam Mar 14, 2025
58dcb3a
Merge remote-tracking branch 'origin/main' into traverse_all_paths
naoyam Mar 14, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
596 changes: 596 additions & 0 deletions csrc/graph_traversal.h

Large diffs are not rendered by default.

10 changes: 10 additions & 0 deletions csrc/scheduler/compile_time_info.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ enum class CompileTimeEntryType {
VECTORIZABLE_INPUTS_AND_OUTPUTS,
INPUTS_AND_OUTPUTS_INNER_DIM_GROUPS,
TV_TO_CONTIG_INNER_SIZE_MAPS,
RESIZE_VECTORIZATION_FACTORS,
UNROLLABLE_INPUTS_AND_OUTPUTS,
REDUCTION_TVS,
PERSISTENT_BUFFER_INFO,
Expand Down Expand Up @@ -106,6 +107,15 @@ class TvToContigInnerSizeMaps {
CompileTimeEntryType::TV_TO_CONTIG_INNER_SIZE_MAPS;
};

//! Entry type definition class for `RESIZE_VECTORIZATION_FACTORS`.
//! Stores the scalar vals that a vectorization factor must be able to
//! divide evenly. This is part of the WAR for vectorization through
//! resized iter domains, where the spanning-tree analysis may miss
//! resize ops (see getResizeVectorizationFactors in vectorize_helper.cpp).
class ResizeVectorizationFactors {
public:
// Cached data: the set of extents/expand vals constraining the factor.
using DataType = std::unordered_set<Val*>;
static const CompileTimeEntryType EntryType =
CompileTimeEntryType::RESIZE_VECTORIZATION_FACTORS;
};

//! Entry type definition class for `INPUTS_AND_OUTPUTS_INNER_DIM_GROUPS`,
//! stores the fusion's inputs and outputs grouped by inner most dimension.
class InputsOutputsInnerDimGroups {
Expand Down
2 changes: 2 additions & 0 deletions csrc/scheduler/registry.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -223,6 +223,8 @@ template class HeuristicDataCacheEntry<
HeuristicCompileTime::VectorizableInputsAndOutputs>;
template class HeuristicDataCacheEntry<
HeuristicCompileTime::TvToContigInnerSizeMaps>;
template class HeuristicDataCacheEntry<
HeuristicCompileTime::ResizeVectorizationFactors>;
template class HeuristicDataCacheEntry<
HeuristicCompileTime::InputsOutputsInnerDimGroups>;
template class HeuristicDataCacheEntry<
Expand Down
2 changes: 1 addition & 1 deletion csrc/scheduler/resize.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ bool ResizeScheduler::canScheduleCompileTime(Fusion* fusion) {
IdModel id_model(fusion, /*build_graphs=*/false);
const auto& broadcast_graph = id_model.buildBroadcastGraph();

auto resize_tensor_ops = ir_utils::getOpsOfType<SliceOp, PadOp>(fusion);
auto resize_tensor_ops = scheduler_tools::getResizeBasedOps(fusion);

// Slicing of or to a broadcast ID is not allowed yet.
for (auto resize_tensor_op : resize_tensor_ops) {
Expand Down
4 changes: 4 additions & 0 deletions csrc/scheduler/tools/resize_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,10 @@ bool hasResizeBasedOps(Fusion* fusion) {
return ir_utils::hasOpsOfType<SliceOp, PadOp>(fusion);
}

// Returns all resize-based tensor ops, i.e., SliceOp and PadOp, in the
// given fusion. Companion to hasResizeBasedOps above.
std::vector<Expr*> getResizeBasedOps(Fusion* fusion) {
return ir_utils::getOpsOfType<SliceOp, PadOp>(fusion);
}

void propagateResizeToInputs(Expr* resize_tensor_op) {
NVF_ERROR(
resize_tensor_op->isA<SliceOp>() || resize_tensor_op->isA<PadOp>(),
Expand Down
2 changes: 2 additions & 0 deletions csrc/scheduler/tools/resize_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ bool isResizeBasedOp(Expr* expr);

bool hasResizeBasedOps(Fusion* fusion);

std::vector<Expr*> getResizeBasedOps(Fusion* fusion);

// For a given resize-based tensor op such as SliceOp and PadOp, make the loop
// domain of each dependent producer tensor exact-mapped by propagating
// the iter-domain ops of the output tensor of the given op. Note that
Expand Down
113 changes: 113 additions & 0 deletions csrc/scheduler/vectorize_helper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
#include <iter_visitor.h>
#include <scheduler/registry.h>
#include <scheduler/runtime_info.h>
#include <scheduler/tools/resize_utils.h>
#include <val_graph_visitor.h>

#include <c10/util/irange.h>
Expand Down Expand Up @@ -843,6 +844,96 @@ std::vector<std::unordered_map<TensorView*, Val*>> getTvToContigInnerSizeMapsOf(
return mappers;
}

// This is a WAR for vectorizing through resized iter domains. The
// spanning tree based analysis is not guaranteed to take all resize
// ops into consideration (issue
// https://github.com/NVIDIA/Fuser/issues/3640). To work around the
// limitation, grab all factors that must be divisible by the
// vectorization factor.
//
// Returns the set of scalar vals (resize expand amounts and resized
// extents) that any chosen vectorization factor must divide evenly.
std::unordered_set<Val*> getResizeVectorizationFactors(
    TensorView* reference_tv,
    int64_t break_point) {
  Fusion* fusion = reference_tv->fusion();
  std::unordered_set<Val*> factors;
  const auto resize_based_ops = scheduler_tools::getResizeBasedOps(fusion);

  if (resize_based_ops.empty()) {
    return factors;
  }

  IdModel id_model(fusion);
  const auto& graph = id_model.buildExactGraph();

  const auto ref_groups = graph.toGroups(reference_tv->getLogicalDomain());

  // Groups of the reference IDs at and to the right of the break
  // point, i.e., the candidates for vectorization. This only depends
  // on the reference tensor and the break point, so compute it once
  // outside of the per-op loop below.
  ValGroups initial_vectorized_groups;
  for (auto it = reference_tv->getLogicalDomain().begin() + break_point;
       it != reference_tv->getLogicalDomain().end();
       ++it) {
    initial_vectorized_groups.pushBack(graph.toGroup(*it));
  }

  // For each of resize-based tensor ops, find all resize ops
  // that exist between the vectorized reference IDs and the output
  // tensor.
  for (auto resize_based_op : resize_based_ops) {
    auto resize_out = resize_based_op->output(0)->as<TensorView>();
    NVF_ERROR(
        resize_out->hasRoot(), "Unexpected op: ", resize_based_op->toString());
    // getAllExprGroupsBetween finds exprs between IDs. To make sure
    // the resize op of this resize_based_op tensor op is found,
    // use both the root and logical domains as the traversal targets.
    ValGroups resize_inp_out;
    resize_inp_out.pushBack(graph.toGroups(resize_out->getRootDomain()));
    resize_inp_out.pushBack(graph.toGroups(resize_out->getLogicalDomain()));

    auto expr_path = getAllExprGroupsBetween(
                         graph,
                         ref_groups,
                         resize_inp_out,
                         /*require_all_to_visited=*/false)
                         .first;

    // Start from the reference's vectorized groups and grow the set as
    // exprs along the path propagate them toward this op's output.
    ValGroups vectorized_groups = initial_vectorized_groups;

    // Find all resize exprs that appear in expr_path and depend on
    // vectorized_groups. Since expr_path is not guaranteed to be
    // topologically sorted, need to loop through the path until
    // converged.
    bool something_has_changed = true;
    while (something_has_changed) {
      something_has_changed = false;
      for (const auto& [expr_g, dir] : expr_path) {
        const auto inputs = getInputsOfExprGroup(graph, expr_g, dir);
        // Only exprs consuming an already-vectorized group matter.
        if (std::none_of(
                inputs.begin(), inputs.end(), [&](const ValGroup& inp) {
                  return vectorized_groups.has(inp);
                })) {
          continue;
        }

        if (vectorized_groups.pushBack(
                getOutputsOfExprGroup(graph, expr_g, dir))) {
          something_has_changed = true;
        }

        auto resize = dynamic_cast<Resize*>(expr_g->front());
        if (resize == nullptr) {
          continue;
        }

        // These three vals need to be divisible
        factors.emplace(resize->leftExpand());
        factors.emplace(resize->rightExpand());
        factors.emplace(
            dir == Direction::Forward ? resize->out()->extent()
                                      : resize->in()->extent());
      }
    }
  }

  return factors;
}

} // namespace

int64_t getVectorizationFactor(
Expand Down Expand Up @@ -881,6 +972,15 @@ int64_t getVectorizationFactor(
return 1;
}

auto resize_factors_entry =
HeuristicDataCacheEntry<HeuristicCompileTime::ResizeVectorizationFactors>(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

👏

data_cache, [&reference_tv, &break_point]() {
return std::make_unique<std::unordered_set<Val*>>(
getResizeVectorizationFactors(reference_tv, break_point));
});

const auto& resize_factors = resize_factors_entry.get();

int64_t max_vec_size = SchedulerRuntimeInfo::max_alignment_size_in_byte;
const auto& tv_to_inner_size_map = vectorize_maps_entry.get().at(break_point);

Expand Down Expand Up @@ -920,6 +1020,19 @@ int64_t getVectorizationFactor(
max_vec_size);
}

// This is a WAR for vectorization through resize as the spanning
// tree based traversal is not guaranteed to reflect all resize ops
// that may affect vectorization. This is a safe but conservative
// analysis since it should only be necessary for innermost IDs.
for (const auto resize_factor : resize_factors) {
auto inferred_val =
runtime_info.expressionEvaluator().evaluate(resize_factor);
if (!inferred_val.hasValue()) {
return 1;
}
max_vec_size = std::gcd(max_vec_size, inferred_val.as<int64_t>());
}

return max_vec_size;
}

Expand Down
32 changes: 28 additions & 4 deletions csrc/val_graph_visitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,9 @@
* SPDX-License-Identifier: BSD-3-Clause
*/
// clang-format on
#include <val_graph_visitor.h>

#include <graph_traversal.h>
#include <id_model/to_string.h>

#include <variant>
#include <val_graph_visitor.h>

namespace nvfuser {

Expand Down Expand Up @@ -245,4 +243,30 @@ bool isCyclic(const ValGraph& graph) {
return ValGraphCycleDetector(graph).cycle_detected_;
}

std::pair<ExprGroupPath, bool> getAllExprGroupsBetween(
const ValGraph& graph,
const ValGroups& from,
const ValGroups& to,
bool require_all_to_visited,
Direction allowed_direction) {
FindAllExprs<
ExprGroup,
ValGroup,
ValGraphDefinitions,
ValGraphUses,
ValGraphInputs,
ValGraphOutputs>
finder(
ValGraphDefinitions{graph},
ValGraphUses{graph},
ValGraphInputs{graph},
ValGraphOutputs{graph},
{from.vector().begin(), from.vector().end()},
{to.vector().begin(), to.vector().end()},
require_all_to_visited,
allowed_direction);
finder.traverseAllEdges();
return finder.getPartiallyOrderedExprs();
}

} // namespace nvfuser
11 changes: 11 additions & 0 deletions csrc/val_graph_visitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,8 @@ struct GetValType<ExprGroup> {
using type = ValGroup;
};

using ExprGroupPath = std::vector<std::pair<ExprGroup, Direction>>;

class ValGraphBFS : public BFS<
ExprGroup,
ValGroup,
Expand Down Expand Up @@ -292,4 +294,13 @@ inline std::vector<ValGroup> getOutputsOfExprGroup(
expr, dir, ValGraphInputs(graph), ValGraphOutputs(graph));
}

// Grab all ExprGroups between two sets of ValGroups. ExprGroups are
// not guaranteed to be topologically sorted.
std::pair<ExprGroupPath, bool> getAllExprGroupsBetween(
const ValGraph& graph,
const ValGroups& from,
const ValGroups& to,
bool require_all_to_visited = true,
Direction allowed_direction = Direction::Undefined);

} // namespace nvfuser
Loading