From 2d6104ff4cf0aa0539f717e97236e160b6cc9e8d Mon Sep 17 00:00:00 2001
From: Trevor Morris
Date: Fri, 10 Apr 2020 00:13:16 +0000
Subject: [PATCH 01/10] Fix duplicate output in partitiongraph

---
 src/relay/analysis/annotated_region_set.cc | 8 +++++++-
 src/relay/transforms/partition_graph.cc    | 9 ++++++---
 2 files changed, 13 insertions(+), 4 deletions(-)

diff --git a/src/relay/analysis/annotated_region_set.cc b/src/relay/analysis/annotated_region_set.cc
index 94c7621e60af..53b58d494365 100644
--- a/src/relay/analysis/annotated_region_set.cc
+++ b/src/relay/analysis/annotated_region_set.cc
@@ -131,7 +131,13 @@ class AnnotatedRegionSet::Creator : public ExprVisitor {
       CHECK_EQ(region->GetTarget(), target);
     }
     region->nodes_.insert(GetRef<Call>(call));
-    region->outs_.push_back(GetRef<Call>(call));
+    if (!std::any_of(region->outs_.begin(), region->outs_.end(),
+                     [call](Expr& out) {
+                       return Downcast<Call>(out)->args[0] ==
+                              GetRef<Call>(call)->args[0];
+                     })) {
+      region->outs_.push_back(GetRef<Call>(call));
+    }
   }
   ExprVisitor::VisitExpr_(call);
 }

diff --git a/src/relay/transforms/partition_graph.cc b/src/relay/transforms/partition_graph.cc
index c8367fb140f2..172faa0c267a 100644
--- a/src/relay/transforms/partition_graph.cc
+++ b/src/relay/transforms/partition_graph.cc
@@ -207,9 +207,12 @@ class Partitioner : public ExprMutator {
 
     if (region_function_calls.find(region) != region_function_calls.end()) {
       // This section is executed only if there are multiple outputs in the
-      // region Thus, the function is always created and at the end there
+      // region or the same output is being accessed multiple times.
+      // Thus, the function is always created and at the end there
       // would be a tuple node Therefore, we insert a tuple get item node.
-
+      if (region->GetOutputs().size() == 1) {
+        return region_function_calls[region];
+      }
       // Use the already created tuple node
       auto sg_call = region_function_calls[region];
       int index = GetRetIdx(region, GetRef<Call>(call));
@@ -462,7 +462,7 @@ class Partitioner : public ExprMutator {
   int GetRetIdx(AnnotatedRegion sg, const Expr& arg) {
     int idx = 0;
     for (auto arg_ : sg->GetOutputs()) {
-      if (arg == arg_) {
+      if (Downcast<Call>(arg)->args[0] == Downcast<Call>(arg_)->args[0]) {
         return idx;
       }
       idx++;

From f652d63d0d00d340a262c874525ae440ec398fc2 Mon Sep 17 00:00:00 2001
From: Trevor Morris
Date: Mon, 13 Apr 2020 19:09:47 +0000
Subject: [PATCH 02/10] Add test case

---
 .../python/relay/test_pass_partition_graph.py | 57 +++++++++++++++++++
 1 file changed, 57 insertions(+)

diff --git a/tests/python/relay/test_pass_partition_graph.py b/tests/python/relay/test_pass_partition_graph.py
index 2ee8538e30ed..08ac4e45a54e 100644
--- a/tests/python/relay/test_pass_partition_graph.py
+++ b/tests/python/relay/test_pass_partition_graph.py
@@ -23,6 +23,7 @@
 
 import tvm
 import tvm.relay.testing
+import tvm.relay.op as reg
 from tvm import relay
 from tvm import runtime
 from tvm.relay import transform
@@ -1036,6 +1037,61 @@ def test_different_output_region():
     test_same_output_region()
     test_different_output_region()
 
+def test_duplicate_outputs():
+    target = "test_duplicate_outputs"
+
+    @reg.register("abs", "target." + target)
+    def abs(attrs, args): # pylint: disable=unused-variable
+        return True
+
+    def create_graph():
+        data = relay.var('data', shape=(10, 10))
+        x = relay.abs(data)
+        out_1 = relay.nn.relu(x)
+        out_2 = relay.tanh(x)
+        out_3 = relay.log(x)
+        out = relay.Tuple([out_1, out_2, out_3])
+        func = relay.Function([data], out)
+        return func
+
+    def expected():
+        mod = tvm.IRModule()
+
+        # function 0
+        f0_i0 = relay.var(target+"_0_i0", shape=(10, 10))
+        f0_o0 = relay.abs(f0_i0)
+        func0 = relay.Function([f0_i0], f0_o0)
+
+        func0 = func0.with_attr("Primitive", tvm.tir.IntImm("int32", 1))
+        func0 = func0.with_attr("Inline", tvm.tir.IntImm("int32", 1))
+        func0 = func0.with_attr("Compiler", target)
+        func0 = func0.with_attr("global_symbol", target+"_0")
+        gv0 = relay.GlobalVar(target+"_0")
+        mod[gv0] = func0
+
+        # body
+        data = relay.var('data', shape=(10, 10))
+        function_out = gv0(data)
+        out_1 = relay.nn.relu(function_out)
+        out_2 = relay.tanh(function_out)
+        out_3 = relay.log(function_out)
+        out = relay.Tuple([out_1, out_2, out_3])
+        func = relay.Function([data], out)
+        mod["main"] = func
+        return mod
+
+    mod = tvm.IRModule()
+    mod["main"] = create_graph()
+
+    seq = transform.Sequential([
+        transform.AnnotateTarget(target),
+        transform.MergeCompilerRegions(),
+        transform.PartitionGraph(),
+    ])
+
+    ref_mod = expected()
+    partitioned = seq(mod)
+    assert tvm.ir.structural_equal(partitioned, ref_mod, map_free_vars=True)
 
 if __name__ == "__main__":
     test_multi_node_compiler()
@@ -1051,3 +1107,4 @@ def test_different_output_region():
     test_mixed_single_multiple_outputs()
     test_dnnl_fuse()
     test_multiple_use_of_an_output()
+    test_duplicate_outputs()

From d2707b5320f8ac67de11c3f52acc26f62e8566eb Mon Sep 17 00:00:00 2001
From: Trevor Morris
Date: Mon, 13 Apr 2020 20:12:00 +0000
Subject: [PATCH 03/10] Fix test_annotated_regions with duplicate compiler_end
 outputs

---
 tests/python/relay/test_annotated_regions.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/python/relay/test_annotated_regions.py b/tests/python/relay/test_annotated_regions.py
index f3c157d296db..21d95a6bd860 100644
--- a/tests/python/relay/test_annotated_regions.py
+++ b/tests/python/relay/test_annotated_regions.py
@@ -56,7 +56,7 @@ def test_region_set_creator_diamond():
         'test_target',
         [cb_1],
         [cb_1, O_1, ce_1, ce_2],
-        [ce_1, ce_2],
+        [ce_1],
     )
     check_region(
         region_set,

From 0af617efd0fb99dd23953d8c3a80ea89af3fbee4 Mon Sep 17 00:00:00 2001
From: Trevor Morris
Date: Mon, 13 Apr 2020 21:14:54 +0000
Subject: [PATCH 04/10] Revert "Fix duplicate output in partitiongraph"

This reverts commit e1f8ef3f4ca5b2aaa31ace6fa968bb50e5e4d1fa.
---
 src/relay/analysis/annotated_region_set.cc | 8 +-------
 src/relay/transforms/partition_graph.cc    | 9 +++------
 2 files changed, 4 insertions(+), 13 deletions(-)

diff --git a/src/relay/analysis/annotated_region_set.cc b/src/relay/analysis/annotated_region_set.cc
index 53b58d494365..94c7621e60af 100644
--- a/src/relay/analysis/annotated_region_set.cc
+++ b/src/relay/analysis/annotated_region_set.cc
@@ -131,13 +131,7 @@ class AnnotatedRegionSet::Creator : public ExprVisitor {
       CHECK_EQ(region->GetTarget(), target);
     }
     region->nodes_.insert(GetRef<Call>(call));
-    if (!std::any_of(region->outs_.begin(), region->outs_.end(),
-                     [call](Expr& out) {
-                       return Downcast<Call>(out)->args[0] ==
-                              GetRef<Call>(call)->args[0];
-                     })) {
-      region->outs_.push_back(GetRef<Call>(call));
-    }
+    region->outs_.push_back(GetRef<Call>(call));
   }
   ExprVisitor::VisitExpr_(call);
 }

diff --git a/src/relay/transforms/partition_graph.cc b/src/relay/transforms/partition_graph.cc
index 172faa0c267a..c8367fb140f2 100644
--- a/src/relay/transforms/partition_graph.cc
+++ b/src/relay/transforms/partition_graph.cc
@@ -207,12 +207,9 @@ class Partitioner : public ExprMutator {
 
     if (region_function_calls.find(region) != region_function_calls.end()) {
       // This section is executed only if there are multiple outputs in the
-      // region or the same output is being accessed multiple times.
-      // Thus, the function is always created and at the end there
+      // region Thus, the function is always created and at the end there
       // would be a tuple node Therefore, we insert a tuple get item node.
-      if (region->GetOutputs().size() == 1) {
-        return region_function_calls[region];
-      }
+
       // Use the already created tuple node
       auto sg_call = region_function_calls[region];
       int index = GetRetIdx(region, GetRef<Call>(call));
@@ -465,7 +462,7 @@ class Partitioner : public ExprMutator {
   int GetRetIdx(AnnotatedRegion sg, const Expr& arg) {
     int idx = 0;
     for (auto arg_ : sg->GetOutputs()) {
-      if (Downcast<Call>(arg)->args[0] == Downcast<Call>(arg_)->args[0]) {
+      if (arg == arg_) {
         return idx;
       }
       idx++;

From f3ba5a5b6e8602ba833f7c1e31defd484076d26e Mon Sep 17 00:00:00 2001
From: Trevor Morris
Date: Tue, 14 Apr 2020 18:04:30 +0000
Subject: [PATCH 05/10] Prevent duplicate outputs in Tuple in PartitionGraph

---
 src/relay/transforms/partition_graph.cc      | 225 +++++++++++--------
 tests/python/relay/test_annotated_regions.py |   2 +-
 2 files changed, 127 insertions(+), 100 deletions(-)

diff --git a/src/relay/transforms/partition_graph.cc b/src/relay/transforms/partition_graph.cc
index c8367fb140f2..29a712ec4a34 100644
--- a/src/relay/transforms/partition_graph.cc
+++ b/src/relay/transforms/partition_graph.cc
@@ -206,97 +206,16 @@ class Partitioner : public ExprMutator {
     // region_function_calls is map that maintains
     // (each annotated regions) --> created function
 
     if (region_function_calls.find(region) != region_function_calls.end()) {
-      // This section is executed only if there are multiple outputs in the
-      // region Thus, the function is always created and at the end there
-      // would be a tuple node Therefore, we insert a tuple get item node.
-
-      // Use the already created tuple node
-      auto sg_call = region_function_calls[region];
-      int index = GetRetIdx(region, GetRef<Call>(call));
-      CHECK_NE(index, -1);
-
-      auto tuple_get_item_ = TupleGetItem(sg_call, index);
-      tuple_get_item_->checked_type_ = GetRef<Call>(call)->args[0]->checked_type_;
-      return std::move(tuple_get_item_);
+      // This section is executed if there are multiple outputs in the region
+      // or if the output of the function is being accessed multiple times by
+      // different nodes.
+      return GetFunctionOutput(region, GetRef<Call>(call));
     } else {
-      // First time this region is encountered in the traversal
-      // Creating the function
-
-      Array<Expr> fields;
-
-      for (auto ret : region->GetOutputs()) {
-        auto ret_expr = VisitExpr(Downcast<Call>(ret)->args[0]);
-        fields.push_back(ret_expr);
-      }
-      int index = GetRetIdx(region, GetRef<Call>(call));
-      CHECK_NE(index, -1);
-
-      Array<Var> params;
-      Array<Expr> param_expr;
-      std::unordered_map<std::string, runtime::NDArray> params_bind;
-
-      for (auto pair : region_args[region]) {
-        params.push_back(pair.first);
-        if (const auto* cn = pair.second.as<ConstantNode>()) {
-          params_bind[pair.first->name_hint()] = cn->data;
-        } else {
-          param_expr.push_back(pair.second);
-        }
-      }
-
-      Function global_region_func;
-      if (region->GetOutputs().size() == 1) {
-        // If there are only a single output; no need to add a tuple
-        global_region_func =
-            Function(params, fields[0], call->args[0]->checked_type_, {}, DictAttrs());
-      } else {
-        auto tuple = Tuple(fields);
-        global_region_func = Function(params, tuple, tuple->checked_type_, {}, DictAttrs());
-      }
-
-      std::string target = call->attrs.as<CompilerAttrs>()->compiler;
-      std::string name = target + "_" + std::to_string(region->GetID());
-
-      global_region_func = WithAttr(std::move(global_region_func), tvm::attr::kGlobalSymbol,
-                                    runtime::String(name));
-      global_region_func =
-          WithAttr(std::move(global_region_func), attr::kPrimitive, tvm::Integer(1));
-      global_region_func = WithAttr(std::move(global_region_func), attr::kCompiler,
-                                    tvm::runtime::String(target));
-      global_region_func =
-          WithAttr(std::move(global_region_func), attr::kInline, tvm::Integer(1));
-
-      // Constant propagation
-      if (!params_bind.empty()) {
-        global_region_func = backend::BindParamsByName(global_region_func, params_bind);
-      }
-
-      std::string fname = name;
-      CHECK(!module_->ContainGlobalVar(fname))
-          << "Global function " << fname << " already exists";
-      // Create a global function and add it to the IRModule for the region.
-      // This way we lift the functions that should be handled by external
-      // codegen to the module scope and rely on the pass manager to prevent
-      // relay function level passes (i.e. simplify inference and fusion)
-      // optimizing it.
-      GlobalVar glob_func(fname);
-      module_->Add(glob_func, global_region_func);
-
-      // The return type of callnode is the same as the type of the
-      // compiler_end node.
-      auto ret = Call(glob_func, param_expr);
-      region_function_calls[region] = ret;
-
-      if (region->GetOutputs().size() == 1) {
-        // If there is only a single output; no need to add a tuplegetitem
-        // node
-        return std::move(ret);
-      } else {
-        // Add a tuplegetitem node to select this output out of many
-        auto tuple_get_item_ = TupleGetItem(ret, index);
-        tuple_get_item_->checked_type_ = GetRef<Call>(call)->args[0]->checked_type_;
-        return std::move(tuple_get_item_);
-      }
+      // First time this region is encountered in the traversal.
+      // Creating the function.
+      CreateFunction(region, call);
+      // Retrieve particular output.
+      return GetFunctionOutput(region, GetRef<Call>(call));
     }
   }
@@ -456,18 +375,109 @@ class Partitioner : public ExprMutator {
   }
 
   /*!
-   * \brief Get the index of the return(output);
-   * this is to be used as tuplegetitem idx
+   * \brief This function is called first time that we encounter a compiler_end
+   * node to create the function for the subgraph.
    */
-  int GetRetIdx(AnnotatedRegion sg, const Expr& arg) {
-    int idx = 0;
-    for (auto arg_ : sg->GetOutputs()) {
-      if (arg == arg_) {
-        return idx;
+  void CreateFunction(AnnotatedRegion region, const CallNode* call) {
+    // Create fields which is a unique list of outputs. Also populate
+    // region_return_indices_ map which maps parent of compiler_end node to
+    // corresponding index in fields.
+    Array<Expr> fields;
+    int i = 0;
+    for (auto ret : region->GetOutputs()) {
+      auto ret_node = Downcast<Call>(ret)->args[0];
+      // Don't duplicate outputs.
+      if (!region_return_indices_.count(region) ||
+          !region_return_indices_[region].count(ret_node)) {
+        auto ret_expr = VisitExpr(ret_node);
+        fields.push_back(ret_expr);
+        region_return_indices_[region][ret_node] = i;
+        i++;
       }
-      idx++;
     }
-    return -1;
+
+    Array<Var> params;
+    Array<Expr> param_expr;
+    std::unordered_map<std::string, runtime::NDArray> params_bind;
+
+    for (auto pair : region_args[region]) {
+      params.push_back(pair.first);
+      if (const auto* cn = pair.second.as<ConstantNode>()) {
+        params_bind[pair.first->name_hint()] = cn->data;
+      } else {
+        param_expr.push_back(pair.second);
+      }
+    }
+
+    Function global_region_func;
+    if (fields.size() == 1) {
+      // If there are only a single output; no need to add a tuple
+      global_region_func =
+          Function(params, fields[0], call->args[0]->checked_type_, {}, DictAttrs());
+    } else {
+      auto tuple = Tuple(fields);
+      global_region_func = Function(params, tuple, tuple->checked_type_, {}, DictAttrs());
+    }
+
+    std::string target = call->attrs.as<CompilerAttrs>()->compiler;
+    std::string name = target + "_" + std::to_string(region->GetID());
+
+    global_region_func = WithAttr(std::move(global_region_func), tvm::attr::kGlobalSymbol,
+                                  runtime::String(name));
+    global_region_func =
+        WithAttr(std::move(global_region_func), attr::kPrimitive, tvm::Integer(1));
+    global_region_func = WithAttr(std::move(global_region_func), attr::kCompiler,
+                                  tvm::runtime::String(target));
+    global_region_func =
+        WithAttr(std::move(global_region_func), attr::kInline, tvm::Integer(1));
+
+    // Constant propagation
+    if (!params_bind.empty()) {
+      global_region_func = backend::BindParamsByName(global_region_func, params_bind);
+    }
+
+    std::string fname = name;
+    CHECK(!module_->ContainGlobalVar(fname))
+        << "Global function " << fname << " already exists";
+    // Create a global function and add it to the IRModule for the region.
+    // This way we lift the functions that should be handled by external
+    // codegen to the module scope and rely on the pass manager to prevent
+    // relay function level passes (i.e. simplify inference and fusion)
+    // optimizing it.
+    GlobalVar glob_func(fname);
+    module_->Add(glob_func, global_region_func);
+
+    // The return type of callnode is the same as the type of the
+    // compiler_end node.
+    auto ret = Call(glob_func, param_expr);
+    region_function_calls[region] = ret;
   }
+
+  /*!
+   * \brief Get the return(output) of the function for compiler end node "arg".
+   */
+  Expr GetFunctionOutput(AnnotatedRegion region, const Expr& end_arg) {
+    Expr arg = Downcast<Call>(end_arg)->args[0];
+    // Function has one output.
+    if (region_return_indices_[region].size() == 1) {
+      return region_function_calls[region];
+    }
+    // Function has multiple outputs.
+    // Use already made TupleGetItem.
+    if (region_return_tuplegetitem_.count(region) &&
+        region_return_tuplegetitem_[region].count(arg)) {
+      return region_return_tuplegetitem_[region][arg];
+    }
+    // Create new TupleGetItem.
+    CHECK(region_return_indices_.count(region) &&
+          region_return_indices_[region].count(arg));
+    int index = region_return_indices_[region][arg];
+
+    auto func_call = region_function_calls[region];
+    auto tuple_get_item_ = TupleGetItem(func_call, index);
+    tuple_get_item_->checked_type_ = arg->checked_type_;
+    region_return_tuplegetitem_[region][arg] = tuple_get_item_;
+    return tuple_get_item_;
+  }
 
   /*!
@@ -485,6 +495,23 @@ class Partitioner : public ExprMutator {
   std::unordered_map<AnnotatedRegion, std::vector<std::pair<Var, Expr>>, ObjectHash, ObjectEqual>
       region_args;
 
+  /*!
+   * \brief This map maintains the index of an output in the subgraph function
+   * for a given region. If there are multiple entries for a region, then the
+   * function has a tuple of multiple outputs for its return.
+   */
+  using RegionReturnIndexMap = std::unordered_map<Expr, int, ObjectHash, ObjectEqual>;
+  std::unordered_map<AnnotatedRegion, RegionReturnIndexMap, ObjectHash, ObjectEqual>
+      region_return_indices_;
+
+  /*!
+   * \brief This map holds already created TupleGetItem nodes for accessing
+   * outputs of a function.
+   */
+  using RegionReturnTupleGetItemMap = std::unordered_map<Expr, TupleGetItem, ObjectHash, ObjectEqual>;
+  std::unordered_map<AnnotatedRegion, RegionReturnTupleGetItemMap, ObjectHash, ObjectEqual>
+      region_return_tuplegetitem_;
+
   /*!
    * \brief Each region set is associated with a function in the module.
    * This map maintains the mapping between regionsets and the function it

diff --git a/tests/python/relay/test_annotated_regions.py b/tests/python/relay/test_annotated_regions.py
index 21d95a6bd860..f3c157d296db 100644
--- a/tests/python/relay/test_annotated_regions.py
+++ b/tests/python/relay/test_annotated_regions.py
@@ -56,7 +56,7 @@ def test_region_set_creator_diamond():
         'test_target',
         [cb_1],
         [cb_1, O_1, ce_1, ce_2],
-        [ce_1],
+        [ce_1, ce_2],
     )
     check_region(
         region_set,

From 73faa8890b23e33eca6cc887c19754c9f3fd67cd Mon Sep 17 00:00:00 2001
From: Trevor Morris
Date: Tue, 14 Apr 2020 18:07:52 +0000
Subject: [PATCH 06/10] Fix lint

---
 src/relay/transforms/partition_graph.cc | 10 +++++-----
 1 file changed, 5 insertions(+), 5 deletions(-)

diff --git a/src/relay/transforms/partition_graph.cc b/src/relay/transforms/partition_graph.cc
index 29a712ec4a34..094322118b0b 100644
--- a/src/relay/transforms/partition_graph.cc
+++ b/src/relay/transforms/partition_graph.cc
@@ -500,16 +500,16 @@ class Partitioner : public ExprMutator {
    * for a given region. If there are multiple entries for a region, then the
    * function has a tuple of multiple outputs for its return.
    */
-  using RegionReturnIndexMap = std::unordered_map<Expr, int, ObjectHash, ObjectEqual>;
-  std::unordered_map<AnnotatedRegion, RegionReturnIndexMap, ObjectHash, ObjectEqual>
+  using RegionRetIndexMap = std::unordered_map<Expr, int, ObjectHash, ObjectEqual>;
+  std::unordered_map<AnnotatedRegion, RegionRetIndexMap, ObjectHash, ObjectEqual>
       region_return_indices_;
- 
+
   /*!
    * \brief This map holds already created TupleGetItem nodes for accessing
    * outputs of a function.
    */
-  using RegionReturnTupleGetItemMap = std::unordered_map<Expr, TupleGetItem, ObjectHash, ObjectEqual>;
-  std::unordered_map<AnnotatedRegion, RegionReturnTupleGetItemMap, ObjectHash, ObjectEqual>
+  using RegionRetTupleGetItemMap = std::unordered_map<Expr, TupleGetItem, ObjectHash, ObjectEqual>;
+  std::unordered_map<AnnotatedRegion, RegionRetTupleGetItemMap, ObjectHash, ObjectEqual>
       region_return_tuplegetitem_;
 
   /*!
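[Editor's sketch, not part of the patch series.] The bookkeeping that patches 05 and 06 introduce is easier to see outside of Relay. The standalone Python sketch below mirrors the dedup loop in CreateFunction (the fields list plus a region_return_indices_-style index map) and the TupleGetItem cache consulted in GetFunctionOutput; plain strings stand in for the Exprs feeding each compiler_end node, and nothing here is TVM API.

# Standalone sketch (no TVM imports): strings stand in for the Exprs that
# feed each compiler_end node of one region.

def create_function(region_outputs):
    """Mirror of CreateFunction's loop: build the unique 'fields' list and
    the output -> tuple-index map (the region_return_indices_ analogue)."""
    fields = []
    return_indices = {}
    for ret_node in region_outputs:
        if ret_node not in return_indices:  # "Don't duplicate outputs."
            return_indices[ret_node] = len(fields)
            fields.append(ret_node)
    return fields, return_indices

def get_function_output(ret_node, fields, return_indices, tgi_cache):
    """Mirror of GetFunctionOutput: a single-output function returns the call
    itself; otherwise one TupleGetItem per distinct output is created once
    and cached (the region_return_tuplegetitem_ analogue)."""
    if len(fields) == 1:
        return "call(region_fn)"
    if ret_node not in tgi_cache:
        tgi_cache[ret_node] = f"TupleGetItem(call(region_fn), {return_indices[ret_node]})"
    return tgi_cache[ret_node]

# Three compiler_end nodes, but only two distinct values leave the region —
# the situation the new test cases set up.
outs = ["relu_out", "bn_out", "bn_out"]
fields, indices = create_function(outs)
assert fields == ["relu_out", "bn_out"]  # no duplicated tuple field
cache = {}
resolved = [get_function_output(o, fields, indices, cache) for o in outs]
assert resolved[1] is resolved[2]        # duplicates share one node

The C++ version keys both maps on Expr node identity via ObjectHash/ObjectEqual, which is what lets two compiler_end nodes with the same parent resolve to one tuple field and one shared TupleGetItem.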
From 4baf986add4d2019d086ec9bbb29bdca6bbfaec2 Mon Sep 17 00:00:00 2001
From: Trevor Morris
Date: Tue, 14 Apr 2020 19:00:52 +0000
Subject: [PATCH 07/10] Add another test case for when regions are merged, and
 when TupleGetItem was duplicated

---
 .../python/relay/test_pass_partition_graph.py | 78 +++++++++++++++++++
 1 file changed, 78 insertions(+)

diff --git a/tests/python/relay/test_pass_partition_graph.py b/tests/python/relay/test_pass_partition_graph.py
index 08ac4e45a54e..772de4debf32 100644
--- a/tests/python/relay/test_pass_partition_graph.py
+++ b/tests/python/relay/test_pass_partition_graph.py
@@ -1093,6 +1093,83 @@ def expected():
     partitioned = seq(mod)
     assert tvm.ir.structural_equal(partitioned, ref_mod, map_free_vars=True)
 
+def test_duplicate_merge_and_tuplegetitem():
+    target = "test_duplicate_merge_and_tuplegetitem"
+
+    @reg.register("nn.batch_norm", "target." + target)
+    def abs(attrs, args): # pylint: disable=unused-variable
+        return True
+
+    @reg.register("nn.relu", "target." + target)
+    def abs(attrs, args): # pylint: disable=unused-variable
+        return True
+
+    def create_graph():
+        data = relay.var('data', shape=(10, 10))
+        bn_gamma = relay.var("bn_gamma")
+        bn_beta = relay.var("bn_beta")
+        bn_mmean = relay.var("bn_mean")
+        bn_mvar = relay.var("bn_var")
+        x = relay.nn.batch_norm(data, bn_gamma, bn_beta, bn_mmean, bn_mvar)
+        out_1 = relay.nn.relu(x[0])
+        bn_out_1 = x[1]
+        out_2 = relay.tanh(bn_out_1)
+        out_3 = relay.log(bn_out_1)
+        out = relay.Tuple([out_1, out_2, out_3])
+        func = relay.Function([data, bn_gamma, bn_beta, bn_mmean, bn_mvar], out)
+        return func
+
+    def expected():
+        mod = tvm.IRModule()
+
+        # function 0
+        f0_i0 = relay.var(target+"_1_i0", shape=(10, 10))
+        f0_i1 = relay.var(target+"_1_i1")
+        f0_i2 = relay.var(target+"_1_i2")
+        f0_i3 = relay.var(target+"_1_i3")
+        f0_i4 = relay.var(target+"_1_i4")
+        f0_n0 = relay.nn.batch_norm(f0_i0, f0_i1, f0_i2, f0_i3, f0_i4)
+        f0_n1 = f0_n0[1]
+        f0_n2 = relay.nn.relu(f0_n0[0])
+        f0_o0 = relay.Tuple([f0_n1, f0_n2])
+        func0 = relay.Function([f0_i0, f0_i1, f0_i2, f0_i3, f0_i4], f0_o0)
+
+        func0 = func0.with_attr("Primitive", tvm.tir.IntImm("int32", 1))
+        func0 = func0.with_attr("Inline", tvm.tir.IntImm("int32", 1))
+        func0 = func0.with_attr("Compiler", target)
+        func0 = func0.with_attr("global_symbol", target+"_1")
+        gv0 = relay.GlobalVar(target+"_1")
+        mod[gv0] = func0
+
+        # body
+        data = relay.var('data', shape=(10, 10))
+        bn_gamma = relay.var("bn_gamma")
+        bn_beta = relay.var("bn_beta")
+        bn_mmean = relay.var("bn_mean")
+        bn_mvar = relay.var("bn_var")
+        function_out = gv0(data, bn_gamma, bn_beta, bn_mmean, bn_mvar)
+        get_out0 = relay.TupleGetItem(function_out, 0)
+        get_out1 = relay.TupleGetItem(function_out, 1)
+        out_2 = relay.tanh(get_out0)
+        out_3 = relay.log(get_out0)
+        out = relay.Tuple([get_out1, out_2, out_3])
+        func = relay.Function([data, bn_gamma, bn_beta, bn_mmean, bn_mvar], out)
+        mod["main"] = func
+        return mod
+
+    mod = tvm.IRModule()
+    mod["main"] = create_graph()
+
+    seq = transform.Sequential([
+        transform.AnnotateTarget(target),
+        transform.MergeCompilerRegions(),
+        transform.PartitionGraph(),
+    ])
+
+    ref_mod = expected()
+    partitioned = seq(mod)
+    assert tvm.ir.structural_equal(partitioned, ref_mod, map_free_vars=True)
+
 if __name__ == "__main__":
     test_multi_node_compiler()
     test_extern_ccompiler_single_op()
@@ -1108,3 +1185,4 @@ def expected():
     test_dnnl_fuse()
     test_multiple_use_of_an_output()
     test_duplicate_outputs()
+    test_duplicate_tuplegetitem()

From 6a4ce51d988c4193ee60458add8e14598cb44b75 Mon Sep 17 00:00:00 2001
From: Trevor Morris
Date: Tue, 14 Apr 2020 19:02:58 +0000
Subject: [PATCH 08/10] Pull GetFunctionOutput out of branch, improve
 description of GetFunctionOutput

---
 src/relay/transforms/partition_graph.cc | 15 ++++++---------
 1 file changed, 6 insertions(+), 9 deletions(-)

diff --git a/src/relay/transforms/partition_graph.cc b/src/relay/transforms/partition_graph.cc
index 094322118b0b..e54106d39b5f 100644
--- a/src/relay/transforms/partition_graph.cc
+++ b/src/relay/transforms/partition_graph.cc
@@ -205,18 +205,13 @@ class Partitioner : public ExprMutator {
     // region_function_calls is map that maintains
     // (each annotated regions) --> created function
 
-    if (region_function_calls.find(region) != region_function_calls.end()) {
-      // This section is executed if there are multiple outputs in the region
-      // or if the output of the function is being accessed multiple times by
-      // different nodes.
-      return GetFunctionOutput(region, GetRef<Call>(call));
-    } else {
+    if (region_function_calls.find(region) == region_function_calls.end()) {
       // First time this region is encountered in the traversal.
       // Creating the function.
       CreateFunction(region, call);
-      // Retrieve particular output.
-      return GetFunctionOutput(region, GetRef<Call>(call));
     }
+    // Retrieve this particular output of function.
+    return GetFunctionOutput(region, GetRef<Call>(call));
   }
 
@@ -454,7 +449,9 @@ class Partitioner : public ExprMutator {
   }
 
   /*!
-   * \brief Get the return(output) of the function for compiler end node "arg".
+   * \brief Get the return(output) of the function for compiler end node "end_arg".
+   * This will return either a Call (for a function with a single output) or a
+   * TupleGetItem (for a function with multiple outputs).
    */
   Expr GetFunctionOutput(AnnotatedRegion region, const Expr& end_arg) {
     Expr arg = Downcast<Call>(end_arg)->args[0];

From 148cacb2de636507bc505dfbc7c5b7d5e08a0534 Mon Sep 17 00:00:00 2001
From: Trevor Morris
Date: Wed, 15 Apr 2020 17:37:43 +0000
Subject: [PATCH 09/10] Use std::move for GetFunctionOutput. Fix typo with
 testcase name

---
 src/relay/transforms/partition_graph.cc         | 2 +-
 tests/python/relay/test_pass_partition_graph.py | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/src/relay/transforms/partition_graph.cc b/src/relay/transforms/partition_graph.cc
index e54106d39b5f..15ad60be3a95 100644
--- a/src/relay/transforms/partition_graph.cc
+++ b/src/relay/transforms/partition_graph.cc
@@ -474,7 +474,7 @@ class Partitioner : public ExprMutator {
     auto tuple_get_item_ = TupleGetItem(func_call, index);
     tuple_get_item_->checked_type_ = arg->checked_type_;
     region_return_tuplegetitem_[region][arg] = tuple_get_item_;
-    return tuple_get_item_;
+    return std::move(tuple_get_item_);
   }
 
   /*!
diff --git a/tests/python/relay/test_pass_partition_graph.py b/tests/python/relay/test_pass_partition_graph.py index 772de4debf32..65d4c8f6f866 100644 --- a/tests/python/relay/test_pass_partition_graph.py +++ b/tests/python/relay/test_pass_partition_graph.py @@ -1185,4 +1185,4 @@ def expected(): test_dnnl_fuse() test_multiple_use_of_an_output() test_duplicate_outputs() - test_duplicate_tuplegetitem() + test_duplicate_merge_and_tuplegetitem() From 569dc57da38899071382556f59abf44f55e5e4e1 Mon Sep 17 00:00:00 2001 From: Trevor Morris Date: Wed, 15 Apr 2020 18:02:57 +0000 Subject: [PATCH 10/10] Use tvm.transform.Sequential --- tests/python/relay/test_pass_partition_graph.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/python/relay/test_pass_partition_graph.py b/tests/python/relay/test_pass_partition_graph.py index 65d4c8f6f866..8827fbf1b8b0 100644 --- a/tests/python/relay/test_pass_partition_graph.py +++ b/tests/python/relay/test_pass_partition_graph.py @@ -1083,7 +1083,7 @@ def expected(): mod = tvm.IRModule() mod["main"] = create_graph() - seq = transform.Sequential([ + seq = tvm.transform.Sequential([ transform.AnnotateTarget(target), transform.MergeCompilerRegions(), transform.PartitionGraph(), @@ -1160,7 +1160,7 @@ def expected(): mod = tvm.IRModule() mod["main"] = create_graph() - seq = transform.Sequential([ + seq = tvm.transform.Sequential([ transform.AnnotateTarget(target), transform.MergeCompilerRegions(), transform.PartitionGraph(),
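[Editor's sketch, not part of the patch series.] For verifying the series end to end, a minimal reproduction in the style of the new tests is shown below. "demo_target" is a made-up external codegen name registered only for this snippet, and a TVM checkout with all ten patches applied is assumed; without them, the partitioned region's function returns its single output once per consumer.

import tvm
import tvm.relay.op as reg
from tvm import relay
from tvm.relay import transform

target = "demo_target"  # hypothetical codegen name, for illustration only

@reg.register("add", "target." + target)
def add(attrs, args):  # pylint: disable=unused-variable
    return True

# One region output (x) consumed by two different nodes outside the region.
data = relay.var('data', shape=(10, 10))
x = relay.add(data, data)
out = relay.Tuple([relay.nn.relu(x), relay.tanh(x)])
mod = tvm.IRModule.from_expr(relay.Function([data], out))

seq = tvm.transform.Sequential([
    transform.AnnotateTarget(target),
    transform.MergeCompilerRegions(),
    transform.PartitionGraph(),
])
# With the series applied, the generated demo_target_* function returns x
# once, and both relu and tanh read it through shared nodes.
print(seq(mod))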