From bbd7786bc0bd7faaafa4bbecf7b6d9dea7f4fac3 Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Fri, 2 Feb 2024 12:21:20 -0800 Subject: [PATCH 01/17] Add a visitor for ValGraph Used in the loop promotion analysis (#32) --- CMakeLists.txt | 1 + csrc/disjoint_set.h | 17 ++ csrc/val_graph.cpp | 58 +++++++ csrc/val_graph.h | 8 + csrc/val_graph_visitor.cpp | 118 +++++++++++++ csrc/val_graph_visitor.h | 121 +++++++++++++ test/test_id_model.cpp | 341 +++++++++++++++++++++++++++++++++++++ 7 files changed, 664 insertions(+) create mode 100644 csrc/val_graph_visitor.cpp create mode 100644 csrc/val_graph_visitor.h diff --git a/CMakeLists.txt b/CMakeLists.txt index 61b20374012..16defcf8b81 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -202,6 +202,7 @@ list(APPEND NVFUSER_SRCS ${NVFUSER_SRCS_DIR}/optimization/pre_segmenter.cpp ${NVFUSER_SRCS_DIR}/optimization/remove_empty.cpp ${NVFUSER_SRCS_DIR}/val_graph.cpp + ${NVFUSER_SRCS_DIR}/val_graph_visitor.cpp ) # We don't link CUPTI for MSVC diff --git a/csrc/disjoint_set.h b/csrc/disjoint_set.h index 25f4c183af0..d53861b2087 100644 --- a/csrc/disjoint_set.h +++ b/csrc/disjoint_set.h @@ -92,6 +92,15 @@ class VectorOfUniqueEntries { return any_added; } + // Remove and returns the first element in vector. Note that this is + // a costly operation as the underlying container is std::vector. + T popFront() { + T v = vector_.front(); + set_.erase(v); + vector_.erase(vector_.begin()); + return v; + } + // Returns a new VectorOfUniqueEntries with entries that are in both this and // other, order is preserved as this. VectorOfUniqueEntries computeIntersect( @@ -242,6 +251,14 @@ class VectorOfUniqueEntries { return vector_.end(); } + T& at(size_t pos) { + return vector_.at(pos); + } + + const T& at(size_t pos) const { + return vector_.at(pos); + } + std::string toString() const { std::stringstream ss; ss << "{ "; diff --git a/csrc/val_graph.cpp b/csrc/val_graph.cpp index 3e9324a0c19..fc26db00bbd 100644 --- a/csrc/val_graph.cpp +++ b/csrc/val_graph.cpp @@ -99,6 +99,58 @@ std::vector ValGraph::inputGroups(const ExprGroup& expr) const { return input_groups; } +ValGroups ValGraph::getTerminatingInputs() const { + // Initialize vals to traverse + ValGroups all_vals{ + disjointValSets().disjointSets().begin(), + disjointValSets().disjointSets().end()}; + + // Initialize exprs to traverse + ExprGroups all_exprs{ + disjointExprSets().disjointSets().begin(), + disjointExprSets().disjointSets().end()}; + + // Grab all vals that are not input, i.e., having a defining expr + // within all_exprs. + // + // Note that an input Val group may be mapped with an output + // group. For example, the AlmostExact graph maps an input of split + // with the outer output if the split factor is one. Such a Val + // group is considered a terminating input as long as the input has + // no defining expression. This is for the use case of + // ValGraphVisitor. + // + // Example: + // + // [i0, i1] + // split by 1 + // [i0/1, 1, i1] + // merge + // [i0/1, 1*i1] + // + // Here, i0 and i0/1 would create a Val group of {i0, i0/1} in the + // AlmostExact graph. This group has a defining expression of the + // split, but since it's a cyclic dependency, we ignore the + // expression and consider the Val group a terminating input. + + ValGroups not_inputs; + for (const ExprGroup& expr_group : all_exprs) { + const std::vector input_groups = inputGroups(expr_group); + const std::vector output_groups = outputGroups(expr_group); + std::unordered_set input_set{ + input_groups.begin(), input_groups.end()}; + + for (const ValGroup& output_group : output_groups) { + if (input_set.count(output_group)) { + continue; + } + not_inputs.pushBack(output_group); + } + } + + return all_vals.computeSubtract(not_inputs); +} + ExprGroups ValGraph::allUsesOf(const ValGroups& of) const { DequeOfExprGroup to_visit; for (const ValGroup& of_val_group : of) { @@ -490,6 +542,12 @@ bool ValGraph::mapThroughExpr(Expr* first, Expr* second, bool forward) { return true; } +bool ValGraph::isTrivialExprGroup(const ExprGroup& expr_group) const { + return !ValGroups(inputGroups(expr_group)) + .computeIntersect(ValGroups(outputGroups(expr_group))) + .empty(); +} + void ValGraph::validateConsistency() const { // Check the consistency of the mapping information. Specifically: // 1. All ValGroup and ExprGroup sets are not empty. This may not be diff --git a/csrc/val_graph.h b/csrc/val_graph.h index f7a20e46bdb..f9a62fea47c 100644 --- a/csrc/val_graph.h +++ b/csrc/val_graph.h @@ -95,6 +95,9 @@ class ValGraph { std::vector outputGroups(const ExprGroup& expr) const; std::vector inputGroups(const ExprGroup& expr) const; + // Return Val groups that have no definition. + ValGroups getTerminatingInputs() const; + // Recursively traverses uses of the IdGroups in 'of' and returns all // ExprGroups that have a use in their definition of provided of IdGroups. ExprGroups allUsesOf(const ValGroups& of) const; @@ -199,6 +202,11 @@ class ValGraph { // be the only call in ValGraph to mapThroughExpr. void maybeMapThroughExprs(Expr* expr0, Expr* expr1, bool forward); + // Returns if the expression group has an input id group that matches an + // output id group. This means traversing on this expression doesn't actually + // do anything. + bool isTrivialExprGroup(const ExprGroup& expr_group) const; + // Can't back prop through merge without making sure one input actually // matches. This can be done on a map or extent basis. // TODO: Move this to val_graph.cpp once validation_utils.cpp is diff --git a/csrc/val_graph_visitor.cpp b/csrc/val_graph_visitor.cpp new file mode 100644 index 00000000000..498a56029a3 --- /dev/null +++ b/csrc/val_graph_visitor.cpp @@ -0,0 +1,118 @@ +// clang-format off +/* + * SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES. + * All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + */ +// clang-format on +#include + +#include + +namespace nvfuser { + +void ValGraphVisitor::traverse() { + const ValGroups terminating_inputs = graph().getTerminatingInputs(); + ValGroups to_visit_ids = terminating_inputs; + ValGroups visited_ids; + + ExprGroups to_visit_exprs; + ExprGroups visited_exprs; + + auto is_expr_ready = [&](const ExprGroup& expr_group) -> bool { + const auto inp_groups = graph().inputGroups(expr_group); + return std::all_of( + inp_groups.begin(), inp_groups.end(), [&](ValGroup id_group) { + return visited_ids.has(id_group) || id_group->empty(); + }); + }; + + // If any input of the def expr is mapped with the val + // group itself, i.e., a trivial expr, allow visiting the + // val group first. The trivial expr group will be visited + // after the val group. + // + // Example: + // + // [i0, 1] + // merge + // [i0*1] + // map i0 and i0*1 + // ValGroups: {{i0, i0*1}, {1}} + // + // Then, {i0, i0*1} and {1} would be visited first, then the merge + // expr group would be visited. {i0, i0*1} is also an output group + // of the merge but since it's already in the visited set, it would + // not be visited again. + // + // See also IdModelTest.ValGraphStmtSort3 for a concrete example. + auto is_val_ready = [&](const ValGroup& val_group) -> bool { + const ExprGroups& unique_defs = graph().getDefinitions(val_group); + return std::all_of( + unique_defs.begin(), unique_defs.end(), [&](ExprGroup expr_group) { + return expr_group->empty() || visited_exprs.has(expr_group) || + terminating_inputs.has(val_group) || + graph().isTrivialExprGroup(expr_group); + }); + }; + + while (!to_visit_ids.empty() || !to_visit_exprs.empty()) { + // Process expressions first as all definitions of vals have to be + // processed before we can process that val. + + // Detect if nothing has been processed which would put us in an infinite + // loop + bool something_was_processed = false; + ExprGroups still_to_visit_exprs; + + while (!to_visit_exprs.empty()) { + ExprGroup current_expr_group = to_visit_exprs.popFront(); + NVF_ERROR(!current_expr_group->empty()); + if (visited_exprs.has(current_expr_group)) { + continue; + } + + if (is_expr_ready(current_expr_group)) { + handle(current_expr_group); + + something_was_processed = true; + visited_exprs.pushBack(current_expr_group); + + to_visit_ids.pushBack(graph().outputGroups(current_expr_group)); + } else { + still_to_visit_exprs.pushBack(current_expr_group); + } + } + + std::swap(to_visit_exprs, still_to_visit_exprs); + + ValGroups still_to_visit_ids; + while (!to_visit_ids.empty()) { + auto current_id_group = to_visit_ids.popFront(); + NVF_ERROR(!current_id_group->empty()); + if (visited_ids.has(current_id_group)) { + continue; + } + + if (is_val_ready(current_id_group)) { + handle(current_id_group); + + something_was_processed = true; + visited_ids.pushBack(current_id_group); + + to_visit_exprs.pushBack(graph().getUses(current_id_group)); + } else { + still_to_visit_ids.pushBack(current_id_group); + } + } + + std::swap(to_visit_ids, still_to_visit_ids); + + NVF_ERROR( + something_was_processed || + (to_visit_ids.empty() && to_visit_exprs.empty()), + "Infinite loop entered."); + } +} + +} // namespace nvfuser diff --git a/csrc/val_graph_visitor.h b/csrc/val_graph_visitor.h new file mode 100644 index 00000000000..391c08902f3 --- /dev/null +++ b/csrc/val_graph_visitor.h @@ -0,0 +1,121 @@ +// clang-format off +/* + * SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES. + * All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + */ +// clang-format on +#pragma once + +#include +#include +#include + +namespace nvfuser { + +// Iterates through a Val Graph in topological order, calling handle on +// all Val and all Expr groups in a forward topological order. +// +// Warning: A ValGraph is not guaranteed to be a DAG. In fact, the +// AlmostExact and Permissive graphs would have cycles with a ValGroup +// and an ExprGroup. For example: +// +// [i0, 1] +// merge +// [i0*1] +// Current ValGroups: {{i0}, {1}, {i0*1}} +// map i0 and i0*1 as they effectively have the same extent +// Final ValGroups: {{i0, i0*1}, {1}} +// +// Here, the merge expr is the user of i0 and the definition of +// i0*1. Since i0 and i0*1 are mapped, the dependency chain looks +// like: +// +// {i0, i0*1} ----> {merge} ----> {i0, i0*1} +// use def +// +// These ExprGroups are called trivial ExprGroups (see also +// ValGraph::isTrivialExprGroup). +// +// Strictly speaking, these cycles mean there's no valid topological +// order anymore. In our use cases for IdModel, however, it's likely +// sufficient to return an ordering such as: +// +// {i0, i0*1} -> {merge} +// +// I.e., we visit {i0, i0*1} first even though {merge} is technically +// a definition. +// +// Another alternative may be simply giving up when such a cycle is +// detected, which may be more preferrable as it would be less +// confusing. At this moment, this visitor is only used with graphs +// with no such cycle. Should be revisited when necessary. +// +// Warning: This is not a great iterator if there's a desire to minimize paths +// traveled to simply visit all ValGroups in order. See ExprsBetween to see how +// we might minimize paths. +class ValGraphVisitor { + public: + ValGraphVisitor() = delete; + + ValGraphVisitor& operator=(const ValGraphVisitor& other) = delete; + + ValGraphVisitor& operator=(ValGraphVisitor&& other) = delete; + + virtual ~ValGraphVisitor() = default; + + protected: + ValGraphVisitor(const ValGraph& val_graph) : val_graph_(val_graph) {} + + ValGraphVisitor(const ValGraphVisitor& other) = default; + + ValGraphVisitor(ValGraphVisitor&& other) = default; + + virtual void handle(const ValGroup& id_group) = 0; + virtual void handle(const ExprGroup& expr_group) = 0; + + void traverse(); + + const ValGraph& graph() { + return val_graph_; + }; + + private: + const ValGraph& val_graph_; +}; + +// Statement sorting based on ValGraphVisitor, see warnings to ValGraph Visitor. +class ValGraphStmtSort : public ValGraphVisitor { + public: + ValGraphStmtSort(const ValGraph& val_graph) : ValGraphVisitor(val_graph) { + ValGraphVisitor::traverse(); + } + + // Return non-reference so that code like below can work + // for (auto expr_group: IdGraphStmtSort(graph).exprs()) + ExprGroups exprs() const { + return sorted_exprs_; + } + + ValGroups vals() const { + return sorted_vals_; + } + + ~ValGraphStmtSort() override = default; + + protected: + using ValGraphVisitor::handle; + + void handle(const ValGroup& val_group) override { + sorted_vals_.pushBack(val_group); + } + + void handle(const ExprGroup& expr_group) override { + sorted_exprs_.pushBack(expr_group); + } + + ExprGroups sorted_exprs_; + ValGroups sorted_vals_; +}; + +} // namespace nvfuser diff --git a/test/test_id_model.cpp b/test/test_id_model.cpp index fc823316564..f9abd51b532 100644 --- a/test/test_id_model.cpp +++ b/test/test_id_model.cpp @@ -16,6 +16,7 @@ #include #include #include +#include namespace nvfuser { @@ -37,4 +38,344 @@ TEST_F(IdModelTest, DetectSelfMapping) { ::testing::HasSubstr("!hasSelfMapping"))); } +namespace { + +// Get n-th parent expr traversing through the first input of each +// parent +Expr* getParentExpr(Val* val, int n) { + for (int i = 0; i < n - 1; ++i) { + NVF_ERROR(val->definition() != nullptr); + val = val->definition()->input(0); + } + NVF_ERROR(val->definition() != nullptr); + return val->definition(); +}; + +TensorView* getTensorByName( + const std::vector& tvs, + StmtNameType name) { + if (auto it = std::find_if( + tvs.begin(), + tvs.end(), + [&](TensorView* tv) { return tv->name() == name; }); + it != tvs.end()) { + return *it; + } else { + return nullptr; + } +} + +// Create a fusion where we're missing a valid concrete id so the compute at map +// processing will fail. We need to be able to create the concrete ID not just +// look for one. It is not yet possible to lower this fusion as the +// current indexing cannot generate correct indices. Also used in +// FusionIndeixing19 as well as Example 2 in the design doc about Loop +// Promotion Analysis. +std::unique_ptr createFusionWithMultipleResolutionPaths() { + std::unique_ptr fusion_ptr = std::make_unique(); + Fusion& fusion = *fusion_ptr.get(); + FusionGuard fg(&fusion); + + auto tv0 = makeConcreteTensor({7}); + fusion.addInput(tv0); + + auto tv1 = set(tv0); + + auto tv2 = broadcast(tv1, {false, true}); + + auto tv3 = makeConcreteTensor({7, 11}); + fusion.addInput(tv3); + + auto tv4 = add(tv3, tv2); + auto tv5 = broadcast(tv4, {false, false, true}); + // tv4[7, 11, 1] + + auto tv6 = broadcast(tv1, {false, true}); + + auto tv7 = makeConcreteTensor({7, 13}); + fusion.addInput(tv7); + auto tv8 = add(tv7, tv6); + auto tv9 = broadcast(tv8, {false, true, false}); + // tv9[7, 1, 13] + + auto tv10 = add(tv5, tv9); + fusion.addOutput(tv10); + + // tv10[7, 11, 13] + tv10->merge(0)->merge(0); + // tv10[7*11*13] + tv10->split(0, 5)->split(0, 3); + // tv10[7*11*13//5//3, 3, 5] + + TransformPropagatorWithCheck propagator(tv10); + MaxRootDomainInfoSpanningTree(tv10).traverse(&propagator); + + std::vector tensors_to_inline{tv1, tv2, tv4, tv6, tv8}; + for (auto tensor : tensors_to_inline) { + tensor->inlineAt(1); + } + + return fusion_ptr; +} + +// Check the results of ValGraphStmtSort. Only the ordering of +// ExprGroups is checked for now as it's likely sufficient. +// +// ref_order: The order must be exactly the +// same as indicated by this list. While there can be different +// order that still satisfy the topologial ordering, we also need +// deterministic ordering, so the results should be always the same. +void checkSortingResults( + const ValGraph& graph, + const ExprGroups& sorted_expr_groups, + const ValGroups& sorted_val_groups, + const std::vector& ref_order) { + // Make sure sorted_val_groups cover all Expr groups + const std::unordered_set& ref_expr_group_set{ + graph.disjointExprSets().disjointSets().begin(), + graph.disjointExprSets().disjointSets().end()}; + std::unordered_set sorted_expr_group_set{ + sorted_expr_groups.begin(), sorted_expr_groups.end()}; + ASSERT_EQ(sorted_expr_group_set, ref_expr_group_set) + << "Mismatched ExprGroups."; + + // Make sure sorted_val_groups covers all Val groups + const std::unordered_set& ref_val_group_set{ + graph.disjointValSets().disjointSets().begin(), + graph.disjointValSets().disjointSets().end()}; + std::unordered_set sorted_val_group_set{ + sorted_val_groups.begin(), sorted_val_groups.end()}; + ASSERT_EQ(sorted_val_group_set, ref_val_group_set) << "Mismatched ValGroups."; + + // Check the ordering + ASSERT_EQ(sorted_expr_groups.size(), ref_order.size()); + for (const auto i : c10::irange(ref_order.size())) { + Expr* ref_expr = ref_order.at(i); + const ExprGroup& eg = sorted_expr_groups.at(i); + ASSERT_TRUE(eg->has(ref_expr)) + << "Expected: " << nvfuser::toString(graph.toGroup(ref_expr)) + << ". Actual: " << nvfuser::toString(eg); + } +} + +} // namespace + +// Sorting test with a trivial fusion +TEST_F(IdModelTest, ValGraphStmtSort1) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto tv0 = makeSymbolicTensor(2); + fusion.addInput(tv0); + auto tv1 = makeSymbolicTensor(2); + fusion.addInput(tv1); + auto tv2 = add(tv0, tv1); + fusion.addOutput(tv2); + + // No ID expr yet + { + IdModel id_model(&fusion); + const ValGraph& vg = id_model.idGraph(IdMappingMode::EXACT); + ValGraphStmtSort vg_stmt_sort(vg); + checkSortingResults(vg, vg_stmt_sort.exprs(), vg_stmt_sort.vals(), {}); + } + + tv2->merge(0)->split(0, 4); + + TransformPropagator propagator(tv2); + MaxRootDomainInfoSpanningTree(tv2).traverse(&propagator); + + { + IdModel id_model(&fusion); + + const ValGraph& vg = id_model.idGraph(IdMappingMode::EXACT); + ValGraphStmtSort vg_stmt_sort(vg); + + // Reference expr order: merge, split + std::vector ref_order; + ref_order.push_back(getParentExpr(tv2->axis(0), 2)); + ref_order.push_back(getParentExpr(tv2->axis(0), 1)); + + checkSortingResults( + vg, vg_stmt_sort.exprs(), vg_stmt_sort.vals(), ref_order); + } +} + +// Sorting test wth a disconnected graph +TEST_F(IdModelTest, ValGraphStmtSort2) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto tv0 = makeSymbolicTensor(2); + fusion.addInput(tv0); + auto tv1 = set(tv0); + fusion.addOutput(tv1); + + auto tv2 = makeSymbolicTensor(2); + fusion.addInput(tv2); + auto tv3 = set(tv2); + fusion.addOutput(tv3); + + // Note that the two groups of tensors, {tv0, tv1} and {tv2, tv3}, + // are not connected + + for (auto tv : ir_utils::allTvs(&fusion)) { + tv->merge(0)->split(0, 4); + } + + IdModel id_model(&fusion); + + const ValGraph& vg = id_model.idGraph(IdMappingMode::EXACT); + ValGraphStmtSort vg_stmt_sort(vg); + + std::vector ref_order; + ref_order.push_back(getParentExpr(tv1->axis(0), 2)); + ref_order.push_back(getParentExpr(tv3->axis(0), 2)); + ref_order.push_back(getParentExpr(tv1->axis(0), 1)); + ref_order.push_back(getParentExpr(tv3->axis(0), 1)); + + checkSortingResults(vg, vg_stmt_sort.exprs(), vg_stmt_sort.vals(), ref_order); +} + +// Sorting with trivial ExprGroup, i.e., ExprGroup whose input and +// output are mapped as the same ValGroup. It's effectively a cyclic +// dependency and the graph is no longer a DAG. +TEST_F(IdModelTest, ValGraphStmtSort3) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto tv0 = makeSymbolicTensor(2); + fusion.addInput(tv0); + auto tv1 = makeSymbolicTensor(2); + fusion.addInput(tv1); + auto tv2 = add(tv0, tv1); + fusion.addOutput(tv2); + + auto tv3 = makeSymbolicTensor(2); + fusion.addInput(tv3); + auto tv4 = set(tv3); + fusion.addOutput(tv4); + + // Merge adn split by one. The split input and output will be mapped. + for (auto tv : {tv0, tv1, tv2}) { + tv->merge(0)->split(0, 1); + } + + // Also test an isolated trivial expr. Note that tv3 and tv4 are not + // connected with tv0, tv1 and tv2. + tv4->merge(0)->split(0, 1); + + IdModel id_model(&fusion); + ValGraph vg = id_model.idGraph(IdMappingMode::EXACT); + + // Map the split-by-1 input and output + vg.mapVals(tv2->axis(0), tv2->axis(0)->definition()->input(0)); + vg.mapVals(tv4->axis(0), tv4->axis(0)->definition()->input(0)); + + ValGraphStmtSort vg_stmt_sort(vg); + + std::vector ref_order; + ref_order.push_back(getParentExpr(tv2->axis(0), 2)); + ref_order.push_back(getParentExpr(tv4->axis(0), 2)); + ref_order.push_back(getParentExpr(tv2->axis(0), 1)); + ref_order.push_back(getParentExpr(tv4->axis(0), 1)); + + checkSortingResults(vg, vg_stmt_sort.exprs(), vg_stmt_sort.vals(), ref_order); +} + +// Sorting test with the same fusion as Indexing19 +TEST_F(IdModelTest, ValGraphStmtSort4) { + auto fusion = createFusionWithMultipleResolutionPaths(); + FusionGuard fg(fusion.get()); + auto all_tvs = ir_utils::allTvs(fusion.get()); + + IdModel id_model(fusion.get(), true, false, false); + + const ValGraph& vg = id_model.idGraph(IdMappingMode::EXACT); + + ValGraphStmtSort vg_stmt_sort(vg); + + auto tv1 = getTensorByName(all_tvs, 1); + auto tv2 = getTensorByName(all_tvs, 2); + auto tv4 = getTensorByName(all_tvs, 4); + auto tv5 = getTensorByName(all_tvs, 5); + auto tv6 = getTensorByName(all_tvs, 6); + auto tv8 = getTensorByName(all_tvs, 8); + auto tv9 = getTensorByName(all_tvs, 9); + auto tv10 = getTensorByName(all_tvs, 10); + + // Expected reference order: + // + // exprg{39}: Merge: iS2{7} and bS3{1} -> iS46{( 7 * 1 )} + // exprg{57}: Merge: iS11{7} and bS12{1} -> iS61{( 7 * 1 )} + // exprg{17}: Merge: iS17{7} and bS18{1} -> iS29{( 7 * 1 )} + // exprg{69 73 89}: Split: iS1{7} by factor 5 -> iS71{( ceilDiv(7, 5) )}, + // iS72{5}, start offset: 0, stop offset: 0 exprg{51 63 93}: Merge: iS15{7} + // and iS16{13} -> iS56{( 7 * 13 )} exprg{9 25 33 45 91 95}: Merge: iS20{7} + // and iS21{11} -> iS23{( 7 * 11 )} exprg{27}: Merge: iS35{( 7 * 11 )} and + // bS10{1} -> iS36{( ( 7 * 11 ) * 1 )} exprg{19}: Merge: iS29{( 7 * 1 )} and + // iS19{13} -> iS30{( ( 7 * 1 ) * 13 )} exprg{11 77 79 99}: Merge: iS23{( 7 * + // 11 )} and iS22{13} -> iS24{( ( 7 * 11 ) * 13 )} exprg{41}: Split: iS46{( 7 + // * 1 )} by factor 5 -> iS47{( ceilDiv(( 7 * 1 ), 5) )}, iS48{5}, start + // offset: 0, stop offset: 0 exprg{59}: Split: iS61{( 7 * 1 )} by factor 5 -> + // iS62{( ceilDiv(( 7 * 1 ), 5) )}, iS63{5}, start offset: 0, stop offset: 0 + // exprg{71 75 101}: Split: iS71{( ceilDiv(7, 5) )} by factor 3 -> iS73{( + // ceilDiv(( ceilDiv(7, 5) ), 3) )}, iS74{3}, start offset: 0, stop offset: 0 + // exprg{53 65 109}: Split: iS56{( 7 * 13 )} by factor 5 -> iS57{( ceilDiv(( 7 + // * 13 ), 5) )}, iS58{5}, start offset: 0, stop offset: 0 exprg{35 47 105}: + // Split: iS41{( 7 * 11 )} by factor 5 -> iS42{( ceilDiv(( 7 * 11 ), 5) )}, + // iS43{5}, start offset: 0, stop offset: 0 exprg{29}: Split: iS36{( ( 7 * 11 + // ) * 1 )} by factor 5 -> iS37{( ceilDiv(( ( 7 * 11 ) * 1 ), 5) )}, iS38{5}, + // start offset: 0, stop offset: 0 exprg{21}: Split: iS30{( ( 7 * 1 ) * 13 )} + // by factor 5 -> iS31{( ceilDiv(( ( 7 * 1 ) * 13 ), 5) )}, iS32{5}, start + // offset: 0, stop offset: 0 exprg{13 81 83 97 103 107 111 115 117 119 121}: + // Split: iS24{( ( 7 * 11 ) * 13 )} by factor 5 -> iS25{( ceilDiv(( ( 7 * 11 ) + // * 13 ), 5) )}, iS26{5}, start offset: 0, stop offset: 0 exprg{43}: Split: + // iS47{( ceilDiv(( 7 * 1 ), 5) )} by factor 3 -> iS49{( ceilDiv(( ceilDiv(( 7 + // * 1 ), 5) ), 3) )}, iS50{3}, start offset: 0, stop offset: 0 exprg{61}: + // Split: iS62{( ceilDiv(( 7 * 1 ), 5) )} by factor 3 -> iS64{( ceilDiv(( + // ceilDiv(( 7 * 1 ), 5) ), 3) )}, iS65{3}, start offset: 0, stop offset: 0 + // exprg{55 67 129}: Split: iS57{( ceilDiv(( 7 * 13 ), 5) )} by factor 3 -> + // iS59{( ceilDiv(( ceilDiv(( 7 * 13 ), 5) ), 3) )}, iS60{3}, start offset: 0, + // stop offset: 0 exprg{37 49 125}: Split: iS42{( ceilDiv(( 7 * 11 ), 5) )} by + // factor 3 -> iS44{( ceilDiv(( ceilDiv(( 7 * 11 ), 5) ), 3) )}, iS45{3}, + // start offset: 0, stop offset: 0 exprg{31}: Split: iS37{( ceilDiv(( ( 7 * 11 + // ) * 1 ), 5) )} by factor 3 -> iS39{( ceilDiv(( ceilDiv(( ( 7 * 11 ) * 1 ), + // 5) ), 3) )}, iS40{3}, start offset: 0, stop offset: 0 exprg{23}: Split: + // iS31{( ceilDiv(( ( 7 * 1 ) * 13 ), 5) )} by factor 3 -> iS33{( ceilDiv(( + // ceilDiv(( ( 7 * 1 ) * 13 ), 5) ), 3) )}, iS34{3}, start offset: 0, stop + // offset: 0 exprg{15 85 87 113 123 127 131 133 135 137 139}: Split: iS25{( + // ceilDiv(( ( 7 * 11 ) * 13 ), 5) )} by factor 3 -> iS27{( ceilDiv(( + // ceilDiv(( ( 7 * 11 ) * 13 ), 5) ), 3) )}, iS28{3}, start offset: 0, stop + // offset: 0 + + std::vector ref_order; + ref_order.push_back(getParentExpr(tv2->axis(0), 3)); + ref_order.push_back(getParentExpr(tv6->axis(0), 3)); + ref_order.push_back(getParentExpr(tv9->axis(0), 4)); + ref_order.push_back(getParentExpr(tv1->axis(0), 2)); + ref_order.push_back(getParentExpr(tv8->axis(0), 3)); + ref_order.push_back(getParentExpr(tv10->axis(0), 4)); + ref_order.push_back(getParentExpr(tv5->axis(0), 3)); + ref_order.push_back(getParentExpr(tv9->axis(0), 3)); + ref_order.push_back(getParentExpr(tv10->axis(0), 3)); + ref_order.push_back(getParentExpr(tv2->axis(0), 2)); + ref_order.push_back(getParentExpr(tv6->axis(0), 2)); + ref_order.push_back(getParentExpr(tv1->axis(0), 1)); + ref_order.push_back(getParentExpr(tv8->axis(0), 2)); + ref_order.push_back(getParentExpr(tv4->axis(0), 2)); + ref_order.push_back(getParentExpr(tv5->axis(0), 2)); + ref_order.push_back(getParentExpr(tv9->axis(0), 2)); + ref_order.push_back(getParentExpr(tv10->axis(0), 2)); + ref_order.push_back(getParentExpr(tv2->axis(0), 1)); + ref_order.push_back(getParentExpr(tv6->axis(0), 1)); + ref_order.push_back(getParentExpr(tv8->axis(0), 1)); + ref_order.push_back(getParentExpr(tv4->axis(0), 1)); + ref_order.push_back(getParentExpr(tv5->axis(0), 1)); + ref_order.push_back(getParentExpr(tv9->axis(0), 1)); + ref_order.push_back(getParentExpr(tv10->axis(0), 1)); + + checkSortingResults(vg, vg_stmt_sort.exprs(), vg_stmt_sort.vals(), ref_order); +} + } // namespace nvfuser From 42b1b238848214a963bdbde235f596ae3d15d7c0 Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Fri, 2 Feb 2024 13:24:50 -0800 Subject: [PATCH 02/17] comments --- test/test_id_model.cpp | 26 ++++++++++++++++++++++++-- 1 file changed, 24 insertions(+), 2 deletions(-) diff --git a/test/test_id_model.cpp b/test/test_id_model.cpp index f9abd51b532..7d1e7131ae5 100644 --- a/test/test_id_model.cpp +++ b/test/test_id_model.cpp @@ -172,7 +172,9 @@ TEST_F(IdModelTest, ValGraphStmtSort1) { auto tv2 = add(tv0, tv1); fusion.addOutput(tv2); - // No ID expr yet + // No ID expr yet. checkSortingResults validates the exprssion + // order, but since there's no expr, it just makes sure exprs() and + // vals() return all the val and expr groups. { IdModel id_model(&fusion); const ValGraph& vg = id_model.idGraph(IdMappingMode::EXACT); @@ -180,11 +182,14 @@ TEST_F(IdModelTest, ValGraphStmtSort1) { checkSortingResults(vg, vg_stmt_sort.exprs(), vg_stmt_sort.vals(), {}); } + // Add ID exprs. Just apply a merge-and-split pattern to all + // tensors. tv2->merge(0)->split(0, 4); - TransformPropagator propagator(tv2); MaxRootDomainInfoSpanningTree(tv2).traverse(&propagator); + // The exact graph should just map all IDs of the tensors. Ther + // ordering of the exprs should be the merge and then the split. { IdModel id_model(&fusion); @@ -223,6 +228,21 @@ TEST_F(IdModelTest, ValGraphStmtSort2) { tv->merge(0)->split(0, 4); } + // Since the two tensors are disconnected, there's no ordering + // between the ID exprs of the two tensor groups. So, the correct + // ordering should have the merge exprs before the split exprs, but + // there's no order between the tv1 and tv3 exprs. For example, + // these are all valid: + // + // tv1 merge -> tv3 merge -> tv1 split -> tv3 split + // tv1 merge -> tv1 split -> tv3 merge -> tv3 split + // tv3 merge -> tv3 split -> tv1 merge -> tv1 split + // tv3 merge -> tv1 merge -> tv3 split -> tv1 split + // + // Here, the actual order returned by ValGraphStmtSort is the first + // one. Since it should be deterministic, we check if the returned + // expr vector is indeed ordered that way. + IdModel id_model(&fusion); const ValGraph& vg = id_model.idGraph(IdMappingMode::EXACT); @@ -289,6 +309,8 @@ TEST_F(IdModelTest, ValGraphStmtSort4) { FusionGuard fg(fusion.get()); auto all_tvs = ir_utils::allTvs(fusion.get()); + // Since this fusion is not supported by ComputeAtMap, the + // validation flag must be false IdModel id_model(fusion.get(), true, false, false); const ValGraph& vg = id_model.idGraph(IdMappingMode::EXACT); From c6a69282c87b7ed9c6aa7b1aa83c508b81870d1d Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Mon, 5 Feb 2024 16:08:03 -0800 Subject: [PATCH 03/17] Remove a stale comment --- csrc/val_graph.h | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/csrc/val_graph.h b/csrc/val_graph.h index f9a62fea47c..0c306f72e44 100644 --- a/csrc/val_graph.h +++ b/csrc/val_graph.h @@ -203,8 +203,7 @@ class ValGraph { void maybeMapThroughExprs(Expr* expr0, Expr* expr1, bool forward); // Returns if the expression group has an input id group that matches an - // output id group. This means traversing on this expression doesn't actually - // do anything. + // output id group. bool isTrivialExprGroup(const ExprGroup& expr_group) const; // Can't back prop through merge without making sure one input actually From a8a3bbf1c5d3b26196c4a16db9236828174fd806 Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Mon, 5 Feb 2024 16:57:08 -0800 Subject: [PATCH 04/17] Avoid using popFront --- csrc/id_model/visitor.cpp | 162 +++++++++++++++++++++++++++++++++++++ csrc/id_model/visitor.h | 97 ++++++++++++++++++++++ csrc/val_graph_visitor.cpp | 28 ++++--- 3 files changed, 277 insertions(+), 10 deletions(-) create mode 100644 csrc/id_model/visitor.cpp create mode 100644 csrc/id_model/visitor.h diff --git a/csrc/id_model/visitor.cpp b/csrc/id_model/visitor.cpp new file mode 100644 index 00000000000..f81ce4177d9 --- /dev/null +++ b/csrc/id_model/visitor.cpp @@ -0,0 +1,162 @@ +// clang-format off +/* + * SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES. + * All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + */ +// clang-format on +#include + +namespace nvfuser { + +void IdGraphVisitor::traverse() { + ValGroups all_ids; + ExprGroups all_exprs; + { + // Initialize IDs to traverse. If sub_selection is provided, only + // traverse IDs that are included in the set are traversed. + if (sub_selection_.empty()) { + all_ids = ValGroups( + graph().disjointValSets().disjointSets().begin(), + graph().disjointValSets().disjointSets().end()); + } else { + for (auto id : sub_selection_) { + if (graph().hasGroup(id)) { + all_ids.pushBack(graph().toGroup(id)); + } + } + } + + // Initialize exprs to traverse. If sub_selection is provided, + // only traverse exprs that are strictly contained within the provided + // sub_selection. Exprs are excluded if any of inputs or outputs + // is not in sub_selection. + if (sub_selection_.empty()) { + all_exprs = ExprGroups( + graph().disjointExprSets().disjointSets().begin(), + graph().disjointExprSets().disjointSets().end()); + } else { + for (const ValGroup& id_group : all_ids) { + for (const ExprGroup& def : graph().getDefinitions(id_group)) { + if (all_exprs.has(def)) { + continue; + } + auto inp_groups = ValGroups(graph().inputGroups(def)); + auto out_groups = ValGroups(graph().outputGroups(def)); + if (inp_groups.computeSubtract(all_ids).empty() && + out_groups.computeSubtract(all_ids).empty()) { + all_exprs.pushBack(def); + } + } + } + } + } + // There could be IterDomains in from or to that are between other from and + // to nodes. Make sure to clear those out. + ValGroups terminating_inputs; + ValGroups terminating_outputs; + + { + ValGroups not_inputs; + ValGroups not_outputs; + for (const ExprGroup& expr_group : all_exprs) { + if (graph().isTrivialExprGroup(expr_group)) { + // Expression is just a loop to its current group, ignore + continue; + } + + not_inputs.pushBack(graph().outputGroups(expr_group)); + not_outputs.pushBack(graph().inputGroups(expr_group)); + } + + terminating_inputs = + ValGroups(all_ids.begin(), all_ids.end()).computeSubtract(not_inputs); + + terminating_outputs = + ValGroups(all_ids.begin(), all_ids.end()).computeSubtract(not_outputs); + } + + ValGroups to_visit_ids = terminating_inputs; + ValGroups visited_ids; + + ExprGroups to_visit_exprs; + ExprGroups visited_exprs; + + auto is_expr_ready = [&](const ExprGroup& expr_group) { + auto inp_groups = graph().inputGroups(expr_group); + return std::all_of( + inp_groups.begin(), inp_groups.end(), [&](ValGroup id_group) { + return visited_ids.has(id_group) || id_group->empty(); + }); + }; + + auto is_id_ready = [&](const ValGroup& id_group) { + const ExprGroups& unique_defs = graph().getDefinitions(id_group); + return std::all_of( + unique_defs.begin(), unique_defs.end(), [&](ExprGroup expr_group) { + return expr_group->empty() || visited_exprs.has(expr_group) || + graph().isTrivialExprGroup(expr_group); + }); + }; + + while (!to_visit_ids.empty() || !to_visit_exprs.empty()) { + // Process expressions first as all definitions of iter domains have to be + // processed before we can process that iter domain. + + // Detect if nothing has been processed which would put us in an infinite + // loop + bool something_was_processed = false; + ExprGroups still_to_visit_exprs; + + while (!to_visit_exprs.empty()) { + ExprGroup current_expr_group = to_visit_exprs.popFront(); + NVF_ERROR(!current_expr_group->empty()); + if (visited_exprs.has(current_expr_group)) { + continue; + } + + if (is_expr_ready(current_expr_group)) { + handle(current_expr_group); + + something_was_processed = true; + visited_exprs.pushBack(current_expr_group); + + to_visit_ids.pushBack(graph().outputGroups(current_expr_group)); + } else { + still_to_visit_exprs.pushBack(current_expr_group); + } + } + + std::swap(to_visit_exprs, still_to_visit_exprs); + + ValGroups still_to_visit_ids; + while (!to_visit_ids.empty()) { + auto current_id_group = to_visit_ids.popFront(); + NVF_ERROR(!current_id_group->empty()); + if (visited_ids.has(current_id_group)) { + continue; + } + + if (is_id_ready(current_id_group)) { + handle(current_id_group); + + something_was_processed = true; + visited_ids.pushBack(current_id_group); + + if (!terminating_outputs.has(current_id_group)) { + const ExprGroups& uses = graph().getUses(current_id_group); + to_visit_exprs.pushBack(uses); + } + } else { + still_to_visit_ids.pushBack(current_id_group); + } + } + std::swap(to_visit_ids, still_to_visit_ids); + + NVF_ERROR( + something_was_processed || + (to_visit_ids.empty() && to_visit_exprs.empty()), + "Infinite loop entered."); + } +} +} // namespace nvfuser diff --git a/csrc/id_model/visitor.h b/csrc/id_model/visitor.h new file mode 100644 index 00000000000..2c13b9efae0 --- /dev/null +++ b/csrc/id_model/visitor.h @@ -0,0 +1,97 @@ +// clang-format off +/* + * SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES. + * All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + */ +// clang-format on +#pragma once + +#include +#include +#include + +namespace nvfuser { + +// Iterates through an IterDomain Graph in topological order, calling handle on +// all Id and all Expr groups in a forward topological order. +// +// Warning: Expr groups that have an input and output in the same ValGroup are +// ignored. +// +// Warning: This is not a great iterator if there's a desire to minimize paths +// traveled to simply visit all ValGroups in order. See ExprsBetween to see how +// we might minimize paths. +class IdGraphVisitor { + public: + IdGraphVisitor() = delete; + + IdGraphVisitor& operator=(const IdGraphVisitor& other) = delete; + + IdGraphVisitor& operator=(IdGraphVisitor&& other) = delete; + + virtual ~IdGraphVisitor() = default; + + protected: + // If sub_selection is assumed to be a set of iter domains by which form a + // sub-regrion of the IdGraph provided. Only that sub-region will be visited. + IdGraphVisitor( + const ValGraph& id_graph, + const VectorOfUniqueEntries sub_selection = {}) + : id_graph_(id_graph), sub_selection_(sub_selection) {} + + IdGraphVisitor(const IdGraphVisitor& other) = default; + + IdGraphVisitor(IdGraphVisitor&& other) = default; + + virtual void handle(ValGroup id_group) = 0; + virtual void handle(ExprGroup expr_group) = 0; + + void traverse(); + + const ValGraph& graph() { + return id_graph_; + }; + + private: + const ValGraph& id_graph_; + const VectorOfUniqueEntries sub_selection_; +}; + +// Statement sorting based on IdGraphVisitor, see warnings to IdGraph Visitor. +class IdGraphStmtSort : public IdGraphVisitor { + public: + IdGraphStmtSort( + const ValGraph& id_graph, + const VectorOfUniqueEntries sub_selection = {}) + : IdGraphVisitor(id_graph, sub_selection) { + IdGraphVisitor::traverse(); + } + + // Return non-reference so that code like below can work + // for (auto expr_group: IdGraphStmtSort(graph).exprs()) + ExprGroups exprs() const { + return sorted_exprs_; + } + + ValGroups ids() const { + return sorted_ids_; + } + + ~IdGraphStmtSort() override = default; + + protected: + using IdGraphVisitor::handle; + void handle(ValGroup id_group) override { + sorted_ids_.pushBack(id_group); + } + + void handle(ExprGroup expr_group) override { + sorted_exprs_.pushBack(expr_group); + } + + ExprGroups sorted_exprs_; + ValGroups sorted_ids_; +}; + +} // namespace nvfuser diff --git a/csrc/val_graph_visitor.cpp b/csrc/val_graph_visitor.cpp index 498a56029a3..76d701444c2 100644 --- a/csrc/val_graph_visitor.cpp +++ b/csrc/val_graph_visitor.cpp @@ -13,10 +13,11 @@ namespace nvfuser { void ValGraphVisitor::traverse() { const ValGroups terminating_inputs = graph().getTerminatingInputs(); - ValGroups to_visit_ids = terminating_inputs; + std::deque to_visit_ids( + terminating_inputs.begin(), terminating_inputs.end()); ValGroups visited_ids; - ExprGroups to_visit_exprs; + std::deque to_visit_exprs; ExprGroups visited_exprs; auto is_expr_ready = [&](const ExprGroup& expr_group) -> bool { @@ -63,10 +64,11 @@ void ValGraphVisitor::traverse() { // Detect if nothing has been processed which would put us in an infinite // loop bool something_was_processed = false; - ExprGroups still_to_visit_exprs; + std::deque still_to_visit_exprs; while (!to_visit_exprs.empty()) { - ExprGroup current_expr_group = to_visit_exprs.popFront(); + ExprGroup current_expr_group = to_visit_exprs.front(); + to_visit_exprs.pop_front(); NVF_ERROR(!current_expr_group->empty()); if (visited_exprs.has(current_expr_group)) { continue; @@ -78,17 +80,21 @@ void ValGraphVisitor::traverse() { something_was_processed = true; visited_exprs.pushBack(current_expr_group); - to_visit_ids.pushBack(graph().outputGroups(current_expr_group)); + for (const ValGroup& output_group : + graph().outputGroups(current_expr_group)) { + to_visit_ids.push_back(output_group); + } } else { - still_to_visit_exprs.pushBack(current_expr_group); + still_to_visit_exprs.push_back(current_expr_group); } } std::swap(to_visit_exprs, still_to_visit_exprs); - ValGroups still_to_visit_ids; + std::deque still_to_visit_ids; while (!to_visit_ids.empty()) { - auto current_id_group = to_visit_ids.popFront(); + auto current_id_group = to_visit_ids.front(); + to_visit_ids.pop_front(); NVF_ERROR(!current_id_group->empty()); if (visited_ids.has(current_id_group)) { continue; @@ -100,9 +106,11 @@ void ValGraphVisitor::traverse() { something_was_processed = true; visited_ids.pushBack(current_id_group); - to_visit_exprs.pushBack(graph().getUses(current_id_group)); + for (const ExprGroup& use_group : graph().getUses(current_id_group)) { + to_visit_exprs.push_back(use_group); + } } else { - still_to_visit_ids.pushBack(current_id_group); + still_to_visit_ids.push_back(current_id_group); } } From ebd14d230d8b502ffcabd6b4c8ca3c03187c5562 Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Mon, 5 Feb 2024 17:50:31 -0800 Subject: [PATCH 05/17] Update the reference order (still valid) --- test/test_id_model.cpp | 86 +++++++++++++++++------------------------- 1 file changed, 35 insertions(+), 51 deletions(-) diff --git a/test/test_id_model.cpp b/test/test_id_model.cpp index 7d1e7131ae5..39dae8fce53 100644 --- a/test/test_id_model.cpp +++ b/test/test_id_model.cpp @@ -153,8 +153,10 @@ void checkSortingResults( Expr* ref_expr = ref_order.at(i); const ExprGroup& eg = sorted_expr_groups.at(i); ASSERT_TRUE(eg->has(ref_expr)) - << "Expected: " << nvfuser::toString(graph.toGroup(ref_expr)) - << ". Actual: " << nvfuser::toString(eg); + << "Mismatch detected at " << i << "-th expr group. " + << "Expected: " << nvfuser::toString(graph.toGroup(ref_expr)) << ", " + << ref_expr->toString() << ". Actual: " << nvfuser::toString(eg) << ", " + << eg->front()->toString(); } } @@ -311,8 +313,8 @@ TEST_F(IdModelTest, ValGraphStmtSort4) { // Since this fusion is not supported by ComputeAtMap, the // validation flag must be false - IdModel id_model(fusion.get(), true, false, false); - + IdModel id_model(fusion.get(), false, false, false); + id_model.buildExactGraph(); const ValGraph& vg = id_model.idGraph(IdMappingMode::EXACT); ValGraphStmtSort vg_stmt_sort(vg); @@ -328,74 +330,56 @@ TEST_F(IdModelTest, ValGraphStmtSort4) { // Expected reference order: // - // exprg{39}: Merge: iS2{7} and bS3{1} -> iS46{( 7 * 1 )} - // exprg{57}: Merge: iS11{7} and bS12{1} -> iS61{( 7 * 1 )} - // exprg{17}: Merge: iS17{7} and bS18{1} -> iS29{( 7 * 1 )} - // exprg{69 73 89}: Split: iS1{7} by factor 5 -> iS71{( ceilDiv(7, 5) )}, - // iS72{5}, start offset: 0, stop offset: 0 exprg{51 63 93}: Merge: iS15{7} - // and iS16{13} -> iS56{( 7 * 13 )} exprg{9 25 33 45 91 95}: Merge: iS20{7} - // and iS21{11} -> iS23{( 7 * 11 )} exprg{27}: Merge: iS35{( 7 * 11 )} and - // bS10{1} -> iS36{( ( 7 * 11 ) * 1 )} exprg{19}: Merge: iS29{( 7 * 1 )} and - // iS19{13} -> iS30{( ( 7 * 1 ) * 13 )} exprg{11 77 79 99}: Merge: iS23{( 7 * - // 11 )} and iS22{13} -> iS24{( ( 7 * 11 ) * 13 )} exprg{41}: Split: iS46{( 7 - // * 1 )} by factor 5 -> iS47{( ceilDiv(( 7 * 1 ), 5) )}, iS48{5}, start - // offset: 0, stop offset: 0 exprg{59}: Split: iS61{( 7 * 1 )} by factor 5 -> - // iS62{( ceilDiv(( 7 * 1 ), 5) )}, iS63{5}, start offset: 0, stop offset: 0 - // exprg{71 75 101}: Split: iS71{( ceilDiv(7, 5) )} by factor 3 -> iS73{( - // ceilDiv(( ceilDiv(7, 5) ), 3) )}, iS74{3}, start offset: 0, stop offset: 0 - // exprg{53 65 109}: Split: iS56{( 7 * 13 )} by factor 5 -> iS57{( ceilDiv(( 7 - // * 13 ), 5) )}, iS58{5}, start offset: 0, stop offset: 0 exprg{35 47 105}: - // Split: iS41{( 7 * 11 )} by factor 5 -> iS42{( ceilDiv(( 7 * 11 ), 5) )}, - // iS43{5}, start offset: 0, stop offset: 0 exprg{29}: Split: iS36{( ( 7 * 11 - // ) * 1 )} by factor 5 -> iS37{( ceilDiv(( ( 7 * 11 ) * 1 ), 5) )}, iS38{5}, - // start offset: 0, stop offset: 0 exprg{21}: Split: iS30{( ( 7 * 1 ) * 13 )} - // by factor 5 -> iS31{( ceilDiv(( ( 7 * 1 ) * 13 ), 5) )}, iS32{5}, start - // offset: 0, stop offset: 0 exprg{13 81 83 97 103 107 111 115 117 119 121}: - // Split: iS24{( ( 7 * 11 ) * 13 )} by factor 5 -> iS25{( ceilDiv(( ( 7 * 11 ) - // * 13 ), 5) )}, iS26{5}, start offset: 0, stop offset: 0 exprg{43}: Split: - // iS47{( ceilDiv(( 7 * 1 ), 5) )} by factor 3 -> iS49{( ceilDiv(( ceilDiv(( 7 - // * 1 ), 5) ), 3) )}, iS50{3}, start offset: 0, stop offset: 0 exprg{61}: - // Split: iS62{( ceilDiv(( 7 * 1 ), 5) )} by factor 3 -> iS64{( ceilDiv(( - // ceilDiv(( 7 * 1 ), 5) ), 3) )}, iS65{3}, start offset: 0, stop offset: 0 - // exprg{55 67 129}: Split: iS57{( ceilDiv(( 7 * 13 ), 5) )} by factor 3 -> - // iS59{( ceilDiv(( ceilDiv(( 7 * 13 ), 5) ), 3) )}, iS60{3}, start offset: 0, - // stop offset: 0 exprg{37 49 125}: Split: iS42{( ceilDiv(( 7 * 11 ), 5) )} by - // factor 3 -> iS44{( ceilDiv(( ceilDiv(( 7 * 11 ), 5) ), 3) )}, iS45{3}, - // start offset: 0, stop offset: 0 exprg{31}: Split: iS37{( ceilDiv(( ( 7 * 11 - // ) * 1 ), 5) )} by factor 3 -> iS39{( ceilDiv(( ceilDiv(( ( 7 * 11 ) * 1 ), - // 5) ), 3) )}, iS40{3}, start offset: 0, stop offset: 0 exprg{23}: Split: - // iS31{( ceilDiv(( ( 7 * 1 ) * 13 ), 5) )} by factor 3 -> iS33{( ceilDiv(( - // ceilDiv(( ( 7 * 1 ) * 13 ), 5) ), 3) )}, iS34{3}, start offset: 0, stop - // offset: 0 exprg{15 85 87 113 123 127 131 133 135 137 139}: Split: iS25{( - // ceilDiv(( ( 7 * 11 ) * 13 ), 5) )} by factor 3 -> iS27{( ceilDiv(( - // ceilDiv(( ( 7 * 11 ) * 13 ), 5) ), 3) )}, iS28{3}, start offset: 0, stop - // offset: 0 + // exprg{39}: Merge iS2 bS3 + // exprg{57}: Merge iS11 bS12 + // exprg{17}: Merge iS17 bS18 + // exprg{51 63}: Merge iS15 iS16 + // exprg{69 73}: Split iS1 + // exprg{9 25 33 45}: Merge iS20 iS21 + // exprg{27}: Merge iS35 bS10 + // exprg{11}: Merge iS23 iS22 + // exprg{19}: Merge iS29 iS19 + // exprg{41}: Split iS46 + // exprg{59}: Split iS61 + // exprg{53 65}: Split iS56 + // exprg{71 75}: Split iS71 + // exprg{35 47}: Split iS41 + // exprg{29}: Split iS36 + // exprg{13}: Split iS24 + // exprg{21}: Split iS30 + // exprg{43}: Split iS47 + // exprg{61}: Split iS62 + // exprg{55 67}: Split iS57 + // exprg{37 49}: Split iS42 + // exprg{31}: Split iS37 + // exprg{15}: Split iS25 + // exprg{23}: Split iS31 std::vector ref_order; ref_order.push_back(getParentExpr(tv2->axis(0), 3)); ref_order.push_back(getParentExpr(tv6->axis(0), 3)); ref_order.push_back(getParentExpr(tv9->axis(0), 4)); - ref_order.push_back(getParentExpr(tv1->axis(0), 2)); ref_order.push_back(getParentExpr(tv8->axis(0), 3)); + ref_order.push_back(getParentExpr(tv1->axis(0), 2)); ref_order.push_back(getParentExpr(tv10->axis(0), 4)); ref_order.push_back(getParentExpr(tv5->axis(0), 3)); - ref_order.push_back(getParentExpr(tv9->axis(0), 3)); ref_order.push_back(getParentExpr(tv10->axis(0), 3)); + ref_order.push_back(getParentExpr(tv9->axis(0), 3)); ref_order.push_back(getParentExpr(tv2->axis(0), 2)); ref_order.push_back(getParentExpr(tv6->axis(0), 2)); - ref_order.push_back(getParentExpr(tv1->axis(0), 1)); ref_order.push_back(getParentExpr(tv8->axis(0), 2)); + ref_order.push_back(getParentExpr(tv1->axis(0), 1)); ref_order.push_back(getParentExpr(tv4->axis(0), 2)); ref_order.push_back(getParentExpr(tv5->axis(0), 2)); - ref_order.push_back(getParentExpr(tv9->axis(0), 2)); ref_order.push_back(getParentExpr(tv10->axis(0), 2)); + ref_order.push_back(getParentExpr(tv9->axis(0), 2)); ref_order.push_back(getParentExpr(tv2->axis(0), 1)); ref_order.push_back(getParentExpr(tv6->axis(0), 1)); ref_order.push_back(getParentExpr(tv8->axis(0), 1)); ref_order.push_back(getParentExpr(tv4->axis(0), 1)); ref_order.push_back(getParentExpr(tv5->axis(0), 1)); - ref_order.push_back(getParentExpr(tv9->axis(0), 1)); ref_order.push_back(getParentExpr(tv10->axis(0), 1)); + ref_order.push_back(getParentExpr(tv9->axis(0), 1)); checkSortingResults(vg, vg_stmt_sort.exprs(), vg_stmt_sort.vals(), ref_order); } From bef4a7880f8565aa092c0c793bd356e62c642eb6 Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Mon, 5 Feb 2024 17:55:20 -0800 Subject: [PATCH 06/17] Remove popFront as it's no longer used --- csrc/disjoint_set.h | 9 --------- 1 file changed, 9 deletions(-) diff --git a/csrc/disjoint_set.h b/csrc/disjoint_set.h index d53861b2087..d1af78fef7f 100644 --- a/csrc/disjoint_set.h +++ b/csrc/disjoint_set.h @@ -92,15 +92,6 @@ class VectorOfUniqueEntries { return any_added; } - // Remove and returns the first element in vector. Note that this is - // a costly operation as the underlying container is std::vector. - T popFront() { - T v = vector_.front(); - set_.erase(v); - vector_.erase(vector_.begin()); - return v; - } - // Returns a new VectorOfUniqueEntries with entries that are in both this and // other, order is preserved as this. VectorOfUniqueEntries computeIntersect( From 2b02316143c3c34374d347894a427454b47fb7eb Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Mon, 5 Feb 2024 18:11:01 -0800 Subject: [PATCH 07/17] Rename id to val --- csrc/val_graph_visitor.cpp | 38 +++++++++++++++++++------------------- csrc/val_graph_visitor.h | 4 ++-- 2 files changed, 21 insertions(+), 21 deletions(-) diff --git a/csrc/val_graph_visitor.cpp b/csrc/val_graph_visitor.cpp index 76d701444c2..2f98bf95ab2 100644 --- a/csrc/val_graph_visitor.cpp +++ b/csrc/val_graph_visitor.cpp @@ -13,9 +13,9 @@ namespace nvfuser { void ValGraphVisitor::traverse() { const ValGroups terminating_inputs = graph().getTerminatingInputs(); - std::deque to_visit_ids( + std::deque to_visit_vals( terminating_inputs.begin(), terminating_inputs.end()); - ValGroups visited_ids; + ValGroups visited_vals; std::deque to_visit_exprs; ExprGroups visited_exprs; @@ -23,8 +23,8 @@ void ValGraphVisitor::traverse() { auto is_expr_ready = [&](const ExprGroup& expr_group) -> bool { const auto inp_groups = graph().inputGroups(expr_group); return std::all_of( - inp_groups.begin(), inp_groups.end(), [&](ValGroup id_group) { - return visited_ids.has(id_group) || id_group->empty(); + inp_groups.begin(), inp_groups.end(), [&](ValGroup val_group) { + return visited_vals.has(val_group) || val_group->empty(); }); }; @@ -57,7 +57,7 @@ void ValGraphVisitor::traverse() { }); }; - while (!to_visit_ids.empty() || !to_visit_exprs.empty()) { + while (!to_visit_vals.empty() || !to_visit_exprs.empty()) { // Process expressions first as all definitions of vals have to be // processed before we can process that val. @@ -82,7 +82,7 @@ void ValGraphVisitor::traverse() { for (const ValGroup& output_group : graph().outputGroups(current_expr_group)) { - to_visit_ids.push_back(output_group); + to_visit_vals.push_back(output_group); } } else { still_to_visit_exprs.push_back(current_expr_group); @@ -91,34 +91,34 @@ void ValGraphVisitor::traverse() { std::swap(to_visit_exprs, still_to_visit_exprs); - std::deque still_to_visit_ids; - while (!to_visit_ids.empty()) { - auto current_id_group = to_visit_ids.front(); - to_visit_ids.pop_front(); - NVF_ERROR(!current_id_group->empty()); - if (visited_ids.has(current_id_group)) { + std::deque still_to_visit_vals; + while (!to_visit_vals.empty()) { + auto current_val_group = to_visit_vals.front(); + to_visit_vals.pop_front(); + NVF_ERROR(!current_val_group->empty()); + if (visited_vals.has(current_val_group)) { continue; } - if (is_val_ready(current_id_group)) { - handle(current_id_group); + if (is_val_ready(current_val_group)) { + handle(current_val_group); something_was_processed = true; - visited_ids.pushBack(current_id_group); + visited_vals.pushBack(current_val_group); - for (const ExprGroup& use_group : graph().getUses(current_id_group)) { + for (const ExprGroup& use_group : graph().getUses(current_val_group)) { to_visit_exprs.push_back(use_group); } } else { - still_to_visit_ids.push_back(current_id_group); + still_to_visit_vals.push_back(current_val_group); } } - std::swap(to_visit_ids, still_to_visit_ids); + std::swap(to_visit_vals, still_to_visit_vals); NVF_ERROR( something_was_processed || - (to_visit_ids.empty() && to_visit_exprs.empty()), + (to_visit_vals.empty() && to_visit_exprs.empty()), "Infinite loop entered."); } } diff --git a/csrc/val_graph_visitor.h b/csrc/val_graph_visitor.h index 391c08902f3..54a5ed253b8 100644 --- a/csrc/val_graph_visitor.h +++ b/csrc/val_graph_visitor.h @@ -71,7 +71,7 @@ class ValGraphVisitor { ValGraphVisitor(ValGraphVisitor&& other) = default; - virtual void handle(const ValGroup& id_group) = 0; + virtual void handle(const ValGroup& val_group) = 0; virtual void handle(const ExprGroup& expr_group) = 0; void traverse(); @@ -92,7 +92,7 @@ class ValGraphStmtSort : public ValGraphVisitor { } // Return non-reference so that code like below can work - // for (auto expr_group: IdGraphStmtSort(graph).exprs()) + // for (auto expr_group: ValGraphStmtSort(graph).exprs()) ExprGroups exprs() const { return sorted_exprs_; } From 8accbd94f71a264791b1c1f0c5cbeda9b208a366 Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Mon, 5 Feb 2024 18:13:33 -0800 Subject: [PATCH 08/17] Remove accidentally added files --- csrc/id_model/visitor.cpp | 162 -------------------------------------- csrc/id_model/visitor.h | 97 ----------------------- 2 files changed, 259 deletions(-) delete mode 100644 csrc/id_model/visitor.cpp delete mode 100644 csrc/id_model/visitor.h diff --git a/csrc/id_model/visitor.cpp b/csrc/id_model/visitor.cpp deleted file mode 100644 index f81ce4177d9..00000000000 --- a/csrc/id_model/visitor.cpp +++ /dev/null @@ -1,162 +0,0 @@ -// clang-format off -/* - * SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES. - * All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - */ -// clang-format on -#include - -namespace nvfuser { - -void IdGraphVisitor::traverse() { - ValGroups all_ids; - ExprGroups all_exprs; - { - // Initialize IDs to traverse. If sub_selection is provided, only - // traverse IDs that are included in the set are traversed. - if (sub_selection_.empty()) { - all_ids = ValGroups( - graph().disjointValSets().disjointSets().begin(), - graph().disjointValSets().disjointSets().end()); - } else { - for (auto id : sub_selection_) { - if (graph().hasGroup(id)) { - all_ids.pushBack(graph().toGroup(id)); - } - } - } - - // Initialize exprs to traverse. If sub_selection is provided, - // only traverse exprs that are strictly contained within the provided - // sub_selection. Exprs are excluded if any of inputs or outputs - // is not in sub_selection. - if (sub_selection_.empty()) { - all_exprs = ExprGroups( - graph().disjointExprSets().disjointSets().begin(), - graph().disjointExprSets().disjointSets().end()); - } else { - for (const ValGroup& id_group : all_ids) { - for (const ExprGroup& def : graph().getDefinitions(id_group)) { - if (all_exprs.has(def)) { - continue; - } - auto inp_groups = ValGroups(graph().inputGroups(def)); - auto out_groups = ValGroups(graph().outputGroups(def)); - if (inp_groups.computeSubtract(all_ids).empty() && - out_groups.computeSubtract(all_ids).empty()) { - all_exprs.pushBack(def); - } - } - } - } - } - // There could be IterDomains in from or to that are between other from and - // to nodes. Make sure to clear those out. - ValGroups terminating_inputs; - ValGroups terminating_outputs; - - { - ValGroups not_inputs; - ValGroups not_outputs; - for (const ExprGroup& expr_group : all_exprs) { - if (graph().isTrivialExprGroup(expr_group)) { - // Expression is just a loop to its current group, ignore - continue; - } - - not_inputs.pushBack(graph().outputGroups(expr_group)); - not_outputs.pushBack(graph().inputGroups(expr_group)); - } - - terminating_inputs = - ValGroups(all_ids.begin(), all_ids.end()).computeSubtract(not_inputs); - - terminating_outputs = - ValGroups(all_ids.begin(), all_ids.end()).computeSubtract(not_outputs); - } - - ValGroups to_visit_ids = terminating_inputs; - ValGroups visited_ids; - - ExprGroups to_visit_exprs; - ExprGroups visited_exprs; - - auto is_expr_ready = [&](const ExprGroup& expr_group) { - auto inp_groups = graph().inputGroups(expr_group); - return std::all_of( - inp_groups.begin(), inp_groups.end(), [&](ValGroup id_group) { - return visited_ids.has(id_group) || id_group->empty(); - }); - }; - - auto is_id_ready = [&](const ValGroup& id_group) { - const ExprGroups& unique_defs = graph().getDefinitions(id_group); - return std::all_of( - unique_defs.begin(), unique_defs.end(), [&](ExprGroup expr_group) { - return expr_group->empty() || visited_exprs.has(expr_group) || - graph().isTrivialExprGroup(expr_group); - }); - }; - - while (!to_visit_ids.empty() || !to_visit_exprs.empty()) { - // Process expressions first as all definitions of iter domains have to be - // processed before we can process that iter domain. - - // Detect if nothing has been processed which would put us in an infinite - // loop - bool something_was_processed = false; - ExprGroups still_to_visit_exprs; - - while (!to_visit_exprs.empty()) { - ExprGroup current_expr_group = to_visit_exprs.popFront(); - NVF_ERROR(!current_expr_group->empty()); - if (visited_exprs.has(current_expr_group)) { - continue; - } - - if (is_expr_ready(current_expr_group)) { - handle(current_expr_group); - - something_was_processed = true; - visited_exprs.pushBack(current_expr_group); - - to_visit_ids.pushBack(graph().outputGroups(current_expr_group)); - } else { - still_to_visit_exprs.pushBack(current_expr_group); - } - } - - std::swap(to_visit_exprs, still_to_visit_exprs); - - ValGroups still_to_visit_ids; - while (!to_visit_ids.empty()) { - auto current_id_group = to_visit_ids.popFront(); - NVF_ERROR(!current_id_group->empty()); - if (visited_ids.has(current_id_group)) { - continue; - } - - if (is_id_ready(current_id_group)) { - handle(current_id_group); - - something_was_processed = true; - visited_ids.pushBack(current_id_group); - - if (!terminating_outputs.has(current_id_group)) { - const ExprGroups& uses = graph().getUses(current_id_group); - to_visit_exprs.pushBack(uses); - } - } else { - still_to_visit_ids.pushBack(current_id_group); - } - } - std::swap(to_visit_ids, still_to_visit_ids); - - NVF_ERROR( - something_was_processed || - (to_visit_ids.empty() && to_visit_exprs.empty()), - "Infinite loop entered."); - } -} -} // namespace nvfuser diff --git a/csrc/id_model/visitor.h b/csrc/id_model/visitor.h deleted file mode 100644 index 2c13b9efae0..00000000000 --- a/csrc/id_model/visitor.h +++ /dev/null @@ -1,97 +0,0 @@ -// clang-format off -/* - * SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES. - * All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - */ -// clang-format on -#pragma once - -#include -#include -#include - -namespace nvfuser { - -// Iterates through an IterDomain Graph in topological order, calling handle on -// all Id and all Expr groups in a forward topological order. -// -// Warning: Expr groups that have an input and output in the same ValGroup are -// ignored. -// -// Warning: This is not a great iterator if there's a desire to minimize paths -// traveled to simply visit all ValGroups in order. See ExprsBetween to see how -// we might minimize paths. -class IdGraphVisitor { - public: - IdGraphVisitor() = delete; - - IdGraphVisitor& operator=(const IdGraphVisitor& other) = delete; - - IdGraphVisitor& operator=(IdGraphVisitor&& other) = delete; - - virtual ~IdGraphVisitor() = default; - - protected: - // If sub_selection is assumed to be a set of iter domains by which form a - // sub-regrion of the IdGraph provided. Only that sub-region will be visited. - IdGraphVisitor( - const ValGraph& id_graph, - const VectorOfUniqueEntries sub_selection = {}) - : id_graph_(id_graph), sub_selection_(sub_selection) {} - - IdGraphVisitor(const IdGraphVisitor& other) = default; - - IdGraphVisitor(IdGraphVisitor&& other) = default; - - virtual void handle(ValGroup id_group) = 0; - virtual void handle(ExprGroup expr_group) = 0; - - void traverse(); - - const ValGraph& graph() { - return id_graph_; - }; - - private: - const ValGraph& id_graph_; - const VectorOfUniqueEntries sub_selection_; -}; - -// Statement sorting based on IdGraphVisitor, see warnings to IdGraph Visitor. -class IdGraphStmtSort : public IdGraphVisitor { - public: - IdGraphStmtSort( - const ValGraph& id_graph, - const VectorOfUniqueEntries sub_selection = {}) - : IdGraphVisitor(id_graph, sub_selection) { - IdGraphVisitor::traverse(); - } - - // Return non-reference so that code like below can work - // for (auto expr_group: IdGraphStmtSort(graph).exprs()) - ExprGroups exprs() const { - return sorted_exprs_; - } - - ValGroups ids() const { - return sorted_ids_; - } - - ~IdGraphStmtSort() override = default; - - protected: - using IdGraphVisitor::handle; - void handle(ValGroup id_group) override { - sorted_ids_.pushBack(id_group); - } - - void handle(ExprGroup expr_group) override { - sorted_exprs_.pushBack(expr_group); - } - - ExprGroups sorted_exprs_; - ValGroups sorted_ids_; -}; - -} // namespace nvfuser From e33a005b83c83ba2e94f94c78f36361c717d7b08 Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Thu, 8 Feb 2024 14:31:03 -0800 Subject: [PATCH 09/17] Refactor traversal loop --- csrc/val_graph_visitor.cpp | 34 +++++++++++++++++++++++++--------- 1 file changed, 25 insertions(+), 9 deletions(-) diff --git a/csrc/val_graph_visitor.cpp b/csrc/val_graph_visitor.cpp index 2f98bf95ab2..cc6604cb582 100644 --- a/csrc/val_graph_visitor.cpp +++ b/csrc/val_graph_visitor.cpp @@ -57,15 +57,17 @@ void ValGraphVisitor::traverse() { }); }; - while (!to_visit_vals.empty() || !to_visit_exprs.empty()) { + // Detect if nothing has been processed which would put us in an infinite + // loop + bool something_was_processed = false; + + do { + something_was_processed = false; + // Process expressions first as all definitions of vals have to be // processed before we can process that val. - // Detect if nothing has been processed which would put us in an infinite - // loop - bool something_was_processed = false; std::deque still_to_visit_exprs; - while (!to_visit_exprs.empty()) { ExprGroup current_expr_group = to_visit_exprs.front(); to_visit_exprs.pop_front(); @@ -116,10 +118,24 @@ void ValGraphVisitor::traverse() { std::swap(to_visit_vals, still_to_visit_vals); - NVF_ERROR( - something_was_processed || - (to_visit_vals.empty() && to_visit_exprs.empty()), - "Infinite loop entered."); + } while (something_was_processed); + + if (!to_visit_vals.empty()) { + std::stringstream ss; + ss << "Remaining Vals to visit:"; + for (const ValGroup& vg : to_visit_vals) { + ss << " " << nvfuser::toString(vg); + } + NVF_ERROR(false, ss.str()); + } + + if (!to_visit_exprs.empty()) { + std::stringstream ss; + ss << "Remaining Exprs to visit:"; + for (const ExprGroup& eg : to_visit_exprs) { + ss << " " << nvfuser::toString(eg); + } + NVF_ERROR(false, ss.str()); } } From b92ec79cfa013dbc8512afc1e0157f6c36f723cd Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Thu, 8 Feb 2024 15:18:05 -0800 Subject: [PATCH 10/17] cleanup --- csrc/val_graph_visitor.cpp | 5 ----- test/test_id_model.cpp | 32 ++++++++++++++++---------------- 2 files changed, 16 insertions(+), 21 deletions(-) diff --git a/csrc/val_graph_visitor.cpp b/csrc/val_graph_visitor.cpp index cc6604cb582..9d7604d317b 100644 --- a/csrc/val_graph_visitor.cpp +++ b/csrc/val_graph_visitor.cpp @@ -67,7 +67,6 @@ void ValGraphVisitor::traverse() { // Process expressions first as all definitions of vals have to be // processed before we can process that val. - std::deque still_to_visit_exprs; while (!to_visit_exprs.empty()) { ExprGroup current_expr_group = to_visit_exprs.front(); to_visit_exprs.pop_front(); @@ -86,13 +85,9 @@ void ValGraphVisitor::traverse() { graph().outputGroups(current_expr_group)) { to_visit_vals.push_back(output_group); } - } else { - still_to_visit_exprs.push_back(current_expr_group); } } - std::swap(to_visit_exprs, still_to_visit_exprs); - std::deque still_to_visit_vals; while (!to_visit_vals.empty()) { auto current_val_group = to_visit_vals.front(); diff --git a/test/test_id_model.cpp b/test/test_id_model.cpp index 39dae8fce53..8687ff3e34b 100644 --- a/test/test_id_model.cpp +++ b/test/test_id_model.cpp @@ -336,24 +336,24 @@ TEST_F(IdModelTest, ValGraphStmtSort4) { // exprg{51 63}: Merge iS15 iS16 // exprg{69 73}: Split iS1 // exprg{9 25 33 45}: Merge iS20 iS21 - // exprg{27}: Merge iS35 bS10 - // exprg{11}: Merge iS23 iS22 - // exprg{19}: Merge iS29 iS19 // exprg{41}: Split iS46 // exprg{59}: Split iS61 + // exprg{19}: Merge iS29 iS19 // exprg{53 65}: Split iS56 // exprg{71 75}: Split iS71 + // exprg{11}: Merge iS23 iS22 + // exprg{27}: Merge iS35 bS10 // exprg{35 47}: Split iS41 - // exprg{29}: Split iS36 - // exprg{13}: Split iS24 - // exprg{21}: Split iS30 // exprg{43}: Split iS47 // exprg{61}: Split iS62 + // exprg{21}: Split iS30 // exprg{55 67}: Split iS57 + // exprg{13}: Split iS24 + // exprg{29}: Split iS36 // exprg{37 49}: Split iS42 - // exprg{31}: Split iS37 - // exprg{15}: Split iS25 // exprg{23}: Split iS31 + // exprg{15}: Split iS25 + // exprg{31}: Split iS37 std::vector ref_order; ref_order.push_back(getParentExpr(tv2->axis(0), 3)); @@ -362,24 +362,24 @@ TEST_F(IdModelTest, ValGraphStmtSort4) { ref_order.push_back(getParentExpr(tv8->axis(0), 3)); ref_order.push_back(getParentExpr(tv1->axis(0), 2)); ref_order.push_back(getParentExpr(tv10->axis(0), 4)); - ref_order.push_back(getParentExpr(tv5->axis(0), 3)); - ref_order.push_back(getParentExpr(tv10->axis(0), 3)); - ref_order.push_back(getParentExpr(tv9->axis(0), 3)); ref_order.push_back(getParentExpr(tv2->axis(0), 2)); ref_order.push_back(getParentExpr(tv6->axis(0), 2)); + ref_order.push_back(getParentExpr(tv9->axis(0), 3)); ref_order.push_back(getParentExpr(tv8->axis(0), 2)); ref_order.push_back(getParentExpr(tv1->axis(0), 1)); + ref_order.push_back(getParentExpr(tv10->axis(0), 3)); + ref_order.push_back(getParentExpr(tv5->axis(0), 3)); ref_order.push_back(getParentExpr(tv4->axis(0), 2)); - ref_order.push_back(getParentExpr(tv5->axis(0), 2)); - ref_order.push_back(getParentExpr(tv10->axis(0), 2)); - ref_order.push_back(getParentExpr(tv9->axis(0), 2)); ref_order.push_back(getParentExpr(tv2->axis(0), 1)); ref_order.push_back(getParentExpr(tv6->axis(0), 1)); + ref_order.push_back(getParentExpr(tv9->axis(0), 2)); ref_order.push_back(getParentExpr(tv8->axis(0), 1)); + ref_order.push_back(getParentExpr(tv10->axis(0), 2)); + ref_order.push_back(getParentExpr(tv5->axis(0), 2)); ref_order.push_back(getParentExpr(tv4->axis(0), 1)); - ref_order.push_back(getParentExpr(tv5->axis(0), 1)); - ref_order.push_back(getParentExpr(tv10->axis(0), 1)); ref_order.push_back(getParentExpr(tv9->axis(0), 1)); + ref_order.push_back(getParentExpr(tv10->axis(0), 1)); + ref_order.push_back(getParentExpr(tv5->axis(0), 1)); checkSortingResults(vg, vg_stmt_sort.exprs(), vg_stmt_sort.vals(), ref_order); } From 0059992c25031915a8f6f8a9c4390a7044f41537 Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Thu, 8 Feb 2024 15:19:20 -0800 Subject: [PATCH 11/17] Update csrc/val_graph_visitor.cpp Co-authored-by: Jingyue Wu --- csrc/val_graph_visitor.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/csrc/val_graph_visitor.cpp b/csrc/val_graph_visitor.cpp index 9d7604d317b..189eb7870c7 100644 --- a/csrc/val_graph_visitor.cpp +++ b/csrc/val_graph_visitor.cpp @@ -117,7 +117,7 @@ void ValGraphVisitor::traverse() { if (!to_visit_vals.empty()) { std::stringstream ss; - ss << "Remaining Vals to visit:"; + ss << "The graph has an infinite loop. The following Vals should be visited but are never ready:"; for (const ValGroup& vg : to_visit_vals) { ss << " " << nvfuser::toString(vg); } From cfc784e421cc8b913902b5e2195a9ddec2493608 Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Thu, 8 Feb 2024 15:19:28 -0800 Subject: [PATCH 12/17] Update csrc/val_graph_visitor.cpp Co-authored-by: Jingyue Wu --- csrc/val_graph_visitor.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/csrc/val_graph_visitor.cpp b/csrc/val_graph_visitor.cpp index 189eb7870c7..79a406a18b3 100644 --- a/csrc/val_graph_visitor.cpp +++ b/csrc/val_graph_visitor.cpp @@ -126,7 +126,7 @@ void ValGraphVisitor::traverse() { if (!to_visit_exprs.empty()) { std::stringstream ss; - ss << "Remaining Exprs to visit:"; + ss << "The graph has an infinite loop. The following Exprs should be visited but are never ready:"; for (const ExprGroup& eg : to_visit_exprs) { ss << " " << nvfuser::toString(eg); } From b891e366681eda366d4042d7493a80724d4440b9 Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Thu, 8 Feb 2024 17:16:56 -0800 Subject: [PATCH 13/17] handle trivial expr --- csrc/val_graph_visitor.cpp | 19 ++++++++++++++++--- 1 file changed, 16 insertions(+), 3 deletions(-) diff --git a/csrc/val_graph_visitor.cpp b/csrc/val_graph_visitor.cpp index 79a406a18b3..202934049a7 100644 --- a/csrc/val_graph_visitor.cpp +++ b/csrc/val_graph_visitor.cpp @@ -51,9 +51,22 @@ void ValGraphVisitor::traverse() { const ExprGroups& unique_defs = graph().getDefinitions(val_group); return std::all_of( unique_defs.begin(), unique_defs.end(), [&](ExprGroup expr_group) { - return expr_group->empty() || visited_exprs.has(expr_group) || - terminating_inputs.has(val_group) || - graph().isTrivialExprGroup(expr_group); + if (expr_group->empty() || visited_exprs.has(expr_group) || + terminating_inputs.has(val_group)) { + return true; + } + // Handle trivial expr groups. This expr_group is not + // visited yet, which means there're input ValGroups that + // are not yet visited. If those not-visited inputs are + // actually the same as val_group, visit val_group at this + // point to resolve the circular dependency. + for (const ValGroup& input_group : graph().inputGroups(expr_group)) { + if (input_group != val_group && !visited_vals.has(input_group) && + input_group->empty()) { + return false; + } + } + return true; }); }; From d1ca51216547554ebe7334d9213002eeb914d3ac Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Mon, 12 Feb 2024 11:58:21 -0800 Subject: [PATCH 14/17] Update csrc/val_graph_visitor.cpp Co-authored-by: Jingyue Wu --- csrc/val_graph_visitor.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/csrc/val_graph_visitor.cpp b/csrc/val_graph_visitor.cpp index 202934049a7..930089d39c3 100644 --- a/csrc/val_graph_visitor.cpp +++ b/csrc/val_graph_visitor.cpp @@ -55,7 +55,7 @@ void ValGraphVisitor::traverse() { terminating_inputs.has(val_group)) { return true; } - // Handle trivial expr groups. This expr_group is not + // Handle ExprGroups that return one or some of its input ValGroups as output. This expr_group is not // visited yet, which means there're input ValGroups that // are not yet visited. If those not-visited inputs are // actually the same as val_group, visit val_group at this From 117513a3531a4bb121c3a613bb8de58a33101a49 Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Mon, 12 Feb 2024 12:15:10 -0800 Subject: [PATCH 15/17] clang-format --- csrc/val_graph_visitor.cpp | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/csrc/val_graph_visitor.cpp b/csrc/val_graph_visitor.cpp index 930089d39c3..59fc54850e8 100644 --- a/csrc/val_graph_visitor.cpp +++ b/csrc/val_graph_visitor.cpp @@ -55,10 +55,10 @@ void ValGraphVisitor::traverse() { terminating_inputs.has(val_group)) { return true; } - // Handle ExprGroups that return one or some of its input ValGroups as output. This expr_group is not - // visited yet, which means there're input ValGroups that - // are not yet visited. If those not-visited inputs are - // actually the same as val_group, visit val_group at this + // Handle ExprGroups that return one or some of its input ValGroups as + // output. This expr_group is not visited yet, which means there're + // input ValGroups that are not yet visited. If those not-visited + // inputs are actually the same as val_group, visit val_group at this // point to resolve the circular dependency. for (const ValGroup& input_group : graph().inputGroups(expr_group)) { if (input_group != val_group && !visited_vals.has(input_group) && From a66bdb9d431913604e3c96b4a30a5f31ae1af84d Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Mon, 12 Feb 2024 12:15:18 -0800 Subject: [PATCH 16/17] typo --- test/test_id_model.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/test_id_model.cpp b/test/test_id_model.cpp index 8687ff3e34b..d101ca6cabe 100644 --- a/test/test_id_model.cpp +++ b/test/test_id_model.cpp @@ -278,7 +278,7 @@ TEST_F(IdModelTest, ValGraphStmtSort3) { auto tv4 = set(tv3); fusion.addOutput(tv4); - // Merge adn split by one. The split input and output will be mapped. + // Merge and split by one. The split input and output will be mapped. for (auto tv : {tv0, tv1, tv2}) { tv->merge(0)->split(0, 1); } From 465146ddd41afa8fee33a6d575dd4cd1bb412195 Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Mon, 12 Feb 2024 14:57:42 -0800 Subject: [PATCH 17/17] cleanup --- csrc/val_graph.cpp | 6 ------ csrc/val_graph.h | 4 ---- csrc/val_graph_visitor.cpp | 6 ++++-- 3 files changed, 4 insertions(+), 12 deletions(-) diff --git a/csrc/val_graph.cpp b/csrc/val_graph.cpp index fc26db00bbd..f5578a76648 100644 --- a/csrc/val_graph.cpp +++ b/csrc/val_graph.cpp @@ -542,12 +542,6 @@ bool ValGraph::mapThroughExpr(Expr* first, Expr* second, bool forward) { return true; } -bool ValGraph::isTrivialExprGroup(const ExprGroup& expr_group) const { - return !ValGroups(inputGroups(expr_group)) - .computeIntersect(ValGroups(outputGroups(expr_group))) - .empty(); -} - void ValGraph::validateConsistency() const { // Check the consistency of the mapping information. Specifically: // 1. All ValGroup and ExprGroup sets are not empty. This may not be diff --git a/csrc/val_graph.h b/csrc/val_graph.h index 0c306f72e44..72182e3f2f8 100644 --- a/csrc/val_graph.h +++ b/csrc/val_graph.h @@ -202,10 +202,6 @@ class ValGraph { // be the only call in ValGraph to mapThroughExpr. void maybeMapThroughExprs(Expr* expr0, Expr* expr1, bool forward); - // Returns if the expression group has an input id group that matches an - // output id group. - bool isTrivialExprGroup(const ExprGroup& expr_group) const; - // Can't back prop through merge without making sure one input actually // matches. This can be done on a map or extent basis. // TODO: Move this to val_graph.cpp once validation_utils.cpp is diff --git a/csrc/val_graph_visitor.cpp b/csrc/val_graph_visitor.cpp index 59fc54850e8..7d6646c747e 100644 --- a/csrc/val_graph_visitor.cpp +++ b/csrc/val_graph_visitor.cpp @@ -48,11 +48,13 @@ void ValGraphVisitor::traverse() { // // See also IdModelTest.ValGraphStmtSort3 for a concrete example. auto is_val_ready = [&](const ValGroup& val_group) -> bool { + if (terminating_inputs.has(val_group)) { + return true; + } const ExprGroups& unique_defs = graph().getDefinitions(val_group); return std::all_of( unique_defs.begin(), unique_defs.end(), [&](ExprGroup expr_group) { - if (expr_group->empty() || visited_exprs.has(expr_group) || - terminating_inputs.has(val_group)) { + if (expr_group->empty() || visited_exprs.has(expr_group)) { return true; } // Handle ExprGroups that return one or some of its input ValGroups as