Merged
149 changes: 85 additions & 64 deletions csrc/scheduler/vectorize_helper.cpp
@@ -842,6 +842,59 @@ std::vector<std::unordered_map<TensorView*, Val*>> getTvToContigInnerSizeMapsOf(
return mappers;
}

// Check if a traversal from vectorized reference IDs may reach the
// IDs of a resize expr without visiting the Resize expr itself. That's
// problematic for the vectorization analysis as the spanning-tree
// based analysis may miss the constraint by the Resize expr.
//
// For this analysis, we start a traversal from the vectorized
// reference IDs to both the input and output of the Resize expr but
// disallow visiting the Resize expr itself. If the traversal is still
// successful, it means there's a path from the reference IDs to the
// resize input and output IDs without visiting the Resize expr.
//
// Permissive BFS is used in this traversal as the vectorized
// reference IDs may not have all the dependencies for the
// traversal. For example, suppose there's a split reshape, and only
// the innermost ID is vectorized. The standard BFS is not able to
// move forward if only the vectorized ID is given, as the backward
// split requires both outputs to be present.
class CanSkipResize : public ValGraphPermissiveBFS {
public:
static bool run(
const ValGraph& graph,
const ValGroups& ref_groups,
Resize* resize) {
ValGroups resize_in_out_groups;
resize_in_out_groups.pushBack(graph.toGroup(resize->in()));
resize_in_out_groups.pushBack(graph.toGroup(resize->out()));
CanSkipResize bfs(graph, ref_groups, resize_in_out_groups, resize);
bfs.traverse();
return bfs.allToNodesVisited();
Collaborator:
> qq: here we are calling `allToNodesVisited()`? but the init function below has `/*require_all_to_visited=*/false`, so we are returning true here as long as a single node is visited in the target, which I think is the right behavior. But the function name is somewhat confusing.

Collaborator Author:
> No, the traversal should continue until no further progress is made. The `require_all_to_visited` flag means it's considered an error if not all of the to nodes are reachable. Here, we just want to check whether all of the to nodes are reachable; it isn't an error even if they are not.

Collaborator:
> Sorry, I got confused myself. This returns true indicating it's safe to skip the check, so `allToNodesVisited` is the proper name for the function. Thanks for elaborating on this one.
}

CanSkipResize(
const ValGraph& graph,
const ValGroups& ref_groups,
const ValGroups& resize_in_out_groups,
Resize* resize)
: ValGraphPermissiveBFS(
graph,
{ref_groups.begin(), ref_groups.end()},
{resize_in_out_groups.begin(), resize_in_out_groups.end()},
/*require_all_to_visited=*/false,
/*allowed_direction=*/Direction::Undefined),
resize_(resize) {}

bool excludeFromTraversal(const NodeType& node) const override {
const ExprGroup* e = std::get_if<ExprGroup>(&node);
return e != nullptr && (*e)->has(resize_);
}

private:
Resize* resize_ = nullptr;
};
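As a standalone illustration of the permissive relaxation described in the class comment, the following sketch iterates to a fixed point and fires an expression as soon as any one of its inputs has been visited. The `Expr` struct and `permissiveReachable` helper are hypothetical simplifications for the example, not the nvFuser `ValGraphPermissiveBFS` API, and direction handling is omitted:

```cpp
#include <cassert>
#include <set>
#include <vector>

// Hypothetical stand-in for an expression group: value IDs in, value IDs out.
struct Expr {
  std::vector<int> inputs;
  std::vector<int> outputs;
};

// Permissive traversal: an expr may be visited as soon as ANY of its inputs
// is visited, unlike a standard BFS that requires ALL inputs. This is why a
// backward split can make progress even when only one of its outputs (e.g.,
// the vectorized innermost ID) is available as a seed.
std::set<int> permissiveReachable(
    const std::vector<Expr>& exprs,
    const std::set<int>& seeds) {
  std::set<int> visited = seeds;
  bool changed = true;
  while (changed) {
    changed = false;
    for (const auto& e : exprs) {
      bool any_input_visited = false;
      for (int in : e.inputs) {
        if (visited.count(in)) {
          any_input_visited = true;
          break;
        }
      }
      if (!any_input_visited) {
        continue;
      }
      for (int out : e.outputs) {
        if (visited.insert(out).second) {
          changed = true;
        }
      }
    }
  }
  return visited;
}
```

With a single expr mapping inputs {1, 2} to output {3}, seeding only {2} still reaches 3 under this relaxation, whereas a standard BFS would stall waiting for 1.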

// This is a WAR for vectorizing through resized iter domains. The
// spanning tree based analysis is not guaranteed to take all resize
// ops into consideration (issue
@@ -852,84 +905,48 @@ std::unordered_set<Val*> getResizeVectorizationFactors(
TensorView* reference_tv,
int64_t break_point) {
Fusion* fusion = reference_tv->fusion();
std::unordered_set<Val*> factors;
const auto resize_based_ops = scheduler_tools::getResizeBasedOps(fusion);

if (resize_based_ops.empty()) {
return factors;
return {};
}

IdModel id_model(reference_tv->fusion());
IdModel id_model(fusion);
const auto& graph = id_model.buildExactGraph();

const auto ref_groups = graph.toGroups(reference_tv->getLogicalDomain());
std::unordered_set<Val*> resize_factors;

// For each of resize-based tensor ops, find all resize ops
// that exist between the vectorized reference IDs and the output
// tensor.
for (auto resize_based_op : resize_based_ops) {
auto resize_out = resize_based_op->output(0)->as<TensorView>();
NVF_ERROR(
resize_out->hasRoot(), "Unexpected op: ", resize_based_op->toString());
// getAllExprGroupsBetween finds exprs between IDs. To make sure
// the resize op of this resize_based_op tensor op is found,
// use both the root and logical domains as the traversal targets.
ValGroups resize_inp_out;
resize_inp_out.pushBack(graph.toGroups(resize_out->getRootDomain()));
resize_inp_out.pushBack(graph.toGroups(resize_out->getLogicalDomain()));

auto expr_path = getAllExprGroupsBetween(
graph,
ref_groups,
resize_inp_out,
/*require_all_to_visited=*/false)
.first;

ValGroups vectorized_groups;
for (auto it = reference_tv->getLogicalDomain().begin() + break_point;
it != reference_tv->getLogicalDomain().end();
++it) {
vectorized_groups.pushBack(graph.toGroup(*it));
auto add_resize_factors = [&](Resize* resize) {
if (!resize->leftExpand()->isZeroInt()) {
resize_factors.insert(resize->leftExpand());
}
if (!resize->rightExpand()->isZeroInt()) {
resize_factors.insert(resize->rightExpand());
}
};

// Find all resize exprs that appear in expr_path and depend on
// vectorized_groups. Since expr_path is not guaranteed to be
// topologically sorted, we need to loop over the path until
// convergence.

bool something_has_changed = true;
while (something_has_changed) {
something_has_changed = false;
for (const auto& [expr_g, dir] : expr_path) {
const auto inputs = getInputsOfExprGroup(graph, expr_g, dir);
if (std::none_of(
inputs.begin(), inputs.end(), [&](const ValGroup& inp) {
return vectorized_groups.has(inp);
})) {
continue;
}

if (vectorized_groups.pushBack(
getOutputsOfExprGroup(graph, expr_g, dir))) {
something_has_changed = true;
}

auto resize = dynamic_cast<Resize*>(expr_g->front());
if (resize == nullptr) {
continue;
}
const ValGroups ref_vec_groups = graph.toGroups(std::vector<Val*>{
reference_tv->getLogicalDomain().begin() + break_point,
reference_tv->getLogicalDomain().end()});

// For each Resize expr, if it's reachable from the reference
// vectorized IDs without visiting the Resize expr itself, its
// constraint may not be reflected in the inner sizes.
for (auto resize : resize_based_ops) {
auto resize_out_tv = resize->output(0)->as<TensorView>();
for (const auto logical_id : resize_out_tv->getLogicalDomain()) {
auto resize = dynamic_cast<Resize*>(logical_id->definition());
if (resize == nullptr) {
continue;
}

// These three vals need to be divisible
factors.emplace(resize->leftExpand());
factors.emplace(resize->rightExpand());
factors.emplace(
dir == Direction::Forward ? resize->out()->extent()
: resize->in()->extent());
if (CanSkipResize::run(graph, ref_vec_groups, resize)) {
add_resize_factors(resize);
}
}
}

return factors;
return resize_factors;
}

} // namespace
@@ -1028,7 +1045,11 @@ int64_t getVectorizationFactor(
if (!inferred_val.hasValue()) {
return 1;
}
max_vec_size = std::gcd(max_vec_size, inferred_val.as<int64_t>());
auto inferred_val_int = inferred_val.as<int64_t>();
if (inferred_val_int == 0) {
Collaborator:
> this is for dynamic resize extents that would be 0?
continue;
}
max_vec_size = std::gcd(max_vec_size, inferred_val_int);
}

return max_vec_size;
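The gcd folding changed above, including the new zero-size skip, can be illustrated in isolation. This is a hypothetical standalone sketch; the `foldVectorizationFactor` helper and its `std::optional` inputs are inventions for the example, standing in for the values the real code infers with its expression evaluator:

```cpp
#include <cassert>
#include <cstdint>
#include <numeric>
#include <optional>
#include <vector>

// Fold candidate sizes into one vectorization factor with std::gcd.
// A zero size (e.g., a dynamic resize extent that evaluates to 0) is
// skipped so that it imposes no divisibility constraint; a size that
// cannot be inferred at all disables vectorization entirely.
int64_t foldVectorizationFactor(
    int64_t max_vec_size,
    const std::vector<std::optional<int64_t>>& inferred_sizes) {
  for (const auto& size : inferred_sizes) {
    if (!size.has_value()) {
      return 1; // unknown size: fall back to no vectorization
    }
    if (*size == 0) {
      continue; // zero extent: skip, mirroring the patch's `continue`
    }
    max_vec_size = std::gcd(max_vec_size, *size);
  }
  return max_vec_size;
}
```

For instance, folding sizes {16, 4, 0} into a starting factor of 8 yields 4: the 16 leaves it at 8, the 4 reduces it to gcd(8, 4) = 4, and the 0 is skipped.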
91 changes: 90 additions & 1 deletion tests/cpp/test_resize.cpp
@@ -5970,7 +5970,7 @@ TEST_F(ResizeTest, AvoidCachingSliceInput) {
}
}

TEST_F(ResizeTest, VectorizeSliceMultiplePaths) {
TEST_F(ResizeTest, VectorizeInnerSliceMultiplePaths) {
auto fusion_ptr = std::make_unique<Fusion>();
auto& fusion = *fusion_ptr;
FusionGuard fg(fusion_ptr.get());
@@ -6005,6 +6005,50 @@
EXPECT_EQ(tv6->getLoopDomain().back()->extent()->evaluate(), 2);
}

// The current analysis is not precise enough to pass this test
Collaborator Author:
> In this test, tv0 is resized in two different ways. The spanning-tree based analysis is not guaranteed to correctly identify the vectorization constraint. The WAR, when applied to this case, is still too conservative.

TEST_F(ResizeTest, DISABLED_VectorizeOuterSliceMultiplePaths) {
auto fusion_ptr = std::make_unique<Fusion>();
auto& fusion = *fusion_ptr;
FusionGuard fg(fusion_ptr.get());

const std::vector<int64_t> shape{4, 1024 * 1024};

auto tv0 = makeContigConcreteTensor(shape);
fusion.addInput(tv0);

auto tv1 =
pad(tv0,
{fusion.zeroVal(),
fusion.zeroVal(),
IrBuilder::create<Val>(2),
IrBuilder::create<Val>(2)});
auto tv2 =
pad(tv0,
{fusion.zeroVal(),
fusion.zeroVal(),
fusion.zeroVal(),
IrBuilder::create<Val>(4)});
auto tv3 = add(tv1, tv2);
fusion.addOutput(tv3);

auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
auto t0 = at::randn(shape, options);

auto outputs = scheduleAndRun(&fusion, SchedulerType::PointWise, {t0});
testValidate(&fusion, outputs.outputs, {t0}, __LINE__, __FILE__);

// While there's a pad with factor of 2, it shouldn't matter as the
// inner ID is large enough.
auto out_tv = tv3;
auto vec_id_it =
std::ranges::find_if(out_tv->getLoopDomain(), [](IterDomain* loop_id) {
return loop_id->getParallelType() == ParallelType::Vectorize;
});
ASSERT_NE(vec_id_it, out_tv->getLoopDomain().end())
<< "Vectorized ID not found: " << out_tv->toString();
EXPECT_EQ((*vec_id_it)->extent()->evaluate(), 4);
}

// Repro of issue #4202
TEST_F(ResizeTest, PropagateResizeThroughMultiplePaths) {
auto fusion_ptr = std::make_unique<Fusion>();
@@ -6040,4 +6084,49 @@
testValidate(&fusion, outputs.outputs, {t0, t1}, __LINE__, __FILE__);
}

// Check that vectorization is properly applied when the resized ID is
// not reachable from the vectorized IDs, i.e., when it is safe to
// skip the additional resize check. Pattern extracted from Litgpt
// LLama RoPE backward.
Collaborator:
> nitpick on comment: this is the case where it's safe to skip the additional resize check, so the resized ID is NOT reachable from vectorized IDs.

TEST_F(ResizeTest, VectorizeOuterPad) {
auto fusion_ptr = std::make_unique<Fusion>();
auto& fusion = *fusion_ptr;
FusionGuard fg(fusion_ptr.get());

const std::vector<int64_t> shape1{1, 8, 4, 8192, 128};
const std::vector<int64_t> shape2{1, 8, 1, 8192, 128};
auto tv0 = makeContigConcreteTensor(shape1, DataType::BFloat16);
fusion.addInput(tv0);
auto tv1 = makeContigConcreteTensor(shape2, DataType::BFloat16);
fusion.addInput(tv1);
auto tv2 = makeContigConcreteTensor(shape2, DataType::BFloat16);
fusion.addInput(tv2);

// [1, 8, 6, 8192, 128]
auto tv3 = cat({tv0, tv1, tv2}, 2);
// [1, 8192, 8, 6, 128]
auto tv4 = permute(tv3, {0, 3, 1, 2, 4});
auto tv5 = reshape(tv4, {1, 8192, 8, 6, 128}, {1, 8192, 6144});
fusion.addOutput(tv5);

auto options = at::TensorOptions().dtype(at::kBFloat16).device(at::kCUDA, 0);
auto t0 = at::randn(shape1, options);
auto t1 = at::randn(shape2, options);
auto t2 = at::randn(shape2, options);

auto outputs =
scheduleAndRun(&fusion, SchedulerType::PointWise, {t0, t1, t2});
testValidate(&fusion, outputs.outputs, {t0, t1, t2}, __LINE__, __FILE__);

auto out_tv = tv5;
// While there's a pad with factor of 2, it shouldn't matter as the
// inner ID is large enough.
auto vec_id_it =
std::ranges::find_if(out_tv->getLoopDomain(), [](IterDomain* loop_id) {
return loop_id->getParallelType() == ParallelType::Vectorize;
});
ASSERT_NE(vec_id_it, out_tv->getLoopDomain().end())
<< "Vectorized ID not found: " << out_tv->toString();
EXPECT_EQ((*vec_id_it)->extent()->evaluate(), 8);
}

} // namespace nvfuser