diff --git a/include/tvm/runtime/vm/vm.h b/include/tvm/runtime/vm/vm.h
index 5a72a99fa635..6fa91832a731 100644
--- a/include/tvm/runtime/vm/vm.h
+++ b/include/tvm/runtime/vm/vm.h
@@ -226,6 +226,16 @@ class TVM_DLL VirtualMachine : public runtime::ModuleNode {
    */
   ObjectRef Invoke(const std::string& name, const std::vector<ObjectRef>& args);
 
+  /*!
+   * \brief Invoke a VM function.
+   * \param func The function.
+   * \param input_args The input arguments to the function.
+   * \param output_args The pre-allocated output arguments of the function.
+   * \return The object(s) representing the result.
+   */
+  ObjectRef Invoke(const VMFunction& func, const std::vector<ObjectRef>& input_args,
+                   const std::vector<ObjectRef>& output_args);
+
   /*!
    * \brief Invoke a PackedFunction
    *
@@ -249,7 +259,7 @@ class TVM_DLL VirtualMachine : public runtime::ModuleNode {
                           const std::vector<AllocatorType>& alloc_types);
 
   /*! \brief Run VM dispatch loop. */
-  void RunLoop();
+  void RunLoop(const std::vector<Index>& output_tensor_reg_indices = {});
 
   /*! \brief Get device from the device list based on a given device index. */
   Device GetDevice(Index device_index) const;
@@ -281,6 +291,32 @@ class TVM_DLL VirtualMachine : public runtime::ModuleNode {
    */
   void SetOneInput(std::string name, const TVMArgValue& tag, const TVMArgValue& tensor);
 
+  /*!
+   * \brief Set pre-allocated output tensors for a function.
+   * It is the native implementation of the 'set_outputs' Python method.
+   * It is used in the scenario where output tensors are allocated outside of each invocation.
+   * Note: it sets set_outputs_enabled_[name] to true and fills outputs_[name],
+   * but after the invocation the former is switched off and the latter is cleared.
+   * \param name The function name.
+   * \param args The outputs for the function.
+   */
+  void SetOutputs(std::string name, TVMArgs args);
+
+  /*!
+   * \brief The preparation part of the Invoke method, run before RunLoop.
+   * \param func The function.
+   * \param args The input args.
+   */
+  void PrintInfoAndSetInputArgs(const VMFunction& func, const std::vector<ObjectRef>& args);
+
+  /*!
+   * \brief Set pre-allocated outputs to the registers for the specified function.
+   * \param func_name The function's name.
+   * \param outputs The set of output tensors.
+   */
+  void SetOutputTensorsToRegister(const std::string& func_name,
+                                  const std::vector<ObjectRef>& outputs);
+
   /*!
    * \brief Internal hook for profiling the start of an op.
    *
@@ -339,6 +375,51 @@ class TVM_DLL VirtualMachine : public runtime::ModuleNode {
   void SetInputTensorWithIndex(std::vector<ObjectRef>& tensors,  // NOLINT(*)
                                const TVMArgValue& tensor, int index, Device dev);
 
+  /*!
+   * \brief Convert a tensor from TVMArgValue to ObjectRef.
+   * DLTensor and NDArray types are supported.
+   * \param tensor The given arg value containing a tensor.
+   * \return The tensor in ObjectRef format.
+   */
+  ObjectRef TensorFromTVMArgValueToObjectRef(const TVMArgValue& tensor) const;
+
+  /*!
+   * \brief Get the index of the outputs in the register_file from the function code.
+   * \return The result register index.
+   */
+  Index GetResultRegisterIndex() const;
+
+  /*!
+   * \brief Calculate the index of the operation whose destination is the result.
+   * \param res_index The index of the op returning the result.
+   */
+  void CalculatePreResultOpIndex(Index res_index);
+
+  /*!
+   * \brief Get the indices in the register_file for the output tensors.
+   * It helps to replace the output tensors allocated in RunLoop by
+   * tensors pre-allocated outside. This is the scenario where `set_outputs` is used.
+   * \return The indices in the register_file for the output tensors.
+   */
+  std::vector<Index> GetOutputTensorRegIndices();
+
+  /*!
+   * \brief Write a newly allocated tensor to the register_file of the frame.
+   * \param instr The current instruction containing shape and storage info.
+   */
+  void WriteAllocatedTensor(const Instruction& instr);
+
+  /*!
+   * \brief 'set_outputs_enabled' is assumed to be true when this method is called.
+   * The result register is expected to already contain a tensor passed from outside,
+   * so new memory is not allocated and written; instead, the expected shape and
+   * data type are checked. For other registers the WriteAllocatedTensor method is used.
+   * \param instr The current instruction containing shape and storage info.
+   */
+  void WriteAllocatedTensorFromOutside(const Instruction& instr);
+
+  bool FindIndex(const std::vector<Index>& indices, Index val) const;
+
  protected:
   /*! \brief The virtual machine's packed function table. */
   std::vector<PackedFunc> packed_funcs_;
@@ -356,6 +437,14 @@ class TVM_DLL VirtualMachine : public runtime::ModuleNode {
   ObjectPtr<Executable> exec_;
   /*! \brief The function name to inputs mapping. */
   std::unordered_map<std::string, std::vector<ObjectRef>> inputs_;
+  /*! \brief The function name to a flag enabling the set-outputs scenario. */
+  std::unordered_map<std::string, bool> set_outputs_enabled_;
+  /*! \brief The index of the operation whose destination is the result. */
+  Index preresult_op_index_ = -1;
+  /*! \brief The function name to the indices of the output tensors in the register file. */
+  std::unordered_map<std::string, std::vector<Index>> output_tensor_reg_indices_;
+  /*! \brief The function name to pre-allocated outputs mapping. */
+  std::unordered_map<std::string, std::vector<ObjectRef>> outputs_;
   /*!
    * \brief The "physical" devices the VM can execute primitives on. All "device indexes"
    * are w.r.t. this vector. Each entry in this vector must match the corresponding entry
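A note on the API surface above: the new Invoke overload is reached from Python through the module's "set_input"/"set_outputs"/"invoke" packed functions, which the vm.cc changes below wire up. A minimal sketch of that flow, assuming a trivial x + x module compiled the same way as in the tests at the end of this patch (all variable names here are illustrative):

    import numpy as np
    import tvm
    from tvm import relay, runtime
    from tvm.ir import IRModule
    from tvm.relay.backend import vm

    # Compile a trivial x + x module.
    x = relay.var("x", shape=(10, 1))
    exe = vm.compile(IRModule.from_expr(relay.Function([x], x + x)), target="llvm")
    vm_rt = runtime.vm.VirtualMachine(exe, tvm.cpu())

    inp = tvm.nd.array(np.ones((10, 1), "float32"))
    out = tvm.nd.empty((10, 1), "float32")
    # "set_outputs" flips set_outputs_enabled_["main"] and fills outputs_["main"];
    # the following "invoke" then dispatches to the new Invoke overload and writes
    # the result into `out` instead of a tensor freshly allocated inside RunLoop.
    vm_rt.module["set_input"]("main", inp)
    vm_rt.module["set_outputs"]("main", out)
    vm_rt.module["invoke"]("main")
    np.testing.assert_allclose(out.numpy(), 2 * np.ones((10, 1), "float32"))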
diff --git a/python/tvm/runtime/vm.py b/python/tvm/runtime/vm.py
index 615f66fdcc1c..20778c40fd51 100644
--- a/python/tvm/runtime/vm.py
+++ b/python/tvm/runtime/vm.py
@@ -399,6 +399,7 @@ def __init__(self, exe, device, memory_cfg=None):
         self._get_input_index = self.module["get_input_index"]
         self._set_input = self.module["set_input"]
         self._set_one_input = self.module["set_one_input"]
+        self._set_outputs = self.module["set_outputs"]
         self._setup_device(device, memory_cfg)
 
     def _setup_device(self, dev, memory_cfg):
@@ -560,6 +561,41 @@ def invoke_stateful(self, func_name, *args, **kwargs):
         self.set_input(func_name, *args, **kwargs)
         self._invoke_stateful(func_name)
 
+    def invoke_with_outputs(self, func_name, input_args, output_args):
+        # TODO(vvchernov): consider the scenario where the output tensors are set once
+        """Invoke a function with pre-allocated output tensors.
+        The output tensors must be set on every invocation.
+        input_args can be None if the set_input method was used beforehand.
+
+        This invoke method avoids excess copying if the memory for the output
+        tensors was allocated before inference.
+
+        Parameters
+        ----------
+        func_name : str
+            The name of the function.
+
+        input_args : dict of str to tvm.runtime.NDArray or np.ndarray
+            Named arguments to the function.
+
+        output_args : list[tvm.runtime.NDArray] or list[DLTensor]
+            The output tensors of the function.
+        """
+        if input_args:
+            func_params = self._exec.get_function_params(func_name)
+            new_args = [None] * len(func_params)
+            cnt = 0
+            for k in input_args:
+                if k in func_params:
+                    idx = func_params.index(k)
+                    new_args[idx] = input_args[k]
+                    cnt += 1
+            assert cnt == len(func_params)
+            cargs = convert(new_args)
+            self._set_input(func_name, *cargs)
+        self._set_outputs(func_name, *output_args)
+        self._invoke(func_name)
+
     def get_outputs(self):
         """Get the outputs from a call to :py:func`invoke_stateful`.
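For reference, the wrapper above accepts the inputs either as a name-keyed dict or as None when set_input was already called; the outputs must be re-passed each time because outputs_[func_name] is cleared after every run. A short usage sketch of both paths, continuing the snippet after the header diff (vm_rt, inp, and out as defined there):

    # Dict path: inputs are matched to the function parameters by name.
    vm_rt.invoke_with_outputs("main", input_args={"x": inp}, output_args=[out])

    # None path: reuse the inputs bound earlier through set_input; only the
    # output tensors need to be provided again.
    vm_rt.set_input("main", inp)
    vm_rt.invoke_with_outputs("main", input_args=None, output_args=[out])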
diff --git a/src/runtime/vm/vm.cc b/src/runtime/vm/vm.cc
index 6f52f4b83c81..aaf4675733a8 100644
--- a/src/runtime/vm/vm.cc
+++ b/src/runtime/vm/vm.cc
@@ -143,8 +143,16 @@ PackedFunc VirtualMachine::GetFunction(const std::string& name,
       } else {
         auto it = inputs_.find(func_name);
         ICHECK(it != inputs_.end()) << "Input has not been set for function " << func_name;
-        const std::vector<ObjectRef>& func_args = it->second;
-        *rv = Invoke(func, func_args);
+        const std::vector<ObjectRef>& input_args = it->second;
+        if (set_outputs_enabled_.count(func_name) && set_outputs_enabled_[func_name]) {
+          ICHECK(outputs_.count(func_name))
+              << "Outputs have not been set for function " << func_name;
+          *rv = Invoke(func, input_args, outputs_[func_name]);
+          outputs_[func_name].clear();
+          set_outputs_enabled_[func_name] = false;
+        } else {
+          *rv = Invoke(func, input_args);
+        }
       }
     });
   } else if (name == "invoke_stateful") {
@@ -224,6 +232,9 @@ PackedFunc VirtualMachine::GetFunction(const std::string& name,
                                 << "(func_name, index or name, tensor)";
       SetOneInput(args[0], args[1], args[2]);
     });
+  } else if (name == "set_outputs") {
+    return PackedFunc(
+        [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { SetOutputs(args[0], args); });
   } else if (name == "load_late_bound_consts") {
     return PackedFunc([this](TVMArgs args, TVMRetValue* rv) {
       CHECK_EQ(args.size(), 1);
@@ -272,6 +283,62 @@ void VirtualMachine::SetOneInput(std::string func_name, const TVMArgValue& tag,
   SetInputTensorWithIndex(inputs_[func_name], tensor, inp_index, dev);
 }
 
+void VirtualMachine::SetOutputs(std::string func_name, TVMArgs args) {
+  set_outputs_enabled_[func_name] = true;
+  size_t outputs_size = args.size();
+  // The first arg is func_name
+  ICHECK_GT(outputs_size, 1) << "There are no output arguments to be set";
+
+  std::vector<ObjectRef> func_args(outputs_size - 1);
+  for (size_t i = 1; i < outputs_size; ++i) {
+    // TODO(vvchernov): device?
+    func_args[i - 1] = TensorFromTVMArgValueToObjectRef(args[i]);
+  }
+  outputs_.erase(func_name);
+  outputs_.emplace(func_name, func_args);
+}
+
+void VirtualMachine::PrintInfoAndSetInputArgs(const VMFunction& func,
+                                              const std::vector<ObjectRef>& args) {
+  VLOG(2) << "Executing Function: " << std::endl << func;
+  for (int i = 0; i < static_cast<int>(devices_.size()); ++i) {
+    VLOG(2) << "Device " << i << " has device type " << devices_[i].device_type << " and device id "
+            << devices_[i].device_id
+            << (i == exec_->host_device_index ? " (using as host device)" : "");
+  }
+
+  InvokeGlobal(func, args);
+}
+
+void VirtualMachine::SetOutputTensorsToRegister(const std::string& func_name,
+                                                const std::vector<ObjectRef>& outputs) {
+  size_t size = outputs.size();
+
+  if (output_tensor_reg_indices_[func_name].empty()) {
+    output_tensor_reg_indices_[func_name] = GetOutputTensorRegIndices();
+  }
+  auto& reg_indices = output_tensor_reg_indices_[func_name];
+  ICHECK_EQ(reg_indices.size(), size)
+      << "The number of external output tensors should be equal to the number of model outputs";
+  size_t i = 0;
+  for (auto it = reg_indices.begin(); it != reg_indices.end(); ++it, ++i) {
+    WriteRegister(*it, outputs[i]);
+  }
+}
+
+ObjectRef VirtualMachine::TensorFromTVMArgValueToObjectRef(const TVMArgValue& output_tensor) const {
+  if (output_tensor.type_code() == kTVMDLTensorHandle) {
+    DLTensor* dl_tensor = output_tensor;
+    return NDArray::FromExternalDLTensor(*dl_tensor);
+  } else if (output_tensor.type_code() == kTVMNDArrayHandle) {
+    return output_tensor.AsObjectRef<tvm::runtime::NDArray>();
+  } else {
+    LOG(FATAL) << "Only tensors of DLTensor or NDArray type are supported! Given type is "
+               << output_tensor.type_code();
+  }
+  return ObjectRef();
+}
+
 int64_t VirtualMachine::GetInputIndexFromVMFunction(const std::string& func_name,
                                                     const std::string& input_name) const {
   const auto& vm_func = CheckAndGetVMFunction(func_name);
@@ -359,14 +426,7 @@ void VirtualMachine::InvokeGlobal(const VMFunction& func,
 }
 
 ObjectRef VirtualMachine::Invoke(const VMFunction& func, const std::vector<ObjectRef>& args) {
-  VLOG(2) << "Executing Function: " << std::endl << func;
-  for (int i = 0; i < static_cast<int>(devices_.size()); ++i) {
-    VLOG(2) << "Device " << i << " has device type " << devices_[i].device_type << " and device id "
-            << devices_[i].device_id
-            << (i == exec_->host_device_index ? " (using as host device)" : "");
-  }
-
-  InvokeGlobal(func, args);
+  PrintInfoAndSetInputArgs(func, args);
   RunLoop();
   return return_register_;
 }
@@ -380,6 +440,14 @@ ObjectRef VirtualMachine::Invoke(const std::string& name, const std::vector<Obj
   return Invoke(exec_->functions[func_index], args);
 }
 
+ObjectRef VirtualMachine::Invoke(const VMFunction& func, const std::vector<ObjectRef>& input_args,
+                                 const std::vector<ObjectRef>& output_args) {
+  PrintInfoAndSetInputArgs(func, input_args);
+  SetOutputTensorsToRegister(func.name, output_args);
+  RunLoop(output_tensor_reg_indices_[func.name]);
+  return return_register_;
+}
+
 void VirtualMachine::InvokePacked(Index packed_index, const PackedFunc& func, Index arg_count,
                                   Index output_size, const std::vector<ObjectRef>& args) {
   size_t arity = 0;
@@ -518,7 +586,45 @@ int64_t VirtualMachine::LoadScalarInt(Index r) const {
   return result;
 }
 
-void VirtualMachine::RunLoop() {
+Index VirtualMachine::GetResultRegisterIndex() const {
+  Index op_index = 0;
+  while (code_[op_index].op != Opcode::Ret) {
+    ++op_index;
+  }
+
+  return code_[op_index].result;
+}
+
+void VirtualMachine::CalculatePreResultOpIndex(Index res_index) {
+  if (preresult_op_index_ == -1) {
+    preresult_op_index_ = 0;
+    while (code_[preresult_op_index_].dst != res_index) {
+      ++preresult_op_index_;
+    }
+  }
+}
+
+std::vector<Index> VirtualMachine::GetOutputTensorRegIndices() {
+  std::vector<Index> reg_indices;
+  Index res_index = GetResultRegisterIndex();
+  CalculatePreResultOpIndex(res_index);
+  auto& preres_instr = code_[preresult_op_index_];
+  auto op_code = preres_instr.op;
+  if (op_code == Opcode::AllocTensor) {
+    reg_indices.emplace_back(res_index);
+  } else if (op_code == Opcode::AllocADT) {
+    for (Index i = 0; i < preres_instr.num_fields; ++i) {
+      reg_indices.push_back(preres_instr.datatype_fields[i]);
+    }
+  } else if (op_code == Opcode::ReshapeTensor) {
+    reg_indices.push_back(preres_instr.reshape_tensor.tensor);
+  } else {
+    LOG(FATAL) << "Operation " << size_t(op_code) << " is not supported for the set_outputs method";
+  }
+  return reg_indices;
+}
+
+void VirtualMachine::RunLoop(const std::vector<Index>& output_tensor_reg_indices) {
   ICHECK(this->exec_);
   ICHECK(this->code_);
   pc_ = 0;
@@ -666,21 +772,11 @@ void VirtualMachine::RunLoop() {
       }
       case Opcode::AllocTensor: {
         OpStartHook(instr);
-        auto shape = std::vector<int64_t>(instr.alloc_tensor.ndim);
-
-        for (uint32_t i = 0; i < instr.alloc_tensor.ndim; ++i) {
-          shape[i] = instr.alloc_tensor.shape[i];
+        if (!output_tensor_reg_indices.empty() && FindIndex(output_tensor_reg_indices, instr.dst)) {
+          WriteAllocatedTensorFromOutside(instr);
+        } else {
+          WriteAllocatedTensor(instr);
         }
-
-        auto storage_obj = ReadRegister(instr.alloc_tensor.storage);
-        auto offset = LoadScalarInt(instr.alloc_tensor.offset);
-        auto storage = Downcast<Storage>(storage_obj);
-        auto obj = storage->AllocNDArray(offset, shape, instr.alloc_tensor.dtype);
-        VLOG(2) << "allocated "
-                << RuntimeObject2String(obj, GetDevice(exec_->host_device_index),
-                                        /*show_contents=*/false);
-
-        WriteRegister(instr.dst, obj);
         OpStopHook();
         pc_++;
         goto main_loop;
@@ -825,6 +921,75 @@ void VirtualMachine::RunLoop() {
   }
 }
 
+void VirtualMachine::WriteAllocatedTensor(const Instruction& instr) {
+  auto shape = std::vector<int64_t>(instr.alloc_tensor.ndim);
+
+  for (uint32_t i = 0; i < instr.alloc_tensor.ndim; ++i) {
+    shape[i] = instr.alloc_tensor.shape[i];
+  }
+
+  auto storage_obj = ReadRegister(instr.alloc_tensor.storage);
+  auto offset = LoadScalarInt(instr.alloc_tensor.offset);
+  auto storage = Downcast<Storage>(storage_obj);
+  auto obj = storage->AllocNDArray(offset, shape, instr.alloc_tensor.dtype);
+  VLOG(2) << "allocated "
+          << RuntimeObject2String(obj, GetDevice(exec_->host_device_index),
+                                  /*show_contents=*/false);
+
+  WriteRegister(instr.dst, obj);
+}
+
+void VirtualMachine::WriteAllocatedTensorFromOutside(const Instruction& instr) {
+  // The external tensor has already been written to the register (instr.dst)
+  auto ex_arr = Downcast<NDArray>(ReadRegister(instr.dst));
+  auto ex_shape = ex_arr.Shape();
+  auto ex_size = ex_shape.size();
+  auto ex_dtype = ex_arr->dtype;
+
+  auto in_size = instr.alloc_tensor.ndim;
+  auto in_dtype = instr.alloc_tensor.dtype;
+  ICHECK_EQ(TypeEqual(in_dtype, ex_dtype), true)
+      << "The data types of the internal and external output tensors do not match";
+
+  bool size_check = false;
+  if (ex_size != in_size) {
+    size_check = true;
+  } else {
+    for (size_t i = 0; i < in_size; ++i) {
+      if (ex_shape[i] != instr.alloc_tensor.shape[i]) {
+        size_check = true;
+        break;
+      }
+    }
+  }
+
+  if (size_check) {
+    // Check that the element counts match
+    size_t in_el_num = 1, ex_el_num = 1;
+    for (size_t i = 0; i < ex_size; ++i) {
+      ex_el_num *= ex_shape[i];
+    }
+    for (size_t i = 0; i < in_size; ++i) {
+      in_el_num *= instr.alloc_tensor.shape[i];
+    }
+    ICHECK_EQ(in_el_num, ex_el_num)
+        << "The element counts of the internal and external output tensors do not match";
+    if (code_[preresult_op_index_].op == Opcode::ReshapeTensor) {
+      int64_t* dims = instr.alloc_tensor.shape;
+      std::vector<int64_t> ref_shape(dims, dims + int64_t(in_size));
+      auto reshaped_tensor = ex_arr.CreateView(ref_shape, ex_dtype);
+      WriteRegister(instr.dst, reshaped_tensor);
+    } else {
+      LOG(FATAL) << "The shapes of the internal and external output tensors do not match";
+    }
+  }
+}
+
+bool VirtualMachine::FindIndex(const std::vector<Index>& indices, Index val) const {
+  auto it = std::find(indices.begin(), indices.end(), val);
+  return it != indices.end();
+}
+
 runtime::Module CreateVirtualMachine(Executable* exec) {
   auto vm = make_object<VirtualMachine>();
   vm->LoadExecutable(GetObjectPtr<Executable>(exec));
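A note on WriteAllocatedTensorFromOutside above: the externally provided tensor is validated, not replaced, so passing an output buffer with the wrong dtype (or a mismatched shape outside the ReshapeTensor case) is expected to fail the ICHECK. Assuming the check surfaces as a TVMError on the Python side, a hypothetical negative sketch continuing the x + x example:

    # The model produces float32; a float64 buffer should be rejected by the
    # TypeEqual(in_dtype, ex_dtype) check inside WriteAllocatedTensorFromOutside.
    bad_out = tvm.nd.empty((10, 1), "float64")
    try:
        vm_rt.invoke_with_outputs("main", input_args={"x": inp}, output_args=[bad_out])
    except tvm.error.TVMError:
        print("dtype mismatch rejected")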
diff --git a/tests/python/relay/test_vm.py b/tests/python/relay/test_vm.py
index 0b62db85c904..45e305c9a195 100644
--- a/tests/python/relay/test_vm.py
+++ b/tests/python/relay/test_vm.py
@@ -846,26 +846,38 @@ def relay_ext_test(func):
     assert "shape_func" in opt_mod.astext(False)
 
 
-def test_vm_rpc():
+def prepare_vm_model(path, tensor_shape):
     """
-    This test checks to make sure you can export a VMExecutable,
-    upload it to a remote machine using RPC and then execute it
-    on the other machine.
+    The Virtual Machine is compiled for a simple topology and
+    exported as a library to the given path
     """
     target = tvm.target.Target("llvm --host=llvm")
 
     # Build a IRModule.
-    x = relay.var("x", shape=(10, 1))
+    x = relay.var("x", shape=tensor_shape)
     f = relay.Function([x], x + x)
     mod = IRModule.from_expr(f)
 
     # Compile to VMExecutable.
     vm_exec = vm.compile(mod, target=target)
+    # Export to Disk
+    vm_exec.mod.export_library(path)
+
+
+def test_vm_rpc():
+    """
+    This test checks to make sure you can export a VMExecutable,
+    upload it to a remote machine using RPC and then execute it
+    on the other machine.
+    """
+    # Shape for input and output tensors
+    shape = (10, 1)
 
     # Export to Disk
     temp = utils.tempdir()
     path = temp.relpath("vm_library.so")
-    vm_exec.mod.export_library(path)
+    prepare_vm_model(path, shape)
 
     # Use local rpc server for testing.
     # Server must use popen so it doesn't inherit the current process state. It
@@ -881,7 +893,7 @@ def check_remote(server):
         device = remote.cpu()
         # Build a VM out of the executable and context.
         vm_factory = runtime.vm.VirtualMachine(rexec, device)
-        np_input = np.random.uniform(size=(10, 1)).astype("float32")
+        np_input = np.random.uniform(size=shape).astype("float32")
         input_tensor = tvm.nd.array(np_input, device)
         # Invoke its "main" function.
         out = vm_factory.invoke("main", input_tensor)
@@ -891,6 +903,72 @@ def check_remote(server):
     check_remote(rpc.Server("127.0.0.1"))
 
 
+def test_vm_invoke_with_outputs_rpc():
+    """
+    This test checks to make sure you can export a VMExecutable,
+    upload it to a remote machine using RPC and then execute it
+    on the other machine with pre-allocated outputs.
+    """
+    # Shape for input and output tensors
+    shape = (3, 2)
+
+    # Export to Disk
+    temp = utils.tempdir()
+    path = temp.relpath("vm_library.so")
+    prepare_vm_model(path, shape)
+
+    # Use local rpc server for testing.
+    # Server must use popen so it doesn't inherit the current process state. It
+    # will crash otherwise.
+    def check_remote_invoke_with_outputs(server):
+        remote = rpc.connect(server.host, server.port, session_timeout=10)
+
+        # Upload the serialized Executable.
+        remote.upload(path)
+        # Get a handle to remote Executable.
+        rexec = remote.load_module("vm_library.so")
+
+        device = remote.cpu()
+        # Build a VM out of the executable and context.
+        vm_factory = runtime.vm.VirtualMachine(rexec, device)
+        np_input = np.random.uniform(size=shape).astype("float32")
+        input_tensor = tvm.nd.array(np_input, device)
+        np_output = np.empty(shape, dtype="float32")
+        output_tensor = tvm.nd.array(np_output, device)
+        # Invoke its "main" function.
+        vm_factory.invoke_with_outputs(
+            "main", input_args={"x": input_tensor}, output_args=[output_tensor]
+        )
+        # Check the result.
+        np.testing.assert_allclose(output_tensor.numpy(), np_input + np_input)
+
+    check_remote_invoke_with_outputs(rpc.Server("127.0.0.1"))
+
+
+def test_vm_invoke_with_outputs():
+    target = tvm.target.Target("llvm")
+    shape = (3, 2)
+
+    # Build an IRModule.
+    x = relay.var("x", shape=shape)
+    f = relay.Function([x], x + x)
+    mod = IRModule.from_expr(f)
+
+    # Compile to VMExecutable.
+    vm_exec = vm.compile(mod, target=target)
+    vm_factory = runtime.vm.VirtualMachine(vm_exec, tvm.cpu())
+    np_input = np.random.uniform(size=shape).astype("float32")
+    input_tensor = tvm.nd.array(np_input)
+    np_output = np.empty(shape, dtype="float32")
+    output_tensor = tvm.nd.array(np_output)
+    # Invoke
+    vm_factory.invoke_with_outputs(
+        "main", input_args={"x": input_tensor}, output_args=[output_tensor]
+    )
+    # Check the result.
+    np.testing.assert_allclose(output_tensor.numpy(), np_input + np_input)
+
+
 def test_get_output_single():
     target = tvm.target.Target("llvm")
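The tests above exercise the single-output AllocTensor path. Per GetOutputTensorRegIndices, a tuple-valued function (the AllocADT branch) should take one pre-allocated tensor per tuple field, in field order. A hypothetical sketch of that case, not part of this PR's test suite:

    import numpy as np
    import tvm
    from tvm import relay, runtime
    from tvm.ir import IRModule
    from tvm.relay.backend import vm

    shape = (3, 2)
    x = relay.var("x", shape=shape)
    # A tuple result, so the result register is expected to come from AllocADT.
    f = relay.Function([x], relay.Tuple([x + x, x * x]))
    exe = vm.compile(IRModule.from_expr(f), target="llvm")
    vm_rt = runtime.vm.VirtualMachine(exe, tvm.cpu())

    np_input = np.random.uniform(size=shape).astype("float32")
    outs = [tvm.nd.empty(shape, "float32"), tvm.nd.empty(shape, "float32")]
    # One external tensor per tuple field, in field order.
    vm_rt.invoke_with_outputs("main", input_args={"x": tvm.nd.array(np_input)}, output_args=outs)
    np.testing.assert_allclose(outs[0].numpy(), np_input + np_input)
    np.testing.assert_allclose(outs[1].numpy(), np_input * np_input)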