diff --git a/csrc/preseg_passes/move_split_cat.cpp b/csrc/preseg_passes/move_split_cat.cpp index ad92030f4ae..2a854f6cf9b 100644 --- a/csrc/preseg_passes/move_split_cat.cpp +++ b/csrc/preseg_passes/move_split_cat.cpp @@ -26,18 +26,97 @@ class CancelSplitCat { public: CancelSplitCat(Fusion* fusion) : fusion_(fusion), - id_model_(fusion, /*build_graphs=*/true, /*allow_self_mapping=*/true) {} + id_model_( + fusion, + /*build_graphs=*/false, + /*allow_self_mapping=*/true), + id_model_for_merging_( + fusion, + /*build_graphs=*/false, + /*allow_self_mapping=*/true) { + id_model_.buildExactGraph(); + id_model_for_merging_.buildExactGraph(); + } // Finds all cancellable pairs, cancels them and horizontallly // merges ops in between. void run(); private: - // Returns true when Exprs in the frontier can be horizontally merged and - // applied on the unsplit tensor. - bool horizontallyMergeable( - const std::vector& frontier, - int64_t& split_axis); + // Returns true when the def-use chain from slices[i] to pads[i] apply the + // same IterDomain transforms as the one from slices[j] to pads[j]. This is a + // necessary condition for horizontally merging the chains. + // + // Pre-condition: this is called after findPairingSplit so we know these + // chains contain the same sequence of op types and attributes. + bool sameIterDomainTransforms( + const std::vector& slices, + const std::vector& pads, + int64_t cat_axis); + + // Imagine we "zip" the cat upwards as following: + // + // s0, s1 = split(in) + // s0 = unary_0(s0) + // s1 = unary_0(s1) + // ... + // s0 = unary_k(s0) + // s1 = unary_k(s1) + // s = cat({s0, s1}) + // + // ==> + // + // s0, s1 = split(in) + // s = cat({s0, s1}) + // s = unary_0(s) + // ... + // s = unary_k(s) + // + // This function returns the concatenated axis of the new cat so the above + // transform preserves the semantics. This axis will then be compared with the + // split axis to determine whether the split and the cat cancel out. 
+ // + // If we can't zip the cat up to the split outputs (see one of the following + // examples), this function returns -1. + // + // Before calling this function, we already checked the chains contain the + // same sequence of op type and attributes and transform IterDomains in the + // same way. So this function takes the rfactor domain of any one of the + // slices and the catted IterDomain at the end of that chain. + // + // Example 1: + // t = permute(slice, {1, 2, 0}) + // out = cat({t, ...}, 1) + // + // Returns 2 because the catted dimension (dimension 1 of `t`) is permuted + // from dimension 2 of `slice`. + // + // Example 2: + // t = reshape(slice, {2, 3, 5}, {6, 5}) + // out = cat({t, ...}, 1) + // + // Returns 2 because the catted dimension comes from dimension 2 of `slice`. + // + // Example 3: + // t = reshape(slice, {2, 3}, {6}) + // out = cat({t, ...}, 0) + // + // Returns 0 because `slice`'s dimension 0 is the outer dimension. + // + // Example 4: + // t = reshape(slice, {6}, {2, 3}) + // out = cat({t, ...}, 0) + // + // Returns 0 because `out`'s dimension 0 is the outer dimension. + // + // Example 5: + // t = reshape(slice, {6}, {2, 3}) + // out = cat({t, ...}, 1) + // + // Returns -1 because `out`'s dimension 1 is the inner dimension. + int64_t computeCatAxisAfterZipping( + const std::vector& slice_rfactor, + IterDomain* cat_id); // Finds the canceling split of `cat` and returns the input TensorView of the // split. A split (implemented as multiple `slice`s) and a cat cancel when @@ -58,79 +137,90 @@ class CancelSplitCat { // out = cat([t0, t1], dim=0) // // In addition to returning `in`, findCancelingSplit(out) puts `t0`'s defining 
- TensorView* findCancelingSplit(CatOp* cat, std::vector& use_def_chain); + // `permute` into `def_use_chain` so the caller can reconstruct `out` by + // replaying `def_use_chain` on `in`. + TensorView* findCancelingSplit(CatOp* cat, std::vector& def_use_chain); Fusion* fusion_; + // `id_model_` is supposed to be read-only and reflect the + // original fusion. IdModel id_model_; + + // `id_model_for_merging_` is used for + // `sameIterDomainTransforms`, which unionizes IterDomains in slices and + // checks IterDomains in pads are unionized in the same way. + IdModel id_model_for_merging_; }; -bool CancelSplitCat::horizontallyMergeable( - const std::vector& frontier, - int64_t& split_axis) { - NVF_ERROR(!frontier.empty()); - - // Check all Exprs in `frontier` - // 1. have the same op type and attributes, - // 2. transform IDs in the same way, and - // 3. don't resize the split axis. - - if (std::adjacent_find( - frontier.begin(), frontier.end(), [](Expr* lhs, Expr* rhs) { - return !lhs->sameOp(rhs); - }) != frontier.end()) { - return false; +bool sameOp(const std::vector& frontier) { + return std::adjacent_find( + frontier.begin(), frontier.end(), [](Expr* lhs, Expr* rhs) { + return !lhs->sameOp(rhs); + }) == frontier.end(); +} + +bool CancelSplitCat::sameIterDomainTransforms( + const std::vector& slices, + const std::vector& pads, + const int64_t cat_axis) { + NVF_ERROR(slices.size() == pads.size()); + NVF_ERROR(!slices.empty()); + + ValGraph& exact_graph = id_model_for_merging_.idGraph(IdMappingMode::EXACT); + { + // Map pads[i0].root[cat_axis] and pads[i1].root[cat_axis]. Other axes were + // already mapped due to the `cat` when the IdModel was built. 
+ const std::vector& first_root = + pads[0]->out()->as()->getRootDomain(); + for (size_t i = 1; i < pads.size(); i++) { + const std::vector& other_root = + pads[i]->out()->as()->getRootDomain(); + NVF_ERROR(first_root.size() == other_root.size()); + exact_graph.mapVals(first_root[cat_axis], other_root[cat_axis]); + } } - if (auto* set = dynamic_cast(frontier[0])) { - if (set->opType() == LoadStoreOpType::Set) { - auto* set_out = set->out()->as(); - std::optional> permutation = - ir_utils::computePermutation( - set_out->getRootDomain(), set_out->getMaybeRFactorDomain()); - if (!permutation.has_value()) { + // The above code block only maps IterDomains across chains. If a self mapping + // is detected at this point, it's likely because some IterDomains are permuted + // differently between two chains. See + // MoveSplitCatTest.Noncancellable_PermutedDifferently for an example. + for (auto* slice : slices) { + if (id_model_for_merging_.hasSelfMapping(slice->out())) { + return false; + } + } + + { + const std::vector& first_rfactor = + slices[0]->out()->getMaybeRFactorDomain(); + size_t num_dims = first_rfactor.size(); + for (size_t i = 1; i < slices.size(); i++) { + const std::vector& other_rfactor = + slices[i]->out()->getMaybeRFactorDomain(); + if (other_rfactor.size() != num_dims) { return false; } - - for (size_t i = 1; i < frontier.size(); i++) { - auto* other_set_out = - frontier[i]->as()->out()->as(); - std::optional> other_permutation = - ir_utils::computePermutation( - other_set_out->getRootDomain(), - other_set_out->getMaybeRFactorDomain()); - if (!other_permutation.has_value()) { - return false; - } - if (*permutation != *other_permutation) { + for (size_t j = 0; j < num_dims; j++) { + if (!exact_graph.disjointValSets().strictAreMapped( + first_rfactor[j], other_rfactor[j])) { return false; } } - - split_axis = (*permutation)[split_axis]; - return true; } + return true; } - - return false; } -// If `exprs` are `SliceOp`s that form a split, returns the base 
tensor of the +// If `slices` form a split, returns the base tensor of the // split. Returns null otherwise. -TensorView* exprsFormSplit( - const std::vector& exprs, +TensorView* slicesFormSplit( + const std::vector& slices, const int64_t split_axis) { // Checks that all exprs are slices and are based on the // same tensor. Otherwise, they don't form a split. TensorView* split_in = nullptr; - for (Expr* e : exprs) { - auto* slice = dynamic_cast(e); - if (slice == nullptr) { - return nullptr; - } - + for (auto* slice : slices) { if (split_in == nullptr) { split_in = slice->in(); } else if (split_in != slice->in()) { @@ -144,9 +234,8 @@ TensorView* exprsFormSplit( // // `split_ranges[i]` is the slice range of `exprs[i]` for the split axis. std::vector split_ranges; - split_ranges.reserve(exprs.size()); - for (auto i : c10::irange(exprs.size())) { - auto* slice = exprs[i]->as(); + split_ranges.reserve(slices.size()); + for (auto* slice : slices) { const std::vector& slice_ranges = slice->getRanges(); // Check the steps are all one. if (std::any_of( @@ -179,8 +268,7 @@ TensorView* exprsFormSplit( // Due to the limitation of `sameAs` mentioned in #1859, I can't check // split_ranges.back().stop is the same as the dimension size. Below is a // slightly lengthy workaround. 
- if (!exprs.back() - ->as() + if (!slices.back() ->out() ->getMaybeRFactorDomain()[split_axis] ->definition() @@ -189,7 +277,7 @@ TensorView* exprsFormSplit( ->isZero()) { return nullptr; } - for (size_t i = 0; i + 1 < exprs.size(); i++) { + for (size_t i = 0; i + 1 < slices.size(); i++) { if (!split_ranges[i].stop->sameAs(split_ranges[i + 1].start)) { return nullptr; } @@ -198,30 +286,102 @@ TensorView* exprsFormSplit( return split_in; } -TensorView* CancelSplitCat::findCancelingSplit( - CatOp* cat, - std::vector& use_def_chain) { +int64_t CancelSplitCat::computeCatAxisAfterZipping( + const std::vector& slice_rfactor, + IterDomain* cat_id) { + const ValGraph& exact_graph = id_model_.idGraph(IdMappingMode::EXACT); + ValGroup cat_group = exact_graph.toGroup(cat_id); + while (cat_group != nullptr) { + // If `cat_group` contains a slice rfactor ID, return the index of that ID. + auto i = std::find_if( + slice_rfactor.begin(), slice_rfactor.end(), [&](IterDomain* id) { + return exact_graph.toGroup(id) == cat_group; + }); + if (i != slice_rfactor.end()) { + return i - slice_rfactor.begin(); + } + + // Conceptually zip `cat_group` over its definition. + auto cat_group_after_zipping = [&](ValGroup cat_group) -> ValGroup { + const ExprGroups& defining_groups = exact_graph.getDefinitions(cat_group); + if (defining_groups.size() != 1) { + return nullptr; + } + ExprGroup defining_group = defining_groups.front(); + // Pick an arbitrary Expr from defining_group as the representative. + Expr* def = defining_group->front(); + + if (Split* split = dynamic_cast(def)) { + if (exact_graph.toGroup(split->outer()) == cat_group) { + return exact_graph.toGroup(split->in()); + } + return nullptr; + } + + if (Merge* merge = dynamic_cast(def)) { + return exact_graph.toGroup(merge->outer()); + } + + return nullptr; + }; + cat_group = cat_group_after_zipping(cat_group); + } + + return -1; +} + +// Finds the pairing split of `cat` by traversing the use-def chains. 
If found, +// returns the slices of the pairing split and `cat`'s preceding `PadOp`s. This +// function does some basic checks like: +// 1. Ops between the chains must have the same op type and attributes. +// 2. Chains must end with slices. +// However, these checks are necessary but not sufficient to guarantee the +// pairing split is canceling. To make that decision, the caller has to further +// inspect the ops in between. +std::optional, std::vector>> +findPairingSplit(CatOp* cat) { NVF_CHECK(!cat->inputs().empty(), "`cat` has zero inputs: ", cat); - // `frontier` initially contains the preceding Exprs of the `PadOp`s. Then, we + // `PadOp`s that produce `cat`'s inputs. + std::vector pads; + pads.reserve(cat->inputs().size()); + + // `frontier` initially contains the `Expr`s that precede `pads`. Then, we // repeatedly try to move the frontier up in lockstep as long as Exprs in the // frontier can be horizontally merged and applied on the unsplit tensor. std::vector frontier; frontier.reserve(cat->inputs().size()); for (Val* in : cat->inputs()) { auto* pad = in->definition()->as(); + pads.push_back(pad); frontier.push_back(pad->in()->definition()); } // Exit the loop when any Expr in `frontier` is a slice or a null. - int64_t split_axis = cat->concatenatedDim(); while (std::none_of(frontier.begin(), frontier.end(), [](Expr* e) { return e == nullptr || e->isA(); })) { - if (!horizontallyMergeable(frontier, std::ref(split_axis))) { - return nullptr; + if (!sameOp(frontier)) { + return std::nullopt; + } + + // We can probably extend this list to include many other unary ops. + // Currently, I limit this to only reshapes and permutes to reduce blast + // radius. 
+ auto supported = [](Expr* e) -> bool { + if (e->isA()) { + return true; + } + if (auto* set = dynamic_cast(e)) { + if (set->opType() == LoadStoreOpType::Set) { + return true; + } + } + return false; + }; + if (!supported(frontier[0])) { + return std::nullopt; } - use_def_chain.push_back(frontier[0]); // Advance the frontier in lockstep. for (Expr*& e : frontier) { @@ -233,23 +393,65 @@ TensorView* CancelSplitCat::findCancelingSplit( } } - TensorView* split_in = exprsFormSplit(frontier, split_axis); + std::vector slices; + slices.reserve(frontier.size()); + for (Expr* e : frontier) { + auto* slice = dynamic_cast(e); + if (slice == nullptr) { + return std::nullopt; + } + slices.push_back(slice); + } + + return std::make_pair(slices, pads); +} + +TensorView* CancelSplitCat::findCancelingSplit( + CatOp* cat, + std::vector& def_use_chain) { + auto heads_and_tails = findPairingSplit(cat); + if (!heads_and_tails.has_value()) { + return nullptr; + } + std::vector slices; + std::vector pads; + std::tie(slices, pads) = *heads_and_tails; + + int64_t cat_axis = cat->concatenatedDim(); + if (!sameIterDomainTransforms(slices, pads, cat_axis)) { + return nullptr; + } + + cat_axis = computeCatAxisAfterZipping( + slices[0]->out()->getMaybeRFactorDomain(), + pads[0]->out()->as()->getRootDomain()[cat_axis]); + if (cat_axis == -1) { + return nullptr; + } + + TensorView* split_in = slicesFormSplit(slices, cat_axis); + if (split_in == nullptr) { + return nullptr; + } + + std::vector first_chain = + StmtSort::getExprsBetween({slices[0]->out()}, {pads[0]->in()}); + def_use_chain.swap(first_chain); return split_in; } void CancelSplitCat::run() { std::vector exprs = fusion_->exprs(); for (auto* cat : ir_utils::filterByType(exprs)) { - std::vector use_def_chain; - TensorView* split_in = findCancelingSplit(cat, std::ref(use_def_chain)); + std::vector def_use_chain; + TensorView* split_in = findCancelingSplit(cat, std::ref(def_use_chain)); if (split_in == nullptr) { continue; } Val* 
merged_out = split_in; - for (auto i = use_def_chain.rbegin(), end = use_def_chain.rend(); i != end; - i++) { - Expr* merged = replayExprWithNewInput(*i, merged_out); + for (Expr* e : def_use_chain) { + Expr* merged = replayExprWithNewInput(e, merged_out); NVF_ERROR( merged->outputs().size() == 1, "Currently, we merge only unary ops, so it would be a programming " diff --git a/test/test_move_split_cat.cpp b/test/test_move_split_cat.cpp index 8f8bd6a16ff..fe4ccb32063 100644 --- a/test/test_move_split_cat.cpp +++ b/test/test_move_split_cat.cpp @@ -6,6 +6,7 @@ */ // clang-format on #include +#include #include #include @@ -17,10 +18,12 @@ namespace nvfuser { using testing::Contains; +using testing::IsTrue; +using testing::Property; using MoveSplitCatTest = NVFuserTest; -TEST_F(MoveSplitCatTest, Cancellable_Adjacent) { +TEST_F(MoveSplitCatTest, Cancellable_SplitImmediatelyFollowedByCat) { auto fusion = std::make_unique(); FusionGuard fg(fusion.get()); @@ -43,6 +46,52 @@ TEST_F(MoveSplitCatTest, Cancellable_Adjacent) { EXPECT_TRUE(out_tensors[0].is_alias_of(in_tensor)); } +TEST_F(MoveSplitCatTest, Noncancellable_DifferentOrder) { + auto fusion = std::make_unique(); + FusionGuard fg(fusion.get()); + + TensorView* in = makeContigConcreteTensor({2, 6}); + TensorView* s0 = slice(in, {0, 0}, {2, 3}); + TensorView* s1 = slice(in, {0, 3}, {2, 6}); + TensorView* out = cat({s1, s0}, /*dim=*/-1); + + fusion->addInput(in); + fusion->addOutput(out); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::Tensor in_tensor = at::randn({2, 6}, options); + + FusionExecutorCache fec(std::move(fusion)); + auto out_tensors = fec.runFusionWithInputs({in_tensor}); + testValidate(fec.fusion(), out_tensors, {in_tensor}, __LINE__, __FILE__); + + EXPECT_FALSE(out_tensors[0].is_alias_of(in_tensor)); +} + +TEST_F(MoveSplitCatTest, Cancellable_SetWithoutPermute) { + auto fusion = std::make_unique(); + FusionGuard fg(fusion.get()); + + TensorView* in = 
makeContigConcreteTensor({2, 5}); + TensorView* s0 = slice(in, {0, 0}, {2, 2}); + TensorView* s1 = slice(in, {0, 2}, {2, 5}); + s0 = set(s0); + s1 = set(s1); + TensorView* out = cat({s0, s1}, /*dim=*/-1); + + fusion->addInput(in); + fusion->addOutput(out); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::Tensor in_tensor = at::randn({2, 5}, options); + + FusionExecutorCache fec(std::move(fusion)); + auto out_tensors = fec.runFusionWithInputs({in_tensor}); + testValidate(fec.fusion(), out_tensors, {in_tensor}, __LINE__, __FILE__); + + EXPECT_TRUE(out_tensors[0].is_alias_of(in_tensor)); +} + TEST_F(MoveSplitCatTest, Noncancellable_SliceAmountAndPaddingAmountMismatch) { auto fusion = std::make_unique(); FusionGuard fg(fusion.get()); @@ -118,6 +167,7 @@ TEST_F(MoveSplitCatTest, Cancellable_PermuteInBetween) { EXPECT_TRUE(out_tensors[0].is_alias_of(in_tensor)); } +namespace { MATCHER(IsPermute, "") { if (auto* set = dynamic_cast(arg)) { if (auto* set_out = dynamic_cast(set->out())) { @@ -126,6 +176,7 @@ MATCHER(IsPermute, "") { } return false; } +} // namespace TEST_F(MoveSplitCatTest, Cancellable_IncompatibleAllocationOrder) { auto fusion = std::make_unique(); @@ -151,11 +202,8 @@ TEST_F(MoveSplitCatTest, Cancellable_IncompatibleAllocationOrder) { // Check the two permutes are merged to one. FusionKernelRuntime* runtime = fec.getMostRecentKernelRuntime(); - ASSERT_EQ(runtime->executors().size(), 1) - << "After merging, the whole fusion can be scheduled unsegmented."; - const FusionExecutor& executor = runtime->executors().front(); - kir::Kernel* kernel = executor.kernel(); - EXPECT_THAT(kernel->exprs(), Contains(IsPermute()).Times(1)); + Fusion* complete_fusion = runtime->fusionSegments()->completeFusion(); + EXPECT_THAT(complete_fusion->exprs(), Contains(IsPermute()).Times(1)); // Due to the incompatible output allocation order, the output can't be an // alias. 
@@ -250,20 +298,22 @@ TEST_F(MoveSplitCatTest, Noncancellable_PermutedDifferently) { auto fusion = std::make_unique(); FusionGuard fg(fusion.get()); - TensorView* in = makeContigConcreteTensor({2, 2, 2, 10}); - TensorView* s0 = slice(in, {0, 0, 0, 0}, {2, 2, 2, 2}); - TensorView* s1 = slice(in, {0, 0, 0, 2}, {2, 2, 2, 5}); - TensorView* s2 = slice(in, {0, 0, 0, 5}, {2, 2, 2, 10}); - s0 = permute(s0, {2, 1, 0, 3}); - s1 = permute(s1, {1, 0, 2, 3}); - s2 = permute(s2, {2, 1, 0, 3}); - TensorView* out = cat({s0, s1, s2}, /*dim=*/-1); + TensorView* in = makeContigConcreteTensor({4, 2}); + TensorView* s0 = slice(in, {0, 0}, {2, 2}); + s0 = set(s0); + s0 = reshape(s0, {2, 2}, {4}); + + TensorView* s1 = slice(in, {2, 0}, {4, 2}); + s1 = permute(s1, {1, 0}); + s1 = reshape(s1, {2, 2}, {4}); + + TensorView* out = cat({s0, s1}, /*dim=*/0); fusion->addInput(in); fusion->addOutput(out); auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - at::Tensor in_tensor = at::randn({2, 2, 2, 10}, options); + at::Tensor in_tensor = at::randn({4, 2}, options); FusionExecutorCache fec(std::move(fusion)); auto out_tensors = fec.runFusionWithInputs({in_tensor}); @@ -299,4 +349,168 @@ TEST_F(MoveSplitCatTest, Noncancellable_UnsupportedOps) { EXPECT_FALSE(out_tensors[0].is_alias_of(in_tensor)); } +TEST_F(MoveSplitCatTest, Cancellable_ReshapeInBetween) { + auto fusion = std::make_unique(); + FusionGuard fg(fusion.get()); + + TensorView* in = makeContigConcreteTensor({4, 10}); + TensorView* s0 = slice(in, {0, 0}, {4, 2}); + TensorView* s1 = slice(in, {0, 2}, {4, 5}); + TensorView* s2 = slice(in, {0, 5}, {4, 10}); + s0 = reshape(s0, {4, 2}, {2, 2, 2}); + s1 = reshape(s1, {4, 3}, {2, 2, 3}); + s2 = reshape(s2, {4, 5}, {2, 2, 5}); + TensorView* out = cat({s0, s1, s2}, /*dim=*/-1); + + fusion->addInput(in); + fusion->addOutput(out); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::Tensor in_tensor = at::randn({4, 10}, options); + + 
FusionExecutorCache fec(std::move(fusion)); + auto out_tensors = fec.runFusionWithInputs({in_tensor}); + testValidate(fec.fusion(), out_tensors, {in_tensor}, __LINE__, __FILE__); + + EXPECT_TRUE(out_tensors[0].is_alias_of(in_tensor)); +} + +TEST_F(MoveSplitCatTest, Cancellable_ReshapeAndPermuteInBetween) { + auto fusion = std::make_unique(); + FusionGuard fg(fusion.get()); + + TensorView* in = makeContigConcreteTensor({6, 10}); + TensorView* s0 = slice(in, {0, 0}, {6, 2}); + TensorView* s1 = slice(in, {0, 2}, {6, 5}); + TensorView* s2 = slice(in, {0, 5}, {6, 10}); + s0 = reshape(s0, {6, 2}, {2, 3, 2}); + s1 = reshape(s1, {6, 3}, {2, 3, 3}); + s2 = reshape(s2, {6, 5}, {2, 3, 5}); + s0 = permute(s0, {1, 0, 2}); + s1 = permute(s1, {1, 0, 2}); + s2 = permute(s2, {1, 0, 2}); + TensorView* out = cat({s0, s1, s2}, /*dim=*/-1); + + fusion->addInput(in); + fusion->addOutput(out); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::Tensor in_tensor = at::randn({6, 10}, options); + + FusionExecutorCache fec(std::move(fusion)); + auto out_tensors = fec.runFusionWithInputs({in_tensor}); + testValidate(fec.fusion(), out_tensors, {in_tensor}, __LINE__, __FILE__); + + EXPECT_TRUE(out_tensors[0].is_alias_of(in_tensor)); +} + +TEST_F(MoveSplitCatTest, Cancellable_Issue1768) { + constexpr int b = 16; // batch size + constexpr int h = 12; // number of heads + constexpr int s = 128; // sequence length + constexpr int f = 64; // feature size per head + + auto fusion = std::make_unique(); + FusionGuard fg(fusion.get()); + + TensorView* sdpa_backward_out = + makeContigConcreteTensor({b, h * 3, s, f}, DataType::Half); + sdpa_backward_out->setAllocationDomain( + {sdpa_backward_out->axis(0), + sdpa_backward_out->axis(2), + sdpa_backward_out->axis(1), + sdpa_backward_out->axis(3)}, + true); + TensorView* dq = slice(sdpa_backward_out, {0, 0, 0, 0}, {b, h, s, f}); + TensorView* dk = slice(sdpa_backward_out, {0, h, 0, 0}, {b, h * 2, s, f}); + TensorView* dv = 
slice(sdpa_backward_out, {0, h * 2, 0, 0}, {b, h * 3, s, f}); + // Swap the head dimension and the sequence length dimension. + dq = permute(dq, {0, 2, 1, 3}); + dk = permute(dk, {0, 2, 1, 3}); + dv = permute(dv, {0, 2, 1, 3}); + dq = reshape(dq, {b, s, h, f}, {b, s, h * f}); + dk = reshape(dk, {b, s, h, f}, {b, s, h * f}); + dv = reshape(dv, {b, s, h, f}, {b, s, h * f}); + TensorView* cat_out = cat({dq, dk, dv}, /*dim=*/-1); + TensorView* sum_out = castOp(DataType::Float, cat_out); + sum_out = sum(sum_out, {0, 1}); + sum_out = castOp(DataType::Half, sum_out); + TensorView* view_out = + reshape(cat_out, {b, s, h * f * 3}, {b * s, h * f * 3}); + TensorView* permute_out = permute(view_out, {1, 0}); + + fusion->addInput(sdpa_backward_out); + fusion->addOutput(sum_out); + fusion->addOutput(view_out); + fusion->addOutput(permute_out); + + auto options = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA, 0); + at::Tensor in_tensor = + at::randn({b * h * 3 * s * f}, options) + .as_strided({b, h * 3, s, f}, {h * 3 * s * f, f, h * 3 * f, 1}); + + FusionExecutorCache fec(std::move(fusion)); + auto out_tensors = fec.runFusionWithInputs({in_tensor}); + testValidate(fec.fusion(), out_tensors, {in_tensor}, __LINE__, __FILE__); + + EXPECT_TRUE(out_tensors[1].is_alias_of(in_tensor)); + EXPECT_TRUE(out_tensors[2].is_alias_of(in_tensor)); +} + +TEST_F(MoveSplitCatTest, MultiplePairs) { + auto fusion = std::make_unique(); + FusionGuard fg(fusion.get()); + + TensorView* merged = makeContigConcreteTensor({4, 6}); + fusion->addInput(merged); + + // Region 0. Mergeable because both slices are permuted in the same way and + // the cat axis matches the split axis. + TensorView* s0 = slice(merged, {0, 0}, {2, 6}); + TensorView* s1 = slice(merged, {2, 0}, {4, 6}); + s0 = permute(s0, {1, 0}); + s1 = permute(s1, {1, 0}); + merged = cat({s0, s1}, /*dim=*/1); + + // Region 1. Not mergeable because the outer dimension is split and the inner + // dimension is catted. 
+ s0 = slice(merged, {0, 0}, {3, 4}); + s1 = slice(merged, {3, 0}, {6, 4}); + s0 = reshape(s0, {3, 4}, {6, 2}); + s1 = reshape(s1, {3, 4}, {6, 2}); + merged = cat({s0, s1}, /*dim=*/1); + + // Region 2. Mergeable because both slices are reshaped in the same way and + // the outer dimension is split and catted. + s0 = slice(merged, {0, 0}, {3, 4}); + s1 = slice(merged, {3, 0}, {6, 4}); + s0 = reshape(s0, {3, 4}, {6, 2}); + s1 = reshape(s1, {3, 4}, {6, 2}); + merged = cat({s0, s1}, /*dim=*/0); + + fusion->addOutput(merged); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::Tensor in_tensor = at::randn({4, 6}, options); + + FusionExecutorCache fec(std::move(fusion)); + auto out_tensors = fec.runFusionWithInputs({in_tensor}); + testValidate(fec.fusion(), out_tensors, {in_tensor}, __LINE__, __FILE__); + + FusionKernelRuntime* runtime = fec.getMostRecentKernelRuntime(); + Fusion* complete_fusion = runtime->fusionSegments()->completeFusion(); + std::vector exprs = complete_fusion->exprs(); + + // Only region 1 is not mergeable, so we expect to see only that region + // contains two slices and one cat in the pre-segmenter fusion. + EXPECT_THAT( + exprs, Contains(Property(&Expr::isA, IsTrue())).Times(2)); + EXPECT_THAT(exprs, Contains(Property(&Expr::isA, IsTrue())).Times(1)); + // The two permutes in region 0 are expected to be merged. + EXPECT_THAT(exprs, Contains(IsPermute()).Times(1)); + // The two reshapes in region 1 stay as is and the two reshapes in region 2 + // are merged. Therefore, three reshapes in total. + EXPECT_THAT(exprs, Contains(Property(&Expr::isA, IsTrue())).Times(3)); +} + } // namespace nvfuser