Skip to content
7 changes: 6 additions & 1 deletion python/tvm/contrib/graph_runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,12 @@ def set_input(self, key=None, value=None, **params):
keys = list(params.keys())
keys.sort(key=lambda x: -np.prod(params[x].shape))
for k in keys:
self._get_input(k).copyfrom(params[k])
# TODO(zhiics) Skip the weights for submodule in a better way.
# We should use MetadataModule for initialization and remove
# params from set_input
val = self._get_input(k)
if val:
self._get_input(k).copyfrom(params[k])

def run(self, **input_dict):
"""Run forward execution of the graph
Expand Down
13 changes: 9 additions & 4 deletions python/tvm/runtime/vm.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ def save(self):

import numpy as np
import tvm
from tvm import te
from tvm import te
from tvm import relay
# define a simple network.
x = relay.var('x', shape=(10, 10))
Expand Down Expand Up @@ -309,12 +309,17 @@ def set_input(self, func_name, *args, **kwargs):
Named arguments to the function.
"""
if kwargs:
# kwargs is a super set of the required function parameters. We
# only find the ones that are needed.
func_params = self._exec.get_function_params(func_name)
new_args = [None] * len(func_params)
assert len(args) + len(kwargs) == len(func_params)
cnt = 0
for k in kwargs:
idx = func_params.index(k)
new_args[idx] = kwargs[k]
if k in func_params:
idx = func_params.index(k)
new_args[idx] = kwargs[k]
cnt += 1
assert len(args) + cnt == len(func_params)
idx = 0
for i, arg in enumerate(new_args):
if arg is None:
Expand Down
7 changes: 5 additions & 2 deletions src/relay/backend/build_module.cc
Original file line number Diff line number Diff line change
Expand Up @@ -453,8 +453,11 @@ class RelayBuildModule : public runtime::ModuleNode {
}

Array<tvm::runtime::Module> ext_mods = graph_codegen_->GetExternalModules();
// Import all external runtime modules.
for (const auto& it : ext_mods) ret_.mod.Import(it);
// TODO(zhiics) We should be able to completely switch to MetadataModule no
// matter whether there are external modules or not.
if (!ext_mods.empty()) {
ret_.mod = tvm::codegen::CreateMetadataModule(ret_.params, ret_.mod, ext_mods);
}
}

private:
Expand Down
38 changes: 21 additions & 17 deletions src/relay/backend/compile_engine.cc
Original file line number Diff line number Diff line change
Expand Up @@ -572,7 +572,8 @@ class CompileEngineImpl : public CompileEngineNode {
}

Array<tvm::runtime::Module> LowerExternalFunctions() {
std::unordered_map<std::string, IRModule> ext_mods;
Array<tvm::runtime::Module> ret;
std::unordered_map<std::string, std::string> cached_symbol;
std::vector<CCacheKey> cached_ext_funcs;
for (const auto& it : cache_) {
auto src_func = it.first->source_func;
Expand All @@ -581,29 +582,31 @@ class CompileEngineImpl : public CompileEngineNode {
auto code_gen = src_func->GetAttr<String>(attr::kCompiler);
CHECK(code_gen.defined()) << "No external codegen is set";
std::string code_gen_name = code_gen.value();
if (ext_mods.find(code_gen_name) == ext_mods.end()) {
ext_mods[code_gen_name] = IRModule({}, {});
}
cached_ext_funcs.push_back(it.first);

auto symbol_name = src_func->GetAttr<String>(tvm::attr::kGlobalSymbol);
CHECK(symbol_name.defined()) << "No external symbol is set for:\n"
<< AsText(src_func, false);
auto gv = GlobalVar(symbol_name.value());

std::string sn = symbol_name.value();
if (cached_symbol.count(sn)) {
cached_symbol[sn] = code_gen_name;
} else {
CHECK_NE(sn, code_gen_name)
<< "Found duplicated symbol: " << sn << " for: " << code_gen_name;
}

std::string ext_name = "relay.ext." + code_gen_name;
auto pf = tvm::runtime::Registry::Get(ext_name);
CHECK(pf) << "Failed to find the codegen tool for " << ext_name << "\n";
// No need to keep compiler attribute at this point, functions have been
// extracted for specific codegen.
src_func = WithAttr(std::move(src_func), attr::kCompiler, NullValue<ObjectRef>());
ext_mods[code_gen_name]->Add(gv, src_func);
cached_ext_funcs.push_back(it.first);
}
}
runtime::Module ext_mod = (*pf)(src_func);

Array<tvm::runtime::Module> ret;
for (const auto& it : ext_mods) {
std::string ext_name = "relay.ext." + it.first;
auto pf = tvm::runtime::Registry::Get(ext_name);
CHECK(pf) << "Failed to find the codegen tool for " << ext_name << "\n";
runtime::Module ext_mod = (*pf)(it.second);
CHECK(ext_mod.defined()) << "No external runtime is generated.";
ret.push_back(ext_mod);
CHECK(ext_mod.defined()) << "No external runtime is generated.";
ret.push_back(ext_mod);
}
}

// No need to cache external functions as we collected them all to create
Expand Down Expand Up @@ -659,6 +662,7 @@ class CompileEngineImpl : public CompileEngineNode {
CHECK(name_node.defined()) << "External function has not been attached a name yet.";
cache_node->func_name = std::string(name_node.value());
cache_node->target = tvm::target::ext_dev();
cache_node->funcs->Add(GlobalVar(cache_node->func_name), key->source_func);
value->cached_func = CachedFunc(cache_node);
return value;
}
Expand Down
85 changes: 39 additions & 46 deletions src/relay/backend/contrib/codegen_c/codegen.cc
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@

#include <fstream>
#include <sstream>
#include <string>

#include "../../utils.h"
#include "codegen_c.h"
Expand Down Expand Up @@ -76,43 +77,29 @@ class CodegenC : public MemoizedExprTranslator<std::vector<Output>>, public Code
}

std::vector<Output> VisitExpr_(const ConstantNode* cn) final {
// Note this is for demonstration purpose. ConstantNode doesn't necessarily
// belong to calls. We need to revisit this when tuples come into play.

std::ostringstream decl_stream;
std::ostringstream buf_stream;

Output output;
output.name = "const_" + std::to_string(const_idx_++);

runtime::NDArray array = cn->data;
const auto& shape = array.Shape();

// Get the number of elements.
int64_t num_elems = 1;
for (auto i : shape) num_elems *= i;

// Get const: static_cast<float*>(gcc_0_consts[0]->data)
output.name = CreateDataReference(ext_func_id_, const_idx_);
const auto* type_node = cn->checked_type().as<TensorTypeNode>();
CHECK(type_node);
const auto& dtype = GetDtypeString(type_node);
// Define a const buffer: float const_0[64] = {1.0, 2.0, ...};
//
// Technically, you may need: static float* const_0 = (float*)malloc(4 * 64)
// to avoid possible stack overflow.
buf_stream << dtype << " " << output.name << "[" << num_elems << "] = {";
if (dtype == "float") {
float* p_flt = static_cast<float*>(array->data);
for (int64_t i = 0; i < num_elems - 1; i++) buf_stream << p_flt[i] << ", ";
if (num_elems) buf_stream << p_flt[num_elems - 1];
} else if (dtype == "int") {
int* p_flt = static_cast<int*>(array->data);
for (int64_t i = 0; i < num_elems - 1; i++) buf_stream << p_flt[i] << ", ";
if (num_elems) buf_stream << p_flt[num_elems - 1];
} else {
LOG(FATAL) << "Only float and int are supported for now.";

// Generate the global variable for needed ndarrays
if (const_array_name_.empty()) {
const_array_name_ = CreateNDArrayPool(ext_func_id_);
std::string checker = CreateInitChecker(ext_func_id_);
ext_func_body_.insert(ext_func_body_.begin(), checker);
}
buf_stream << "};";
ext_func_body.insert(ext_func_body.begin(), buf_stream.str());

CHECK(dtype == "float" || dtype == "int") << "Only float and int are supported for now.";
output.dtype = dtype;

std::string const_var_name = CreateConstVar(ext_func_id_, const_idx_);
const_vars_.push_back(const_var_name);
const_idx_++;

return {output};
}
Expand Down Expand Up @@ -175,7 +162,7 @@ class CodegenC : public MemoizedExprTranslator<std::vector<Output>>, public Code
buf_decl_.push_back(buf_stream.str());

decl_stream << ", " << out << ");";
ext_func_body.push_back(decl_stream.str());
ext_func_body_.push_back(decl_stream.str());

// Update output buffer
// Note C codegen only handles TensorType. Therefore, we don't flatten
Expand All @@ -198,7 +185,7 @@ class CodegenC : public MemoizedExprTranslator<std::vector<Output>>, public Code
for (auto decl : func_decl_) {
code_stream_ << decl << "\n";
}
return JitImpl(ext_func_id_, ext_func_args_, buf_decl_, ext_func_body, out);
return JitImpl(ext_func_id_, ext_func_args_, buf_decl_, ext_func_body_, const_array_name_, out);
}

private:
Expand All @@ -213,16 +200,22 @@ class CodegenC : public MemoizedExprTranslator<std::vector<Output>>, public Code
/*! \brief The arguments of a C compiler compatible function. */
Array<Var> ext_func_args_;
/*! \brief The statements of a C compiler compatible function. */
std::vector<std::string> ext_func_body;
std::vector<std::string> ext_func_body_;
/*! \brief The array declared to store the constant values. */
std::string const_array_name_;
/*! \brief The declaration statements of a C compiler compatible function. */
std::vector<std::string> func_decl_;
/*! \brief The declaration statements of buffers. */
std::vector<std::string> buf_decl_;
/*! \brief The variable name to constant mapping. */
Array<String> const_vars_;

friend class CSourceCodegen;
};

class CSourceCodegen : public CSourceModuleCodegenBase {
public:
void GenCFunc(const Function& func) {
std::pair<std::string, Array<String>> GenCFunc(const Function& func) {
CHECK(func.defined()) << "Input error: expect a Relay function.";

// Record the external symbol for runtime lookup.
Expand All @@ -231,14 +224,19 @@ class CSourceCodegen : public CSourceModuleCodegenBase {
CodegenC builder(sid);
auto out = builder.VisitExpr(func->body);
code_stream_ << builder.JIT(out);

return {sid, builder.const_vars_};
}

runtime::Module CreateCSourceModule(const ObjectRef& ref) override {
// Create headers
code_stream_ << "#include <cstring>\n";
code_stream_ << "#include <vector>\n";
code_stream_ << "#include <tvm/runtime/c_runtime_api.h>\n";
code_stream_ << "#include <tvm/runtime/container.h>\n";
code_stream_ << "#include <tvm/runtime/packed_func.h>\n";
code_stream_ << "#include <dlpack/dlpack.h>\n";
code_stream_ << "using namespace tvm::runtime;\n";

// Append some common macro for operator definition.
const char* operator_macro = R"op_macro(
Expand All @@ -262,22 +260,17 @@ class CSourceCodegen : public CSourceModuleCodegenBase {

code_stream_ << operator_macro << "\n\n";

if (ref->IsInstance<FunctionNode>()) {
GenCFunc(Downcast<Function>(ref));
} else if (ref->IsInstance<IRModuleNode>()) {
IRModule mod = Downcast<IRModule>(ref);
for (const auto& it : mod->functions) {
GenCFunc(Downcast<Function>(it.second));
}
} else {
LOG(FATAL) << "The input ref is expected to be a Relay function or module"
<< "\n";
}
CHECK(ref->IsInstance<FunctionNode>());
auto res = GenCFunc(Downcast<Function>(ref));
std::string code = code_stream_.str();

String sym = std::get<0>(res);
Array<String> variables = std::get<1>(res);

// Create a CSourceModule
// Create a CSource module
const auto* pf = runtime::Registry::Get("runtime.CSourceModuleCreate");
CHECK(pf != nullptr) << "Cannot find csource module to create the external runtime module";
return (*pf)(code_stream_.str(), "cc");
return (*pf)(code, "c", sym, variables);
}

private:
Expand Down
Loading