From d90c7adbbf431e838b5fb48dfe3f13ed7e41550b Mon Sep 17 00:00:00 2001 From: Jingyue Wu Date: Sat, 17 Feb 2024 23:47:40 +0000 Subject: [PATCH 01/13] Tests for reshape. --- test/test_move_split_cat.cpp | 108 +++++++++++++++++++++++++++++++++++ 1 file changed, 108 insertions(+) diff --git a/test/test_move_split_cat.cpp b/test/test_move_split_cat.cpp index 8f8bd6a16ff..8d2469d436d 100644 --- a/test/test_move_split_cat.cpp +++ b/test/test_move_split_cat.cpp @@ -299,4 +299,112 @@ TEST_F(MoveSplitCatTest, Noncancellable_UnsupportedOps) { EXPECT_FALSE(out_tensors[0].is_alias_of(in_tensor)); } +TEST_F(MoveSplitCatTest, Cancellable_ReshapeInBetween) { + auto fusion = std::make_unique(); + FusionGuard fg(fusion.get()); + + TensorView* in = makeContigConcreteTensor({4, 10}); + TensorView* s0 = slice(in, {0, 0}, {4, 2}); + TensorView* s1 = slice(in, {0, 2}, {4, 5}); + TensorView* s2 = slice(in, {0, 5}, {4, 10}); + s0 = reshape(s0, {4, 2}, {2, 2, 2}); + s1 = reshape(s1, {4, 3}, {2, 2, 3}); + s2 = reshape(s2, {4, 5}, {2, 2, 5}); + TensorView* out = cat({s0, s1, s2}, /*dim=*/-1); + + fusion->addInput(in); + fusion->addOutput(out); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::Tensor in_tensor = at::randn({4, 10}, options); + + FusionExecutorCache fec(std::move(fusion)); + auto out_tensors = fec.runFusionWithInputs({in_tensor}); + testValidate(fec.fusion(), out_tensors, {in_tensor}, __LINE__, __FILE__); + + EXPECT_TRUE(out_tensors[0].is_alias_of(in_tensor)); +} + +TEST_F(MoveSplitCatTest, Cancellable_ReshapeAndPermuteInBetween) { + auto fusion = std::make_unique(); + FusionGuard fg(fusion.get()); + + TensorView* in = makeContigConcreteTensor({6, 10}); + TensorView* s0 = slice(in, {0, 0}, {6, 2}); + TensorView* s1 = slice(in, {0, 2}, {6, 5}); + TensorView* s2 = slice(in, {0, 5}, {6, 10}); + s0 = reshape(s0, {6, 2}, {2, 3, 2}); + s1 = reshape(s1, {6, 3}, {2, 3, 3}); + s2 = reshape(s2, {6, 5}, {2, 3, 5}); + s0 = permute(s0, {1, 0, 2}); + s1 = 
permute(s1, {1, 0, 2}); + s2 = permute(s2, {1, 0, 2}); + TensorView* out = cat({s0, s1, s2}, /*dim=*/-1); + + fusion->addInput(in); + fusion->addOutput(out); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::Tensor in_tensor = at::randn({6, 10}, options); + + FusionExecutorCache fec(std::move(fusion)); + auto out_tensors = fec.runFusionWithInputs({in_tensor}); + testValidate(fec.fusion(), out_tensors, {in_tensor}, __LINE__, __FILE__); + + EXPECT_TRUE(out_tensors[0].is_alias_of(in_tensor)); +} + +TEST_F(MoveSplitCatTest, Cancellable_Issue1768) { + constexpr int b = 16; // batch size + constexpr int h = 12; // number of heads + constexpr int s = 128; // sequence length + constexpr int f = 64; // feature size per head + + auto fusion = std::make_unique(); + FusionGuard fg(fusion.get()); + + TensorView* sdpa_backward_out = + makeContigConcreteTensor({b, h * 3, s, f}, DataType::BFloat16); + sdpa_backward_out->setAllocationDomain( + {sdpa_backward_out->axis(0), + sdpa_backward_out->axis(2), + sdpa_backward_out->axis(1), + sdpa_backward_out->axis(3)}, + true); + TensorView* dq = slice(sdpa_backward_out, {0, 0, 0, 0}, {b, h, s, f}); + TensorView* dk = slice(sdpa_backward_out, {0, h, 0, 0}, {b, h * 2, s, f}); + TensorView* dv = slice(sdpa_backward_out, {0, h * 2, 0, 0}, {b, h * 3, s, f}); + // Swap the head dimension and the sequence length dimension. 
+ dq = permute(dq, {0, 2, 1, 3}); + dk = permute(dk, {0, 2, 1, 3}); + dv = permute(dv, {0, 2, 1, 3}); + dq = reshape(dq, {b, s, h, f}, {b, s, h * f}); + dk = reshape(dk, {b, s, h, f}, {b, s, h * f}); + dv = reshape(dv, {b, s, h, f}, {b, s, h * f}); + TensorView* cat_out = cat({dq, dk, dv}, /*dim=*/-1); + TensorView* sum_out = castOp(DataType::Float, cat_out); + sum_out = sum(sum_out, {0, 1}); + sum_out = castOp(DataType::BFloat16, sum_out); + TensorView* view_out = + reshape(cat_out, {b, s, h * f * 3}, {b * s, h * f * 3}); + TensorView* permute_out = permute(view_out, {1, 0}); + + fusion->addInput(sdpa_backward_out); + fusion->addOutput(sum_out); + fusion->addOutput(view_out); + fusion->addOutput(permute_out); + + auto options = at::TensorOptions().dtype(at::kBFloat16).device(at::kCUDA, 0); + at::Tensor in_tensor = + at::randn({b * h * 3 * s * f}, options) + .as_strided({b, h * 3, s, f}, {h * 3 * s * f, f, h * 3 * f, 1}); + + FusionExecutorCache fec(std::move(fusion)); + auto out_tensors = fec.runFusionWithInputs({in_tensor}); + testValidate(fec.fusion(), out_tensors, {in_tensor}, __LINE__, __FILE__); + + EXPECT_TRUE(out_tensors[1].is_alias_of(in_tensor)); + EXPECT_TRUE(out_tensors[2].is_alias_of(in_tensor)); +} + } // namespace nvfuser From 828cddfcee81192968e395abf39c474d50269058 Mon Sep 17 00:00:00 2001 From: Jingyue Wu Date: Mon, 19 Feb 2024 06:07:21 +0000 Subject: [PATCH 02/13] Use IdModel to determine mergeability and propagate axis. 
--- csrc/preseg_passes/move_split_cat.cpp | 190 +++++++++++++++++++------- 1 file changed, 144 insertions(+), 46 deletions(-) diff --git a/csrc/preseg_passes/move_split_cat.cpp b/csrc/preseg_passes/move_split_cat.cpp index ad92030f4ae..e53ef1933f8 100644 --- a/csrc/preseg_passes/move_split_cat.cpp +++ b/csrc/preseg_passes/move_split_cat.cpp @@ -26,18 +26,30 @@ class CancelSplitCat { public: CancelSplitCat(Fusion* fusion) : fusion_(fusion), - id_model_(fusion, /*build_graphs=*/true, /*allow_self_mapping=*/true) {} + id_model_for_merging_( + fusion, + /*build_graphs=*/true, + /*allow_self_mapping=*/true), + id_model_for_propagation_( + fusion, + /*build_graphs=*/true, + /*allow_self_mapping=*/true) {} // Finds all cancellable pairs, cancels them and horizontallly // merges ops in between. void run(); private: - // Returns true when Exprs in the frontier can be horizontally merged and - // applied on the unsplit tensor. + // Returns true when Exprs between `slices` and `pads` can be horizontally + // merged and applied on the input of the split. bool horizontallyMergeable( - const std::vector& frontier, - int64_t& split_axis); + const std::vector& slices, + const std::vector& pads); + + int64_t propagateCatAxis( + const std::vector& source, + const std::vector& destination, + int64_t cat_axis); // Finds the canceling split of `cat` and returns the input TensorView of the // split. A split (implemented as multiple `slice`s) and a cat cancel when @@ -64,57 +76,70 @@ class CancelSplitCat { Fusion* fusion_; - IdModel id_model_; + // TODO(wujingyue): keep two `IdGraph`s not two `IdModel`s. An `IdModel` + // contains multiple graphs and we only care about the exact graph in it. 
+ IdModel id_model_for_merging_; + IdModel id_model_for_propagation_; }; +bool sameOp(const std::vector& frontier) { + return std::adjacent_find( + frontier.begin(), frontier.end(), [](Expr* lhs, Expr* rhs) { + return !lhs->sameOp(rhs); + }) == frontier.end(); +} + bool CancelSplitCat::horizontallyMergeable( - const std::vector& frontier, - int64_t& split_axis) { - NVF_ERROR(!frontier.empty()); - - // Check all Exprs in `frontier` - // 1. have the same op type and attributes, - // 2. transform IDs in the same way, and - // 3. don't resize the split axis. - - if (std::adjacent_find( - frontier.begin(), frontier.end(), [](Expr* lhs, Expr* rhs) { - return !lhs->sameOp(rhs); - }) != frontier.end()) { - return false; + const std::vector& slices, + const std::vector& pads) { + NVF_ERROR(slices.size() == pads.size()); + NVF_ERROR(!slices.empty()); + + // FIXME: make it a class member. + ValGraph& exact_graph = id_model_for_merging_.idGraph(IdMappingMode::EXACT); + { + const std::vector& first_rfactor = + slices[0]->output(0)->as()->getMaybeRFactorDomain(); + size_t num_dims = first_rfactor.size(); + for (size_t i = 1; i < slices.size(); i++) { + const std::vector& rfactor = + slices[i]->output(0)->as()->getMaybeRFactorDomain(); + if (rfactor.size() != num_dims) { + return false; + } + for (size_t j = 0; j < num_dims; j++) { + exact_graph.mapVals(first_rfactor[j], rfactor[j]); + } + } } - if (auto* set = dynamic_cast(frontier[0])) { - if (set->opType() == LoadStoreOpType::Set) { - auto* set_out = set->out()->as(); - std::optional> permutation = - ir_utils::computePermutation( - set_out->getRootDomain(), set_out->getMaybeRFactorDomain()); - if (!permutation.has_value()) { + for (PadOp* pad : pads) { + auto* pad_out = pad->out()->as(); + if (id_model_for_merging_.hasSelfMapping(pad_out)) { + return false; + } + } + + { + const std::vector& first_root = + pads[0]->out()->as()->getRootDomain(); + size_t num_dims = first_root.size(); + for (size_t i = 1; i < pads.size(); i++) { 
+ const std::vector& root = + pads[i]->out()->as()->getRootDomain(); + if (root.size() != num_dims) { return false; } - - for (size_t i = 1; i < frontier.size(); i++) { - auto* other_set_out = - frontier[i]->as()->out()->as(); - std::optional> other_permutation = - ir_utils::computePermutation( - other_set_out->getRootDomain(), - other_set_out->getMaybeRFactorDomain()); - if (!other_permutation.has_value()) { - return false; - } - if (*permutation != *other_permutation) { + for (size_t j = 0; j < num_dims; j++) { + if (!exact_graph.disjointValSets().strictAreMapped( + first_root[j], root[j])) { return false; } } - - split_axis = (*permutation)[split_axis]; - return true; } } - return false; + return true; } // If `exprs` are `SliceOp`s that form a split, returns the base tensor of the @@ -198,29 +223,89 @@ TensorView* exprsFormSplit( return split_in; } +int64_t CancelSplitCat::propagateCatAxis( + const std::vector& source, + const std::vector& destination, + int64_t cat_axis) { + ValGraph& exact_graph = + id_model_for_propagation_.idGraph(IdMappingMode::EXACT); + ValGroup cat_dim = exact_graph.toGroup(destination[cat_axis]); + while ( + std::none_of(source.begin(), source.end(), [&](IterDomain* source_dim) { + return exact_graph.toGroup(source_dim) == cat_dim; + })) { + const ExprGroups& defining_groups = exact_graph.getDefinitions(cat_dim); + if (defining_groups.size() != 1) { + return -1; + } + ExprGroup defining_group = defining_groups.front(); + Expr* def = defining_group->front(); + // FIXME: make this a function so we can early return. 
+ if (Split* split = dynamic_cast(def)) { + if (exact_graph.toGroup(split->outer()) == cat_dim) { + cat_dim = exact_graph.toGroup(split->in()); + } else { + return -1; + } + } else if (Merge* merge = dynamic_cast(def)) { + cat_dim = exact_graph.toGroup(merge->outer()); + } else { + return -1; + } + } + + cat_axis = std::find_if( + source.begin(), + source.end(), + [&](IterDomain* source_dim) { + return exact_graph.toGroup(source_dim) == cat_dim; + }) - + source.begin(); + return cat_axis; +} + TensorView* CancelSplitCat::findCancelingSplit( CatOp* cat, std::vector& use_def_chain) { NVF_CHECK(!cat->inputs().empty(), "`cat` has zero inputs: ", cat); - // `frontier` initially contains the preceding Exprs of the `PadOp`s. Then, we + // `PadOp`s that produce `cat`'s inputs. + std::vector pads; + pads.reserve(cat->inputs().size()); + // `frontier` initially contains the `Expr`s that precede `pads`. Then, we // repeatedly try to move the frontier up in lockstep as long as Exprs in the // frontier can be horizontally merged and applied on the unsplit tensor. std::vector frontier; frontier.reserve(cat->inputs().size()); for (Val* in : cat->inputs()) { auto* pad = in->definition()->as(); + pads.push_back(pad); frontier.push_back(pad->in()->definition()); } // Exit the loop when any Expr in `frontier` is a slice or a null. - int64_t split_axis = cat->concatenatedDim(); while (std::none_of(frontier.begin(), frontier.end(), [](Expr* e) { return e == nullptr || e->isA(); })) { - if (!horizontallyMergeable(frontier, std::ref(split_axis))) { + if (!sameOp(frontier)) { return nullptr; } + + auto supported = [](Expr* e) -> bool { + if (e->isA()) { + return true; + } + if (auto* set = dynamic_cast(e)) { + if (set->opType() == LoadStoreOpType::Set) { + return true; + } + } + return false; + }; + if (!supported(frontier[0])) { + return nullptr; + } + use_def_chain.push_back(frontier[0]); // Advance the frontier in lockstep. 
@@ -233,6 +318,19 @@ TensorView* CancelSplitCat::findCancelingSplit( } } + if (!horizontallyMergeable(frontier, pads)) { + return nullptr; + } + + // Find the corresponding split_axis. + int64_t split_axis = propagateCatAxis( + frontier[0]->as()->out()->getMaybeRFactorDomain(), + pads[0]->out()->as()->getRootDomain(), + cat->concatenatedDim()); + if (split_axis == -1) { + return nullptr; + } + TensorView* split_in = exprsFormSplit(frontier, split_axis); return split_in; } From 3b1b0e4247b83d864009cc4f9f7a7936a96a76c3 Mon Sep 17 00:00:00 2001 From: Jingyue Wu Date: Tue, 27 Feb 2024 21:57:45 +0000 Subject: [PATCH 03/13] More TODOs. --- test/test_move_split_cat.cpp | 2 ++ 1 file changed, 2 insertions(+) diff --git a/test/test_move_split_cat.cpp b/test/test_move_split_cat.cpp index 8d2469d436d..1285a9cc514 100644 --- a/test/test_move_split_cat.cpp +++ b/test/test_move_split_cat.cpp @@ -407,4 +407,6 @@ TEST_F(MoveSplitCatTest, Cancellable_Issue1768) { EXPECT_TRUE(out_tensors[2].is_alias_of(in_tensor)); } +// FIXME: test multiple split+cat pairs. + } // namespace nvfuser From 89622e3781d9ac399a32f464d8843bd531753efc Mon Sep 17 00:00:00 2001 From: Jingyue Wu Date: Mon, 4 Mar 2024 07:57:08 +0000 Subject: [PATCH 04/13] Check slices earlier. --- csrc/preseg_passes/move_split_cat.cpp | 45 ++++++++++++++------------- 1 file changed, 24 insertions(+), 21 deletions(-) diff --git a/csrc/preseg_passes/move_split_cat.cpp b/csrc/preseg_passes/move_split_cat.cpp index e53ef1933f8..71efff0c34a 100644 --- a/csrc/preseg_passes/move_split_cat.cpp +++ b/csrc/preseg_passes/move_split_cat.cpp @@ -43,7 +43,7 @@ class CancelSplitCat { // Returns true when Exprs between `slices` and `pads` can be horizontally // merged and applied on the input of the split. 
bool horizontallyMergeable( - const std::vector& slices, + const std::vector& slices, const std::vector& pads); int64_t propagateCatAxis( @@ -90,7 +90,7 @@ bool sameOp(const std::vector& frontier) { } bool CancelSplitCat::horizontallyMergeable( - const std::vector& slices, + const std::vector& slices, const std::vector& pads) { NVF_ERROR(slices.size() == pads.size()); NVF_ERROR(!slices.empty()); @@ -142,20 +142,15 @@ bool CancelSplitCat::horizontallyMergeable( return true; } -// If `exprs` are `SliceOp`s that form a split, returns the base tensor of the +// If `slices` form a split, returns the base tensor of the // split. Returns null otherwise. -TensorView* exprsFormSplit( - const std::vector& exprs, +TensorView* slicesFormSplit( + const std::vector& slices, const int64_t split_axis) { // Checks that all exprs are slices and are based on the // same tensor. Otherwise, they don't form a split. TensorView* split_in = nullptr; - for (Expr* e : exprs) { - auto* slice = dynamic_cast(e); - if (slice == nullptr) { - return nullptr; - } - + for (auto* slice : slices) { if (split_in == nullptr) { split_in = slice->in(); } else if (split_in != slice->in()) { @@ -169,9 +164,8 @@ TensorView* exprsFormSplit( // // `split_ranges[i]` is the slice range of `exprs[i]` for the split axis. std::vector split_ranges; - split_ranges.reserve(exprs.size()); - for (auto i : c10::irange(exprs.size())) { - auto* slice = exprs[i]->as(); + split_ranges.reserve(slices.size()); + for (auto* slice : slices) { const std::vector& slice_ranges = slice->getRanges(); // Check the steps are all one. if (std::any_of( @@ -204,8 +198,7 @@ TensorView* exprsFormSplit( // Due to the limitation of `sameAs` mentioned in #1859, I can't check // split_ranges.back().stop is the same as the dimension size. Below is a // slightly lengthy workaround. 
- if (!exprs.back() - ->as() + if (!slices.back() ->out() ->getMaybeRFactorDomain()[split_axis] ->definition() @@ -214,7 +207,7 @@ TensorView* exprsFormSplit( ->isZero()) { return nullptr; } - for (size_t i = 0; i + 1 < exprs.size(); i++) { + for (size_t i = 0; i + 1 < slices.size(); i++) { if (!split_ranges[i].stop->sameAs(split_ranges[i + 1].start)) { return nullptr; } @@ -318,20 +311,30 @@ TensorView* CancelSplitCat::findCancelingSplit( } } - if (!horizontallyMergeable(frontier, pads)) { + std::vector slices; + slices.reserve(frontier.size()); + for (Expr* e : frontier) { + auto* slice = dynamic_cast(e); + if (slice == nullptr) { + return nullptr; + } + slices.push_back(slice); + } + + if (!horizontallyMergeable(slices, pads)) { return nullptr; } // Find the corresponding split_axis. - int64_t split_axis = propagateCatAxis( - frontier[0]->as()->out()->getMaybeRFactorDomain(), + const int64_t split_axis = propagateCatAxis( + slices[0]->out()->getMaybeRFactorDomain(), pads[0]->out()->as()->getRootDomain(), cat->concatenatedDim()); if (split_axis == -1) { return nullptr; } - TensorView* split_in = exprsFormSplit(frontier, split_axis); + TensorView* split_in = slicesFormSplit(slices, split_axis); return split_in; } From 456f3d8953cf94fa7f7685ff02a3a64b3594e0be Mon Sep 17 00:00:00 2001 From: Jingyue Wu Date: Mon, 4 Mar 2024 16:29:09 +0000 Subject: [PATCH 05/13] Change tests to use half so it can run on V100. 
--- csrc/preseg_passes/move_split_cat.cpp | 4 ++-- test/test_move_split_cat.cpp | 6 +++--- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/csrc/preseg_passes/move_split_cat.cpp b/csrc/preseg_passes/move_split_cat.cpp index 71efff0c34a..b06b9cb44ff 100644 --- a/csrc/preseg_passes/move_split_cat.cpp +++ b/csrc/preseg_passes/move_split_cat.cpp @@ -99,11 +99,11 @@ bool CancelSplitCat::horizontallyMergeable( ValGraph& exact_graph = id_model_for_merging_.idGraph(IdMappingMode::EXACT); { const std::vector& first_rfactor = - slices[0]->output(0)->as()->getMaybeRFactorDomain(); + slices[0]->out()->getMaybeRFactorDomain(); size_t num_dims = first_rfactor.size(); for (size_t i = 1; i < slices.size(); i++) { const std::vector& rfactor = - slices[i]->output(0)->as()->getMaybeRFactorDomain(); + slices[i]->out()->getMaybeRFactorDomain(); if (rfactor.size() != num_dims) { return false; } diff --git a/test/test_move_split_cat.cpp b/test/test_move_split_cat.cpp index 1285a9cc514..26fa169531b 100644 --- a/test/test_move_split_cat.cpp +++ b/test/test_move_split_cat.cpp @@ -364,7 +364,7 @@ TEST_F(MoveSplitCatTest, Cancellable_Issue1768) { FusionGuard fg(fusion.get()); TensorView* sdpa_backward_out = - makeContigConcreteTensor({b, h * 3, s, f}, DataType::BFloat16); + makeContigConcreteTensor({b, h * 3, s, f}, DataType::Half); sdpa_backward_out->setAllocationDomain( {sdpa_backward_out->axis(0), sdpa_backward_out->axis(2), @@ -384,7 +384,7 @@ TEST_F(MoveSplitCatTest, Cancellable_Issue1768) { TensorView* cat_out = cat({dq, dk, dv}, /*dim=*/-1); TensorView* sum_out = castOp(DataType::Float, cat_out); sum_out = sum(sum_out, {0, 1}); - sum_out = castOp(DataType::BFloat16, sum_out); + sum_out = castOp(DataType::Half, sum_out); TensorView* view_out = reshape(cat_out, {b, s, h * f * 3}, {b * s, h * f * 3}); TensorView* permute_out = permute(view_out, {1, 0}); @@ -394,7 +394,7 @@ TEST_F(MoveSplitCatTest, Cancellable_Issue1768) { fusion->addOutput(view_out); 
fusion->addOutput(permute_out); - auto options = at::TensorOptions().dtype(at::kBFloat16).device(at::kCUDA, 0); + auto options = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA, 0); at::Tensor in_tensor = at::randn({b * h * 3 * s * f}, options) .as_strided({b, h * 3, s, f}, {h * 3 * s * f, f, h * 3 * f, 1}); From e642043a2069626af5b05acbc96b69878523e9e1 Mon Sep 17 00:00:00 2001 From: Jingyue Wu Date: Mon, 4 Mar 2024 23:18:02 +0000 Subject: [PATCH 06/13] More tests. --- test/test_move_split_cat.cpp | 68 ++++++++++++++++++++++++++++++++---- 1 file changed, 62 insertions(+), 6 deletions(-) diff --git a/test/test_move_split_cat.cpp b/test/test_move_split_cat.cpp index 26fa169531b..d1517bd58ec 100644 --- a/test/test_move_split_cat.cpp +++ b/test/test_move_split_cat.cpp @@ -6,6 +6,7 @@ */ // clang-format on #include +#include #include #include @@ -17,6 +18,8 @@ namespace nvfuser { using testing::Contains; +using testing::IsTrue; +using testing::Property; using MoveSplitCatTest = NVFuserTest; @@ -118,6 +121,7 @@ TEST_F(MoveSplitCatTest, Cancellable_PermuteInBetween) { EXPECT_TRUE(out_tensors[0].is_alias_of(in_tensor)); } +namespace { MATCHER(IsPermute, "") { if (auto* set = dynamic_cast(arg)) { if (auto* set_out = dynamic_cast(set->out())) { @@ -126,6 +130,7 @@ MATCHER(IsPermute, "") { } return false; } +} // namespace TEST_F(MoveSplitCatTest, Cancellable_IncompatibleAllocationOrder) { auto fusion = std::make_unique(); @@ -151,11 +156,8 @@ TEST_F(MoveSplitCatTest, Cancellable_IncompatibleAllocationOrder) { // Check the two permutes are merged to one. 
FusionKernelRuntime* runtime = fec.getMostRecentKernelRuntime(); - ASSERT_EQ(runtime->executors().size(), 1) - << "After merging, the whole fusion can be scheduled unsegmented."; - const FusionExecutor& executor = runtime->executors().front(); - kir::Kernel* kernel = executor.kernel(); - EXPECT_THAT(kernel->exprs(), Contains(IsPermute()).Times(1)); + Fusion* complete_fusion = runtime->fusionSegments()->completeFusion(); + EXPECT_THAT(complete_fusion->exprs(), Contains(IsPermute()).Times(1)); // Due to the incompatible output allocation order, the output can't be an // alias. @@ -407,6 +409,60 @@ TEST_F(MoveSplitCatTest, Cancellable_Issue1768) { EXPECT_TRUE(out_tensors[2].is_alias_of(in_tensor)); } -// FIXME: test multiple split+cat pairs. +TEST_F(MoveSplitCatTest, MultiplePairs) { + auto fusion = std::make_unique(); + FusionGuard fg(fusion.get()); + + TensorView* merged = makeContigConcreteTensor({4, 6}); + fusion->addInput(merged); + + // Region 0. Mergeable because both slices are permuted in the same way and + // the cat axis matches the split axis. + TensorView* s0 = slice(merged, {0, 0}, {2, 6}); + TensorView* s1 = slice(merged, {2, 0}, {4, 6}); + s0 = permute(s0, {1, 0}); + s1 = permute(s1, {1, 0}); + merged = cat({s0, s1}, /*dim=*/1); + + // Region 1. Not mergeable because the outer dimension is split and the inner + // dimension is catted. + s0 = slice(merged, {0, 0}, {3, 4}); + s1 = slice(merged, {3, 0}, {6, 4}); + s0 = reshape(s0, {3, 4}, {6, 2}); + s1 = reshape(s1, {3, 4}, {6, 2}); + merged = cat({s0, s1}, /*dim=*/1); + + // Region 2. Mergeable because both slices are reshaped in the same way and + // the outer dimension is split and catted. 
+ s0 = slice(merged, {0, 0}, {3, 4}); + s1 = slice(merged, {3, 0}, {6, 4}); + s0 = reshape(s0, {3, 4}, {6, 2}); + s1 = reshape(s1, {3, 4}, {6, 2}); + merged = cat({s0, s1}, /*dim=*/0); + + fusion->addOutput(merged); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::Tensor in_tensor = at::randn({4, 6}, options); + + FusionExecutorCache fec(std::move(fusion)); + auto out_tensors = fec.runFusionWithInputs({in_tensor}); + testValidate(fec.fusion(), out_tensors, {in_tensor}, __LINE__, __FILE__); + + FusionKernelRuntime* runtime = fec.getMostRecentKernelRuntime(); + Fusion* complete_fusion = runtime->fusionSegments()->completeFusion(); + std::vector exprs = complete_fusion->exprs(); + + // Only region 1 is not mergeable, so we expect to see only that region + // contains two slices and one cat in the pre-segmenter fusion. + EXPECT_THAT( + exprs, Contains(Property(&Expr::isA, IsTrue())).Times(2)); + EXPECT_THAT(exprs, Contains(Property(&Expr::isA, IsTrue())).Times(1)); + // The two permutes in region 0 are expected to be merged. + EXPECT_THAT(exprs, Contains(IsPermute()).Times(1)); + // The two reshapes in region 1 stay as is and the two reshapes in region 2 + // are merged. Therefore, three reshapes in total. + EXPECT_THAT(exprs, Contains(Property(&Expr::isA, IsTrue())).Times(3)); +} } // namespace nvfuser From 2572b8963afd1a55e8743b5cd597fa9ed1f01879 Mon Sep 17 00:00:00 2001 From: Jingyue Wu Date: Mon, 4 Mar 2024 23:58:04 +0000 Subject: [PATCH 07/13] Build only exact graphs. 
--- csrc/preseg_passes/move_split_cat.cpp | 52 +++++++++++++++------------ 1 file changed, 30 insertions(+), 22 deletions(-) diff --git a/csrc/preseg_passes/move_split_cat.cpp b/csrc/preseg_passes/move_split_cat.cpp index b06b9cb44ff..8f37ee9c10b 100644 --- a/csrc/preseg_passes/move_split_cat.cpp +++ b/csrc/preseg_passes/move_split_cat.cpp @@ -26,14 +26,20 @@ class CancelSplitCat { public: CancelSplitCat(Fusion* fusion) : fusion_(fusion), - id_model_for_merging_( + id_model_( fusion, - /*build_graphs=*/true, + /*build_graphs=*/false, /*allow_self_mapping=*/true), - id_model_for_propagation_( + id_model_for_merging_( fusion, - /*build_graphs=*/true, - /*allow_self_mapping=*/true) {} + /*build_graphs=*/false, + /*allow_self_mapping=*/true) { + id_model_.buildExactGraph(); + exact_graph_ = &id_model_.idGraph(IdMappingMode::EXACT); + id_model_for_merging_.buildExactGraph(); + exact_graph_for_merging_ = + &id_model_for_merging_.idGraph(IdMappingMode::EXACT); + } // Finds all cancellable pairs, cancels them and horizontallly // merges ops in between. @@ -76,10 +82,16 @@ class CancelSplitCat { Fusion* fusion_; - // TODO(wujingyue): keep two `IdGraph`s not two `IdModel`s. An `IdModel` - // contains multiple graphs and we only care about the exact graph in it. - IdModel id_model_for_merging_; - IdModel id_model_for_propagation_; + // `id_model_` and `exact_graph_` are supposed to be read-only and reflect the + // original fusion. + IdModel id_model_; // Holds *exact_graph_. + ValGraph* exact_graph_; + + // `id_model_for_merging_` and `exact_graph_for_merging_` are used for + // `horizontallyMergeable`, which unionizes IterDomains in slices and checks + // IterDomains in pads are unionized in the same way. + IdModel id_model_for_merging_; // Holds *exact_graph_for_merging. 
+ ValGraph* exact_graph_for_merging_; }; bool sameOp(const std::vector& frontier) { @@ -95,8 +107,6 @@ bool CancelSplitCat::horizontallyMergeable( NVF_ERROR(slices.size() == pads.size()); NVF_ERROR(!slices.empty()); - // FIXME: make it a class member. - ValGraph& exact_graph = id_model_for_merging_.idGraph(IdMappingMode::EXACT); { const std::vector& first_rfactor = slices[0]->out()->getMaybeRFactorDomain(); @@ -108,7 +118,7 @@ bool CancelSplitCat::horizontallyMergeable( return false; } for (size_t j = 0; j < num_dims; j++) { - exact_graph.mapVals(first_rfactor[j], rfactor[j]); + exact_graph_for_merging_->mapVals(first_rfactor[j], rfactor[j]); } } } @@ -131,7 +141,7 @@ bool CancelSplitCat::horizontallyMergeable( return false; } for (size_t j = 0; j < num_dims; j++) { - if (!exact_graph.disjointValSets().strictAreMapped( + if (!exact_graph_for_merging_->disjointValSets().strictAreMapped( first_root[j], root[j])) { return false; } @@ -220,14 +230,12 @@ int64_t CancelSplitCat::propagateCatAxis( const std::vector& source, const std::vector& destination, int64_t cat_axis) { - ValGraph& exact_graph = - id_model_for_propagation_.idGraph(IdMappingMode::EXACT); - ValGroup cat_dim = exact_graph.toGroup(destination[cat_axis]); + ValGroup cat_dim = exact_graph_->toGroup(destination[cat_axis]); while ( std::none_of(source.begin(), source.end(), [&](IterDomain* source_dim) { - return exact_graph.toGroup(source_dim) == cat_dim; + return exact_graph_->toGroup(source_dim) == cat_dim; })) { - const ExprGroups& defining_groups = exact_graph.getDefinitions(cat_dim); + const ExprGroups& defining_groups = exact_graph_->getDefinitions(cat_dim); if (defining_groups.size() != 1) { return -1; } @@ -235,13 +243,13 @@ int64_t CancelSplitCat::propagateCatAxis( Expr* def = defining_group->front(); // FIXME: make this a function so we can early return. 
if (Split* split = dynamic_cast(def)) { - if (exact_graph.toGroup(split->outer()) == cat_dim) { - cat_dim = exact_graph.toGroup(split->in()); + if (exact_graph_->toGroup(split->outer()) == cat_dim) { + cat_dim = exact_graph_->toGroup(split->in()); } else { return -1; } } else if (Merge* merge = dynamic_cast(def)) { - cat_dim = exact_graph.toGroup(merge->outer()); + cat_dim = exact_graph_->toGroup(merge->outer()); } else { return -1; } @@ -251,7 +259,7 @@ int64_t CancelSplitCat::propagateCatAxis( source.begin(), source.end(), [&](IterDomain* source_dim) { - return exact_graph.toGroup(source_dim) == cat_dim; + return exact_graph_->toGroup(source_dim) == cat_dim; }) - source.begin(); return cat_axis; From 65e355930c36090d4f5e1691b6eaefb1b64b4cd7 Mon Sep 17 00:00:00 2001 From: Jingyue Wu Date: Tue, 5 Mar 2024 00:12:09 +0000 Subject: [PATCH 08/13] Simplify propagateCatAxis. --- csrc/preseg_passes/move_split_cat.cpp | 48 ++++++++++++--------------- 1 file changed, 22 insertions(+), 26 deletions(-) diff --git a/csrc/preseg_passes/move_split_cat.cpp b/csrc/preseg_passes/move_split_cat.cpp index 8f37ee9c10b..f007c28db8f 100644 --- a/csrc/preseg_passes/move_split_cat.cpp +++ b/csrc/preseg_passes/move_split_cat.cpp @@ -52,10 +52,9 @@ class CancelSplitCat { const std::vector& slices, const std::vector& pads); - int64_t propagateCatAxis( + int64_t computeSplitAxis( const std::vector& source, - const std::vector& destination, - int64_t cat_axis); + IterDomain* cat_id); // Finds the canceling split of `cat` and returns the input TensorView of the // split. 
A split (implemented as multiple `slice`s) and a cat cancel when @@ -226,16 +225,14 @@ TensorView* slicesFormSplit( return split_in; } -int64_t CancelSplitCat::propagateCatAxis( +int64_t CancelSplitCat::computeSplitAxis( const std::vector& source, - const std::vector& destination, - int64_t cat_axis) { - ValGroup cat_dim = exact_graph_->toGroup(destination[cat_axis]); - while ( - std::none_of(source.begin(), source.end(), [&](IterDomain* source_dim) { - return exact_graph_->toGroup(source_dim) == cat_dim; - })) { - const ExprGroups& defining_groups = exact_graph_->getDefinitions(cat_dim); + IterDomain* cat_id) { + ValGroup cat_group = exact_graph_->toGroup(cat_id); + while (std::none_of(source.begin(), source.end(), [&](IterDomain* source_id) { + return exact_graph_->toGroup(source_id) == cat_group; + })) { + const ExprGroups& defining_groups = exact_graph_->getDefinitions(cat_group); if (defining_groups.size() != 1) { return -1; } @@ -243,26 +240,25 @@ int64_t CancelSplitCat::propagateCatAxis( Expr* def = defining_group->front(); // FIXME: make this a function so we can early return. 
if (Split* split = dynamic_cast(def)) { - if (exact_graph_->toGroup(split->outer()) == cat_dim) { - cat_dim = exact_graph_->toGroup(split->in()); + if (exact_graph_->toGroup(split->outer()) == cat_group) { + cat_group = exact_graph_->toGroup(split->in()); } else { return -1; } } else if (Merge* merge = dynamic_cast(def)) { - cat_dim = exact_graph_->toGroup(merge->outer()); + cat_group = exact_graph_->toGroup(merge->outer()); } else { return -1; } } - cat_axis = std::find_if( - source.begin(), - source.end(), - [&](IterDomain* source_dim) { - return exact_graph_->toGroup(source_dim) == cat_dim; - }) - + return std::find_if( + source.begin(), + source.end(), + [&](IterDomain* source_id) { + return exact_graph_->toGroup(source_id) == cat_group; + }) - source.begin(); - return cat_axis; } TensorView* CancelSplitCat::findCancelingSplit( @@ -333,11 +329,11 @@ TensorView* CancelSplitCat::findCancelingSplit( return nullptr; } - // Find the corresponding split_axis. - const int64_t split_axis = propagateCatAxis( + // Compute the corresponding split_axis. + const int64_t cat_axis = cat->concatenatedDim(); + const int64_t split_axis = computeSplitAxis( slices[0]->out()->getMaybeRFactorDomain(), - pads[0]->out()->as()->getRootDomain(), - cat->concatenatedDim()); + pads[0]->out()->as()->getRootDomain()[cat_axis]); if (split_axis == -1) { return nullptr; } From 41653f50624715af065a9677de5879a068cae815 Mon Sep 17 00:00:00 2001 From: Jingyue Wu Date: Tue, 5 Mar 2024 07:52:07 +0000 Subject: [PATCH 09/13] Refactors and comments. 
--- csrc/preseg_passes/move_split_cat.cpp | 206 +++++++++++++++++++------- 1 file changed, 150 insertions(+), 56 deletions(-) diff --git a/csrc/preseg_passes/move_split_cat.cpp b/csrc/preseg_passes/move_split_cat.cpp index f007c28db8f..f2543d6414c 100644 --- a/csrc/preseg_passes/move_split_cat.cpp +++ b/csrc/preseg_passes/move_split_cat.cpp @@ -46,14 +46,72 @@ class CancelSplitCat { void run(); private: - // Returns true when Exprs between `slices` and `pads` can be horizontally - // merged and applied on the input of the split. - bool horizontallyMergeable( + // Returns true when the def-use chain from slices[i] to pads[i] apply the + // same IterDomain transforms as the one from slices[j] to pads[j]. This is a + // necessary condition for horizontally merging the chains. + // + // Pre-condition: this is called after findPairingSplit so we know these + // chains contain the same sequence of op types and attributes. + bool sameIterDomainTransforms( const std::vector& slices, const std::vector& pads); - int64_t computeSplitAxis( - const std::vector& source, + // Imagine we "zip" the cat upwards as following: + // + // s0, s1 = split(in) + // s0 = unary_0(s0) + // s1 = unary_0(s1) + // ... + // s0 = unary_k(s0) + // s1 = unary_k(s1) + // s = cat({s0, s1}) + // + // ==> + // + // s0, s1 = split(in) + // s = cat({s0, s1}) + // s = unary_0(s) + // ... + // s = unary_k(s) + // + // This function returns the concatenated axis of the new cat so the above + // transform preserves the semantics. This axis will then be compared with the + // split axis to determine whether the split and the cat cancel out. + // + // If we can't zip the cat up to the split outputs (see one of the following + // examples), this function returns -1. + // + // Before calling this function, we already checked the chains contain the + // same sequence of op type and attributes and transform IterDomains in the + // same way. 
So this function takes the rfactor domain of any one of the + // slices and the catted IterDomain at the end of that chain. + // + // Example 1: + // t = permute(slice, {1, 2, 0}) + // out = cat({t, ...}, 1) + // + // Returns 2 because the catted dimension (dimension 1 of `t`) is permuted + // from dimension 2 of `slice`. + // + // Example 2: + // t = reshape(slice, {2, 3, 5}, {6, 5}) + // out = cat({t, ...}, 1} + // + // Returns 2 because the catted dimension comes from dimension 2 of `slice`. + // + // Example 3: + // t = reshape(slice, {2, 3}, {6}) + // out = cat({t, ...}, 0} + // + // Returns 0 because `slice`'s dimension 0 is the outer dimension. + // + // Example 4: + // t = reshape(slice, {6}, {2, 3}) + // out = cat({t, ...}, 1} + // + // Returns -1 because `out`'s dimension 1 is the inner dimension. + int64_t computeCatAxisAfterZipping( + const std::vector& slice_rfactor, IterDomain* cat_id); // Finds the canceling split of `cat` and returns the input TensorView of the @@ -75,9 +133,9 @@ class CancelSplitCat { // out = cat([t0, t1], dim=0) // // In addition to returning `in`, findCancelingSplit(out) puts `t0`'s defining - // `permute` into `use_def_chain` so the caller can reconstruct `out` by - // replaying `use_def_chain` (in reverse order) on `in`. - TensorView* findCancelingSplit(CatOp* cat, std::vector& use_def_chain); + // `permute` into `def_use_chain` so the caller can reconstruct `out` by + // replaying `def_use_chain` on `in`. 
+ TensorView* findCancelingSplit(CatOp* cat, std::vector& def_use_chain); Fusion* fusion_; @@ -100,7 +158,7 @@ bool sameOp(const std::vector& frontier) { }) == frontier.end(); } -bool CancelSplitCat::horizontallyMergeable( +bool CancelSplitCat::sameIterDomainTransforms( const std::vector& slices, const std::vector& pads) { NVF_ERROR(slices.size() == pads.size()); @@ -225,50 +283,66 @@ TensorView* slicesFormSplit( return split_in; } -int64_t CancelSplitCat::computeSplitAxis( - const std::vector& source, +int64_t CancelSplitCat::computeCatAxisAfterZipping( + const std::vector& slice_rfactor, IterDomain* cat_id) { ValGroup cat_group = exact_graph_->toGroup(cat_id); - while (std::none_of(source.begin(), source.end(), [&](IterDomain* source_id) { - return exact_graph_->toGroup(source_id) == cat_group; - })) { - const ExprGroups& defining_groups = exact_graph_->getDefinitions(cat_group); - if (defining_groups.size() != 1) { - return -1; + while (cat_group != nullptr) { + // If `cat_group` contains a slice rfactor ID, return the index of that ID. + auto i = std::find_if( + slice_rfactor.begin(), slice_rfactor.end(), [&](IterDomain* id) { + return exact_graph_->toGroup(id) == cat_group; + }); + if (i != slice_rfactor.end()) { + return i - slice_rfactor.begin(); } - ExprGroup defining_group = defining_groups.front(); - Expr* def = defining_group->front(); - // FIXME: make this a function so we can early return. - if (Split* split = dynamic_cast(def)) { - if (exact_graph_->toGroup(split->outer()) == cat_group) { - cat_group = exact_graph_->toGroup(split->in()); - } else { - return -1; + + // Conceptually zip `cat_group` over its definition. 
+ auto cat_group_after_zipping = [&](ValGroup cat_group) -> ValGroup { + const ExprGroups& defining_groups = + exact_graph_->getDefinitions(cat_group); + if (defining_groups.size() != 1) { + return nullptr; } - } else if (Merge* merge = dynamic_cast(def)) { - cat_group = exact_graph_->toGroup(merge->outer()); - } else { - return -1; - } + ExprGroup defining_group = defining_groups.front(); + // Pick an arbitrary Expr from defining_group as the representative. + Expr* def = defining_group->front(); + + if (Split* split = dynamic_cast(def)) { + if (exact_graph_->toGroup(split->outer()) == cat_group) { + return exact_graph_->toGroup(split->in()); + } + return nullptr; + } + + if (Merge* merge = dynamic_cast(def)) { + return exact_graph_->toGroup(merge->outer()); + } + + return nullptr; + }; + cat_group = cat_group_after_zipping(cat_group); } - return std::find_if( - source.begin(), - source.end(), - [&](IterDomain* source_id) { - return exact_graph_->toGroup(source_id) == cat_group; - }) - - source.begin(); + return -1; } -TensorView* CancelSplitCat::findCancelingSplit( - CatOp* cat, - std::vector& use_def_chain) { +// Finds the pairing split of `cat` by traversing the use-def chains. If found, +// returns the slices of the pairing split and `cat`'s preceding `PadOp`s. This +// function does some basic checks like: +// 1. Ops between the chains must have the same op type and attributes. +// 2. Chains must end with slices. +// However, these checks are necessary but not sufficient to guarantee the +// pairing split is canceling. To make that decision, the caller has to further +// inspect the ops in between. +std::optional, std::vector>> +findPairingSplit(CatOp* cat) { NVF_CHECK(!cat->inputs().empty(), "`cat` has zero inputs: ", cat); // `PadOp`s that produce `cat`'s inputs. std::vector pads; pads.reserve(cat->inputs().size()); + // `frontier` initially contains the `Expr`s that precede `pads`. 
Then, we // repeatedly try to move the frontier up in lockstep as long as Exprs in the // frontier can be horizontally merged and applied on the unsplit tensor. @@ -285,9 +359,12 @@ TensorView* CancelSplitCat::findCancelingSplit( return e == nullptr || e->isA(); })) { if (!sameOp(frontier)) { - return nullptr; + return std::nullopt; } + // We can probably extend this list to include many other unary ops. + // Currently, I limit this to only reshapes and permutes to reduce blast + // radius. auto supported = [](Expr* e) -> bool { if (e->isA()) { return true; @@ -300,11 +377,9 @@ TensorView* CancelSplitCat::findCancelingSplit( return false; }; if (!supported(frontier[0])) { - return nullptr; + return std::nullopt; } - use_def_chain.push_back(frontier[0]); - // Advance the frontier in lockstep. for (Expr*& e : frontier) { NVF_ERROR( @@ -320,41 +395,60 @@ TensorView* CancelSplitCat::findCancelingSplit( for (Expr* e : frontier) { auto* slice = dynamic_cast(e); if (slice == nullptr) { - return nullptr; + return std::nullopt; } slices.push_back(slice); } - if (!horizontallyMergeable(slices, pads)) { + return std::make_pair(slices, pads); +} + +TensorView* CancelSplitCat::findCancelingSplit( + CatOp* cat, + std::vector& def_use_chain) { + auto heads_and_tails = findPairingSplit(cat); + if (!heads_and_tails.has_value()) { return nullptr; } + std::vector slices; + std::vector pads; + std::tie(slices, pads) = *heads_and_tails; - // Compute the corresponding split_axis. 
- const int64_t cat_axis = cat->concatenatedDim(); - const int64_t split_axis = computeSplitAxis( + if (!sameIterDomainTransforms(slices, pads)) { + return nullptr; + } + + int64_t cat_axis = cat->concatenatedDim(); + cat_axis = computeCatAxisAfterZipping( slices[0]->out()->getMaybeRFactorDomain(), pads[0]->out()->as()->getRootDomain()[cat_axis]); - if (split_axis == -1) { + if (cat_axis == -1) { + return nullptr; + } + + TensorView* split_in = slicesFormSplit(slices, cat_axis); + if (split_in == nullptr) { return nullptr; } - TensorView* split_in = slicesFormSplit(slices, split_axis); + std::vector first_chain = + StmtSort::getExprsBetween({slices[0]->out()}, {pads[0]->in()}); + def_use_chain.swap(first_chain); return split_in; } void CancelSplitCat::run() { std::vector exprs = fusion_->exprs(); for (auto* cat : ir_utils::filterByType(exprs)) { - std::vector use_def_chain; - TensorView* split_in = findCancelingSplit(cat, std::ref(use_def_chain)); + std::vector def_use_chain; + TensorView* split_in = findCancelingSplit(cat, std::ref(def_use_chain)); if (split_in == nullptr) { continue; } Val* merged_out = split_in; - for (auto i = use_def_chain.rbegin(), end = use_def_chain.rend(); i != end; - i++) { - Expr* merged = replayExprWithNewInput(*i, merged_out); + for (Expr* e : def_use_chain) { + Expr* merged = replayExprWithNewInput(e, merged_out); NVF_ERROR( merged->outputs().size() == 1, "Currently, we merge only unary ops, so it would be a programming " From 0505b762b686a9ba3d6d583d597e0a1c5c4b6331 Mon Sep 17 00:00:00 2001 From: Jingyue Wu Date: Tue, 5 Mar 2024 08:17:08 +0000 Subject: [PATCH 10/13] Fix clang-tidy. 
--- csrc/preseg_passes/move_split_cat.cpp | 38 ++++++++++++--------------- 1 file changed, 17 insertions(+), 21 deletions(-) diff --git a/csrc/preseg_passes/move_split_cat.cpp b/csrc/preseg_passes/move_split_cat.cpp index f2543d6414c..51512a8dee0 100644 --- a/csrc/preseg_passes/move_split_cat.cpp +++ b/csrc/preseg_passes/move_split_cat.cpp @@ -35,10 +35,7 @@ class CancelSplitCat { /*build_graphs=*/false, /*allow_self_mapping=*/true) { id_model_.buildExactGraph(); - exact_graph_ = &id_model_.idGraph(IdMappingMode::EXACT); id_model_for_merging_.buildExactGraph(); - exact_graph_for_merging_ = - &id_model_for_merging_.idGraph(IdMappingMode::EXACT); } // Finds all cancellable pairs, cancels them and horizontallly @@ -139,16 +136,14 @@ class CancelSplitCat { Fusion* fusion_; - // `id_model_` and `exact_graph_` are supposed to be read-only and reflect the + // `id_model_` is supposed to be read-only and reflect the // original fusion. - IdModel id_model_; // Holds *exact_graph_. - ValGraph* exact_graph_; - - // `id_model_for_merging_` and `exact_graph_for_merging_` are used for - // `horizontallyMergeable`, which unionizes IterDomains in slices and checks - // IterDomains in pads are unionized in the same way. - IdModel id_model_for_merging_; // Holds *exact_graph_for_merging. - ValGraph* exact_graph_for_merging_; + IdModel id_model_; + + // `id_model_for_merging_` is used for + // `sameIterDomainTransforms`, which unionizes IterDomains in slices and + // checks IterDomains in pads are unionized in the same way. 
+ IdModel id_model_for_merging_; }; bool sameOp(const std::vector& frontier) { @@ -164,6 +159,7 @@ bool CancelSplitCat::sameIterDomainTransforms( NVF_ERROR(slices.size() == pads.size()); NVF_ERROR(!slices.empty()); + ValGraph& exact_graph = id_model_for_merging_.idGraph(IdMappingMode::EXACT); { const std::vector& first_rfactor = slices[0]->out()->getMaybeRFactorDomain(); @@ -175,7 +171,7 @@ bool CancelSplitCat::sameIterDomainTransforms( return false; } for (size_t j = 0; j < num_dims; j++) { - exact_graph_for_merging_->mapVals(first_rfactor[j], rfactor[j]); + exact_graph.mapVals(first_rfactor[j], rfactor[j]); } } } @@ -198,7 +194,7 @@ bool CancelSplitCat::sameIterDomainTransforms( return false; } for (size_t j = 0; j < num_dims; j++) { - if (!exact_graph_for_merging_->disjointValSets().strictAreMapped( + if (!exact_graph.disjointValSets().strictAreMapped( first_root[j], root[j])) { return false; } @@ -286,12 +282,13 @@ TensorView* slicesFormSplit( int64_t CancelSplitCat::computeCatAxisAfterZipping( const std::vector& slice_rfactor, IterDomain* cat_id) { - ValGroup cat_group = exact_graph_->toGroup(cat_id); + ValGraph& exact_graph = id_model_.idGraph(IdMappingMode::EXACT); + ValGroup cat_group = exact_graph.toGroup(cat_id); while (cat_group != nullptr) { // If `cat_group` contains a slice rfactor ID, return the index of that ID. auto i = std::find_if( slice_rfactor.begin(), slice_rfactor.end(), [&](IterDomain* id) { - return exact_graph_->toGroup(id) == cat_group; + return exact_graph.toGroup(id) == cat_group; }); if (i != slice_rfactor.end()) { return i - slice_rfactor.begin(); @@ -299,8 +296,7 @@ int64_t CancelSplitCat::computeCatAxisAfterZipping( // Conceptually zip `cat_group` over its definition. 
auto cat_group_after_zipping = [&](ValGroup cat_group) -> ValGroup { - const ExprGroups& defining_groups = - exact_graph_->getDefinitions(cat_group); + const ExprGroups& defining_groups = exact_graph.getDefinitions(cat_group); if (defining_groups.size() != 1) { return nullptr; } @@ -309,14 +305,14 @@ int64_t CancelSplitCat::computeCatAxisAfterZipping( Expr* def = defining_group->front(); if (Split* split = dynamic_cast(def)) { - if (exact_graph_->toGroup(split->outer()) == cat_group) { - return exact_graph_->toGroup(split->in()); + if (exact_graph.toGroup(split->outer()) == cat_group) { + return exact_graph.toGroup(split->in()); } return nullptr; } if (Merge* merge = dynamic_cast(def)) { - return exact_graph_->toGroup(merge->outer()); + return exact_graph.toGroup(merge->outer()); } return nullptr; From a5a4e923f828ce234a50175380585ec318c77948 Mon Sep 17 00:00:00 2001 From: Jingyue Wu Date: Tue, 5 Mar 2024 12:23:55 -0800 Subject: [PATCH 11/13] Update csrc/preseg_passes/move_split_cat.cpp Co-authored-by: jjsjann123 --- csrc/preseg_passes/move_split_cat.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/csrc/preseg_passes/move_split_cat.cpp b/csrc/preseg_passes/move_split_cat.cpp index 51512a8dee0..ed981f3dd0a 100644 --- a/csrc/preseg_passes/move_split_cat.cpp +++ b/csrc/preseg_passes/move_split_cat.cpp @@ -282,7 +282,7 @@ TensorView* slicesFormSplit( int64_t CancelSplitCat::computeCatAxisAfterZipping( const std::vector& slice_rfactor, IterDomain* cat_id) { - ValGraph& exact_graph = id_model_.idGraph(IdMappingMode::EXACT); + const ValGraph& exact_graph = id_model_.idGraph(IdMappingMode::EXACT); ValGroup cat_group = exact_graph.toGroup(cat_id); while (cat_group != nullptr) { // If `cat_group` contains a slice rfactor ID, return the index of that ID. From ec85e5a19561409aa61dbbe2a01b2c9f8b25f295 Mon Sep 17 00:00:00 2001 From: Jingyue Wu Date: Tue, 5 Mar 2024 23:20:07 +0000 Subject: [PATCH 12/13] Address some review comments. 
--- csrc/preseg_passes/move_split_cat.cpp | 8 +++++++- test/test_move_split_cat.cpp | 24 +++++++++++++++++++++++- 2 files changed, 30 insertions(+), 2 deletions(-) diff --git a/csrc/preseg_passes/move_split_cat.cpp b/csrc/preseg_passes/move_split_cat.cpp index ed981f3dd0a..d1f15b30c1d 100644 --- a/csrc/preseg_passes/move_split_cat.cpp +++ b/csrc/preseg_passes/move_split_cat.cpp @@ -101,9 +101,15 @@ class CancelSplitCat { // out = cat({t, ...}, 0} // // Returns 0 because `slice`'s dimension 0 is the outer dimension. - // + // Example 4: // t = reshape(slice, {6}, {2, 3}) + // out = cat({t, ...}, 0} + // + // Returns 0 because `out`'s dimension 0 is the outer dimension. + // + // Example 5: + // t = reshape(slice, {6}, {2, 3}) // out = cat({t, ...}, 1} // // Returns -1 because `out`'s dimension 1 is the inner dimension. diff --git a/test/test_move_split_cat.cpp b/test/test_move_split_cat.cpp index d1517bd58ec..4e540ed27d7 100644 --- a/test/test_move_split_cat.cpp +++ b/test/test_move_split_cat.cpp @@ -23,7 +23,7 @@ using testing::Property; using MoveSplitCatTest = NVFuserTest; -TEST_F(MoveSplitCatTest, Cancellable_Adjacent) { +TEST_F(MoveSplitCatTest, Cancellable_SplitImmediatelyFollowedByCat) { auto fusion = std::make_unique(); FusionGuard fg(fusion.get()); @@ -46,6 +46,28 @@ TEST_F(MoveSplitCatTest, Cancellable_Adjacent) { EXPECT_TRUE(out_tensors[0].is_alias_of(in_tensor)); } +TEST_F(MoveSplitCatTest, Noncancellable_DifferentOrder) { + auto fusion = std::make_unique(); + FusionGuard fg(fusion.get()); + + TensorView* in = makeContigConcreteTensor({2, 6}); + TensorView* s0 = slice(in, {0, 0}, {2, 3}); + TensorView* s1 = slice(in, {0, 3}, {2, 6}); + TensorView* out = cat({s1, s0}, /*dim=*/-1); + + fusion->addInput(in); + fusion->addOutput(out); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::Tensor in_tensor = at::randn({2, 6}, options); + + FusionExecutorCache fec(std::move(fusion)); + auto out_tensors = 
fec.runFusionWithInputs({in_tensor}); + testValidate(fec.fusion(), out_tensors, {in_tensor}, __LINE__, __FILE__); + + EXPECT_FALSE(out_tensors[0].is_alias_of(in_tensor)); +} + TEST_F(MoveSplitCatTest, Noncancellable_SliceAmountAndPaddingAmountMismatch) { auto fusion = std::make_unique(); FusionGuard fg(fusion.get()); From 5334b223e3e5f050b1b83f5fe0350022a5f4eb93 Mon Sep 17 00:00:00 2001 From: Jingyue Wu Date: Wed, 6 Mar 2024 07:11:24 +0000 Subject: [PATCH 13/13] Speed up sameIterDomainTransforms by mapping only the cat axis. --- csrc/preseg_passes/move_split_cat.cpp | 59 ++++++++++++++------------- test/test_move_split_cat.cpp | 44 ++++++++++++++++---- 2 files changed, 65 insertions(+), 38 deletions(-) diff --git a/csrc/preseg_passes/move_split_cat.cpp b/csrc/preseg_passes/move_split_cat.cpp index d1f15b30c1d..2a854f6cf9b 100644 --- a/csrc/preseg_passes/move_split_cat.cpp +++ b/csrc/preseg_passes/move_split_cat.cpp @@ -51,7 +51,8 @@ class CancelSplitCat { // chains contain the same sequence of op types and attributes. 
bool sameIterDomainTransforms( const std::vector& slices, - const std::vector& pads); + const std::vector& pads, + int64_t cat_axis); // Imagine we "zip" the cat upwards as following: // @@ -161,54 +162,54 @@ bool sameOp(const std::vector& frontier) { bool CancelSplitCat::sameIterDomainTransforms( const std::vector& slices, - const std::vector& pads) { + const std::vector& pads, + const int64_t cat_axis) { NVF_ERROR(slices.size() == pads.size()); NVF_ERROR(!slices.empty()); ValGraph& exact_graph = id_model_for_merging_.idGraph(IdMappingMode::EXACT); { - const std::vector& first_rfactor = - slices[0]->out()->getMaybeRFactorDomain(); - size_t num_dims = first_rfactor.size(); - for (size_t i = 1; i < slices.size(); i++) { - const std::vector& rfactor = - slices[i]->out()->getMaybeRFactorDomain(); - if (rfactor.size() != num_dims) { - return false; - } - for (size_t j = 0; j < num_dims; j++) { - exact_graph.mapVals(first_rfactor[j], rfactor[j]); - } + // Map pads[i0].root[cat_axis] and pads[i1].root[cat_axis]. Other axes were + // already mapped due to the `cat` when the IdModel was built. + const std::vector& first_root = + pads[0]->out()->as()->getRootDomain(); + for (size_t i = 1; i < pads.size(); i++) { + const std::vector& other_root = + pads[i]->out()->as()->getRootDomain(); + NVF_ERROR(first_root.size() == other_root.size()); + exact_graph.mapVals(first_root[cat_axis], other_root[cat_axis]); } } - for (PadOp* pad : pads) { - auto* pad_out = pad->out()->as(); - if (id_model_for_merging_.hasSelfMapping(pad_out)) { + // The above code block only maps IterDomains across chains. If a self mapping + // is detected at this point, it's likely because some IterDomains are permuted + // differently between two chains. See + // MoveSplitCatTest.Noncancellable_PermutedDifferently for an example. 
+ for (auto* slice : slices) { + if (id_model_for_merging_.hasSelfMapping(slice->out())) { return false; } } { - const std::vector& first_root = - pads[0]->out()->as()->getRootDomain(); - size_t num_dims = first_root.size(); - for (size_t i = 1; i < pads.size(); i++) { - const std::vector& root = - pads[i]->out()->as()->getRootDomain(); - if (root.size() != num_dims) { + const std::vector& first_rfactor = + slices[0]->out()->getMaybeRFactorDomain(); + size_t num_dims = first_rfactor.size(); + for (size_t i = 1; i < slices.size(); i++) { + const std::vector& other_rfactor = + slices[i]->out()->getMaybeRFactorDomain(); + if (other_rfactor.size() != num_dims) { return false; } for (size_t j = 0; j < num_dims; j++) { if (!exact_graph.disjointValSets().strictAreMapped( - first_root[j], root[j])) { + first_rfactor[j], other_rfactor[j])) { return false; } } } + return true; } - - return true; } // If `slices` form a split, returns the base tensor of the @@ -416,11 +417,11 @@ TensorView* CancelSplitCat::findCancelingSplit( std::vector pads; std::tie(slices, pads) = *heads_and_tails; - if (!sameIterDomainTransforms(slices, pads)) { + int64_t cat_axis = cat->concatenatedDim(); + if (!sameIterDomainTransforms(slices, pads, cat_axis)) { return nullptr; } - int64_t cat_axis = cat->concatenatedDim(); cat_axis = computeCatAxisAfterZipping( slices[0]->out()->getMaybeRFactorDomain(), pads[0]->out()->as()->getRootDomain()[cat_axis]); diff --git a/test/test_move_split_cat.cpp b/test/test_move_split_cat.cpp index 4e540ed27d7..fe4ccb32063 100644 --- a/test/test_move_split_cat.cpp +++ b/test/test_move_split_cat.cpp @@ -68,6 +68,30 @@ TEST_F(MoveSplitCatTest, Noncancellable_DifferentOrder) { EXPECT_FALSE(out_tensors[0].is_alias_of(in_tensor)); } +TEST_F(MoveSplitCatTest, Cancellable_SetWithoutPermute) { + auto fusion = std::make_unique(); + FusionGuard fg(fusion.get()); + + TensorView* in = makeContigConcreteTensor({2, 5}); + TensorView* s0 = slice(in, {0, 0}, {2, 2}); + TensorView* s1 = 
slice(in, {0, 2}, {2, 5}); + s0 = set(s0); + s1 = set(s1); + TensorView* out = cat({s0, s1}, /*dim=*/-1); + + fusion->addInput(in); + fusion->addOutput(out); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::Tensor in_tensor = at::randn({2, 5}, options); + + FusionExecutorCache fec(std::move(fusion)); + auto out_tensors = fec.runFusionWithInputs({in_tensor}); + testValidate(fec.fusion(), out_tensors, {in_tensor}, __LINE__, __FILE__); + + EXPECT_TRUE(out_tensors[0].is_alias_of(in_tensor)); +} + TEST_F(MoveSplitCatTest, Noncancellable_SliceAmountAndPaddingAmountMismatch) { auto fusion = std::make_unique(); FusionGuard fg(fusion.get()); @@ -274,20 +298,22 @@ TEST_F(MoveSplitCatTest, Noncancellable_PermutedDifferently) { auto fusion = std::make_unique(); FusionGuard fg(fusion.get()); - TensorView* in = makeContigConcreteTensor({2, 2, 2, 10}); - TensorView* s0 = slice(in, {0, 0, 0, 0}, {2, 2, 2, 2}); - TensorView* s1 = slice(in, {0, 0, 0, 2}, {2, 2, 2, 5}); - TensorView* s2 = slice(in, {0, 0, 0, 5}, {2, 2, 2, 10}); - s0 = permute(s0, {2, 1, 0, 3}); - s1 = permute(s1, {1, 0, 2, 3}); - s2 = permute(s2, {2, 1, 0, 3}); - TensorView* out = cat({s0, s1, s2}, /*dim=*/-1); + TensorView* in = makeContigConcreteTensor({4, 2}); + TensorView* s0 = slice(in, {0, 0}, {2, 2}); + s0 = set(s0); + s0 = reshape(s0, {2, 2}, {4}); + + TensorView* s1 = slice(in, {2, 0}, {4, 2}); + s1 = permute(s1, {1, 0}); + s1 = reshape(s1, {2, 2}, {4}); + + TensorView* out = cat({s0, s1}, /*dim=*/0); fusion->addInput(in); fusion->addOutput(out); auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - at::Tensor in_tensor = at::randn({2, 2, 2, 10}, options); + at::Tensor in_tensor = at::randn({4, 2}, options); FusionExecutorCache fec(std::move(fusion)); auto out_tensors = fec.runFusionWithInputs({in_tensor});