From 249e74be5892b2143dc4243d881a85985c92cd65 Mon Sep 17 00:00:00 2001 From: Haichen Shen Date: Mon, 7 Oct 2019 21:52:00 +0000 Subject: [PATCH 1/3] Fix VM invoke with set_params --- src/runtime/vm/vm.cc | 44 +++++++++++++++++++++++--------------------- 1 file changed, 23 insertions(+), 21 deletions(-) diff --git a/src/runtime/vm/vm.cc b/src/runtime/vm/vm.cc index 02ea3a42b156..e8ae4278ee33 100644 --- a/src/runtime/vm/vm.cc +++ b/src/runtime/vm/vm.cc @@ -577,31 +577,33 @@ PackedFunc VirtualMachine::GetFunction(const std::string& name, if (name == "invoke") { return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { std::string func_name = args[0]; + auto gvit = this->global_map.find(func_name); + CHECK(gvit != this->global_map.end()) << "Cannot find function " << func_name; + auto func_index = gvit->second; + auto const& vm_func = this->functions[func_index]; + auto const& param_names = vm_func.params; auto ctx = this->GetParamsContext(); - std::vector func_args; + + // Prepare the func args + std::vector func_args(param_names.size()); + std::vector empty_slots; + + for (size_t i = 0; i < param_names.size(); ++i) { + const auto& pit = params_.find(param_names[i]); + if (pit != params_.end()) { + func_args[i] = pit->second; + } else { + empty_slots.push_back(i); + } + } + CHECK_EQ(empty_slots.size(), args.size() - 1) + << "The number of provided parameters doesn't match the number of arguments"; for (int i = 1; i < args.size(); ++i) { Object obj = CopyTo(args[i], ctx); - func_args.push_back(obj); - } - auto it = std::find_if(functions.begin(), functions.end(), - [func_name](const VMFunction& func) { - return func.name == func_name; - }); - - CHECK(it != functions.end()) << "Cannot find function " << func_name << "\n"; - CHECK_EQ(func_args.size() + params_.size(), it->params.size()) - << "The number of provided parameters doesn't match the number of arguments" - << "\n"; - if (!params_.empty()) { - for (const auto& p : it->params) { - const auto& pit = params_.find(p); - if (pit != params_.end()) { - func_args.push_back(pit->second); - } - } - CHECK_EQ(func_args.size(), it->params.size()); + func_args[empty_slots[i - 1]] = obj; } - *rv = this->Invoke(func_name, func_args); + + *rv = this->Invoke(vm_func, func_args); }); } else if (name == "init") { return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { From 4af2f9617c97c104e2f7dcb3c1583537dba7921f Mon Sep 17 00:00:00 2001 From: Haichen Shen Date: Mon, 7 Oct 2019 23:58:45 +0000 Subject: [PATCH 2/3] add test --- src/runtime/vm/vm.cc | 2 +- tests/python/relay/test_vm.py | 22 ++++++++++++++++++++++ 2 files changed, 23 insertions(+), 1 deletion(-) diff --git a/src/runtime/vm/vm.cc b/src/runtime/vm/vm.cc index e8ae4278ee33..86ccdf48ef15 100644 --- a/src/runtime/vm/vm.cc +++ b/src/runtime/vm/vm.cc @@ -602,7 +602,7 @@ PackedFunc VirtualMachine::GetFunction(const std::string& name, Object obj = CopyTo(args[i], ctx); func_args[empty_slots[i - 1]] = obj; } - + *rv = this->Invoke(vm_func, func_args); }); } else if (name == "init") { diff --git a/tests/python/relay/test_vm.py b/tests/python/relay/test_vm.py index 593d7ac64cf8..f60c53317407 100644 --- a/tests/python/relay/test_vm.py +++ b/tests/python/relay/test_vm.py @@ -575,6 +575,27 @@ def test_add_op_broadcast(): mod["main"] = func check_result([x_data, y_data], x_data + y_data, mod=mod) +def test_set_params(): + mod = relay.Module() + x = relay.var('x', shape=(10, 5)) + w = relay.var('w', shape=(6, 5)) + b = relay.var('b', shape=(6,)) + y = relay.nn.bias_add(relay.nn.dense(x, w), b) + mod["main"] = relay.Function([x, w, b], y) + compiler = relay.vm.VMCompiler() + vm = compiler.compile(mod, 'llvm') + vm.init(tvm.cpu()) + + x_np = np.random.uniform(size=(10, 5)).astype('float32') + w_np = np.random.uniform(size=(6, 5)).astype('float32') + b_np = np.random.uniform(size=(6,)).astype('float32') + ref_np = np.dot(x_np, w_np.T) + b_np + params = {'w': w_np} + vm.load_params(params) + out = vm.run(x_np, b_np) + tvm.testing.assert_allclose(out.asnumpy(), ref_np) + + if __name__ == "__main__": test_id() test_op() @@ -608,3 +629,4 @@ def test_add_op_broadcast(): test_add_op_scalar() test_add_op_tensor() test_add_op_broadcast() + test_set_params() From bc21d42e0ffeccd9515782034cb634e00f9bf217 Mon Sep 17 00:00:00 2001 From: Haichen Shen Date: Tue, 8 Oct 2019 04:24:05 +0000 Subject: [PATCH 3/3] tweak --- src/runtime/vm/vm.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/runtime/vm/vm.cc b/src/runtime/vm/vm.cc index 86ccdf48ef15..ed12d77d80a8 100644 --- a/src/runtime/vm/vm.cc +++ b/src/runtime/vm/vm.cc @@ -580,8 +580,8 @@ PackedFunc VirtualMachine::GetFunction(const std::string& name, auto gvit = this->global_map.find(func_name); CHECK(gvit != this->global_map.end()) << "Cannot find function " << func_name; auto func_index = gvit->second; - auto const& vm_func = this->functions[func_index]; - auto const& param_names = vm_func.params; + const auto& vm_func = this->functions[func_index]; + const auto& param_names = vm_func.params; auto ctx = this->GetParamsContext(); // Prepare the func args