From 60254fde4ee8f2fc63aa64890ee98a76318da8e8 Mon Sep 17 00:00:00 2001 From: huajsj Date: Mon, 2 Aug 2021 21:10:57 -0700 Subject: [PATCH 1/5] [Runtime] Add graph_executor get_input_index API. In graph_executor use case, user can use set_input with input index to set input parameter, but there is no straight forward way to get correct index number with input name, here provide get_input_index API to do such work. --- python/tvm/contrib/graph_executor.py | 14 ++++++++++++++ src/runtime/graph_executor/graph_executor.cc | 8 ++++++++ tests/python/relay/test_backend_graph_executor.py | 14 ++++++++++++++ 3 files changed, 36 insertions(+) diff --git a/python/tvm/contrib/graph_executor.py b/python/tvm/contrib/graph_executor.py index a4bc85905f5e..f5675c30c2b5 100644 --- a/python/tvm/contrib/graph_executor.py +++ b/python/tvm/contrib/graph_executor.py @@ -157,6 +157,7 @@ def __init__(self, module): self._get_output = module["get_output"] self._get_input = module["get_input"] self._get_num_outputs = module["get_num_outputs"] + self._get_input_index = module["get_input_index"] self._get_num_inputs = module["get_num_inputs"] self._load_params = module["load_params"] self._share_params = module["share_params"] @@ -242,6 +243,19 @@ def get_input(self, index, out=None): return self._get_input(index) + def get_input_index(self, name): + """Get inputs index via input name + Parameters + ---------- + name : str + The input key name + Returns + ------- + index: int + The input index + """ + return self._get_input_index(name) + def get_output(self, index, out=None): """Get index-th output to out diff --git a/src/runtime/graph_executor/graph_executor.cc b/src/runtime/graph_executor/graph_executor.cc index 7aae12b32377..b118200bd3e2 100644 --- a/src/runtime/graph_executor/graph_executor.cc +++ b/src/runtime/graph_executor/graph_executor.cc @@ -502,6 +502,14 @@ PackedFunc GraphExecutor::GetFunction(const std::string& name, dmlc::MemoryStringStream strm(const_cast(¶m_blob)); this->ShareParams(dynamic_cast(*module.operator->()), &strm); }); + } else if (name == "get_input_index") { + return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { + if (String::CanConvertFrom(args[0])) { + *rv = this->GetInputIndex(args[0].operator String()); + } else { + *rv = args[0]; + } + }); } else { return PackedFunc(); } diff --git a/tests/python/relay/test_backend_graph_executor.py b/tests/python/relay/test_backend_graph_executor.py index 234095f67864..7beac197fb3a 100644 --- a/tests/python/relay/test_backend_graph_executor.py +++ b/tests/python/relay/test_backend_graph_executor.py @@ -311,5 +311,19 @@ def test_graph_executor_nested_tuples(): tvm.testing.assert_allclose(out[1][1][1].numpy(), data[3]) +def test_graph_executor_api(): + dname_0, dname_1 = "data_0", "data_1" + data_0, data_1 = [relay.var(c, shape=(1, 1), dtype="float32") for c in [dname_0, dname_1]] + net = relay.add(data_0, data_1) + func = relay.Function((data_0, data_1), net) + + lib = relay.build(tvm.IRModule.from_expr(func), "llvm") + mod = graph_executor.GraphModule(lib["default"](tvm.cpu(0))) + + assert mod.get_input_index(dname_1) == 1 + assert mod.get_input_index(dname_0) == 0 + assert mod.get_input_index("Invalid") == -1 + + if __name__ == "__main__": pytest.main([__file__]) From a29f57a2f43787645646e51182957cfe3d919f26 Mon Sep 17 00:00:00 2001 From: Hua Jiang Date: Tue, 3 Aug 2021 11:23:02 -0700 Subject: [PATCH 2/5] Update python/tvm/contrib/graph_executor.py Co-authored-by: Cody Yu --- python/tvm/contrib/graph_executor.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/python/tvm/contrib/graph_executor.py b/python/tvm/contrib/graph_executor.py index f5675c30c2b5..e21d681c6b86 100644 --- a/python/tvm/contrib/graph_executor.py +++ b/python/tvm/contrib/graph_executor.py @@ -244,7 +244,8 @@ def get_input(self, index, out=None): return self._get_input(index) def get_input_index(self, name): - """Get inputs index via input name + """Get inputs index via input name. + Parameters ---------- name : str From 1d6f699e1bf8e316f26c81e852abd881cb859b96 Mon Sep 17 00:00:00 2001 From: Hua Jiang Date: Tue, 3 Aug 2021 11:23:13 -0700 Subject: [PATCH 3/5] Update python/tvm/contrib/graph_executor.py Co-authored-by: Cody Yu --- python/tvm/contrib/graph_executor.py | 1 + 1 file changed, 1 insertion(+) diff --git a/python/tvm/contrib/graph_executor.py b/python/tvm/contrib/graph_executor.py index e21d681c6b86..91a155dd1aa9 100644 --- a/python/tvm/contrib/graph_executor.py +++ b/python/tvm/contrib/graph_executor.py @@ -250,6 +250,7 @@ def get_input_index(self, name): ---------- name : str The input key name + Returns ------- index: int From a4c71b0b213d109b2685437908cebd6f2e70a158 Mon Sep 17 00:00:00 2001 From: Hua Jiang Date: Tue, 3 Aug 2021 11:23:21 -0700 Subject: [PATCH 4/5] Update src/runtime/graph_executor/graph_executor.cc Co-authored-by: Cody Yu --- src/runtime/graph_executor/graph_executor.cc | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/src/runtime/graph_executor/graph_executor.cc b/src/runtime/graph_executor/graph_executor.cc index b118200bd3e2..bc73a5988377 100644 --- a/src/runtime/graph_executor/graph_executor.cc +++ b/src/runtime/graph_executor/graph_executor.cc @@ -504,11 +504,8 @@ PackedFunc GraphExecutor::GetFunction(const std::string& name, }); } else if (name == "get_input_index") { return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { - if (String::CanConvertFrom(args[0])) { - *rv = this->GetInputIndex(args[0].operator String()); - } else { - *rv = args[0]; - } + CHECK(String::CanConvertFrom(args[0])) << "Input key is not a string"; + *rv = this->GetInputIndex(args[0].operator String()); }); } else { return PackedFunc(); From 3f6bce5336739a8c36fe0dfb37995c36ffadc404 Mon Sep 17 00:00:00 2001 From: Hua Jiang Date: Tue, 3 Aug 2021 11:23:29 -0700 Subject: [PATCH 5/5] Update python/tvm/contrib/graph_executor.py Co-authored-by: Cody Yu --- python/tvm/contrib/graph_executor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/tvm/contrib/graph_executor.py b/python/tvm/contrib/graph_executor.py index 91a155dd1aa9..f9d1b9734d45 100644 --- a/python/tvm/contrib/graph_executor.py +++ b/python/tvm/contrib/graph_executor.py @@ -254,7 +254,7 @@ def get_input_index(self, name): Returns ------- index: int - The input index + The input index. -1 will be returned if the given input name is not found. """ return self._get_input_index(name)