diff --git a/CMakeLists.txt b/CMakeLists.txt index 13b3642c3bf..6fd5cda97c8 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -210,6 +210,7 @@ list(APPEND NVFUSER_SRCS ${NVFUSER_SRCS_DIR}/optimization/pre_segmenter.cpp ${NVFUSER_SRCS_DIR}/optimization/remove_empty.cpp ${NVFUSER_SRCS_DIR}/val_graph.cpp + ${NVFUSER_SRCS_DIR}/val_graph_visitor.cpp ) # We don't link CUPTI for MSVC diff --git a/csrc/disjoint_set.h b/csrc/disjoint_set.h index 25f4c183af0..d1af78fef7f 100644 --- a/csrc/disjoint_set.h +++ b/csrc/disjoint_set.h @@ -242,6 +242,14 @@ class VectorOfUniqueEntries { return vector_.end(); } + T& at(size_t pos) { + return vector_.at(pos); + } + + const T& at(size_t pos) const { + return vector_.at(pos); + } + std::string toString() const { std::stringstream ss; ss << "{ "; diff --git a/csrc/val_graph.cpp b/csrc/val_graph.cpp index 3e9324a0c19..f5578a76648 100644 --- a/csrc/val_graph.cpp +++ b/csrc/val_graph.cpp @@ -99,6 +99,58 @@ std::vector ValGraph::inputGroups(const ExprGroup& expr) const { return input_groups; } +ValGroups ValGraph::getTerminatingInputs() const { + // Initialize vals to traverse + ValGroups all_vals{ + disjointValSets().disjointSets().begin(), + disjointValSets().disjointSets().end()}; + + // Initialize exprs to traverse + ExprGroups all_exprs{ + disjointExprSets().disjointSets().begin(), + disjointExprSets().disjointSets().end()}; + + // Grab all vals that are not input, i.e., having a defining expr + // within all_exprs. + // + // Note that an input Val group may be mapped with an output + // group. For example, the AlmostExact graph maps an input of split + // with the outer output if the split factor is one. Such a Val + // group is considered a terminating input as long as the input has + // no defining expression. This is for the use case of + // ValGraphVisitor.
+ // + // Example: + // + // [i0, i1] + // split by 1 + // [i0/1, 1, i1] + // merge + // [i0/1, 1*i1] + // + // Here, i0 and i0/1 would create a Val group of {i0, i0/1} in the + // AlmostExact graph. This group has a defining expression of the + // split, but since it's a cyclic dependency, we ignore the + // expression and consider the Val group a terminating input. + + ValGroups not_inputs; + for (const ExprGroup& expr_group : all_exprs) { + const std::vector input_groups = inputGroups(expr_group); + const std::vector output_groups = outputGroups(expr_group); + std::unordered_set input_set{ + input_groups.begin(), input_groups.end()}; + + for (const ValGroup& output_group : output_groups) { + if (input_set.count(output_group)) { + continue; + } + not_inputs.pushBack(output_group); + } + } + + return all_vals.computeSubtract(not_inputs); +} + ExprGroups ValGraph::allUsesOf(const ValGroups& of) const { DequeOfExprGroup to_visit; for (const ValGroup& of_val_group : of) { diff --git a/csrc/val_graph.h b/csrc/val_graph.h index f7a20e46bdb..72182e3f2f8 100644 --- a/csrc/val_graph.h +++ b/csrc/val_graph.h @@ -95,6 +95,9 @@ class ValGraph { std::vector outputGroups(const ExprGroup& expr) const; std::vector inputGroups(const ExprGroup& expr) const; + // Return Val groups that have no definition. A group is also returned when every expr group that outputs it takes that same group as an input (a cyclic mapping). + ValGroups getTerminatingInputs() const; + // Recursively traverses uses of the IdGroups in 'of' and returns all // ExprGroups that have a use in their definition of provided of IdGroups. ExprGroups allUsesOf(const ValGroups& of) const; diff --git a/csrc/val_graph_visitor.cpp b/csrc/val_graph_visitor.cpp new file mode 100644 index 00000000000..7d6646c747e --- /dev/null +++ b/csrc/val_graph_visitor.cpp @@ -0,0 +1,152 @@ +// clang-format off +/* + * SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES. + * All rights reserved.
+ * SPDX-License-Identifier: BSD-3-Clause + */ +// clang-format on +#include + +#include + +namespace nvfuser { + +void ValGraphVisitor::traverse() { + const ValGroups terminating_inputs = graph().getTerminatingInputs(); + std::deque to_visit_vals( + terminating_inputs.begin(), terminating_inputs.end()); + ValGroups visited_vals; + + std::deque to_visit_exprs; + ExprGroups visited_exprs; + + auto is_expr_ready = [&](const ExprGroup& expr_group) -> bool { + const auto inp_groups = graph().inputGroups(expr_group); + return std::all_of( + inp_groups.begin(), inp_groups.end(), [&](ValGroup val_group) { + return visited_vals.has(val_group) || val_group->empty(); + }); + }; + + // If any input of the def expr is mapped with the val + // group itself, i.e., a trivial expr, allow visiting the + // val group first. The trivial expr group will be visited + // after the val group. + // + // Example: + // + // [i0, 1] + // merge + // [i0*1] + // map i0 and i0*1 + // ValGroups: {{i0, i0*1}, {1}} + // + // Then, {i0, i0*1} and {1} would be visited first, then the merge + // expr group would be visited. {i0, i0*1} is also an output group + // of the merge but since it's already in the visited set, it would + // not be visited again. + // + // See also IdModelTest.ValGraphStmtSort3 for a concrete example. + auto is_val_ready = [&](const ValGroup& val_group) -> bool { + if (terminating_inputs.has(val_group)) { + return true; + } + const ExprGroups& unique_defs = graph().getDefinitions(val_group); + return std::all_of( + unique_defs.begin(), unique_defs.end(), [&](ExprGroup expr_group) { + if (expr_group->empty() || visited_exprs.has(expr_group)) { + return true; + } + // Handle ExprGroups that return one or some of its input ValGroups as + // output. This expr_group is not visited yet, which means there're + // input ValGroups that are not yet visited.
If those not-visited + // inputs are actually the same as val_group, visit val_group at this + // point to resolve the circular dependency. NOTE(review): the rejection below fires only when input_group->empty(), whereas is_expr_ready treats an empty group as already satisfied; this may have been meant as !input_group->empty() -- confirm before relying on it. + for (const ValGroup& input_group : graph().inputGroups(expr_group)) { + if (input_group != val_group && !visited_vals.has(input_group) && + input_group->empty()) { + return false; + } + } + return true; + }); + }; + + // Detect if nothing has been processed which would put us in an infinite + // loop + bool something_was_processed = false; + + do { + something_was_processed = false; + + // Process expressions first as all definitions of vals have to be + // processed before we can process that val. + + while (!to_visit_exprs.empty()) { + ExprGroup current_expr_group = to_visit_exprs.front(); + to_visit_exprs.pop_front(); + NVF_ERROR(!current_expr_group->empty()); + if (visited_exprs.has(current_expr_group)) { + continue; + } + + if (is_expr_ready(current_expr_group)) { + handle(current_expr_group); + + something_was_processed = true; + visited_exprs.pushBack(current_expr_group); + + for (const ValGroup& output_group : + graph().outputGroups(current_expr_group)) { + to_visit_vals.push_back(output_group); + } + } + } + + std::deque still_to_visit_vals; + while (!to_visit_vals.empty()) { + auto current_val_group = to_visit_vals.front(); + to_visit_vals.pop_front(); + NVF_ERROR(!current_val_group->empty()); + if (visited_vals.has(current_val_group)) { + continue; + } + + if (is_val_ready(current_val_group)) { + handle(current_val_group); + + something_was_processed = true; + visited_vals.pushBack(current_val_group); + + for (const ExprGroup& use_group : graph().getUses(current_val_group)) { + to_visit_exprs.push_back(use_group); + } + } else { + still_to_visit_vals.push_back(current_val_group); + } + } + + std::swap(to_visit_vals, still_to_visit_vals); + + } while (something_was_processed); + + if (!to_visit_vals.empty()) { + std::stringstream ss; + ss << "The graph has an infinite loop.
The following Vals should be visited but are never ready:"; + for (const ValGroup& vg : to_visit_vals) { + ss << " " << nvfuser::toString(vg); + } + NVF_ERROR(false, ss.str()); + } + + if (!to_visit_exprs.empty()) { + std::stringstream ss; + ss << "The graph has an infinite loop. The following Exprs should be visited but are never ready:"; + for (const ExprGroup& eg : to_visit_exprs) { + ss << " " << nvfuser::toString(eg); + } + NVF_ERROR(false, ss.str()); + } +} + +} // namespace nvfuser diff --git a/csrc/val_graph_visitor.h b/csrc/val_graph_visitor.h new file mode 100644 index 00000000000..54a5ed253b8 --- /dev/null +++ b/csrc/val_graph_visitor.h @@ -0,0 +1,121 @@ +// clang-format off +/* + * SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES. + * All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + */ +// clang-format on +#pragma once + +#include +#include +#include + +namespace nvfuser { + +// Iterates through a Val Graph in topological order, calling handle on +// all Val and all Expr groups in a forward topological order. +// +// Warning: A ValGraph is not guaranteed to be a DAG. In fact, the +// AlmostExact and Permissive graphs would have cycles with a ValGroup +// and an ExprGroup. For example: +// +// [i0, 1] +// merge +// [i0*1] +// Current ValGroups: {{i0}, {1}, {i0*1}} +// map i0 and i0*1 as they effectively have the same extent +// Final ValGroups: {{i0, i0*1}, {1}} +// +// Here, the merge expr is the user of i0 and the definition of +// i0*1. Since i0 and i0*1 are mapped, the dependency chain looks +// like: +// +// {i0, i0*1} ----> {merge} ----> {i0, i0*1} +// use def +// +// These ExprGroups are called trivial ExprGroups (see also +// ValGraph::isTrivialExprGroup). +// +// Strictly speaking, these cycles mean there's no valid topological +// order anymore.
In our use cases for IdModel, however, it's likely +// sufficient to return an ordering such as: +// +// {i0, i0*1} -> {merge} +// +// I.e., we visit {i0, i0*1} first even though {merge} is technically +// a definition. +// +// Another alternative may be simply giving up when such a cycle is +// detected, which may be more preferable as it would be less +// confusing. At this moment, this visitor is only used with graphs +// with no such cycle. Should be revisited when necessary. +// +// Warning: This is not a great iterator if there's a desire to minimize paths +// traveled to simply visit all ValGroups in order. See ExprsBetween to see how +// we might minimize paths. +class ValGraphVisitor { + public: + ValGraphVisitor() = delete; + + ValGraphVisitor& operator=(const ValGraphVisitor& other) = delete; + + ValGraphVisitor& operator=(ValGraphVisitor&& other) = delete; + + virtual ~ValGraphVisitor() = default; + + protected: + ValGraphVisitor(const ValGraph& val_graph) : val_graph_(val_graph) {} + + ValGraphVisitor(const ValGraphVisitor& other) = default; + + ValGraphVisitor(ValGraphVisitor&& other) = default; + + virtual void handle(const ValGroup& val_group) = 0; + virtual void handle(const ExprGroup& expr_group) = 0; + + void traverse(); + + const ValGraph& graph() { + return val_graph_; + }; + + private: + const ValGraph& val_graph_; +}; + +// Statement sorting based on ValGraphVisitor, see the warnings on ValGraphVisitor.
+class ValGraphStmtSort : public ValGraphVisitor { + public: + ValGraphStmtSort(const ValGraph& val_graph) : ValGraphVisitor(val_graph) { + ValGraphVisitor::traverse(); + } + + // Return non-reference so that code like below can work + // for (auto expr_group: ValGraphStmtSort(graph).exprs()) + ExprGroups exprs() const { + return sorted_exprs_; + } + + ValGroups vals() const { + return sorted_vals_; + } + + ~ValGraphStmtSort() override = default; + + protected: + using ValGraphVisitor::handle; + + void handle(const ValGroup& val_group) override { + sorted_vals_.pushBack(val_group); + } + + void handle(const ExprGroup& expr_group) override { + sorted_exprs_.pushBack(expr_group); + } + + ExprGroups sorted_exprs_; + ValGroups sorted_vals_; +}; + +} // namespace nvfuser diff --git a/test/test_id_model.cpp b/test/test_id_model.cpp index fc823316564..d101ca6cabe 100644 --- a/test/test_id_model.cpp +++ b/test/test_id_model.cpp @@ -16,6 +16,7 @@ #include #include #include +#include namespace nvfuser { @@ -37,4 +38,350 @@ TEST_F(IdModelTest, DetectSelfMapping) { ::testing::HasSubstr("!hasSelfMapping"))); } +namespace { + +// Get n-th parent expr traversing through the first input of each +// parent +Expr* getParentExpr(Val* val, int n) { + for (int i = 0; i < n - 1; ++i) { + NVF_ERROR(val->definition() != nullptr); + val = val->definition()->input(0); + } + NVF_ERROR(val->definition() != nullptr); + return val->definition(); +}; + +TensorView* getTensorByName( + const std::vector& tvs, + StmtNameType name) { + if (auto it = std::find_if( + tvs.begin(), + tvs.end(), + [&](TensorView* tv) { return tv->name() == name; }); + it != tvs.end()) { + return *it; + } else { + return nullptr; + } +} + +// Create a fusion where we're missing a valid concrete id so the compute at map +// processing will fail. We need to be able to create the concrete ID not just +// look for one. It is not yet possible to lower this fusion as the +// current indexing cannot generate correct indices.
Also used in +// FusionIndexing19 as well as Example 2 in the design doc about Loop +// Promotion Analysis. +std::unique_ptr createFusionWithMultipleResolutionPaths() { + std::unique_ptr fusion_ptr = std::make_unique(); + Fusion& fusion = *fusion_ptr.get(); + FusionGuard fg(&fusion); + + auto tv0 = makeConcreteTensor({7}); + fusion.addInput(tv0); + + auto tv1 = set(tv0); + + auto tv2 = broadcast(tv1, {false, true}); + + auto tv3 = makeConcreteTensor({7, 11}); + fusion.addInput(tv3); + + auto tv4 = add(tv3, tv2); + auto tv5 = broadcast(tv4, {false, false, true}); + // tv5[7, 11, 1] + + auto tv6 = broadcast(tv1, {false, true}); + + auto tv7 = makeConcreteTensor({7, 13}); + fusion.addInput(tv7); + auto tv8 = add(tv7, tv6); + auto tv9 = broadcast(tv8, {false, true, false}); + // tv9[7, 1, 13] + + auto tv10 = add(tv5, tv9); + fusion.addOutput(tv10); + + // tv10[7, 11, 13] + tv10->merge(0)->merge(0); + // tv10[7*11*13] + tv10->split(0, 5)->split(0, 3); + // tv10[7*11*13//5//3, 3, 5] + + TransformPropagatorWithCheck propagator(tv10); + MaxRootDomainInfoSpanningTree(tv10).traverse(&propagator); + + std::vector tensors_to_inline{tv1, tv2, tv4, tv6, tv8}; + for (auto tensor : tensors_to_inline) { + tensor->inlineAt(1); + } + + return fusion_ptr; +} + +// Check the results of ValGraphStmtSort. Only the ordering of +// ExprGroups is checked for now as it's likely sufficient. +// +// ref_order: The order must be exactly the +// same as indicated by this list. While there can be different +// orders that still satisfy the topological ordering, we also need +// deterministic ordering, so the results should be always the same.
+void checkSortingResults( + const ValGraph& graph, + const ExprGroups& sorted_expr_groups, + const ValGroups& sorted_val_groups, + const std::vector& ref_order) { + // Make sure sorted_expr_groups covers all Expr groups + const std::unordered_set& ref_expr_group_set{ + graph.disjointExprSets().disjointSets().begin(), + graph.disjointExprSets().disjointSets().end()}; + std::unordered_set sorted_expr_group_set{ + sorted_expr_groups.begin(), sorted_expr_groups.end()}; + ASSERT_EQ(sorted_expr_group_set, ref_expr_group_set) + << "Mismatched ExprGroups."; + + // Make sure sorted_val_groups covers all Val groups + const std::unordered_set& ref_val_group_set{ + graph.disjointValSets().disjointSets().begin(), + graph.disjointValSets().disjointSets().end()}; + std::unordered_set sorted_val_group_set{ + sorted_val_groups.begin(), sorted_val_groups.end()}; + ASSERT_EQ(sorted_val_group_set, ref_val_group_set) << "Mismatched ValGroups."; + + // Check the ordering + ASSERT_EQ(sorted_expr_groups.size(), ref_order.size()); + for (const auto i : c10::irange(ref_order.size())) { + Expr* ref_expr = ref_order.at(i); + const ExprGroup& eg = sorted_expr_groups.at(i); + ASSERT_TRUE(eg->has(ref_expr)) + << "Mismatch detected at " << i << "-th expr group. " + << "Expected: " << nvfuser::toString(graph.toGroup(ref_expr)) << ", " + << ref_expr->toString() << ". Actual: " << nvfuser::toString(eg) << ", " + << eg->front()->toString(); + } +} + +} // namespace + +// Sorting test with a trivial fusion +TEST_F(IdModelTest, ValGraphStmtSort1) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto tv0 = makeSymbolicTensor(2); + fusion.addInput(tv0); + auto tv1 = makeSymbolicTensor(2); + fusion.addInput(tv1); + auto tv2 = add(tv0, tv1); + fusion.addOutput(tv2); + + // No ID expr yet. checkSortingResults validates the expression + // order, but since there's no expr, it just makes sure exprs() and + // vals() return all the val and expr groups.
+ { + IdModel id_model(&fusion); + const ValGraph& vg = id_model.idGraph(IdMappingMode::EXACT); + ValGraphStmtSort vg_stmt_sort(vg); + checkSortingResults(vg, vg_stmt_sort.exprs(), vg_stmt_sort.vals(), {}); + } + + // Add ID exprs. Just apply a merge-and-split pattern to all + // tensors. + tv2->merge(0)->split(0, 4); + TransformPropagator propagator(tv2); + MaxRootDomainInfoSpanningTree(tv2).traverse(&propagator); + + // The exact graph should just map all IDs of the tensors. The + // ordering of the exprs should be the merge and then the split. + { + IdModel id_model(&fusion); + + const ValGraph& vg = id_model.idGraph(IdMappingMode::EXACT); + ValGraphStmtSort vg_stmt_sort(vg); + + // Reference expr order: merge, split + std::vector ref_order; + ref_order.push_back(getParentExpr(tv2->axis(0), 2)); + ref_order.push_back(getParentExpr(tv2->axis(0), 1)); + + checkSortingResults( + vg, vg_stmt_sort.exprs(), vg_stmt_sort.vals(), ref_order); + } +} + +// Sorting test with a disconnected graph +TEST_F(IdModelTest, ValGraphStmtSort2) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto tv0 = makeSymbolicTensor(2); + fusion.addInput(tv0); + auto tv1 = set(tv0); + fusion.addOutput(tv1); + + auto tv2 = makeSymbolicTensor(2); + fusion.addInput(tv2); + auto tv3 = set(tv2); + fusion.addOutput(tv3); + + // Note that the two groups of tensors, {tv0, tv1} and {tv2, tv3}, + // are not connected + + for (auto tv : ir_utils::allTvs(&fusion)) { + tv->merge(0)->split(0, 4); + } + + // Since the two tensors are disconnected, there's no ordering + // between the ID exprs of the two tensor groups. So, the correct + // ordering should have the merge exprs before the split exprs, but + // there's no order between the tv1 and tv3 exprs.
For example, + // these are all valid: + // + // tv1 merge -> tv3 merge -> tv1 split -> tv3 split + // tv1 merge -> tv1 split -> tv3 merge -> tv3 split + // tv3 merge -> tv3 split -> tv1 merge -> tv1 split + // tv3 merge -> tv1 merge -> tv3 split -> tv1 split + // + // Here, the actual order returned by ValGraphStmtSort is the first + // one. Since it should be deterministic, we check if the returned + // expr vector is indeed ordered that way. + + IdModel id_model(&fusion); + + const ValGraph& vg = id_model.idGraph(IdMappingMode::EXACT); + ValGraphStmtSort vg_stmt_sort(vg); + + std::vector ref_order; + ref_order.push_back(getParentExpr(tv1->axis(0), 2)); + ref_order.push_back(getParentExpr(tv3->axis(0), 2)); + ref_order.push_back(getParentExpr(tv1->axis(0), 1)); + ref_order.push_back(getParentExpr(tv3->axis(0), 1)); + + checkSortingResults(vg, vg_stmt_sort.exprs(), vg_stmt_sort.vals(), ref_order); +} + +// Sorting with trivial ExprGroup, i.e., ExprGroup whose input and +// output are mapped as the same ValGroup. It's effectively a cyclic +// dependency and the graph is no longer a DAG. +TEST_F(IdModelTest, ValGraphStmtSort3) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto tv0 = makeSymbolicTensor(2); + fusion.addInput(tv0); + auto tv1 = makeSymbolicTensor(2); + fusion.addInput(tv1); + auto tv2 = add(tv0, tv1); + fusion.addOutput(tv2); + + auto tv3 = makeSymbolicTensor(2); + fusion.addInput(tv3); + auto tv4 = set(tv3); + fusion.addOutput(tv4); + + // Merge and split by one. The split input and output will be mapped. + for (auto tv : {tv0, tv1, tv2}) { + tv->merge(0)->split(0, 1); + } + + // Also test an isolated trivial expr. Note that tv3 and tv4 are not + // connected with tv0, tv1 and tv2.
+ tv4->merge(0)->split(0, 1); + + IdModel id_model(&fusion); + ValGraph vg = id_model.idGraph(IdMappingMode::EXACT); + + // Map the split-by-1 input and output + vg.mapVals(tv2->axis(0), tv2->axis(0)->definition()->input(0)); + vg.mapVals(tv4->axis(0), tv4->axis(0)->definition()->input(0)); + + ValGraphStmtSort vg_stmt_sort(vg); + + std::vector ref_order; + ref_order.push_back(getParentExpr(tv2->axis(0), 2)); + ref_order.push_back(getParentExpr(tv4->axis(0), 2)); + ref_order.push_back(getParentExpr(tv2->axis(0), 1)); + ref_order.push_back(getParentExpr(tv4->axis(0), 1)); + + checkSortingResults(vg, vg_stmt_sort.exprs(), vg_stmt_sort.vals(), ref_order); +} + +// Sorting test with the same fusion as Indexing19 +TEST_F(IdModelTest, ValGraphStmtSort4) { + auto fusion = createFusionWithMultipleResolutionPaths(); + FusionGuard fg(fusion.get()); + auto all_tvs = ir_utils::allTvs(fusion.get()); + + // Since this fusion is not supported by ComputeAtMap, the + // validation flag must be false + IdModel id_model(fusion.get(), false, false, false); + id_model.buildExactGraph(); + const ValGraph& vg = id_model.idGraph(IdMappingMode::EXACT); + + ValGraphStmtSort vg_stmt_sort(vg); + + auto tv1 = getTensorByName(all_tvs, 1); + auto tv2 = getTensorByName(all_tvs, 2); + auto tv4 = getTensorByName(all_tvs, 4); + auto tv5 = getTensorByName(all_tvs, 5); + auto tv6 = getTensorByName(all_tvs, 6); + auto tv8 = getTensorByName(all_tvs, 8); + auto tv9 = getTensorByName(all_tvs, 9); + auto tv10 = getTensorByName(all_tvs, 10); + + // Expected reference order: + // + // exprg{39}: Merge iS2 bS3 + // exprg{57}: Merge iS11 bS12 + // exprg{17}: Merge iS17 bS18 + // exprg{51 63}: Merge iS15 iS16 + // exprg{69 73}: Split iS1 + // exprg{9 25 33 45}: Merge iS20 iS21 + // exprg{41}: Split iS46 + // exprg{59}: Split iS61 + // exprg{19}: Merge iS29 iS19 + // exprg{53 65}: Split iS56 + // exprg{71 75}: Split iS71 + // exprg{11}: Merge iS23 iS22 + // exprg{27}: Merge iS35 bS10 + // exprg{35 47}: Split
iS41 + // exprg{43}: Split iS47 + // exprg{61}: Split iS62 + // exprg{21}: Split iS30 + // exprg{55 67}: Split iS57 + // exprg{13}: Split iS24 + // exprg{29}: Split iS36 + // exprg{37 49}: Split iS42 + // exprg{23}: Split iS31 + // exprg{15}: Split iS25 + // exprg{31}: Split iS37 + + std::vector ref_order; + ref_order.push_back(getParentExpr(tv2->axis(0), 3)); + ref_order.push_back(getParentExpr(tv6->axis(0), 3)); + ref_order.push_back(getParentExpr(tv9->axis(0), 4)); + ref_order.push_back(getParentExpr(tv8->axis(0), 3)); + ref_order.push_back(getParentExpr(tv1->axis(0), 2)); + ref_order.push_back(getParentExpr(tv10->axis(0), 4)); + ref_order.push_back(getParentExpr(tv2->axis(0), 2)); + ref_order.push_back(getParentExpr(tv6->axis(0), 2)); + ref_order.push_back(getParentExpr(tv9->axis(0), 3)); + ref_order.push_back(getParentExpr(tv8->axis(0), 2)); + ref_order.push_back(getParentExpr(tv1->axis(0), 1)); + ref_order.push_back(getParentExpr(tv10->axis(0), 3)); + ref_order.push_back(getParentExpr(tv5->axis(0), 3)); + ref_order.push_back(getParentExpr(tv4->axis(0), 2)); + ref_order.push_back(getParentExpr(tv2->axis(0), 1)); + ref_order.push_back(getParentExpr(tv6->axis(0), 1)); + ref_order.push_back(getParentExpr(tv9->axis(0), 2)); + ref_order.push_back(getParentExpr(tv8->axis(0), 1)); + ref_order.push_back(getParentExpr(tv10->axis(0), 2)); + ref_order.push_back(getParentExpr(tv5->axis(0), 2)); + ref_order.push_back(getParentExpr(tv4->axis(0), 1)); + ref_order.push_back(getParentExpr(tv9->axis(0), 1)); + ref_order.push_back(getParentExpr(tv10->axis(0), 1)); + ref_order.push_back(getParentExpr(tv5->axis(0), 1)); + + checkSortingResults(vg, vg_stmt_sort.exprs(), vg_stmt_sort.vals(), ref_order); +} + } // namespace nvfuser