Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 4 additions & 8 deletions csrc/scheduler/resize.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -234,18 +234,13 @@ std::unique_ptr<HeuristicParams> ResizeScheduler::computeHeuristics(
});
TensorView* ref_tv = ref_tv_entry.get()[0];

// Before applying the vectorization split, any reshape transform of
// the largest input will be cancelled whenever possible, so the
// largest input is used as the reference of vectorization.
auto vec_ref_tv = largest_input != nullptr ? largest_input : ref_tv;

// Only consider the innermost dimension to vectorize for now.
// TODO: Consider vectorizing merged IDs, not just the innermost
params->vectorization_factor = vectorize_helper::getVectorizationFactor(
runtime_info,
vec_ref_tv,
ref_tv,
data_cache,
(int64_t)vec_ref_tv->getLogicalDomain().size() - 1,
(int64_t)ref_tv->getLogicalDomain().size() - 1,
{});

return params;
Expand Down Expand Up @@ -300,7 +295,8 @@ void ResizeScheduler::schedule(Fusion* fusion, const HeuristicParams* params) {
// The tensors are going to be reordered to align with the largest
// input. To make it work, merge operations for reshape should be
// cancelled.
scheduler_tools::cancelReshapeInLoopDomains(largest_input);
scheduler_tools::cancelReshapeInLoopDomains(
largest_input, /*skip_innermost_id=*/true);
}

for (auto expr : fusion->exprs()) {
Expand Down
27 changes: 25 additions & 2 deletions csrc/scheduler/tools/loop_domain_scheduler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,8 @@ class LoopDomainSchedulerReplayTransform : OptInConstDispatch {
const std::vector<IterDomain*>& output_ids_;
};

// Replay a given IterDomain transform expression on the loop domain
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Unrelated to this PR, but this is a comment I had for a past PR but forgot to add.

// of a given tensor using specified loop IDs as its inputs.
class ReplayForwardTransformOnLoopDomain : OptInConstDispatch {
public:
static void replayAs(
Expand Down Expand Up @@ -541,7 +543,11 @@ void scheduleLoopDomainsBy(

// When the direction is forward, the TensorView transform
// APIs, e.g., TensorView::split, can be used, which doesn't need
// to use TensorView::setLoopDomain.
// to use TensorView::setLoopDomain. This is important as
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Unrelated to this PR, but this is a comment I had for a past PR but forgot to add.

// setLoopDomain may result in losing extra IDs added by prior
// scheduleLoopDomain calls, which was indeed the case with the
// Llama 3 RoPE backward (see also
// https://github.com/NVIDIA/Fuser/issues/3571).
if (replay_dir_tv == Direction::Forward) {
ReplayForwardTransformOnLoopDomain::replayAs(tv, input_ids, transform);
continue;
Expand Down Expand Up @@ -582,7 +588,7 @@ void scheduleLoopDomainsBy(
return;
}

void cancelReshapeInLoopDomains(TensorView* from_tv) {
void cancelReshapeInLoopDomains(TensorView* from_tv, bool skip_innermost_id) {
Fusion* fusion = from_tv->fusion();
IdModel id_model(fusion, /*build_graphs=*/false);
id_model.buildExactGraph();
Expand Down Expand Up @@ -677,13 +683,30 @@ void cancelReshapeInLoopDomains(TensorView* from_tv) {
{reshape_out->getLogicalDomain().begin(),
reshape_out->getLogicalDomain().end()});

std::unordered_set<Expr*> reshape_exprs_with_innermost_logical_id_set;
if (skip_innermost_id) {
auto reshape_exprs_with_innermost_logical_id =
DependencyCheck::getAllExprsBetween(
{reshape_out->getRootDomain().begin(),
reshape_out->getRootDomain().end()},
{reshape_out->getLogicalDomain().back()});
reshape_exprs_with_innermost_logical_id_set = {
reshape_exprs_with_innermost_logical_id.begin(),
reshape_exprs_with_innermost_logical_id.end()};
}

auto reshape_out_loop_domain = reshape_out->getLoopDomain();

for (auto reshape_exprs_it = reshape_exprs.rbegin();
reshape_exprs_it != reshape_exprs.rend();
++reshape_exprs_it) {
auto reshape_expr = *reshape_exprs_it;

if (skip_innermost_id &&
reshape_exprs_with_innermost_logical_id_set.count(reshape_expr)) {
continue;
}

// If any of the output IDs of reshape_expr is not found in
// cancellable_ids, that means the expr cannot be cancelled.
if (std::any_of(
Expand Down
8 changes: 7 additions & 1 deletion csrc/scheduler/tools/loop_domain_scheduler.h
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,13 @@ void scheduleLoopDomainsBy(
// iter domain is reduced, the split needs to remain. If a reshape
// only consists of merge transforms, cancellation should be possible,
// but that is not currently supported.
void cancelReshapeInLoopDomains(TensorView* from_tv);
//
// When the skip_innermost_id flag is true, any reshape that involves
// the innermost logical ID is not canceled even when cancellation is
// technically possible. This is a WAR for the resize scheduler.
void cancelReshapeInLoopDomains(
TensorView* from_tv,
bool skip_innermost_id = false);

} // namespace scheduler_tools
} // namespace nvfuser
84 changes: 82 additions & 2 deletions tests/cpp/test_resize.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4755,7 +4755,7 @@ TEST_P(ResizeSchedulerTest, SliceRotateCat) {
Fusion& fusion = *fusion_ptr;
FusionGuard fg(fusion_ptr.get());

std::vector<int64_t> shape({-1, 100});
std::vector<int64_t> shape({-1, 128});
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a WAR for the other vectorization issue.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is the other vectorization issue referring to #3640 ?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes


EnableOptionsGuard enable_options_guard;
EnableOptionsGuard::getCurOptions().set(EnableOption::IdModel, {"all"});
Expand All @@ -4781,7 +4781,7 @@ TEST_P(ResizeSchedulerTest, SliceRotateCat) {
auto tv5 = cat({tv4, tv2}, 1);

auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
auto t0 = at::randn({16, 100}, options);
auto t0 = at::randn({16, 128}, options);

fusion.addOutput(tv5);

Expand Down Expand Up @@ -5822,4 +5822,84 @@ TEST_F(ResizeTest, DoNotFuseResizeAndIndexOps) {
}
}

// Split-based reshape followed by a slice. The reshape is not
// cancelable, and the vectorization factor derived from the innermost
// logical ID of the input is not valid because the fusion is
// scheduled based on the post-reshape shape.
TEST_F(ResizeTest, VectorizeInnermostWithReshapeSplit) {
  auto fusion_ptr = std::make_unique<Fusion>();
  auto& fusion = *fusion_ptr;
  FusionGuard fg(fusion_ptr.get());

  // 1-D input of 2048 elements, reshaped via a split into [1024, 2].
  std::vector<int64_t> flat_shape{128L * 16L};
  std::vector<int64_t> split_shape{flat_shape[0] / 2L, 2L};

  auto in_tv = makeContigConcreteTensor(flat_shape);
  fusion.addInput(in_tv);

  auto sin_tv = sin(in_tv);
  auto reshaped_tv = reshape(sin_tv, flat_shape, split_shape);
  // Slice the outer dimension only; the innermost dimension is kept whole.
  auto out_tv = slice(
      reshaped_tv,
      {{IrBuilder::create<Val>(0L), IrBuilder::create<Val>(2L)},
       {IrBuilder::create<Val>(0L), IrBuilder::create<Val>(split_shape[1])}});
  fusion.addOutput(out_tv);

  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
  auto input = at::randn(flat_shape, options);

  auto outputs = scheduleAndRun(&fusion, SchedulerType::Resize, {input});
  testValidate(&fusion, outputs.outputs, {input}, __LINE__, __FILE__);

  // Vectorization should be by a factor of 2 because the resize
  // scheduler only uses the innermost logical ID, and the extent of
  // the output tensor is just 2. Before PR #3955, the resize scheduler
  // attempted to vectorize by 4. Note that the slice op itself does
  // not matter for the vectorization as the sliced ID is not involved
  // in the vectorization.
  auto* innermost_loop_id = out_tv->getLoopDomain().back();
  EXPECT_EQ(innermost_loop_id->getParallelType(), ParallelType::Vectorize);
  EXPECT_EQ(innermost_loop_id->extent()->evaluate(), 2);
}

// Merge-based reshape followed by a slice. The reshape is
// cancelable. If the output is used as the reference but the reshape
// is canceled, the valid vectorization factor should be 2. The WAR of
// PR #3955 gives up canceling any reshape that involves innermost
// logical IDs to avoid this inconsistency.
TEST_F(ResizeTest, VectorizeInnermostWithReshapeMerge) {
  auto fusion_ptr = std::make_unique<Fusion>();
  auto& fusion = *fusion_ptr;
  FusionGuard fg(fusion_ptr.get());

  // 3-D input [16, 1024, 2], merged back into [16, 2048] by the reshape.
  std::vector<int64_t> merged_shape{16, 128L * 16L};
  std::vector<int64_t> input_shape{16, merged_shape[1] / 2L, 2L};

  auto in_tv = makeContigConcreteTensor(input_shape);
  fusion.addInput(in_tv);

  auto sin_tv = sin(in_tv);
  // [16, 128 * 16 / 2, 2] -> [16, 128 * 16]. Cancelable reshape.
  auto reshaped_tv = reshape(sin_tv, input_shape, merged_shape);
  // Slice the outer dimension only; the innermost dimension is kept whole.
  auto out_tv = slice(
      reshaped_tv,
      {{IrBuilder::create<Val>(0L), IrBuilder::create<Val>(2L)},
       {IrBuilder::create<Val>(0L), IrBuilder::create<Val>(merged_shape[1])}});
  fusion.addOutput(out_tv);

  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
  auto input = at::randn(input_shape, options);

  auto outputs = scheduleAndRun(&fusion, SchedulerType::Resize, {input});
  testValidate(&fusion, outputs.outputs, {input}, __LINE__, __FILE__);

  // Vectorization should be by a factor of 4. If the reshape were
  // canceled, it should have been 2, but in this case since it
  // involves the innermost logical ID of the reshape output, it is not
  // canceled, thus vectorization by 4 should be chosen.
  auto* innermost_loop_id = out_tv->getLoopDomain().back();
  EXPECT_EQ(innermost_loop_id->getParallelType(), ParallelType::Vectorize);
  EXPECT_EQ(innermost_loop_id->extent()->evaluate(), 4);
}

} // namespace nvfuser