Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 7 additions & 3 deletions apps/hexagon_launcher/launcher_core.cc
Original file line number Diff line number Diff line change
Expand Up @@ -148,12 +148,13 @@ const tvm::runtime::PackedFunc get_module_func(tvm::runtime::Module module,
}

void reset_device_api() {
const tvm::runtime::PackedFunc api = get_runtime_func("device_api.cpu");
const tvm::runtime::PackedFunc api = get_runtime_func("device_api.hexagon.v2");
tvm::runtime::Registry::Register("device_api.hexagon", true).set_body(api);
}

tvm::runtime::Module load_module(const std::string& file_name) {
static const tvm::runtime::PackedFunc loader = get_runtime_func("runtime.module.loadfile_so");
static const tvm::runtime::PackedFunc loader =
get_runtime_func("runtime.module.loadfile_hexagon");
tvm::runtime::TVMRetValue rv = loader(file_name);
if (rv.type_code() == kTVMModuleHandle) {
return rv.operator tvm::runtime::Module();
Expand All @@ -169,7 +170,10 @@ tvm::runtime::Module create_graph_executor(const std::string& graph_json,
uint64_t device_type = device.device_type;
uint64_t device_id = device.device_id;

std::string linked_params = "tvm.runtime.hexagon.lookup_linked_params";
const tvm::runtime::PackedFunc lookup_linked_params = get_runtime_func(linked_params);
// Use default param lookup function (linked into the module).
tvm::runtime::TVMRetValue rv = create_executor(graph_json, graph_module, device_type, device_id);
tvm::runtime::TVMRetValue rv =
create_executor(graph_json, graph_module, lookup_linked_params, device_type, device_id);
return rv.operator tvm::runtime::Module();
}
2 changes: 2 additions & 0 deletions apps/hexagon_launcher/launcher_core.h
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,8 @@ struct Model {

static tvm::Device device() { return tvm::Device{static_cast<DLDeviceType>(kDLHexagon), 0}; }

static tvm::Device external() { return tvm::Device{static_cast<DLDeviceType>(kDLCPU), 0}; }

tvm::runtime::PackedFunc run;
};

Expand Down
17 changes: 14 additions & 3 deletions apps/hexagon_launcher/launcher_hexagon.cc
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@ extern "C" {
#include <qurt_hvx.h>
}

#include <tvm/runtime/object.h>

#include <algorithm>
#include <memory>
#include <string>
Expand Down Expand Up @@ -106,7 +108,7 @@ AEEResult __QAIC_HEADER(launcher_rpc_set_input)(remote_handle64 handle, int inpu

DLTensor tensor{
const_cast<unsigned char*>(input_value),
Model::device(),
Model::external(),
meta->ndim,
meta->dtype,
const_cast<int64_t*>(meta->shape),
Expand Down Expand Up @@ -153,6 +155,16 @@ AEEResult __QAIC_HEADER(launcher_rpc_get_output)(remote_handle64 handle, int out
tvm::runtime::PackedFunc get_output = get_module_func(TheModel->graph_executor, "get_output");
tvm::runtime::NDArray output = get_output(output_idx);

std::vector<int64_t> shape_vec{output->shape, output->shape + output->ndim};

auto* container = new tvm::runtime::NDArray::Container(
static_cast<void*>(output_value), shape_vec, output->dtype, Model::external());
container->SetDeleter([](tvm::Object* container) {
delete static_cast<tvm::runtime::NDArray::Container*>(container);
});

tvm::runtime::NDArray host_output(GetObjectPtr<tvm::Object>(container));

if (meta_size != 0) {
auto* meta = reinterpret_cast<tensor_meta*>(output_meta);
if (meta_size < meta->meta_size(output->ndim)) {
Expand All @@ -170,8 +182,7 @@ AEEResult __QAIC_HEADER(launcher_rpc_get_output)(remote_handle64 handle, int out
return error_too_small(__func__, "value_size", value_size, data_size);
}

auto data = reinterpret_cast<decltype(output_value)>(output->data);
std::copy(data, data + data_size, output_value);
host_output.CopyFrom(output);
}

return AEE_SUCCESS;
Expand Down