diff --git a/python/tvm/contrib/graph_executor.py b/python/tvm/contrib/graph_executor.py index a4bc85905f5e..f9d1b9734d45 100644 --- a/python/tvm/contrib/graph_executor.py +++ b/python/tvm/contrib/graph_executor.py @@ -157,6 +157,7 @@ def __init__(self, module): self._get_output = module["get_output"] self._get_input = module["get_input"] self._get_num_outputs = module["get_num_outputs"] + self._get_input_index = module["get_input_index"] self._get_num_inputs = module["get_num_inputs"] self._load_params = module["load_params"] self._share_params = module["share_params"] @@ -242,6 +243,21 @@ def get_input(self, index, out=None): return self._get_input(index) + def get_input_index(self, name): + """Get inputs index via input name. + + Parameters + ---------- + name : str + The input key name + + Returns + ------- + index: int + The input index. -1 will be returned if the given input name is not found. + """ + return self._get_input_index(name) + def get_output(self, index, out=None): """Get index-th output to out diff --git a/src/runtime/graph_executor/graph_executor.cc b/src/runtime/graph_executor/graph_executor.cc index 7aae12b32377..bc73a5988377 100644 --- a/src/runtime/graph_executor/graph_executor.cc +++ b/src/runtime/graph_executor/graph_executor.cc @@ -502,6 +502,11 @@ PackedFunc GraphExecutor::GetFunction(const std::string& name, dmlc::MemoryStringStream strm(const_cast(¶m_blob)); this->ShareParams(dynamic_cast(*module.operator->()), &strm); }); + } else if (name == "get_input_index") { + return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { + CHECK(String::CanConvertFrom(args[0])) << "Input key is not a string"; + *rv = this->GetInputIndex(args[0].operator String()); + }); } else { return PackedFunc(); } diff --git a/tests/python/relay/test_backend_graph_executor.py b/tests/python/relay/test_backend_graph_executor.py index 234095f67864..7beac197fb3a 100644 --- a/tests/python/relay/test_backend_graph_executor.py +++ b/tests/python/relay/test_backend_graph_executor.py @@ -311,5 +311,19 @@ def test_graph_executor_nested_tuples(): tvm.testing.assert_allclose(out[1][1][1].numpy(), data[3]) +def test_graph_executor_api(): + dname_0, dname_1 = "data_0", "data_1" + data_0, data_1 = [relay.var(c, shape=(1, 1), dtype="float32") for c in [dname_0, dname_1]] + net = relay.add(data_0, data_1) + func = relay.Function((data_0, data_1), net) + + lib = relay.build(tvm.IRModule.from_expr(func), "llvm") + mod = graph_executor.GraphModule(lib["default"](tvm.cpu(0))) + + assert mod.get_input_index(dname_1) == 1 + assert mod.get_input_index(dname_0) == 0 + assert mod.get_input_index("Invalid") == -1 + + if __name__ == "__main__": pytest.main([__file__])