From 0bb6fcaf71569c00402a7b2b0675247fab195c1c Mon Sep 17 00:00:00 2001 From: Andrew Reusch Date: Sun, 28 Nov 2021 22:19:03 -0800 Subject: [PATCH 01/41] Move ShapeToJSON to utils. --- src/relay/backend/graph_executor_codegen.cc | 18 ++---------------- src/relay/backend/utils.cc | 9 +++++++++ src/relay/backend/utils.h | 8 ++++++++ 3 files changed, 19 insertions(+), 16 deletions(-) diff --git a/src/relay/backend/graph_executor_codegen.cc b/src/relay/backend/graph_executor_codegen.cc index f61fe9b402b3..b07b3f8c6cd1 100644 --- a/src/relay/backend/graph_executor_codegen.cc +++ b/src/relay/backend/graph_executor_codegen.cc @@ -294,20 +294,6 @@ class GraphExecutorCodegen : public backend::MemoizedExprTranslator - * - * \param shape - * \return std::vector - */ - std::vector _ShapeToJSON(tvm::Array shape) { - std::vector ret; - for (IndexExpr dim : shape) { - const int64_t* pval = tir::as_const_int(dim); - ret.push_back(*pval); - } - return ret; - } /*! * \brief Add node to graph @@ -352,7 +338,7 @@ class GraphExecutorCodegen : public backend::MemoizedExprTranslatorfields.size(); ++i) { if (const auto* typ = tuple_type->fields[i].as()) { ret.push_back(GraphNodeRef(node_id, i)); - shape.emplace_back(_ShapeToJSON(typ->shape)); + shape.emplace_back(ShapeToJSON(typ->shape)); dtype.emplace_back(DType2String(typ->dtype)); } else { LOG(FATAL) << "type " << checked_type->GetTypeKey() << " not supported"; @@ -369,7 +355,7 @@ class GraphExecutorCodegen : public backend::MemoizedExprTranslator()) { ShapeVector shape; std::vector dtype; - shape.emplace_back(_ShapeToJSON(tensor_type->shape)); + shape.emplace_back(ShapeToJSON(tensor_type->shape)); dtype.emplace_back(DType2String(tensor_type->dtype)); node->attrs_["shape"] = shape; node->attrs_["dtype"] = dtype; diff --git a/src/relay/backend/utils.cc b/src/relay/backend/utils.cc index 608d4cdb9f85..bfd2a9a8c559 100644 --- a/src/relay/backend/utils.cc +++ b/src/relay/backend/utils.cc @@ -275,6 +275,15 @@ void 
UpdateAutoSchedulerOpWeights(const IRModule& module) { (*te_compiler_update_weights)(weight_map); } +std::vector ShapeToJSON(tvm::Array shape) { + std::vector ret; + for (IndexExpr dim : shape) { + const int64_t* pval = tir::as_const_int(dim); + ret.push_back(*pval); + } + return ret; +} + } // namespace backend } // namespace relay } // namespace tvm diff --git a/src/relay/backend/utils.h b/src/relay/backend/utils.h index 658283b5dc36..c8e3788d5129 100644 --- a/src/relay/backend/utils.h +++ b/src/relay/backend/utils.h @@ -527,6 +527,14 @@ Map TargetStrModuleMapToTargetModuleMap( */ void UpdateAutoSchedulerOpWeights(const IRModule& module); +/*! + * \brief Extract shape from expr to vector + * + * \param shape + * \return std::vector + */ +std::vector ShapeToJSON(tvm::Array shape); + } // namespace backend } // namespace relay } // namespace tvm From 5ccf3b068e5774f8ef93aa49d923e8006edbe8b2 Mon Sep 17 00:00:00 2001 From: Andrew Reusch Date: Mon, 29 Nov 2021 19:08:49 -0800 Subject: [PATCH 02/41] Return new Metadata from graph-level codegen. 
--- src/relay/backend/aot_executor_codegen.cc | 44 ++++++++++++++++++--- src/relay/backend/build_module.cc | 2 +- src/relay/backend/graph_executor_codegen.cc | 8 ++++ src/relay/backend/utils.h | 3 +- src/relay/backend/vm/compiler.cc | 3 +- 5 files changed, 51 insertions(+), 9 deletions(-) diff --git a/src/relay/backend/aot_executor_codegen.cc b/src/relay/backend/aot_executor_codegen.cc index d901f8a26c4f..44e1e848795f 100644 --- a/src/relay/backend/aot_executor_codegen.cc +++ b/src/relay/backend/aot_executor_codegen.cc @@ -44,6 +44,7 @@ #include "../op/annotation/annotation.h" #include "../op/call/call.h" #include "../op/memory/device_copy.h" +#include "../../target/metadata.h" #include "../transforms/device_aware_visitors.h" #include "./name_transforms.h" #include "./te_compiler.h" @@ -68,6 +69,7 @@ class AOTOnDemandAllocator : public transform::DeviceAwareExprVisitor { void Run(const Function& func) { VisitExpr(func); } std::vector GetReturnIds() const { return return_ids_; } + std::vector GetReturnTtypes() const { return return_ttypes_; } StorageMap GetStorageMap() const { return storage_device_map_; } @@ -175,6 +177,12 @@ class AOTOnDemandAllocator : public transform::DeviceAwareExprVisitor { for (auto sid : sinfo->storage_ids) { return_ids_.push_back(sid); } + return_ttypes_.clear(); + auto ttypes = FlattenTupleType(e->checked_type()); + return_ttypes_.reserve(ttypes.size()); + for (auto ttype : ttypes) { + return_ttypes_.push_back(ttype); + } } } /*! @@ -250,6 +258,8 @@ class AOTOnDemandAllocator : public transform::DeviceAwareExprVisitor { int next_available_sid_{0}; /*! \brief the set of intermediate tensors that are return variables */ std::vector return_ids_; + /*! \brief the data types of the return values */ + std::vector return_ttypes_; }; /*! 
\brief Code generator for AOT executor */ @@ -867,12 +877,34 @@ class AOTExecutorCodegen : public MixedModeVisitor { ret.lowered_funcs.Set(target_host_, mod_run); } - std::vector input_var_names(input_vars_.size()); - std::transform(input_vars_.begin(), input_vars_.end(), input_var_names.begin(), - [](Var input_var) -> String { return input_var->name_hint(); }); - ret.metadata = - runtime::Metadata(input_var_names, ListDevices(), return_sid_.size(), - runtime::kTvmExecutorAot, mod_name, interface_api, use_unpacked_api_); + std::vector inputs; + for (auto v : input_vars_) { + auto ttype = Downcast(v->type_annotation); + inputs.push_back( + runtime::metadata::TensorInfo( + make_object( + v->name_hint(), ShapeToJSON(ttype->shape), ttype->dtype))); + } + + std::vector outputs; + auto output_ttypes = final_aot_allocator.GetReturnTtypes(); + for (unsigned int i = 0; i < output_ttypes.size(); i++) { + auto ttype = Downcast(output_ttypes[i]); + std::stringstream name; + name << "output" << i; + outputs.push_back( + runtime::metadata::TensorInfo( + make_object( + name.str(), ShapeToJSON(ttype->shape), ttype->dtype))); + } + auto devices = ListDevices(); + std::vector devices_vector; + for (auto d : devices) { + devices_vector.push_back(d.operator std::string()); + } + auto n = make_object( + kMetadataVersion, inputs, outputs, devices_vector, runtime::kTvmExecutorAot, mod_name, interface_api, use_unpacked_api_); + ret.metadata = runtime::metadata::Metadata(std::move(n)); return ret; } diff --git a/src/relay/backend/build_module.cc b/src/relay/backend/build_module.cc index ccfd30476f67..c7e9ee779dbf 100644 --- a/src/relay/backend/build_module.cc +++ b/src/relay/backend/build_module.cc @@ -103,7 +103,7 @@ struct ExecutorCodegen { Array ListDevices() { return CallFunc>("get_devices"); } - runtime::Metadata GetMetadata() { return CallFunc("get_metadata"); } + runtime::metadata::Metadata GetMetadata() { return CallFunc("get_metadata"); } virtual ~ExecutorCodegen() {} protected: 
diff --git a/src/relay/backend/graph_executor_codegen.cc b/src/relay/backend/graph_executor_codegen.cc index b07b3f8c6cd1..59d3786951c8 100644 --- a/src/relay/backend/graph_executor_codegen.cc +++ b/src/relay/backend/graph_executor_codegen.cc @@ -290,6 +290,14 @@ class GraphExecutorCodegen : public backend::MemoizedExprTranslator inputs; + std::vector outputs; + std::vector devices_vector; + auto n = make_object( + kMetadataVersion, inputs, outputs, devices_vector, runtime::kTvmExecutorGraph, mod_name_, + "packed", Bool(false)); + ret.metadata = runtime::metadata::Metadata(std::move(n)); return ret; } diff --git a/src/relay/backend/utils.h b/src/relay/backend/utils.h index c8e3788d5129..8f430d292e07 100644 --- a/src/relay/backend/utils.h +++ b/src/relay/backend/utils.h @@ -31,6 +31,7 @@ #include #include #include +#include #include #include @@ -147,7 +148,7 @@ struct LoweredOutput { Array external_mods; Map function_metadata; std::unordered_map> params; - runtime::Metadata metadata; + runtime::metadata::Metadata metadata; // points to InMemoryMetadataNode }; /*! diff --git a/src/relay/backend/vm/compiler.cc b/src/relay/backend/vm/compiler.cc index 73f4b672a81c..6af7991f045d 100644 --- a/src/relay/backend/vm/compiler.cc +++ b/src/relay/backend/vm/compiler.cc @@ -46,6 +46,7 @@ #include #include +#include "../../../target/metadata.h" #include "../../../target/metadata_module.h" #include "../../../target/source/codegen_source_base.h" #include "../../op/annotation/annotation.h" @@ -1162,7 +1163,7 @@ void VMCompiler::Codegen() { } lib = codegen::CreateMetadataModule(params_, lib, ext_mods, config_->host_target, - Runtime::Create("cpp"), runtime::Metadata()); + Runtime::Create("cpp"), runtime::metadata::Metadata(make_object())); exec_->SetLib(lib); } From 55ca67c41a037ab2fe1dda1dafc4cec6c6de9be5 Mon Sep 17 00:00:00 2001 From: Andrew Reusch Date: Mon, 29 Nov 2021 19:13:29 -0800 Subject: [PATCH 03/41] Stack-allocate DLTensor instances when necessary. 
--- src/relay/backend/aot_executor_codegen.cc | 120 +++++++++++++++++++--- 1 file changed, 106 insertions(+), 14 deletions(-) diff --git a/src/relay/backend/aot_executor_codegen.cc b/src/relay/backend/aot_executor_codegen.cc index 44e1e848795f..59ceab4fe4c3 100644 --- a/src/relay/backend/aot_executor_codegen.cc +++ b/src/relay/backend/aot_executor_codegen.cc @@ -27,6 +27,7 @@ #include #include #include +#include #include #include #include @@ -325,6 +326,43 @@ class AOTExecutorCodegen : public MixedModeVisitor { } } + PrimExpr MakeDLTensor(Expr relay_arg, TensorType ttype, PrimExpr data) { + for (Var v : input_vars_) { + if (v == relay_arg) { + return data; + } + } + for (int return_sid : return_sid_) { + auto return_expr = sids_table_[return_sid]; + if (return_expr == relay_arg) { + return data; + } + } + return data; /*tvm::tir::Call( + DataType::Handle(), + tvm::tir::builtin::tvm_stack_make_array(), + Array({data, tvm::tir::Call(DataType::Handle(), + tvm::tir::builtin::tvm_stack_make_shape(), + {ttype->shape}), + tvm::Integer(0), + tvm::Integer(ttype->shape.size()), + tvm::tir::make_const(ttype->dtype, 0), + tvm::Integer(0)})); */ + } + + void PushTuple(Tuple tuple, std::vector sids, Array args) { + CHECK_EQ(sids.size(), tuple->fields.size()) + << "Relay tuple does not map 1:1 into TIR; AOT can't handle this type of Relay Expr in a CallNode."; + StorageInfo& sinfo = storage_device_map_[tuple]; + for (unsigned int i = 0; i < sids.size(); ++i) { + if (std::find(return_sid_.begin(), return_sid_.end(), sinfo->storage_ids[i]) != return_sid_.end()) { + args.push_back(sids[i]); + } else { + args.push_back(MakeDLTensor(tuple->fields[i], Downcast(tuple->fields[i]->checked_type()), sids[i])); + } + } + } + /*! 
* brief Create a function call * \param call_lowered_props The lowered function and the arguments to call it with @@ -339,32 +377,53 @@ class AOTExecutorCodegen : public MixedModeVisitor { // Pack the inputs for (const Expr& arg : call_lowered_props.arguments) { if (params_by_expr_.find(arg) != params_by_expr_.end()) { - auto param_handle = tvm::tir::Call(DataType::Handle(), tvm::tir::builtin::lookup_param(), - {tir::StringImm(params_by_expr_[arg])}); - args.push_back(param_handle); + args.push_back(MakeDLTensor(arg, Downcast(arg->checked_type()), + tir::Cast(runtime::DataType(DataType::TypeCode::kHandle, 32, 1), + tvm::tir::Call(DataType::Handle(), tvm::tir::builtin::lookup_param(), + {tir::StringImm(params_by_expr_[arg])})))); } else { - auto var_arg = FindExpr(arg); - for (const auto& var : var_arg) { - args.push_back(var); + auto sids = FindExpr(arg); + if (sids.size() > 1) { + auto tuple = Downcast(arg); + PushTuple(tuple, sids, args); + } else { + StorageInfo& sinfo = storage_device_map_[arg]; + if (std::find(return_sid_.begin(), return_sid_.end(), sinfo->storage_ids[0]) != return_sid_.end()) { + args.push_back(sids[0]); + } else { + args.push_back(MakeDLTensor(arg, Downcast(arg->checked_type()), sids[0])); + } } } } // Pack the return(s) value. 
A call node can produce multiple outputs - for (const auto& var : PackSid(result_expr)) { - args.push_back(var); + auto result_expr_sid = PackSid(result_expr); + if (result_expr_sid.size() > 1) { + auto tuple = Downcast(result_expr); + PushTuple(tuple, result_expr_sid, args); + } else { + StorageInfo& sinfo = storage_device_map_[result_expr]; + if (std::find(return_sid_.begin(), return_sid_.end(), sinfo->storage_ids[0]) != return_sid_.end()) { + args.push_back(result_expr_sid[0]); + } else { + args.push_back(MakeDLTensor(result_expr, Downcast(result_expr->checked_type()), result_expr_sid[0])); + } } - // Use tvm_call_packed to execute the function unless we're calling directly - auto calling_pattern = tvm::tir::builtin::tvm_call_cpacked(); + // Choose call style based on Runtime/Executor config. + Op calling_pattern; if (use_unpacked_api_) { calling_pattern = tvm::tir::builtin::call_extern(); + } else if (use_call_cpacked_) { + calling_pattern = tvm::tir::builtin::tvm_call_cpacked(); + } else { + calling_pattern = tvm::tir::builtin::tvm_call_packed(); } GlobalVar global_var = call_lowered_props.lowered_func; tir::Var empty_var("no_device_context", DataType::Handle()); bool has_c_device_api_context = device_contexts_.count(global_var) != 0; - bool use_cpacked_api = !use_unpacked_api_; // The device context is passed to the operator in one of the following calling patterns: // * Unpacked / direct function call with context: @@ -388,7 +447,7 @@ class AOTExecutorCodegen : public MixedModeVisitor { func_call, GenerateDeviceHook(context, "Close"), })); - } else if (use_cpacked_api) { + } else if (use_call_cpacked_) { // call_cpacked calling convention needs a blank context args.push_back(tir::make_zero(DataType::Handle())); tir::Evaluate func_call(tvm::tir::Call(DataType::Int(32), calling_pattern, args)); @@ -698,13 +757,25 @@ class AOTExecutorCodegen : public MixedModeVisitor { Target target_host_; /*! 
* \brief unpacked api toggle - * When set to true the code generated will use unpacked calls to functions: + * When set to true, the generated code will use unpacked calls to functions: * func(void* arg0, void* arg1) * Rather than packed calls: * func(void* args) * Defaults to using the packed calling convention */ Bool use_unpacked_api_; + /*! + * \brief cpacked api toggle + * When set to true, the generated code will use call_cpacked to call functions directly, assuming + * they exist in a DSO-exportable module. + * func(...) + * Rather than through the traditional call_packed calls, which should use function pointers + * looked-up through TVMBackendGetFuncFromEnv: + * TVMBackendPackedCFunc* func_ptr = TVMBackendGetFuncFromEnv("func"); + * func_ptr(...) + * Defaults to using the packed calling convention + */ + Bool use_call_cpacked_; /*! * \brief parameters (i.e. ConstantNodes found in the graph). @@ -731,7 +802,8 @@ class AOTExecutorCodegen : public MixedModeVisitor { public: AOTExecutorCodegen(runtime::Module* mod, const tec::TargetMap& targets, Target target_host) - : mod_(mod), targets_(targets), target_host_(target_host), use_unpacked_api_(Bool(false)) {} + : mod_(mod), targets_(targets), target_host_(target_host), use_unpacked_api_(Bool(false)), + use_call_cpacked_(Bool(false)) {} LoweredOutput Codegen(IRModule mod, relay::Function func, String mod_name) { VLOG_CONTEXT << "AOT"; @@ -741,11 +813,29 @@ class AOTExecutorCodegen : public MixedModeVisitor { ICHECK(target_host_.defined()) << "require a target_host to be given for AOT codegen"; VLOG(1) << "target host: " << target_host_->ToDebugString(); + Runtime runtime_config = mod->GetAttr(tvm::attr::kRuntime).value(); Executor executor_config = mod->GetAttr(tvm::attr::kExecutor).value(); String interface_api = executor_config->GetAttr("interface-api").value_or("packed"); Integer workspace_byte_alignment = executor_config->GetAttr("workspace-byte-alignment").value_or(16); use_unpacked_api_ = 
executor_config->GetAttr("unpacked-api").value_or(Bool(false)); + use_call_cpacked_ = Bool(interface_api == "c"); + + // Validate choice of use_unpacked_api_ and use_call_cpacked_ + if (runtime_config->name == kTvmRuntimeCrt) { + CHECK(interface_api == "c" || bool(use_unpacked_api_) == false) + << "Either need interface_api == \"c\" (got: " << interface_api + << ") or unpacked-api == false (got: " << use_unpacked_api_ + << ") when targeting c runtime"; + } else if (runtime_config->name == kTvmRuntimeCpp) { + CHECK(bool(use_unpacked_api_) == false && bool(use_call_cpacked_) == true) + << "Need unpacked-api == false (got: " << use_unpacked_api_ + << ") and interface-api == \"c\" (got: " << interface_api + << ") when targeting c++ runtime"; + } else { + ICHECK(false) << "runtime_config (" << runtime_config->name + << ") is not one of the expected values"; + } // TODO(mbs): Plumb from compiler config VirtualDevice host_virtual_device = VirtualDevice::ForTarget(target_host_); @@ -886,6 +976,7 @@ class AOTExecutorCodegen : public MixedModeVisitor { v->name_hint(), ShapeToJSON(ttype->shape), ttype->dtype))); } + LOG(INFO) << "MAKE METADATA? "; std::vector outputs; auto output_ttypes = final_aot_allocator.GetReturnTtypes(); for (unsigned int i = 0; i < output_ttypes.size(); i++) { @@ -905,6 +996,7 @@ class AOTExecutorCodegen : public MixedModeVisitor { auto n = make_object( kMetadataVersion, inputs, outputs, devices_vector, runtime::kTvmExecutorAot, mod_name, interface_api, use_unpacked_api_); ret.metadata = runtime::metadata::Metadata(std::move(n)); + LOG(INFO) << "MAKE METADATA: " << ret.metadata; return ret; } From 027078d36044b027c46c07d537ef8ba0d9d141da Mon Sep 17 00:00:00 2001 From: Andrew Reusch Date: Sun, 28 Nov 2021 22:33:13 -0800 Subject: [PATCH 04/41] Rename MetadataModule to ConstLoaderModule. 
--- python/tvm/contrib/graph_executor.py | 2 +- ...adata_module.cc => const_loader_module.cc} | 70 ++++++++++--------- src/runtime/const_loader_module.h | 52 ++++++++++++++ src/runtime/meta_data.h | 13 ---- src/target/metadata_module.cc | 22 +++--- 5 files changed, 100 insertions(+), 59 deletions(-) rename src/runtime/{metadata_module.cc => const_loader_module.cc} (73%) create mode 100644 src/runtime/const_loader_module.h diff --git a/python/tvm/contrib/graph_executor.py b/python/tvm/contrib/graph_executor.py index 6337c6e6fec5..5c97e4c33b50 100644 --- a/python/tvm/contrib/graph_executor.py +++ b/python/tvm/contrib/graph_executor.py @@ -188,7 +188,7 @@ def set_input(self, key=None, value=None, **params): keys.sort(key=lambda x: -np.prod(params[x].shape)) for k in keys: # TODO(zhiics) Skip the weights for submodule in a better way. - # We should use MetadataModule for initialization and remove + # We should use ConstLoaderModule for initialization and remove # params from set_input val = self._get_input(k) if val: diff --git a/src/runtime/metadata_module.cc b/src/runtime/const_loader_module.cc similarity index 73% rename from src/runtime/metadata_module.cc rename to src/runtime/const_loader_module.cc index 7cb986bba62c..fd890e40cd0b 100644 --- a/src/runtime/metadata_module.cc +++ b/src/runtime/const_loader_module.cc @@ -43,17 +43,17 @@ namespace runtime { /*! * \brief The metadata module is designed to manage initialization of the - * imported submodules. + * imported submodules for the C++ runtime. 
*/ -class MetadataModuleNode : public ModuleNode { +class ConstLoaderModuleNode : public ModuleNode { public: - MetadataModuleNode(const std::unordered_map& metadata, - const std::unordered_map>& sym_vars) - : metadata_(metadata), sym_vars_(sym_vars) { + ConstLoaderModuleNode(const std::unordered_map& const_var_ndarray, + const std::unordered_map>& const_vars_by_symbol) + : const_var_ndarray_(const_var_ndarray), const_vars_by_symbol_(const_vars_by_symbol) { // Only the related submodules are cached to reduce the number of runtime // symbol lookup for initialization. Otherwise, symbols/primitives in the // DSO module will also be cached but they never need to be initialized. - for (const auto& it : sym_vars_) { + for (const auto& it : const_vars_by_symbol_) { initialized_[it.first] = false; } } @@ -78,7 +78,7 @@ class MetadataModuleNode : public ModuleNode { return PackedFunc(nullptr); } - const char* type_key() const { return "metadata"; } + const char* type_key() const { return "const_loader"; } /*! * \brief Get the list of metadata that is required by the given module. @@ -87,11 +87,11 @@ class MetadataModuleNode : public ModuleNode { */ Array GetRequiredMetadata(const std::string& symbol) { Array ret; - ICHECK_GT(sym_vars_.count(symbol), 0U) << "No symbol is recorded for " << symbol; - std::vector vars = sym_vars_[symbol]; + ICHECK_GT(const_vars_by_symbol_.count(symbol), 0U) << "No symbol is recorded for " << symbol; + std::vector vars = const_vars_by_symbol_[symbol]; for (const auto& it : vars) { - ICHECK_GT(metadata_.count(it), 0U) << "Found not recorded constant variable: " << it; - ret.push_back(metadata_[it]); + ICHECK_GT(const_var_ndarray_.count(it), 0U) << "Found not recorded constant variable: " << it; + ret.push_back(const_var_ndarray_[it]); } return ret; } @@ -102,7 +102,7 @@ class MetadataModuleNode : public ModuleNode { * for runtime lookup. 
* * \note A module could be like the following: - * MetadataModuleNode (contains all the metadata) + * ConstLoaderModuleNode (contains all the metadata) * - CSourceModule * - JSON runtime module * @@ -128,32 +128,32 @@ class MetadataModuleNode : public ModuleNode { void SaveToBinary(dmlc::Stream* stream) final { std::vector variables; - std::vector metadata; - for (const auto& it : metadata_) { + std::vector const_var_ndarray; + for (const auto& it : const_var_ndarray_) { String var_name = it.first; variables.push_back(var_name); - metadata.push_back(it.second); + const_var_ndarray.push_back(it.second); } // Save all variables in the function. stream->Write(variables); // Save all constant data. - uint64_t sz = static_cast(metadata.size()); + uint64_t sz = static_cast(const_var_ndarray.size()); stream->Write(sz); for (uint64_t i = 0; i < sz; i++) { - metadata[i].Save(stream); + const_var_ndarray[i].Save(stream); } // Save the symbol to list of required constant variables mapping std::vector symbols; std::vector> const_vars; - for (const auto& it : sym_vars_) { + for (const auto& it : const_vars_by_symbol_) { symbols.push_back(it.first); const_vars.push_back(it.second); } stream->Write(symbols); - sz = static_cast(sym_vars_.size()); + sz = static_cast(const_vars_by_symbol_.size()); stream->Write(sz); for (uint64_t i = 0; i < sz; i++) { stream->Write(const_vars[i]); @@ -165,9 +165,9 @@ class MetadataModuleNode : public ModuleNode { // Load the variables. std::vector variables; - ICHECK(stream->Read(&variables)) << "Loading variables failed"; + ICHECK(stream->Read(&variables)) << "Loading variable names failed"; uint64_t sz; - ICHECK(stream->Read(&sz, sizeof(sz))) << "Loading metadata size failed"; + ICHECK(stream->Read(&sz, sizeof(sz))) << "Loading number of vars failed"; ICHECK_EQ(static_cast(sz), variables.size()) << "The number of variables and ndarray counts must match"; // Load the list of ndarray. 
@@ -178,10 +178,10 @@ class MetadataModuleNode : public ModuleNode { arrays.push_back(temp); } - std::unordered_map metadata; + std::unordered_map const_var_ndarray; for (uint64_t i = 0; i < sz; i++) { - ICHECK_EQ(metadata.count(variables[i]), 0U); - metadata[variables[i]] = arrays[i]; + ICHECK_EQ(const_var_ndarray.count(variables[i]), 0U); + const_var_ndarray[variables[i]] = arrays[i]; } // Load the symbol to list of required constant variables mapping @@ -196,12 +196,12 @@ class MetadataModuleNode : public ModuleNode { const_vars.push_back(vars); } - std::unordered_map> sym_vars; + std::unordered_map> const_vars_by_symbol; for (uint64_t i = 0; i < sz; i++) { - sym_vars[symbols[i]] = const_vars[i]; + const_vars_by_symbol[symbols[i]] = const_vars[i]; } - auto n = make_object(metadata, sym_vars); + auto n = make_object(const_var_ndarray, const_vars_by_symbol); return Module(n); } @@ -212,19 +212,21 @@ class MetadataModuleNode : public ModuleNode { */ std::unordered_map initialized_; /*! \brief Variable name to NDArray mapping. */ - std::unordered_map metadata_; + std::unordered_map const_var_ndarray_; /*! \brief Symbol name to required constant variables mapping. 
*/ - std::unordered_map> sym_vars_; + std::unordered_map> const_vars_by_symbol_; }; -Module MetadataModuleCreate( - const std::unordered_map& metadata, - const std::unordered_map>& sym_vars) { - auto n = make_object(metadata, sym_vars); +Module ConstLoaderModuleCreate( + const std::unordered_map& const_var_ndarray, + const std::unordered_map>& const_vars_by_symbol) { + auto n = make_object(const_var_ndarray, const_vars_by_symbol); return Module(n); } TVM_REGISTER_GLOBAL("runtime.module.loadbinary_metadata") - .set_body_typed(MetadataModuleNode::LoadFromBinary); + .set_body_typed(ConstLoaderModuleNode::LoadFromBinary); +TVM_REGISTER_GLOBAL("runtime.module.loadbinary_const_loader") + .set_body_typed(ConstLoaderModuleNode::LoadFromBinary); } // namespace runtime } // namespace tvm diff --git a/src/runtime/const_loader_module.h b/src/runtime/const_loader_module.h new file mode 100644 index 000000000000..bd88f15c5bcc --- /dev/null +++ b/src/runtime/const_loader_module.h @@ -0,0 +1,52 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file const_loader_module.h + * \brief Defines an interface to use the ConstLoaderModule. 
+ */ + +#ifndef TVM_RUNTIME_CONST_LOADER_MODULE_H_ +#define TVM_RUNTIME_CONST_LOADER_MODULE_H_ + +#include +#include +#include +#include + +namespace tvm { +namespace runtime { + +/*! + * \brief Create a ConstLoader module object. + * + * \param const_var_ndarray Maps consts var name to NDArray containing data for the var. + * \param const_vars_by_symbol Maps the name of a module init function to a list of names of + * const vars whose data will be passed to that init function. + * + * \return The created ConstLoaderModule. + */ +Module ConstLoaderModuleCreate( + const std::unordered_map& const_var_ndarray, + const std::unordered_map>& const_vars_by_symbol); + +} // namespace runtime +} // namespace tvm + +#endif // TVM_RUNTIME_CONST_LOADER_MODULE_H_ diff --git a/src/runtime/meta_data.h b/src/runtime/meta_data.h index 8996d1b76e1f..228555ebafda 100644 --- a/src/runtime/meta_data.h +++ b/src/runtime/meta_data.h @@ -96,19 +96,6 @@ class Metadata : public ObjectRef { TVM_DEFINE_OBJECT_REF_COW_METHOD(MetadataNode); }; -/*! - * \brief Create a metadata module object. - * - * \param metadata The variable name to ndarray mapping. - * \param sym_vars The symbol to the list of required constant variables - * mapping. - * - * \return The created metadata module. - */ -Module MetadataModuleCreate( - const std::unordered_map& metadata, - const std::unordered_map>& sym_vars); - /*! 
\brief A tag to specify whether or not dynamic shared memory is used */ constexpr const char* kUseDynamicSharedMemoryTag = "tir.use_dyn_shared_memory"; diff --git a/src/target/metadata_module.cc b/src/target/metadata_module.cc index 2b190e5d66ed..2aa4fe5e234d 100644 --- a/src/target/metadata_module.cc +++ b/src/target/metadata_module.cc @@ -35,7 +35,7 @@ namespace tvm { namespace codegen { runtime::Module CreateMetadataModule( - const std::unordered_map& params, + const std::unordered_map& const_var_ndarray, tvm::runtime::Module target_module, const Array& ext_modules, Target target, tvm::relay::Runtime runtime, runtime::Metadata metadata) { // Here we split modules into two groups: @@ -52,19 +52,19 @@ runtime::Module CreateMetadataModule( bool is_targeting_crt = runtime->name == "crt"; // Wrap all submodules in the initialization wrapper. - std::unordered_map> sym_metadata; + std::unordered_map> const_vars_by_symbol; for (tvm::runtime::Module mod : ext_modules) { auto pf_sym = mod.GetFunction("get_symbol"); auto pf_var = mod.GetFunction("get_const_vars"); - std::vector arrays; + std::vector symbol_const_vars; if (pf_sym != nullptr && pf_var != nullptr) { String symbol = pf_sym(); Array variables = pf_var(); for (size_t i = 0; i < variables.size(); i++) { - arrays.push_back(variables[i].operator std::string()); + symbol_const_vars.push_back(variables[i].operator std::string()); } - ICHECK_EQ(sym_metadata.count(symbol), 0U) << "Found duplicated symbol: " << symbol; - sym_metadata[symbol] = arrays; + ICHECK_EQ(const_vars_by_symbol.count(symbol), 0U) << "Found duplicated symbol: " << symbol; + const_vars_by_symbol[symbol] = symbol_const_vars; } // We only need loading of serialized constant data // if there are constants present and required by the @@ -74,7 +74,7 @@ runtime::Module CreateMetadataModule( // TODO(@manupa-arm) : we should be able to use csource_metadata // if the variables are empty when all the runtime modules implement get_func_names - if 
(arrays.empty() && is_targeting_crt && DSOExportable(mod) && + if (symbol_const_vars.empty() && is_targeting_crt && DSOExportable(mod) && (target->kind->name == "c" || target->kind->name == "llvm")) { crt_exportable_modules.push_back(mod); } else { @@ -116,12 +116,12 @@ runtime::Module CreateMetadataModule( } } else { if (!non_crt_exportable_modules.empty()) { - runtime::Module binary_meta_mod = runtime::MetadataModuleCreate(params, sym_metadata); - binary_meta_mod.Import(target_module); + runtime::Module binary_const_loader_mod = runtime::ConstLoaderModuleCreate(const_var_ndarray, const_vars_by_symbol); + binary_const_loader_mod.Import(target_module); for (const auto& it : non_crt_exportable_modules) { - binary_meta_mod.Import(it); + binary_const_loader_mod.Import(it); } - return binary_meta_mod; + return binary_const_loader_mod; } } return target_module; From f33830568d07555af4b18033f2bd8cd5ccddcfef Mon Sep 17 00:00:00 2001 From: Andrew Reusch Date: Mon, 29 Nov 2021 19:49:18 -0800 Subject: [PATCH 05/41] Add new Metadata classes and base implementation. * These were autogenerated in the original PR, but checking them in as plain code until we can revisit the auto-generator approach. 
--- include/tvm/runtime/metadata.h | 127 +++++++++++++++++ include/tvm/runtime/metadata_base.h | 190 +++++++++++++++++++++++++ include/tvm/support/span.h | 95 +++++++++++++ src/runtime/metadata.cc | 54 ++++++++ src/target/metadata.cc | 44 ++++++ src/target/metadata.h | 207 ++++++++++++++++++++++++++++ 6 files changed, 717 insertions(+) create mode 100644 include/tvm/runtime/metadata.h create mode 100644 include/tvm/runtime/metadata_base.h create mode 100644 include/tvm/support/span.h create mode 100644 src/runtime/metadata.cc create mode 100644 src/target/metadata.cc create mode 100644 src/target/metadata.h diff --git a/include/tvm/runtime/metadata.h b/include/tvm/runtime/metadata.h new file mode 100644 index 000000000000..c4911a179bb0 --- /dev/null +++ b/include/tvm/runtime/metadata.h @@ -0,0 +1,127 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/runtime/metadata.h + * \brief Defines types which can be used in Metadata. 
+ */ +#ifndef TVM_RUNTIME_METADATA_H_ +#define TVM_RUNTIME_METADATA_H_ + +#include +#include +#include +#include + +#define TVM_METADATA_VERSION 1 +static const constexpr int64_t kMetadataVersion = TVM_METADATA_VERSION; +#ifdef __cplusplus +extern "C" { +#endif + +struct TVMMetadata { + int64_t version; + const struct TVMTensorInfo* inputs; + int64_t num_inputs; + const struct TVMTensorInfo* outputs; + int64_t num_outputs; + const char** devices; + int64_t num_devices; + const char* executor; + const char* mod_name; + const char* interface_api; + bool use_unpacked_api; +}; + +struct TVMTensorInfo { + const char* name; + const int64_t* shape; + int64_t num_shape; + DLDataType dtype; +}; +#ifdef __cplusplus +} // extern "C" +#include +namespace tvm { +namespace runtime { +namespace metadata { + +class Metadata; +class TensorInfo; + +class MetadataNode : public MetadataBaseNode { + public: + MetadataNode(const struct ::TVMMetadata* data) : data_{data} {} + static constexpr const char* _type_key = "metadata.MetadataNode"; + std::string get_name() override; + inline int64_t version() const { return int64_t(data_->version); } + inline int64_t num_inputs() const { return data_->num_inputs; } + ArrayAccessor inputs(); + inline int64_t num_outputs() const { return data_->num_outputs; } + ArrayAccessor outputs(); + inline int64_t num_devices() const { return data_->num_devices; } + ArrayAccessor devices(); + inline ::tvm::runtime::String executor() const { return ::tvm::runtime::String(data_->executor); } + inline ::tvm::runtime::String mod_name() const { return ::tvm::runtime::String(data_->mod_name); } + inline ::tvm::runtime::String interface_api() const { return ::tvm::runtime::String(data_->interface_api); } + inline bool use_unpacked_api() const { return bool(data_->use_unpacked_api); } + const struct ::TVMMetadata* data() const { return data_; } + TVM_DECLARE_FINAL_OBJECT_INFO(MetadataNode, MetadataBaseNode); + private: + const struct ::TVMMetadata* data_; + 
::std::shared_ptr<::std::vector> inputs_refs_; + ::std::shared_ptr<::std::vector> outputs_refs_; + ::std::shared_ptr<::std::vector<::tvm::runtime::String>> devices_refs_; +}; + +class Metadata : public MetadataBase { + public: + Metadata(const struct ::TVMMetadata* data); + TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(Metadata, MetadataBase, MetadataNode); +}; + +class TensorInfoNode : public MetadataBaseNode { + public: + TensorInfoNode(const struct ::TVMTensorInfo* data) : data_{data} {} + static constexpr const char* _type_key = "metadata.TensorInfoNode"; + std::string get_name() override; + inline ::tvm::runtime::String name() const { return ::tvm::runtime::String(data_->name); } + inline int64_t num_shape() const { return data_->num_shape; } + inline ::tvm::support::Span shape() const { + return ::tvm::support::Span(data_->shape, data_->shape + data_->num_shape); + } + inline ::tvm::runtime::DataType dtype() const { return ::tvm::runtime::DataType(data_->dtype); } + const struct ::TVMTensorInfo* data() const { return data_; } + TVM_DECLARE_FINAL_OBJECT_INFO(TensorInfoNode, MetadataBaseNode); + private: + const struct ::TVMTensorInfo* data_; +}; + +class TensorInfo : public MetadataBase { + public: + TensorInfo(const struct ::TVMTensorInfo* data); + TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(TensorInfo, MetadataBase, TensorInfoNode); +}; + +} // namespace metadata +} // namespace runtime +} // namespace tvm +#endif // defined(__cplusplus) + +#endif // TVM_RUNTIME_METADATA_H_ diff --git a/include/tvm/runtime/metadata_base.h b/include/tvm/runtime/metadata_base.h new file mode 100644 index 000000000000..4386818ec298 --- /dev/null +++ b/include/tvm/runtime/metadata_base.h @@ -0,0 +1,190 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/runtime/metadata_base.h + * \brief Defines types which can be used in Metadata. + */ +#ifndef TVM_RUNTIME_METADATA_BASE_H_ +#define TVM_RUNTIME_METADATA_BASE_H_ + +#include +#include +#include + +namespace tvm { +namespace runtime { +namespace metadata { + +class MetadataBaseNode : public ::tvm::runtime::Object { + public: + virtual std::string get_name() = 0; + + static constexpr const char* _type_key = "metadata.MetadataBaseNode"; + TVM_DECLARE_BASE_OBJECT_INFO(MetadataBaseNode, ::tvm::runtime::Object); +}; + +class MetadataBase : public ::tvm::runtime::ObjectRef { + public: + TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(MetadataBase, ::tvm::runtime::ObjectRef, MetadataBaseNode); +}; + +template +class ArrayAccessor; + +template +class ArrayIterator { + public: + ArrayIterator(size_t index, ArrayAccessor* parent) : index_{index}, parent_{parent} {} + + inline Ref operator*() { + return (*parent_)[index_]; + } + + inline ArrayIterator& operator++() { + if (index_ < parent_->size()) { + index_++; + } + + return *this; + } + + inline bool operator==(const ArrayIterator& other) { + return parent_ == other.parent_ && index_ == other.index_; + } + + inline bool operator!=(const ArrayIterator& other) { + return !operator==(other); + } + + private: + size_t index_; + ArrayAccessor* parent_; +}; + +template +class ArrayAccessor { + public: + + template ::value>::type> + 
ArrayAccessor(const C* data, size_t num_data, ::std::shared_ptr<::std::vector> refs) : data_{data}, num_data_{num_data}, refs_{refs} {} + + inline size_t size() { return num_data_; } + + inline Ref operator[](size_t index) { + if (index >= num_data_) { + throw std::runtime_error("Index out of range"); + } + + if (refs_->size() <= index) { + refs_->resize(num_data_); + } + + if (!(*refs_)[index].defined()) { + (*refs_)[index] = Ref(&data_[index]); + } + + return (*refs_)[index]; + } + + inline ArrayIterator begin() { + return ArrayIterator{0, this}; + } + + inline ArrayIterator end() { + return ArrayIterator{num_data_, this}; + } + + private: + const C* data_; + size_t num_data_; + ::std::shared_ptr<::std::vector> refs_; +}; + +template <> +class ArrayAccessor { + public: + ArrayAccessor(const char** data, size_t num_data, ::std::shared_ptr> refs) : data_{data}, num_data_{num_data}, refs_{refs} {} + + inline size_t size() { return num_data_; } + + inline ::tvm::runtime::String operator[](size_t index) { + if (index >= num_data_) { + throw std::runtime_error("Index out of range"); + } + + if (refs_->size() <= index) { + refs_->resize(num_data_); + } + + if (!(*refs_)[index].defined()) { + (*refs_)[index] = ::tvm::runtime::String(data_[index]); + } + + return (*refs_)[index]; + } + + inline ArrayIterator begin() { + return ArrayIterator{0, this}; + } + + inline ArrayIterator end() { + return ArrayIterator{num_data_, this}; + } + + private: + const char** data_; + size_t num_data_; + ::std::shared_ptr<::std::vector<::tvm::runtime::String>> refs_; +}; + +enum MetadataTypeIndex : uint8_t { + kUint64 = 0, + kInt64 = 1, + kBool = 2, + kString = 3, + kHandle = 4, + +}; + +class MetadataArrayNode : public MetadataBaseNode { + public: +// MetadataArray(Array array, MetadataTypeIndex type_index) : array{array}, type_index{type_index} {} + MetadataArrayNode(Array array, const char* c_type) : array(std::move(array)), c_type{c_type} {} + + std::string get_name() override; + + 
Array array; + const char* c_type; + static constexpr const char* _type_key = "metadata.MetadataArrayNode"; + TVM_DECLARE_BASE_OBJECT_INFO(MetadataArrayNode, MetadataBaseNode); +}; + +class MetadataArray : public MetadataBase { + public: +// MetadataArray(Array array, MetadataTypeIndex type_index); + MetadataArray(Array array, const char* c_type); + TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(MetadataArray, MetadataBase, MetadataArrayNode); +}; + +} // namespace metadata +} // namespace runtime +} // namespace tvm + +#endif // TVM_RUNTIME_METADATA_BASE_H_ diff --git a/include/tvm/support/span.h b/include/tvm/support/span.h new file mode 100644 index 000000000000..36c86db6fd5e --- /dev/null +++ b/include/tvm/support/span.h @@ -0,0 +1,95 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * + * \file span.h + * \brief Reimplementation of part of C++-20 style span. 
+ */ +#ifndef TVM_SUPPORT_SPAN_H_ +#define TVM_SUPPORT_SPAN_H_ + +#include +#include +#include + +namespace tvm { +namespace support { + +template //, std::enable_if_t::value>::value> = true> +class Span { + public: + class iterator : public std::iterator { + public: + inline iterator(T* ptr, T* end) : ptr_{ptr}, end_{end} { + CHECK_GE(end, ptr); + } + + inline W operator*() { + return W(*ptr_); + } + + inline iterator& operator++() { + if (ptr_ != end_) ptr_++; + return *this; + } + + inline bool operator==(iterator other) { + return ptr_ == other.ptr_ && end_ == other.end_; + } + + inline bool operator!=(iterator other) { + return !(*this == other); + } + + private: + T* ptr_; + T* end_; + }; + + inline Span(T* begin, int num_elements) : begin_{begin}, end_{begin + num_elements} {} + inline Span(T* begin, T* end) : begin_{begin}, end_{end} {} + + inline iterator begin() { + return iterator(begin_, end_); + } + + inline iterator end() { + return iterator(end_, end_); + } + + inline W operator[](int i) { + T* to_return = begin_ + i; + ICHECK_LT(to_return, end_) << "Span access out of bounds: " << i; + return W(*to_return); + } + + inline operator std::vector() { + return std::vector(begin(), end()); + } + + protected: + T* begin_; + T* end_; +}; + +} // namespace support +} // namespace tvm + +#endif // TVM_SUPPORT_SPAN_H_ diff --git a/src/runtime/metadata.cc b/src/runtime/metadata.cc new file mode 100644 index 000000000000..a08e30333a52 --- /dev/null +++ b/src/runtime/metadata.cc @@ -0,0 +1,54 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file metadata.cc + * \brief Implementations of the runtime component of Metadata. + */ + +#include + +namespace tvm { +namespace runtime { +namespace metadata { + +ArrayAccessor MetadataNode::inputs() { + if (inputs_refs_.get() == nullptr) { inputs_refs_.reset(new ::std::vector()); } + return ArrayAccessor(data_->inputs, data_->num_inputs, inputs_refs_); +} +ArrayAccessor MetadataNode::outputs() { + if (outputs_refs_.get() == nullptr) { outputs_refs_.reset(new ::std::vector()); } + return ArrayAccessor(data_->outputs, data_->num_outputs, outputs_refs_); +} +ArrayAccessor MetadataNode::devices() { + if (devices_refs_.get() == nullptr) { devices_refs_.reset(new ::std::vector<::tvm::runtime::String>()); } + return ArrayAccessor(data_->devices, data_->num_devices, devices_refs_); +} +Metadata::Metadata(const struct ::TVMMetadata* data) : + MetadataBase{make_object(data)} {} +std::string MetadataNode::get_name() { return std::string{"Metadata"}; } +TVM_REGISTER_OBJECT_TYPE(MetadataNode); +TensorInfo::TensorInfo(const struct ::TVMTensorInfo* data) : + MetadataBase{make_object(data)} {} +std::string TensorInfoNode::get_name() { return std::string{"TensorInfo"}; } +TVM_REGISTER_OBJECT_TYPE(TensorInfoNode); + +} // namespace metadata +} // namespace runtime +} // namespace tvm diff --git a/src/target/metadata.cc b/src/target/metadata.cc new file mode 100644 index 000000000000..193f63c5133a --- /dev/null +++ b/src/target/metadata.cc @@ -0,0 +1,44 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor 
license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file metadata.cc + * \brief Implementations of the compiler extensions for Metadata. + */ + +#include "metadata.h" +#include + +namespace tvm { +namespace target { +namespace metadata { + +TVM_REGISTER_REFLECTION_VTABLE(VisitableMetadataNode, ::tvm::detail::ReflectionTrait) +.set_creator([](const std::string&) -> ObjectPtr { + return ::tvm::runtime::make_object(); +}); + +TVM_REGISTER_REFLECTION_VTABLE(VisitableTensorInfoNode, ::tvm::detail::ReflectionTrait) +.set_creator([](const std::string&) -> ObjectPtr { + return ::tvm::runtime::make_object(); +}); + +} // namespace metadata +} // namespace target +} // namespace tvm diff --git a/src/target/metadata.h b/src/target/metadata.h new file mode 100644 index 000000000000..1eaaeab8e59f --- /dev/null +++ b/src/target/metadata.h @@ -0,0 +1,207 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/target/metadata.h + * \brief Extends Metadata for use in the compiler. + */ +#ifndef TVM_TARGET_METADATA_H +#define TVM_TARGET_METADATA_H + +#include +#include +#include +#include + +namespace tvm { +namespace target { +namespace metadata { + +class VisitableMetadataNode : public ::tvm::runtime::metadata::MetadataNode { + public: + explicit VisitableMetadataNode(const struct ::TVMMetadata* data) : MetadataNode{data} {} + explicit VisitableMetadataNode() : MetadataNode{nullptr} {} + + void VisitAttrs(AttrVisitor* v) { + int64_t version_cpp{version()}; + v->Visit("version", &version_cpp); + auto inputs_array = Array(); + auto inputs_accessor = inputs(); + inputs_array.reserve(num_inputs()); + for (int64_t i = 0; i < num_inputs(); ++i) { + inputs_array.push_back(TensorInfo{inputs_accessor[i]}); + } + ::tvm::runtime::metadata::MetadataArray inputs_metadata_array{inputs_array, "struct TVMTensorInfo"}; + v->Visit("inputs", &inputs_metadata_array); + auto outputs_array = Array(); + auto outputs_accessor = outputs(); + outputs_array.reserve(num_outputs()); + for (int64_t i = 0; i < num_outputs(); ++i) { + outputs_array.push_back(TensorInfo{outputs_accessor[i]}); + } + ::tvm::runtime::metadata::MetadataArray outputs_metadata_array{outputs_array, "struct TVMTensorInfo"}; + v->Visit("outputs", &outputs_metadata_array); + auto devices_array = Array(); + auto devices_accessor = devices(); + devices_array.reserve(num_devices()); + for (int64_t i = 0; i < num_devices(); ++i) { + 
devices_array.push_back(::tvm::runtime::String{devices_accessor[i]}); + } + ::tvm::runtime::metadata::MetadataArray devices_metadata_array{devices_array, "const char*"}; + v->Visit("devices", &devices_metadata_array); + ::std::string executor_cpp{data()->executor}; + v->Visit("executor", &executor_cpp); + ::std::string mod_name_cpp{data()->mod_name}; + v->Visit("mod_name", &mod_name_cpp); + ::std::string interface_api_cpp{data()->interface_api}; + v->Visit("interface_api", &interface_api_cpp); + bool use_unpacked_api_cpp{use_unpacked_api()}; + v->Visit("use_unpacked_api", &use_unpacked_api_cpp); + } +}; + +class InMemoryMetadataNode : public ::tvm::target::metadata::VisitableMetadataNode { + public: + InMemoryMetadataNode() : InMemoryMetadataNode( + 0 /* version */, + {} /* inputs */, + {} /* outputs */, + {} /* devices */, + "" /* executor */, + "" /* mod_name */, + "" /* interface_api */, + false /* use_unpacked_api */ + ) {} + InMemoryMetadataNode( + int64_t version, + const ::std::vector& inputs, + const ::std::vector& outputs, + const ::std::vector<::std::string>& devices, + const ::tvm::runtime::String executor, + const ::tvm::runtime::String mod_name, + const ::tvm::runtime::String interface_api, + bool use_unpacked_api + ) : + inputs_{new struct TVMTensorInfo[inputs.size()]()}, + inputs_objs_{inputs}, + outputs_{new struct TVMTensorInfo[outputs.size()]()}, + outputs_objs_{outputs}, + devices_{new const char*[devices.size()]()}, + executor_{executor}, + mod_name_{mod_name}, + interface_api_{interface_api}, + storage_{ + version, + NULL, NULL, + NULL, NULL, + NULL, NULL, + executor_.c_str(), + mod_name_.c_str(), + interface_api_.c_str(), + use_unpacked_api + }, + VisitableMetadataNode{&storage_} { + storage_.inputs = inputs_.get(); + storage_.num_inputs = inputs.size(); + for (int i = 0; i < inputs.size(); ++i) { + inputs_.get()[i] = *inputs[i]->data(); + } + storage_.outputs = outputs_.get(); + storage_.num_outputs = outputs.size(); + for (int i = 0; i < 
outputs.size(); ++i) { + outputs_.get()[i] = *outputs[i]->data(); + } + storage_.devices = devices_.get(); + storage_.num_devices = devices.size(); + for (int i = 0; i < devices.size(); ++i) { + devices_.get()[i] = devices[i].c_str(); + } + } + + private: + ::std::unique_ptr inputs_; + std::vector inputs_objs_; + ::std::unique_ptr outputs_; + std::vector outputs_objs_; + ::std::unique_ptr devices_; + ::std::string executor_; + ::std::string mod_name_; + ::std::string interface_api_; + struct ::TVMMetadata storage_; +}; + +class VisitableTensorInfoNode : public ::tvm::runtime::metadata::TensorInfoNode { + public: + explicit VisitableTensorInfoNode(const struct ::TVMTensorInfo* data) : TensorInfoNode{data} {} + explicit VisitableTensorInfoNode() : TensorInfoNode{nullptr} {} + + void VisitAttrs(AttrVisitor* v) { + ::std::string name_cpp{data()->name}; + v->Visit("name", &name_cpp); + auto shape_array = Array(); + auto shape_accessor = shape(); + shape_array.reserve(num_shape()); + for (int64_t i = 0; i < num_shape(); ++i) { + shape_array.push_back(::tvm::Integer{shape_accessor[i]}); + } + ::tvm::runtime::metadata::MetadataArray shape_metadata_array{shape_array, "int64_t"}; + v->Visit("shape", &shape_metadata_array); + ::tvm::runtime::DataType dtype_cpp{dtype()}; + v->Visit("dtype", &dtype_cpp); + } +}; + +class InMemoryTensorInfoNode : public ::tvm::target::metadata::VisitableTensorInfoNode { + public: + InMemoryTensorInfoNode() : InMemoryTensorInfoNode( + "", + {}, + ::tvm::runtime::DataType(0, 0, 0) + ) {} + InMemoryTensorInfoNode( + const ::tvm::runtime::String& name, + const ::std::vector& shape, + ::tvm::runtime::DataType dtype + ) : + name_{name}, + shape_{new int64_t[shape.size()]()}, + storage_{ + name_.c_str(), + NULL, NULL, + dtype + }, + VisitableTensorInfoNode{&storage_} { + storage_.shape = shape_.get(); + storage_.num_shape = shape.size(); + for (int i = 0; i < shape.size(); ++i) { + shape_.get()[i] = shape[i]; + } + } + + private: + ::std::string name_; 
+ ::std::unique_ptr shape_; + struct ::TVMTensorInfo storage_; +}; + +} // namespace metadata +} // namespace runtime +} // namespace tvm + +#endif // TVM_TARGET_METADATA_H From 86cc6ace2c3d823956c2e2ff337b04df2349b12e Mon Sep 17 00:00:00 2001 From: Andrew Reusch Date: Mon, 29 Nov 2021 20:40:22 -0800 Subject: [PATCH 06/41] Add runtime AOT executor module. --- CMakeLists.txt | 8 + python/tvm/relay/backend/executor_factory.py | 10 + python/tvm/runtime/__init__.py | 2 + python/tvm/runtime/executor/__init__.py | 2 + python/tvm/runtime/executor/aot_executor.py | 182 +++++++++++++++++ src/runtime/aot_executor/aot_executor.cc | 190 ++++++++++++++++++ src/runtime/aot_executor/aot_executor.h | 151 ++++++++++++++ .../aot_executor/aot_executor_factory.cc | 133 ++++++++++++ .../aot_executor/aot_executor_factory.h | 119 +++++++++++ 9 files changed, 797 insertions(+) create mode 100644 python/tvm/runtime/executor/__init__.py create mode 100644 python/tvm/runtime/executor/aot_executor.py create mode 100644 src/runtime/aot_executor/aot_executor.cc create mode 100644 src/runtime/aot_executor/aot_executor.h create mode 100644 src/runtime/aot_executor/aot_executor_factory.cc create mode 100644 src/runtime/aot_executor/aot_executor_factory.h diff --git a/CMakeLists.txt b/CMakeLists.txt index c667750ed5d6..f56e8946d398 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -40,6 +40,7 @@ tvm_option(USE_LLVM "Build with LLVM, can be set to specific llvm-config path" O tvm_option(USE_STACKVM_RUNTIME "Include stackvm into the runtime" OFF) tvm_option(USE_GRAPH_EXECUTOR "Build with tiny graph executor" ON) tvm_option(USE_GRAPH_EXECUTOR_CUDA_GRAPH "Build with tiny graph executor with CUDA Graph for GPUs" OFF) +tvm_option(USE_AOT_EXECUTOR "Build with AOT executor" ON) tvm_option(USE_PROFILER "Build profiler for the VM and graph executor" ON) tvm_option(USE_OPENMP "Build with OpenMP thread pool implementation" OFF) tvm_option(USE_RELAY_DEBUG "Building Relay in debug mode..." 
OFF) @@ -392,6 +393,13 @@ if(USE_PROFILER) list(APPEND RUNTIME_SRCS ${RUNTIME_VM_PROFILER_SRCS}) endif(USE_PROFILER) +if(USE_AOT_EXECUTOR) + message(STATUS "Build with AOT Executor support...") + file(GLOB RUNTIME_AOT_EXECUTOR_SRCS src/runtime/aot_executor/*.cc) + list(APPEND RUNTIME_SRCS ${RUNTIME_AOT_EXECUTOR_SRCS}) + +endif(USE_AOT_EXECUTOR) + # Enable ctest if gtest is available if(USE_GTEST) # Check env var for backward compatibility. A better way to specify package diff --git a/python/tvm/relay/backend/executor_factory.py b/python/tvm/relay/backend/executor_factory.py index 5f4a134270ac..b836ce914696 100644 --- a/python/tvm/relay/backend/executor_factory.py +++ b/python/tvm/relay/backend/executor_factory.py @@ -105,6 +105,13 @@ def __init__( function_metadata, devices, ): + fcreate = get_global_func("tvm.aot_executor_factory.create") + args = [] + for k, v in params.items(): + args.append(k) + args.append(ndarray.array(v)) + + self.module = fcreate(libmod, libmod_name, *args) self.ir_mod = ir_mod self.lowered_ir_mods = lowered_ir_mods self.target = target @@ -128,6 +135,9 @@ def get_executor_config(self): def get_lib(self): return self.lib + def export_library(self, file_name, fcompile=None, addons=None, **kwargs): + return self.module.export_library(file_name, fcompile, addons, **kwargs) + class GraphExecutorFactoryModule(ExecutorFactoryModule): """Graph executor factory module. diff --git a/python/tvm/runtime/__init__.py b/python/tvm/runtime/__init__.py index b3504dbac506..ab0fc1709fa9 100644 --- a/python/tvm/runtime/__init__.py +++ b/python/tvm/runtime/__init__.py @@ -31,3 +31,5 @@ from .module import load_module, enabled, system_lib from .container import String, ShapeTuple from .params import save_param_dict, load_param_dict + +from . 
import executor diff --git a/python/tvm/runtime/executor/__init__.py b/python/tvm/runtime/executor/__init__.py new file mode 100644 index 000000000000..92a0402549ec --- /dev/null +++ b/python/tvm/runtime/executor/__init__.py @@ -0,0 +1,2 @@ + +from .aot_executor import AotModule diff --git a/python/tvm/runtime/executor/aot_executor.py b/python/tvm/runtime/executor/aot_executor.py new file mode 100644 index 000000000000..91f056ff25fa --- /dev/null +++ b/python/tvm/runtime/executor/aot_executor.py @@ -0,0 +1,182 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""A Python wrapper for the Module-based Model Runtime Interface for Ahead-of-Time compilation.""" + +import numpy as np + + +class AotModule(object): + """Wraps the AOT executor runtime.Module. + + This is a thin wrapper of the underlying TVM module. + you can also directly call set_input, run, and get_output + of underlying module functions + + Parameters + ---------- + module : tvm.runtime.Module + The internal tvm module that holds the actual graph functions. + + Attributes + ---------- + module : tvm.runtime.Module + The internal tvm module that holds the actual graph functions. + + Examples + -------- + + .. 
code-block:: python + + import tvm + from tvm import relay + from tvm.contrib import graph_executor + + # build the library using graph executor + lib = relay.build(...) + lib.export_library("compiled_lib.so") + # load it back as a runtime + lib: tvm.runtime.Module = tvm.runtime.load_module("compiled_lib.so") + # Call the library factory function for default and create + # a new runtime.Module, wrap with graph module. + gmod = graph_executor.GraphModule(lib["default"](dev)) + # use the graph module. + gmod.set_input("x", data) + gmod.run() + """ + + def __init__(self, module): + self.module = module + self._set_input = module["set_input"] + self._run = module["run"] + self._get_output = module["get_output"] + self._get_input = module["get_input"] + self._get_num_outputs = module["get_num_outputs"] + self._get_input_index = module["get_input_index"] + self._get_num_inputs = module["get_num_inputs"] + + def set_input(self, key=None, value=None, **params): + """Set inputs to the module via kwargs + + Parameters + ---------- + key : int or str + The input key + + value : the input value. + The input key + + params : dict of str to NDArray + Additional arguments + """ + if key is not None: + v = self._get_input(key) + if v is None: + raise RuntimeError("Could not find '%s' in graph's inputs" % key) + v.copyfrom(value) + + if params: + # upload big arrays first to avoid memory issue in rpc mode + keys = list(params.keys()) + keys.sort(key=lambda x: -np.prod(params[x].shape)) + for k in keys: + # TODO(zhiics) Skip the weights for submodule in a better way. 
+ # We should use MetadataModule for initialization and remove + # params from set_input + val = self._get_input(k) + if val: + self._get_input(k).copyfrom(params[k]) + + def run(self, **input_dict): + """Run forward execution of the graph + + Parameters + ---------- + input_dict: dict of str to NDArray + List of input values to be feed to + """ + if input_dict: + self.set_input(**input_dict) + self._run() + + def get_num_outputs(self): + """Get the number of outputs from the graph + + Returns + ------- + count : int + The number of outputs. + """ + return self._get_num_outputs() + + def get_num_inputs(self): + """Get the number of inputs to the graph + + Returns + ------- + count : int + The number of inputs. + """ + return self._get_num_inputs() + + def get_input(self, index, out=None): + """Get index-th input to out + + Parameters + ---------- + index : int + The input index + + out : NDArray + The output array container + """ + if out: + self._get_input(index).copyto(out) + return out + + return self._get_input(index) + + def get_input_index(self, name): + """Get inputs index via input name. + + Parameters + ---------- + name : str + The input key name + + Returns + ------- + index: int + The input index. -1 will be returned if the given input name is not found. + """ + return self._get_input_index(name) + + def get_output(self, index, out=None): + """Get index-th output to out + + Parameters + ---------- + index : int + The output index + + out : NDArray + The output array container + """ + if out: + self._get_output(index, out) + return out + + return self._get_output(index) diff --git a/src/runtime/aot_executor/aot_executor.cc b/src/runtime/aot_executor/aot_executor.cc new file mode 100644 index 000000000000..cfa4c1e36ebd --- /dev/null +++ b/src/runtime/aot_executor/aot_executor.cc @@ -0,0 +1,190 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \brief Defines an implementation of Module-based Model Runtime Interface that works with + * Ahead-of-Time compilation. + * \file aot_executor.cc + */ + +#include "aot_executor.h" + +#include + +namespace tvm { +namespace runtime { + +AotExecutor::AotExecutor(tvm::runtime::Module module, const std::vector& devs) : + module_{module}, devices_{devs} { + + auto fmetadata = module->GetFunction("get_metadata"); + CHECK(fmetadata != nullptr) << "Expected a module with PackedFunc get_metadata"; + auto ret_value = fmetadata(); + metadata_ = ret_value.AsObjectRef(); + + for (auto input : metadata_->inputs()) { + // TODO(areusch): Encode device information in Metadata. + args_.emplace_back(NDArray::Empty(ShapeTuple(input->shape().begin(), input->shape().end()), input->dtype(), devices_[0])); + } + + for (auto output : metadata_->outputs()) { + args_.emplace_back(NDArray::Empty(ShapeTuple(output->shape().begin(), output->shape().end()), output->dtype(), devices_[0])); + } +} + +PackedFunc AotExecutor::GetFunction(const std::string& name, const ObjectPtr& sptr_to_self) { + // Return member functions during query. 
+ if (name == "set_input") { + return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { + if (String::CanConvertFrom(args[0])) { + int in_idx = this->GetInputIndex(args[0].operator String()); + if (in_idx >= 0) this->SetInput(in_idx, args[1]); + } else { + this->SetInput(args[0], args[1]); + } + }); + } else if (name == "set_input_zero_copy") { + return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { + if (String::CanConvertFrom(args[0])) { + int in_idx = this->GetInputIndex(args[0].operator String()); + if (in_idx >= 0) this->SetInputZeroCopy(in_idx, args[1]); + } else { + this->SetInputZeroCopy(args[0], args[1]); + } + }); + } else if (name == "set_output_zero_copy") { + return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { + if (String::CanConvertFrom(args[0])) { + int out_idx = this->GetOutputIndex(args[0].operator String()); + if (out_idx >= 0) this->SetOutputZeroCopy(out_idx, args[1]); + } else { + this->SetOutputZeroCopy(args[0], args[1]); + } + }); + } else if (name == "get_output") { + return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { + if (args.num_args == 2) { + this->CopyOutputTo(args[0], args[1]); + } else { + *rv = this->GetOutput(args[0]); + } + }); + } else if (name == "get_input") { + return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { + int in_idx = 0; + if (String::CanConvertFrom(args[0])) { + in_idx = this->GetInputIndex(args[0].operator String()); + } else { + in_idx = args[0]; + } + if (in_idx >= 0) { + *rv = this->GetInput(in_idx); + } + }); + } else if (name == "get_num_outputs") { + return PackedFunc( + [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = this->NumOutputs(); }); + } else if (name == "get_num_inputs") { + return PackedFunc( + [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = this->NumInputs(); }); + } else if (name == "run") { + return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { this->Run(); }); + 
} else if (name == "get_input_index") { + return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { + CHECK(String::CanConvertFrom(args[0])) << "Input key is not a string"; + *rv = this->GetInputIndex(args[0].operator String()); + }); + } else { + return PackedFunc(); + } +} + +void AotExecutor::Run() { + LOG(INFO) << "Get entrypoint " << metadata_->mod_name() << "_run_model"; + auto pf = module_.GetFunction(metadata_->mod_name() + "_run_model", true /* query_imports */); + ICHECK(pf != nullptr) << "Module entrypoint is not defined"; + + const int num_args = args_.size(); + ::std::unique_ptr call_values{new TVMValue[num_args]}; + ::std::unique_ptr call_type_codes{new int[num_args]}; + for (int i = 0; i < num_args; ++i) { + // Borrow the DLTensor held by args_[i]: ToDLPack() would hand back a DLManagedTensor whose deleter is never called here. + call_values.get()[i].v_handle = const_cast<DLTensor*>(args_[i].operator->()); + call_type_codes.get()[i] = kTVMDLTensorHandle; + } + + TVMArgs args{call_values.get(), call_type_codes.get(), num_args}; // TVMArgs takes raw pointers, not the owning unique_ptr objects + TVMRetValue rv; + pf.CallPacked(args, &rv); +} + +int AotExecutor::GetInputIndex(const std::string& name) { + auto inputs = metadata_->inputs(); + for (unsigned int i = 0; i < inputs.size(); i++) { + if (inputs[i]->name() == name) { + return i; + } + } + return -1; +} + +int AotExecutor::GetOutputIndex(const std::string& name) { + auto outputs = metadata_->outputs(); + for (unsigned int i = 0; i < outputs.size(); i++) { + if (outputs[i]->name() == name) { + return i; + } + } + return -1; +} + +void AotExecutor::SetInput(int index, DLTensor* data_ref) { + args_[index].CopyFrom(data_ref); +} + +void AotExecutor::SetInputZeroCopy(int index, DLTensor* data_ref) { + ICHECK(false) << "not implemented"; +} + +void AotExecutor::SetOutputZeroCopy(int index, DLTensor* data_ref) { + ICHECK(false) << "not implemented"; +} + +int AotExecutor::NumOutputs() const { + return metadata_->num_outputs(); +} + +int AotExecutor::NumInputs() const { + return metadata_->num_inputs(); +} + +NDArray AotExecutor::GetInput(int index) const { + return 
args_[index]; +} + +NDArray AotExecutor::GetOutput(int index) const { + return args_[metadata_->num_inputs() + index]; +} + +void AotExecutor::CopyOutputTo(int index, DLTensor* data_out) { + GetOutput(index).CopyTo(data_out); +} + +} // namespace runtime +} // namespace tvm diff --git a/src/runtime/aot_executor/aot_executor.h b/src/runtime/aot_executor/aot_executor.h new file mode 100644 index 000000000000..591af78284e4 --- /dev/null +++ b/src/runtime/aot_executor/aot_executor.h @@ -0,0 +1,151 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \brief Defines an implementation of Module-based Model Runtime Interface that works with + * Ahead-of-Time compilation. + * \file aot_executor.h + */ +#ifndef TVM_RUNTIME_AOT_EXECUTOR_AOT_EXECUTOR_H_ +#define TVM_RUNTIME_AOT_EXECUTOR_AOT_EXECUTOR_H_ + +#include +#include +#include +#include +#include +#include + + +namespace tvm { +namespace runtime { + +class TVM_DLL AotExecutor : public ModuleNode { + + public: + /*! + * \brief Implements member function lookup for this Module for the frontend. + * \param name The name of the function. + * \param sptr_to_self The pointer to the module node. + * \return The corresponding member function. 
+ */ + PackedFunc GetFunction(const std::string& name, const ObjectPtr& sptr_to_self) override; + + /*! + * \return The type key of the executor. + */ + const char* type_key() const final { return "AotExecutor"; } + + void Run(); + + /*! + * \brief Initialize the AOT executor with metadata, runtime::Module, and device. + * \param module The module containing the compiled functions for the host + * processor. + * \param devs The device of the host and devices where graph nodes will be + * executed on. + * \note The constructor implementation looks up a "get_metadata" PackedFunc on \p module + * and allocates one NDArray per metadata input and output on devs[0], so \p devs must + * not be empty. + */ + AotExecutor(tvm::runtime::Module module, const std::vector& devs); + + /*! + * \brief Get the input index given the name of input. + * \param name The name of the input. + * \return The index of input, or -1 if no input has this name. + */ + int GetInputIndex(const std::string& name); + + /*! + * \brief Get the output index given the name of output. + * \param name The name of the output. + * \return The index of output, or -1 if no output has this name. + */ + int GetOutputIndex(const std::string& name); + + /*! + * \brief set index-th input to the graph. + * \param index The input index. + * \param data_in The input data; it is copied into the executor-owned input array. + */ + void SetInput(int index, DLTensor* data_in); + /*! + * \brief set index-th input to the graph without copying the data (currently unimplemented; always fails an ICHECK). + * \param index The input index. + * \param data_ref The input data that is referred. + */ + void SetInputZeroCopy(int index, DLTensor* data_ref); + /*! + * \brief set index-th output to the graph without copying the data (currently unimplemented; always fails an ICHECK). + * \param index The output index. + * \param data_ref The output data that is referred. + */ + void SetOutputZeroCopy(int index, DLTensor* data_ref); + /*! + * \brief Get the number of outputs + * + * \return The number of outputs from graph. + */ + int NumOutputs() const; + /*! 
+ * \brief Get the number of inputs + * + * \return The number of inputs to the graph. + */ + int NumInputs() const; + /*! + * \brief Return NDArray for given input index. + * \param index The input index. + * + * \return NDArray corresponding to given input node index. + */ + NDArray GetInput(int index) const; + /*! + * \brief Return NDArray for given output index. + * \param index The output index. + * + * \return NDArray corresponding to given output node index. + */ + NDArray GetOutput(int index) const; + /*! + * \brief Copy index-th output to data_out. + * \param index The output index. + * \param data_out the output data. + */ + void CopyOutputTo(int index, DLTensor* data_out); + + private: + /*! \brief Metadata provided to the runtime from the compiler. */ + metadata::Metadata metadata_; + + /*! \brief Runtime module which contains the AOT top-level function. */ + Module module_; + + /*! \brief The devices which should be used to execute the computations. */ + std::vector devices_; + + /*! \brief Holds one NDArray per function argument in the same order. */ + std::vector args_; +}; + +} // namespace runtime +} // namespace tvm + +#endif // TVM_RUNTIME_AOT_EXECUTOR_AOT_EXECUTOR_H_ diff --git a/src/runtime/aot_executor/aot_executor_factory.cc b/src/runtime/aot_executor/aot_executor_factory.cc new file mode 100644 index 000000000000..e8ded8573028 --- /dev/null +++ b/src/runtime/aot_executor/aot_executor_factory.cc @@ -0,0 +1,133 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file aot_executor_factory.cc + * \brief AOT executor factory implementations + */ + +#include "./aot_executor_factory.h" + +#include +#include +#include + +#include +#include + +namespace tvm { +namespace runtime { + +AotExecutorFactory::AotExecutorFactory( + const std::unordered_map& params, + const std::string& module_name) { + params_ = params; + module_name_ = module_name; +} + +PackedFunc AotExecutorFactory::GetFunction( + const std::string& name, const tvm::runtime::ObjectPtr& sptr_to_self) { + if (name == module_name_) { + return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { + ICHECK_GT(args.num_args, 0) << "Must supply at least one device argument"; + std::vector devices; + for (int i = 0; i < args.num_args; ++i) { + devices.emplace_back(args[i].operator Device()); + } + *rv = this->ExecutorCreate(devices); + }); + } else if (name == "remove_params") { + return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { + std::unordered_map empty_params{}; + auto exec = make_object(empty_params, this->module_name_); + exec->Import(this->imports_[0]); + *rv = Module(exec); + }); + } else { + return PackedFunc(); + } +} + +void AotExecutorFactory::SaveToBinary(dmlc::Stream* stream) { + std::vector names; + std::vector arrays; + for (const auto& v : params_) { + names.emplace_back(v.first); + arrays.emplace_back(const_cast(v.second.operator->())); + } + uint64_t sz = arrays.size(); + ICHECK(sz == names.size()); + stream->Write(sz); + stream->Write(names); + for (size_t i = 0; i < sz; ++i) { + 
tvm::runtime::SaveDLTensor(stream, arrays[i]); + } + stream->Write(module_name_); +} + +Module AotExecutorFactory::ExecutorCreate(const std::vector& devs) { + auto exec = make_object(this->imports_[0], devs); + // set params + SetParams(exec.get(), this->params_); + return Module(exec); +} + +Module AotExecutorFactoryModuleLoadBinary(void* strm) { + dmlc::Stream* stream = static_cast(strm); + std::unordered_map params; + std::string module_name; + uint64_t sz; + ICHECK(stream->Read(&sz)); + std::vector names; + ICHECK(stream->Read(&names)); + ICHECK(sz == names.size()); + for (size_t i = 0; i < sz; ++i) { + tvm::runtime::NDArray temp; + temp.Load(stream); + params[names[i]] = temp; + } + ICHECK(stream->Read(&module_name)); + auto exec = make_object(params, module_name); + return Module(exec); +} + +TVM_REGISTER_GLOBAL("tvm.aot_executor_factory.create") + .set_body([](TVMArgs args, TVMRetValue* rv) { + ICHECK_GE(args.num_args, 2) << "The expected number of arguments for " + "aot_executor_factory.create needs at least 2, " + "but it has " + << args.num_args; + // The argument order is module, module_name, param0_name, param0_tensor, + // [param1_name, param1_tensor], ... 
+ ICHECK_EQ((args.size() - 2) % 2, 0); + std::unordered_map params; + for (size_t i = 2; i < static_cast(args.size()); i += 2) { + std::string name = args[i].operator String(); + params[name] = args[i + 1].operator tvm::runtime::NDArray(); + } + auto exec = make_object(params, args[1]); + exec->Import(args[0]); + *rv = Module(exec); + }); + +TVM_REGISTER_GLOBAL("runtime.module.loadbinary_AotExecutorFactory") + .set_body_typed(AotExecutorFactoryModuleLoadBinary); + +} // namespace runtime +} // namespace tvm diff --git a/src/runtime/aot_executor/aot_executor_factory.h b/src/runtime/aot_executor/aot_executor_factory.h new file mode 100644 index 000000000000..fbbebe1a4d86 --- /dev/null +++ b/src/runtime/aot_executor/aot_executor_factory.h @@ -0,0 +1,119 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/runtime/aot_executor/aot_executor_factory.h + * \brief Aot executor factory creating aot executor. 
+ */ + +#ifndef TVM_RUNTIME_AOT_EXECUTOR_AOT_EXECUTOR_FACTORY_H_ +#define TVM_RUNTIME_AOT_EXECUTOR_AOT_EXECUTOR_FACTORY_H_ + +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include + +#include "./aot_executor.h" + +namespace tvm { +namespace runtime { + +class TVM_DLL AotExecutorFactory : public runtime::ModuleNode { + public: + /*! + * \brief Construct the AotExecutorFactory. + * \param params The params of aot. + * \param module_name The module name of aot. + */ + AotExecutorFactory(const std::unordered_map& params, + const std::string& module_name); + + /*! + * \brief Get member function to front-end + * \param name The name of the function. + * \param sptr_to_self The pointer to the module node. + * \return The corresponding member function. + */ + PackedFunc GetFunction(const std::string& name, const ObjectPtr& sptr_to_self) final; + + /*! + * \return The type key of the executor. + */ + const char* type_key() const override { return "AotExecutorFactory"; } + + /*! + * \brief Save the module to binary stream. + * \param stream The binary stream to save to. + */ + void SaveToBinary(dmlc::Stream* stream) override; + + /*! + * \brief Create a specific executor module + * \param devs The device of the host and devices where the model will be + * executed. + * \return created executor module + */ + Module ExecutorCreate(const std::vector& devs); + + /*! + * \brief Set params. + * \param aot_executor The aot executor we want to set the params into. + * \param params The aot params value we want to set. 
+ */ + void SetParams(AotExecutor* aot_executor, + const std::unordered_map& params) const { + std::unordered_map value = params; + // upload big arrays first to avoid memory issue in rpc mode + std::vector keys; + for (const auto& p : value) { + keys.emplace_back(p.first); + } + std::sort(std::begin(keys), std::end(keys), + [&](const std::string& lhs, const std::string& rhs) -> bool { + auto lhs_size = GetDataSize(*value[lhs].operator->()); + auto rhs_size = GetDataSize(*value[rhs].operator->()); + return lhs_size > rhs_size; + }); + for (const auto& key : keys) { + int in_idx = aot_executor->GetInputIndex(key); + if (in_idx >= 0) { + aot_executor->SetInput(in_idx, const_cast(value[key].operator->())); + } + } + } + + protected: + /*! \brief The params. */ + std::unordered_map params_; + /*! \brief Module name; GetFunction returns the executor-creating function under this name. */ + std::string module_name_; +}; + +} // namespace runtime +} // namespace tvm + +#endif // TVM_RUNTIME_AOT_EXECUTOR_AOT_EXECUTOR_FACTORY_H_ From 7bba41b76dc7dbd8ceafff42fdd062204f24e9a6 Mon Sep 17 00:00:00 2001 From: Andrew Reusch Date: Mon, 29 Nov 2021 20:44:07 -0800 Subject: [PATCH 07/41] Add AOT code-generation. 
--- src/target/source/codegen_c_host.cc | 8 + src/target/source/codegen_c_host.h | 1 + src/target/source/codegen_source_base.h | 24 ++ src/target/source/source_module.cc | 498 ++++++++++++++++++++++-- src/target/source/source_module.h | 13 +- 5 files changed, 518 insertions(+), 26 deletions(-) diff --git a/src/target/source/codegen_c_host.cc b/src/target/source/codegen_c_host.cc index 515cdccb88fb..c1a763023bb3 100644 --- a/src/target/source/codegen_c_host.cc +++ b/src/target/source/codegen_c_host.cc @@ -51,6 +51,10 @@ void CodeGenCHost::Init(bool output_ssa, bool emit_asserts, std::string target_s CodeGenC::Init(output_ssa); } +void CodeGenCHost::InitGlobalContext() { + decl_stream << "void* " << tvm::runtime::symbol::tvm_module_ctx << " = NULL;\n"; +} + void CodeGenCHost::DefineModuleName() { decl_stream << "void* " << module_name_ << " = NULL;\n"; } void CodeGenCHost::AddFunction(const PrimFunc& f) { @@ -437,6 +441,10 @@ runtime::Module BuildCHost(IRModule mod, Target target) { cg.AddFunction(aot_executor_fn); } + if (aot_executor_fn.defined()) { + cg.InitGlobalContext(); + } + if (target->GetAttr("system-lib").value_or(Bool(false))) { ICHECK_EQ(target->GetAttr("runtime").value_or(""), "c") << "c target only supports generating C runtime SystemLibs"; diff --git a/src/target/source/codegen_c_host.h b/src/target/source/codegen_c_host.h index c94612cfeac3..8bd83444717d 100644 --- a/src/target/source/codegen_c_host.h +++ b/src/target/source/codegen_c_host.h @@ -40,6 +40,7 @@ class CodeGenCHost : public CodeGenC { CodeGenCHost(); void Init(bool output_ssa, bool emit_asserts, std::string target_str); + void InitGlobalContext(); void AddFunction(const PrimFunc& f); void DefineModuleName(); diff --git a/src/target/source/codegen_source_base.h b/src/target/source/codegen_source_base.h index d938469b8969..7d55273376be 100644 --- a/src/target/source/codegen_source_base.h +++ b/src/target/source/codegen_source_base.h @@ -25,6 +25,7 @@ #ifndef 
TVM_TARGET_SOURCE_CODEGEN_SOURCE_BASE_H_ #define TVM_TARGET_SOURCE_CODEGEN_SOURCE_BASE_H_ +#include #include #include #include @@ -145,6 +146,19 @@ runtime::Module CSourceModuleCreate(const String& code, const String& fmt, const Array& func_names, const Array& const_vars = {}); +/*! + * \brief Wrap the submodules in a metadata module. + * \param params The variable to constant mapping that is collected by the host + * module. + * \param target_module The main TIR-lowered internal runtime module + * \param modules All the external modules that needs to be imported inside the metadata module(s). + * \param target The target that all the modules are compiled for + * \return The wrapped module. + */ +runtime::Module CreateMetadataModule( + const std::unordered_map& params, runtime::Module target_module, + const Array& ext_modules, Target target, runtime::metadata::Metadata metadata); + /*! * \brief Create a source module for viewing and limited saving for device. * \param data The code data to be viewed. @@ -157,6 +171,16 @@ runtime::Module DeviceSourceModuleCreate( std::string data, std::string fmt, std::unordered_map fmap, std::string type_key, std::function fget_source = nullptr); +/*! + * \brief Wrap the submodules that are to be wrapped in a c-source metadata module for C runtime. + * \param modules The modules to be wrapped. + * \param target the target the modules are compiled for. + * \param metadata the metadata needed for code generation. + * \return The wrapped module. 
+ */ +runtime::Module CreateCSourceCrtMetadataModule(const Array& modules, Target target, + runtime::metadata::Metadata metadata); + } // namespace codegen } // namespace tvm #endif // TVM_TARGET_SOURCE_CODEGEN_SOURCE_BASE_H_ diff --git a/src/target/source/source_module.cc b/src/target/source/source_module.cc index e01a3d93d087..b943d7cffba0 100644 --- a/src/target/source/source_module.cc +++ b/src/target/source/source_module.cc @@ -26,6 +26,7 @@ #include #include #include +#include "../metadata.h" #include #include @@ -130,7 +131,7 @@ runtime::Module CSourceModuleCreate(const String& code, const String& fmt, class CSourceCrtMetadataModuleNode : public runtime::ModuleNode { public: CSourceCrtMetadataModuleNode(const Array& func_names, const std::string& fmt, - Target target, relay::Runtime runtime, runtime::Metadata metadata) + Target target, relay::Runtime runtime, runtime::metadata::Metadata metadata) : fmt_(fmt), func_names_(func_names), target_(target), @@ -164,7 +165,7 @@ class CSourceCrtMetadataModuleNode : public runtime::ModuleNode { Array func_names_; Target target_; relay::Runtime runtime_; - runtime::Metadata metadata_; + runtime::metadata::Metadata metadata_; void CreateFuncRegistry() { code_ << "#include \n"; @@ -200,7 +201,7 @@ class CSourceCrtMetadataModuleNode : public runtime::ModuleNode { void GenerateEntrypointForUnpackedAPI(const std::string& entrypoint_name, const std::string& run_func) { code_ << "TVM_DLL int32_t " << run_func << "("; - unsigned int total_args = (metadata_->inputs.size() + metadata_->num_outputs); + unsigned int total_args = (metadata_->num_inputs() + metadata_->num_outputs()); for (unsigned int i = 0; i < total_args; ++i) { code_ << "void* arg" << i; if (i + 1 != total_args) { @@ -246,7 +247,7 @@ class CSourceCrtMetadataModuleNode : public runtime::ModuleNode { code_ << "#include <" << mod_name << ".h>\n"; code_ << "TVM_DLL int32_t " << run_func << "("; unsigned int total_args = - (metadata_->inputs.size() + 
metadata_->devices.size() + metadata_->num_outputs); + (metadata_->num_inputs() + metadata_->num_devices() + metadata_->num_outputs()); for (unsigned int i = 0; i < total_args; ++i) { code_ << "void* arg" << i; if (i + 1 != total_args) { @@ -256,7 +257,7 @@ class CSourceCrtMetadataModuleNode : public runtime::ModuleNode { code_ << ");\n"; code_ << "int32_t " << entrypoint_name << "("; code_ << "struct " << runtime::get_name_mangled(mod_name, "inputs") << "* inputs,"; - if (!metadata_->devices.empty()) { + if (metadata_->num_devices() > 0) { code_ << "struct " << runtime::get_name_mangled(mod_name, "outputs") << "* outputs,"; code_ << "struct " << runtime::get_name_mangled(mod_name, "devices") << "* devices"; } else { @@ -265,27 +266,28 @@ class CSourceCrtMetadataModuleNode : public runtime::ModuleNode { code_ << ") {" << "return " << run_func << "("; - for (const auto& input : metadata_->inputs) { - std::string sanitised_input = input; + for (const auto& input : metadata_->inputs()) { + std::string sanitised_input = input->name(); std::replace_if(sanitised_input.begin(), sanitised_input.end(), isNotAlnum, '_'); code_ << "inputs->" << sanitised_input << ","; } - if (metadata_->num_outputs == 1) { + if (metadata_->num_outputs() == 1) { code_ << "outputs->output"; } else { - for (int i = 0; i < metadata_->num_outputs; ++i) { + for (int i = 0; i < metadata_->num_outputs(); ++i) { code_ << "outputs->output" << i; - if (i + 1 != metadata_->num_outputs) { + if (i + 1 != metadata_->num_outputs()) { code_ << ","; } } } - if (!metadata_->devices.empty()) { + if (metadata_->num_devices() > 0) { code_ << ","; - for (const String& device : metadata_->devices) { + auto devices = metadata_->devices(); + for (const String& device : devices) { code_ << "devices->" << device; - if (device != metadata_->devices.back()) { + if (device != devices[devices.size() -1]) { code_ << ","; } } @@ -299,24 +301,24 @@ class CSourceCrtMetadataModuleNode : public runtime::ModuleNode { const 
std::string run_func_suffix = ::tvm::runtime::symbol::tvm_run_func_suffix; const std::string tvm_entrypoint_suffix = ::tvm::runtime::symbol::tvm_entrypoint_suffix; const std::string run_func_mangled = - runtime::get_name_mangled(metadata_->mod_name, run_func_suffix); + runtime::get_name_mangled(metadata_->mod_name(), run_func_suffix); const std::string entrypoint_mangled = - runtime::get_name_mangled(metadata_->mod_name, tvm_entrypoint_suffix); - const std::string network_mangled = runtime::get_name_mangled(metadata_->mod_name, "network"); + runtime::get_name_mangled(metadata_->mod_name(), tvm_entrypoint_suffix); + const std::string network_mangled = runtime::get_name_mangled(metadata_->mod_name(), "network"); code_ << "#include \"tvm/runtime/c_runtime_api.h\"\n"; code_ << "#ifdef __cplusplus\n"; code_ << "extern \"C\" {\n"; code_ << "#endif\n"; - if (metadata_->unpacked_api) { - if (metadata_->interface_api == "c") { - GenerateCInterfaceEntrypoint(entrypoint_mangled, run_func_mangled, metadata_->mod_name); + if (metadata_->use_unpacked_api()) { + if (metadata_->interface_api() == "c") { + GenerateCInterfaceEntrypoint(entrypoint_mangled, run_func_mangled, metadata_->mod_name()); } else { GenerateEntrypointForUnpackedAPI(entrypoint_mangled, run_func_mangled); } } else { - ICHECK_EQ(metadata_->interface_api, "packed") + ICHECK_EQ(metadata_->interface_api(), "packed") << "Packed interface required for packed operators"; GenerateEntrypointForPackedAPI(entrypoint_mangled, run_func_mangled); } @@ -331,15 +333,441 @@ class CSourceCrtMetadataModuleNode : public runtime::ModuleNode { CreateFuncRegistry(); GenerateCrtSystemLib(); } - if (metadata_.defined() && metadata_->executor == runtime::kTvmExecutorAot) { + if (metadata_.defined() && metadata_->executor() == runtime::kTvmExecutorAot) { GenerateAOTDescriptor(); } code_ << ";"; } }; + +class CMetadataWriterVisitor : public ::tvm::AttrVisitor { + private: + std::stringstream struct_defs_; + + std::vector streams_; + 
std::stringstream* current_stream_; + + void Visit(const char* key, double* value) override { + (*current_stream_) << *value; + } + + void Visit(const char* key, int64_t* value) override { + (*current_stream_) << *value << "L"; + } + + void Visit(const char* key, uint64_t* value) override { + (*current_stream_) << *value << "UL"; + } + + void Visit(const char* key, int* value) override { + (*current_stream_) << *value; + } + + void Visit(const char* key, bool* value) override { + (*current_stream_) << (*value ? "true" : "false"); // test the pointee; the pointer itself is always non-null here + } + + void Visit(const char* key, std::string* value) override { + (*current_stream_) << "\"" << *value << "\""; // emits the string contents, not the pointer address; todo: ->replace('\\', "\\\\").replace('\"', "\\\"") << "\""; + } + + void Visit(const char* key, void** value) override { + (*current_stream_) << *value; + } + + void Visit(const char* key, DataType* value) override { + (*current_stream_) << "DLDataType{" << value->code() << ", " << value->bits() << ", " << value->lanes() << "}"; + } + + void Visit(const char* key, runtime::NDArray* value) override { + ICHECK(false) << "at key " << key << ": cannot emit metadata of type NDArray"; + } + + void Visit(const char* key, runtime::ObjectRef* value) override { +// if (value->as< + // todo + } + +}; + +class MetadataStructDefiner : public AttrVisitor { + public: + + void Visit(const char* key, double* value) final { + // dns: mangle name + code_ << " double " << key << ";" << std::endl; + } + + void Visit(const char* key, int64_t* value) final { + // dns: mangle name + code_ << " int64_t " << key << ";" << std::endl; + } + + void Visit(const char* key, uint64_t* value) final { + // dns: mangle name + code_ << " uint64_t " << key << ";" << std::endl; + } + void Visit(const char* key, int* value) final { + // dns: mangle name + code_ << " int " << key << ";" << std::endl; + } + void Visit(const char* key, bool* value) final { + // dns: mangle name + code_ << " uint8_t " << key << ";" << std::endl; + } + void Visit(const char* key, 
std::string* value) final { + // dns: mangle name + code_ << " const char* " << key << ";" << std::endl; + } + void Visit(const char* key, void** value) final { + // dns: mangle name + code_ << " void* " << key << ";" << std::endl; + } + void Visit(const char* key, DataType* value) final { + // dns: mangle name + code_ << " DLDataType " << key << ";" << std::endl; + } + + void Visit(const char* key, runtime::NDArray* value) final { + // TODO(areusch): probably we could consolidate --link-params here, tho... + ICHECK(false) << "do not support serializing NDArray as metadata"; + } + + void WriteComma() { + if (!is_first_item_) { + code_ << ", "; + } + } + + void VisitArray(const char* key, const runtime::metadata::MetadataArrayNode* array) { + code_ << " " << array->c_type << "* " << key << ";" << std::endl; + } + // switch (array->type_index) { + // case MetadataTypeIndex::kUint64: + // code_ << " uint64_t** " << key << ";" << std::endl; + // case MetadataTypeIndex::kInt64: + // code_ << " int64_t** " << key << ";" << std::endl; + // case MetadataTypeIndex::kString: + // code_ << " const char** " << key << ";" << std::endl; + // case MetadataTypeIndex::kHandle: + // code_ << " void** " << key << ";" << std::endl; + // default: + // CHECK(false) << "Field " << key << ": unknown MetadataTypeIndex: " << array->type_index; + // } + // } + + + // const ArrayNode* arr = value->as(); + // if (arr != nullptr) { + // // dns: mangle name + + // code_ << " " << "" << key << ";" << std::endl; + // WriteComma(); + // code_ << "{"; + // if (arr->size() > 0) { + // is_first_item_ = true; + // for (ObjectRef o : *arr) { + // // todo might have to switch on object type. 
+ // WriteComma(); + // ReflectionVTable::Global()->VisitAttrs(o.get(), this); + // } + // } + // code_ << "}"; + // return; + // } + // } + + void Visit(const char* key, ObjectRef* value) final { + auto metadata = Downcast(*value); + auto arr = metadata.as(); + if (arr != nullptr) { + VisitArray(key, arr); + return; + } + + auto old_is_first_item = is_first_item_; + is_first_item_ = true; + code_ << "{"; + ReflectionVTable::Global()->VisitAttrs(metadata.operator->(), this); + code_ << "}"; + is_first_item_ = old_is_first_item; + } + + std::string GetOutput() { + return code_.str(); + } + + private: + ::std::stringstream code_; + bool is_first_item_{true}; // initialized (as MetadataSerializer does) so the save/restore in Visit(ObjectRef*) and WriteComma never read an indeterminate value +}; + + +static std::string address_from_parts(const std::vector& parts) { + std::stringstream ss; + for (unsigned int i = 0; i < parts.size(); ++i) { + if (i > 0) { + ss << "_"; + } + ss << parts[i]; + } + return ss.str(); +} + +class MetadataQueuer : public AttrVisitor { + public: + using QueueItem = std::tuple; + MetadataQueuer(std::vector* queue) : queue_{queue} {} + + void Visit(const char* key, double* value) final {} + void Visit(const char* key, int64_t* value) final {} + void Visit(const char* key, uint64_t* value) final {} + void Visit(const char* key, int* value) final {} + void Visit(const char* key, bool* value) final {} + void Visit(const char* key, std::string* value) final {} + void Visit(const char* key, DataType* value) final {} + void Visit(const char* key, runtime::NDArray* value) final {} + void Visit(const char* key, void** value) final {} + + void Visit(const char* key, ObjectRef* value) final { + address_parts_.push_back(key); + if (value->as() != nullptr) { + auto metadata = Downcast(*value); + const runtime::metadata::MetadataArrayNode* arr = value->as(); + std::cout << "Is array? 
" << arr << std::endl; + if (arr != nullptr) { + for (unsigned int i = 0; i < arr->array.size(); i++) { + ObjectRef o = arr->array[i]; + std::cout << "queue-visiting array element " << i << ": " << o->type_index() << " (" << o.operator->() << ")" << std::endl; + if (o.as() != nullptr) { + std::stringstream ss; + ss << i; + address_parts_.push_back(ss.str()); + runtime::metadata::MetadataBase metadata = Downcast(o); + ReflectionVTable::Global()->VisitAttrs(metadata.operator->(), this); + address_parts_.pop_back(); + } + } + } else { + ReflectionVTable::Global()->VisitAttrs(metadata.operator->(), this); + } + + queue_->push_back(std::make_tuple(address_from_parts(address_parts_), Downcast(*value))); + } + address_parts_.pop_back(); + } + + private: + std::vector* queue_; + std::vector address_parts_; +}; + +class MetadataSerializer : public AttrVisitor { +public: + static constexpr const char* kGlobalSymbol = "kTvmgenMetadata"; + + MetadataSerializer() : is_first_item_{true} {} + + void WriteComma() { + if (is_first_item_) { + is_first_item_ = false; + } else { + code_ << ", " << std::endl; + } + } + + void WriteKey(const char* key) { + if (key != nullptr) { + code_ << " /* " << key << "*/"; + } + } + + void Visit(const char* key, double* value) final { + WriteComma(); + code_.setf(std::ios::hex | std::ios::showbase | std::ios::fixed | std::ios::scientific, + std::ios::basefield | std::ios::showbase | std::ios::floatfield); + code_ << *value; + WriteKey(key); + } + + void Visit(const char* key, int64_t* value) final { + WriteComma(); + code_ << *value << "L"; + WriteKey(key); + } + + void Visit(const char* key, uint64_t* value) final { + WriteComma(); + code_ << *value << "UL"; + WriteKey(key); + } + void Visit(const char* key, int* value) final { + WriteComma(); + code_ << *value; + WriteKey(key); + } + void Visit(const char* key, bool* value) final { + WriteComma(); + code_ << *value; + WriteKey(key); + } + void Visit(const char* key, std::string* value) final { + 
WriteComma(); + code_ << "\"" << *value << "\""; + WriteKey(key); + } + void Visit(const char* key, void** value) final { + WriteComma(); + code_ << *value; + WriteKey(key); + } + void Visit(const char* key, DataType* value) final { + WriteComma(); + code_ << "DLDataType{" << value->code() << ", " << value->bits() << ", " + << value->lanes() << "}"; + WriteKey(key); + } + + void Visit(const char* key, runtime::NDArray* value) final { + // TODO(areusch): probably we could consolidate --link-params here, tho... + ICHECK(false) << "do not support serializing NDArray as metadata"; + } + + void VisitArray(const runtime::metadata::MetadataArrayNode* array) { + std::cout << "visit array " << array << ": " << array->c_type << " " << array->array.size() << std::endl; + auto old_is_first_item = is_first_item_; + is_first_item_ = true; + for (unsigned int i = 0; i < array->array.size(); ++i) { //ObjectRef o : *(array->array)) { + ObjectRef o = array->array[i]; + std::cout << "visiting array element " << i << ": " << o->type_index() << " (" << o.operator->() << ")" << std::endl; + if (o->IsInstance()) { + int64_t i = Downcast(o); + Visit(nullptr, &i); + continue; + } + + if (o->IsInstance()) { + std::string s = Downcast(o); + Visit(nullptr, &s); + continue; + } + + runtime::metadata::MetadataBase metadata = Downcast(o); + std::cout << "visit member " << metadata->get_name() << std::endl; + std::stringstream i_str; + i_str << i; + address_.push_back(i_str.str()); + Visit(nullptr, &metadata); + address_.pop_back(); +// ReflectionVTable::Global()->VisitAttrs(metadata.operator->(), this); + } + is_first_item_ = old_is_first_item; + } + + void Visit(const char* key, ObjectRef* value) final { + const runtime::metadata::MetadataArrayNode* arr = value->as(); + std::cout << "Is array? 
" << arr << std::endl; + if (arr != nullptr) { + WriteComma(); + if (key != nullptr) { + address_.push_back(key); + } + code_ << address_from_parts(address_) << " , " << arr->array.size() << " /* " << key << "_size */"; + if (key != nullptr) { + address_.pop_back(); + } +// VisitArray(key, Downcast(*value).operator->()); + // WriteComma(); + // code_ << "{"; + // if (arr->size() > 0) { + // is_first_item_ = true; + // for (ObjectRef* o : *arr) { + // // todo might have to switch on object type. + // WriteComma(); + // ReflectionVTable::Global()->VisitAttrs(o.get(), this); + // } + // } + // code_ << "}"; + return; + } + + std::cout << "downcast..." << std::endl; + runtime::metadata::MetadataBase metadata = Downcast(*value); + std::cout << "downcast ok: " << metadata->get_name() << std::endl; + + if (key != nullptr) { // NOTE: outermost call passes nullptr key + address_.push_back(key); + } + ReflectionVTable::Global()->VisitAttrs(metadata.operator->(), this); + if (key != nullptr) { // NOTE: outermost call passes nullptr key + address_.pop_back(); + } + } + + // void EnterStruct(::tvm::runtime::metadata::MetadataBase metadata) { + // const char* type_key = metadata->GetTypeKey(); + // is_defining_struct_.emplace_back( + // !generated_struct_decls_.contains(type_key)); + // if (is_defining_struct()) { + // decl_ << "struct " << get_struct_name(metadata) << "{"; + // } + // is_first_item_.emplace_back(true); + // } + + // void ExitStruct(::tvm::runtime::metadata::MetadataBase metadata) { + // decl_ << "}; // struct " << get_struct_name(metadata); + // is_first_item_.pop_back(); + // } + + void CodegenMetadata(::tvm::runtime::metadata::Metadata metadata) { + decl_ + << "#include " << std::endl + << "#include " << std::endl + << "#include " << std::endl; + std::vector queue; + MetadataQueuer queuer{&queue}; + queuer.Visit(kGlobalSymbol, &metadata); + + for (MetadataQueuer::QueueItem item : queue) { + auto struct_name = std::get<0>(item); + auto obj = std::get<1>(item); 
+ auto arr = obj.as(); + std::cout << "codegen: " << struct_name; + is_first_item_ = true; + address_.push_back(struct_name); + if (arr != nullptr) { + const char* const_part = "const "; + if (strcmp(arr->c_type, "const char*") == 0) { + const_part = ""; + } + code_ << const_part << arr->c_type << " " << struct_name + << "[" << arr->array.size() << "] = {" << std::endl; + VisitArray(arr); + } else { + code_ << "const struct TVMMetadata " << struct_name << " = {" << std::endl; + Visit(nullptr, &obj); + } + address_.pop_back(); + code_ << "};" << std::endl; + } + } + + std::string GetOutput() { + return decl_.str() + code_.str(); + } + +private: + std::vector address_; + std::stringstream decl_; + std::stringstream code_; + bool is_first_item_; + std::unordered_set generated_struct_decls_; + std::vector is_defining_struct_; +}; + runtime::Module CreateCSourceCrtMetadataModule(const Array& modules, Target target, - relay::Runtime runtime, runtime::Metadata metadata) { + relay::Runtime runtime, runtime::metadata::Metadata metadata) { Array func_names; for (runtime::Module mod : modules) { auto pf_funcs = mod.GetFunction("get_func_names"); @@ -358,6 +786,30 @@ runtime::Module CreateCSourceCrtMetadataModule(const Array& mod return std::move(csrc_metadata_module); } +runtime::Module CreateCSourceCppMetadataModule(runtime::metadata::Metadata metadata) { +// MetadataStructDefiner definer; +// ReflectionVTable::Global()->VisitAttrs(metadata.operator->(), &definer); + MetadataSerializer serializer; + serializer.CodegenMetadata(metadata); + std::stringstream lookup_func; + lookup_func << "#ifdef __cplusplus\n" + << "extern \"C\"\n" + << "#endif\n"; + + lookup_func << "TVM_DLL int32_t get_c_metadata(TVMValue* arg_values, int* arg_tcodes, int num_args, TVMValue* ret_values, int* ret_tcodes, void* resource_handle) {" << std::endl; + lookup_func << " ret_values[0].v_handle = (void*) &" << MetadataSerializer::kGlobalSymbol << ";" << std::endl; + lookup_func << " ret_tcodes[0] = 
kTVMOpaqueHandle;" << std::endl; + lookup_func << " return 0;" << std::endl; + lookup_func << "};" << std::endl; + + auto mod = MetadataModuleCreate(metadata); + std::vector func_names{"get_c_metadata"}; + //definer.GetOutput() + + auto c = CSourceModuleCreate(serializer.GetOutput() + lookup_func.str(), "c", func_names, Array()); + mod->Import(c); + return mod; +} + // supports limited save without cross compile class DeviceSourceModuleNode final : public runtime::ModuleNode { public: @@ -423,7 +875,7 @@ TVM_REGISTER_GLOBAL("runtime.CreateCSourceCrtMetadataModule") .set_body_typed([](const Array& modules, Target target, relay::Runtime runtime) { // Note that we don't need metadata when we compile a single operator - return CreateCSourceCrtMetadataModule(modules, target, runtime, runtime::Metadata()); + return CreateCSourceCrtMetadataModule(modules, target, runtime, runtime::metadata::Metadata()); }); } // namespace codegen diff --git a/src/target/source/source_module.h b/src/target/source/source_module.h index fde363c1198a..c7d3302e64b4 100644 --- a/src/target/source/source_module.h +++ b/src/target/source/source_module.h @@ -26,24 +26,31 @@ #define TVM_TARGET_SOURCE_SOURCE_MODULE_H_ #include +#include #include #include -#include "../../runtime/meta_data.h" namespace tvm { namespace codegen { /*! + * \brief Wrap the submodules that are to be wrapped in a c-source metadata module for C runtime. * \param modules The modules to be wrapped. * \param target the target the modules are compiled for. * \param runtime the runtime to code generate against - * \param metadata the metadata needed for code generation. + * \param metadata Compiler-generated metadata exported to runtime. * \return The wrapped module. */ runtime::Module CreateCSourceCrtMetadataModule(const Array& modules, Target target, - relay::Runtime runtime, runtime::Metadata metadata); + relay::Runtime runtime, runtime::metadata::Metadata metadata); + +/*! 
+ * \brief Create C++-runtime targeted metadata module for "c" backend. + * \param metadata Compiler-generated metadata. + */ +runtime::Module CreateCSourceCppMetadataModule(runtime::metadata::Metadata metadata); } // namespace codegen } // namespace tvm From acb56c8314137475ab3dc489bfbb33252ca8ed4b Mon Sep 17 00:00:00 2001 From: Andrew Reusch Date: Wed, 8 Dec 2021 13:50:40 -0800 Subject: [PATCH 08/41] Remove old Metadata --- src/runtime/meta_data.h | 47 ----------------------------------------- 1 file changed, 47 deletions(-) diff --git a/src/runtime/meta_data.h b/src/runtime/meta_data.h index 228555ebafda..c19229243f1f 100644 --- a/src/runtime/meta_data.h +++ b/src/runtime/meta_data.h @@ -49,53 +49,6 @@ inline String get_name_mangled(const String& module_name, const String& name) { return ss.str(); } -/*! - * \brief Structure that can be optionally used by the executor codegen - */ -class MetadataNode : public Object { - public: - /*! \brief input information for the main function */ - Array inputs; - /*! \brief number of outputs of the main function */ - int num_outputs = 1; - /*! \brief device contexts information for the main function */ - Array devices; - /*! \brief the executor to be used to run the model */ - String executor = kTvmExecutorGraph; - /*! \brief The external API (packed or c) in use */ - String interface_api; - /*! \brief The internal API (packed or unpacked) in use */ - bool unpacked_api; - - String mod_name = ""; - - static constexpr const uint32_t _type_index = TypeIndex::kDynamic; - static constexpr const char* _type_key = "MetadataObj"; - TVM_DECLARE_FINAL_OBJECT_INFO(MetadataNode, Object); -}; - -/*! - * \brief Managed reference to MetadataNode. 
- */ -class Metadata : public ObjectRef { - public: - TVM_DLL Metadata(Array inputs, Array devices, int num_outputs, String executor, - String mod_name, String interface_api = "packed", bool unpacked_api = false) { - auto n = make_object(); - n->inputs = inputs; - n->devices = devices; - n->num_outputs = num_outputs; - n->executor = executor; - n->interface_api = interface_api; - n->unpacked_api = unpacked_api; - n->mod_name = mod_name; - data_ = std::move(n); - } - - TVM_DEFINE_OBJECT_REF_METHODS(Metadata, ObjectRef, MetadataNode); - TVM_DEFINE_OBJECT_REF_COW_METHOD(MetadataNode); -}; - /*! \brief A tag to specify whether or not dynamic shared memory is used */ constexpr const char* kUseDynamicSharedMemoryTag = "tir.use_dyn_shared_memory"; From d7d05188f2b046aaa9a45fe62fdc2cf329777ac5 Mon Sep 17 00:00:00 2001 From: Andrew Reusch Date: Wed, 8 Dec 2021 13:50:53 -0800 Subject: [PATCH 09/41] compilation fixes in codegen? --- src/relay/backend/utils.h | 1 + src/target/metadata.h | 38 +++++++++++++++--------------- src/target/metadata_module.cc | 2 +- src/target/metadata_module.h | 3 ++- src/target/source/source_module.cc | 8 +++---- 5 files changed, 27 insertions(+), 25 deletions(-) diff --git a/src/relay/backend/utils.h b/src/relay/backend/utils.h index 8f430d292e07..07c829fbdf62 100644 --- a/src/relay/backend/utils.h +++ b/src/relay/backend/utils.h @@ -43,6 +43,7 @@ #include #include "../../runtime/meta_data.h" +#include "../../target/metadata.h" namespace tvm { namespace relay { diff --git a/src/target/metadata.h b/src/target/metadata.h index 1eaaeab8e59f..0cdf68768dc4 100644 --- a/src/target/metadata.h +++ b/src/target/metadata.h @@ -45,7 +45,7 @@ class VisitableMetadataNode : public ::tvm::runtime::metadata::MetadataNode { auto inputs_accessor = inputs(); inputs_array.reserve(num_inputs()); for (int64_t i = 0; i < num_inputs(); ++i) { - inputs_array.push_back(TensorInfo{inputs_accessor[i]}); + 
inputs_array.push_back(::tvm::runtime::metadata::TensorInfo{inputs_accessor[i]}); } ::tvm::runtime::metadata::MetadataArray inputs_metadata_array{inputs_array, "struct TVMTensorInfo"}; v->Visit("inputs", &inputs_metadata_array); @@ -53,7 +53,7 @@ class VisitableMetadataNode : public ::tvm::runtime::metadata::MetadataNode { auto outputs_accessor = outputs(); outputs_array.reserve(num_outputs()); for (int64_t i = 0; i < num_outputs(); ++i) { - outputs_array.push_back(TensorInfo{outputs_accessor[i]}); + outputs_array.push_back(::tvm::runtime::metadata::TensorInfo{outputs_accessor[i]}); } ::tvm::runtime::metadata::MetadataArray outputs_metadata_array{outputs_array, "struct TVMTensorInfo"}; v->Visit("outputs", &outputs_metadata_array); @@ -90,14 +90,15 @@ class InMemoryMetadataNode : public ::tvm::target::metadata::VisitableMetadataNo ) {} InMemoryMetadataNode( int64_t version, - const ::std::vector& inputs, - const ::std::vector& outputs, + const ::std::vector<::tvm::runtime::metadata::TensorInfo>& inputs, + const ::std::vector<::tvm::runtime::metadata::TensorInfo>& outputs, const ::std::vector<::std::string>& devices, const ::tvm::runtime::String executor, const ::tvm::runtime::String mod_name, const ::tvm::runtime::String interface_api, bool use_unpacked_api ) : + VisitableMetadataNode{&storage_}, inputs_{new struct TVMTensorInfo[inputs.size()]()}, inputs_objs_{inputs}, outputs_{new struct TVMTensorInfo[outputs.size()]()}, @@ -108,37 +109,36 @@ class InMemoryMetadataNode : public ::tvm::target::metadata::VisitableMetadataNo interface_api_{interface_api}, storage_{ version, - NULL, NULL, - NULL, NULL, - NULL, NULL, + nullptr, 0, + nullptr, 0, + nullptr, 0, executor_.c_str(), mod_name_.c_str(), interface_api_.c_str(), use_unpacked_api - }, - VisitableMetadataNode{&storage_} { + } { storage_.inputs = inputs_.get(); storage_.num_inputs = inputs.size(); - for (int i = 0; i < inputs.size(); ++i) { + for (unsigned int i = 0; i < inputs.size(); ++i) { inputs_.get()[i] = 
*inputs[i]->data(); } storage_.outputs = outputs_.get(); storage_.num_outputs = outputs.size(); - for (int i = 0; i < outputs.size(); ++i) { + for (unsigned int i = 0; i < outputs.size(); ++i) { outputs_.get()[i] = *outputs[i]->data(); } storage_.devices = devices_.get(); storage_.num_devices = devices.size(); - for (int i = 0; i < devices.size(); ++i) { + for (unsigned int i = 0; i < devices.size(); ++i) { devices_.get()[i] = devices[i].c_str(); } } private: ::std::unique_ptr inputs_; - std::vector inputs_objs_; + std::vector<::tvm::runtime::metadata::TensorInfo> inputs_objs_; ::std::unique_ptr outputs_; - std::vector outputs_objs_; + std::vector<::tvm::runtime::metadata::TensorInfo> outputs_objs_; ::std::unique_ptr devices_; ::std::string executor_; ::std::string mod_name_; @@ -158,7 +158,7 @@ class VisitableTensorInfoNode : public ::tvm::runtime::metadata::TensorInfoNode auto shape_accessor = shape(); shape_array.reserve(num_shape()); for (int64_t i = 0; i < num_shape(); ++i) { - shape_array.push_back(::tvm::Integer{shape_accessor[i]}); + shape_array.push_back(::tvm::Integer{int(shape_accessor[i])}); } ::tvm::runtime::metadata::MetadataArray shape_metadata_array{shape_array, "int64_t"}; v->Visit("shape", &shape_metadata_array); @@ -179,17 +179,17 @@ class InMemoryTensorInfoNode : public ::tvm::target::metadata::VisitableTensorIn const ::std::vector& shape, ::tvm::runtime::DataType dtype ) : + VisitableTensorInfoNode{&storage_}, name_{name}, shape_{new int64_t[shape.size()]()}, storage_{ name_.c_str(), - NULL, NULL, + nullptr, 0, dtype - }, - VisitableTensorInfoNode{&storage_} { + } { storage_.shape = shape_.get(); storage_.num_shape = shape.size(); - for (int i = 0; i < shape.size(); ++i) { + for (unsigned int i = 0; i < shape.size(); ++i) { shape_.get()[i] = shape[i]; } } diff --git a/src/target/metadata_module.cc b/src/target/metadata_module.cc index 2aa4fe5e234d..a08ec453cd6c 100644 --- a/src/target/metadata_module.cc +++ b/src/target/metadata_module.cc @@ 
-37,7 +37,7 @@ namespace codegen { runtime::Module CreateMetadataModule( const std::unordered_map& const_var_ndarray, tvm::runtime::Module target_module, const Array& ext_modules, Target target, - tvm::relay::Runtime runtime, runtime::Metadata metadata) { + tvm::relay::Runtime runtime, runtime::metadata::Metadata metadata) { // Here we split modules into two groups: // 1. Those modules which can be exported to C-runtime. These are DSO-exportable // (i.e. llvm or c) modules which return nothing from get_const_vars(). diff --git a/src/target/metadata_module.h b/src/target/metadata_module.h index ee6f7231b3a1..9e0a25bb2421 100644 --- a/src/target/metadata_module.h +++ b/src/target/metadata_module.h @@ -26,6 +26,7 @@ #define TVM_TARGET_METADATA_MODULE_H_ #include +#include #include #include #include @@ -54,7 +55,7 @@ namespace codegen { runtime::Module CreateMetadataModule( const std::unordered_map& params, runtime::Module target_module, const Array& ext_modules, Target target, tvm::relay::Runtime runtime, - runtime::Metadata metadata); + runtime::metadata::Metadata metadata); } // namespace codegen } // namespace tvm diff --git a/src/target/source/source_module.cc b/src/target/source/source_module.cc index b943d7cffba0..8d8d8328ffa8 100644 --- a/src/target/source/source_module.cc +++ b/src/target/source/source_module.cc @@ -213,13 +213,13 @@ class CSourceCrtMetadataModuleNode : public runtime::ModuleNode { code_ << "(void* args, void* type_code, int num_args, void* out_value, void* " "out_type_code, void* resource_handle) {\n"; code_ << "return " << run_func << "("; - for (unsigned int i = 0; i < metadata_->inputs.size(); ++i) { + for (unsigned int i = 0; i < metadata_->num_inputs(); ++i) { code_ << "((DLTensor*)(((TVMValue*)args)[" << i << "].v_handle))[0].data,"; } - for (int i = 0; i < metadata_->num_outputs; ++i) { - int j = metadata_->inputs.size() + i; + for (int i = 0; i < metadata_->num_outputs(); ++i) { + int j = metadata_->num_inputs() + i; code_ << 
"((DLTensor*)(((TVMValue*)args)[" << j << "].v_handle))[0].data"; - if (i + 1 != metadata_->num_outputs) { + if (i + 1 != metadata_->num_outputs()) { code_ << ","; } } From a01947ab96caac87b0ce5b169aa8c101b2a9ad61 Mon Sep 17 00:00:00 2001 From: Andrew Reusch Date: Wed, 8 Dec 2021 15:03:23 -0800 Subject: [PATCH 10/41] replace MetadataModuleCreate --- src/target/metadata_module.cc | 129 +++++++++++++++++++++++----------- 1 file changed, 89 insertions(+), 40 deletions(-) diff --git a/src/target/metadata_module.cc b/src/target/metadata_module.cc index a08ec453cd6c..bed78a5c2e2c 100644 --- a/src/target/metadata_module.cc +++ b/src/target/metadata_module.cc @@ -27,6 +27,7 @@ #include +#include "../runtime/const_loader_module.h" #include "../runtime/meta_data.h" #include "llvm/llvm_module.h" #include "source/source_module.h" @@ -34,6 +35,88 @@ namespace tvm { namespace codegen { + +static runtime::Module CreateCrtMetadataModule(runtime::Module target_module, Target target, + relay::Runtime runtime, + runtime::metadata::Metadata metadata, + Array non_crt_exportable_modules, + Array crt_exportable_modules, + const std::unordered_map& const_var_ndarray) { + + if (!non_crt_exportable_modules.empty()) { + std::string non_exportable_modules; + for (unsigned int i = 0; i < non_crt_exportable_modules.size(); i++) { + if (i > 0) { + non_exportable_modules += ", "; + } + auto mod = non_crt_exportable_modules[i]; + auto pf_sym = mod.GetFunction("get_symbol"); + if (pf_sym != nullptr) { + non_exportable_modules += pf_sym().operator std::string(); + } else { + non_exportable_modules += + std::string{"(module type_key="} + mod->type_key() + std::string{")"}; + } + } + CHECK(false) << "These " << non_crt_exportable_modules.size() + << " modules are not exportable to C-runtime: " << non_exportable_modules; + } + + if (target->kind->name == "c") { + crt_exportable_modules.push_back(target_module); + target_module = CreateCSourceCrtMetadataModule(crt_exportable_modules, target, runtime, 
metadata); + } else if (target->kind->name == "llvm") { +#ifdef TVM_LLVM_VERSION + crt_exportable_modules.push_back(target_module); + target_module = CreateLLVMCrtMetadataModule(crt_exportable_modules, target, runtime); +#else // TVM_LLVM_VERSION + LOG(FATAL) << "TVM was not built with LLVM enabled."; +#endif // TVM_LLVM_VERSION + } + + return target_module; +} + +static runtime::Module CreateCppMetadataModule( + runtime::Module target_module, Target target, relay::Runtime runtime, + runtime::metadata::Metadata metadata, + const std::unordered_map>& const_vars_by_symbol, + Array non_crt_exportable_modules, + Array crt_exportable_modules, + const std::unordered_map& const_var_ndarray) { + if (!non_crt_exportable_modules.empty()) { + runtime::Module const_loader_mod = runtime::ConstLoaderModuleCreate(const_var_ndarray, const_vars_by_symbol); + const_loader_mod.Import(target_module); + for (const auto& it : non_crt_exportable_modules) { + const_loader_mod.Import(it); + } + target_module = const_loader_mod; + } + + if (metadata->executor() == runtime::kTvmExecutorAot && runtime->name == relay::kTvmRuntimeCpp) { + if (target->kind->name == "c") { + auto metadata_module = CreateCSourceCppMetadataModule(metadata); + metadata_module->Import(target_module); + target_module = metadata_module; + } else { + CHECK(false) << "Don't know how to create MetadataModule for target type " << target->str(); + } + } + + return target_module; +} + +/*! + * \brief Create a metadata module wrapper. The helper is used by different + * codegens, such as graph executor codegen and the vm compiler. + * + * \param params The metadata for initialization of all modules. + * \param target_module the internal module that is compiled by tvm. + * \param ext_modules The external modules that needs to be imported inside the metadata + * module(s). + * \param target The target that all the modules are compiled for + * \return The created metadata module that manages initialization of metadata. 
+ */ runtime::Module CreateMetadataModule( const std::unordered_map& const_var_ndarray, tvm::runtime::Module target_module, const Array& ext_modules, Target target, @@ -83,49 +166,15 @@ runtime::Module CreateMetadataModule( } if (is_targeting_crt) { - if (!non_crt_exportable_modules.empty()) { - std::string non_exportable_modules; - for (unsigned int i = 0; i < non_crt_exportable_modules.size(); i++) { - if (i > 0) { - non_exportable_modules += ", "; - } - auto mod = non_crt_exportable_modules[i]; - auto pf_sym = mod.GetFunction("get_symbol"); - if (pf_sym != nullptr) { - non_exportable_modules += pf_sym().operator std::string(); - } else { - non_exportable_modules += - std::string{"(module type_key="} + mod->type_key() + std::string{")"}; - } - } - CHECK(false) << "These " << non_crt_exportable_modules.size() - << " modules are not exportable to C-runtime: " << non_exportable_modules; - } - - if (target->kind->name == "c") { - crt_exportable_modules.push_back(target_module); - target_module = - CreateCSourceCrtMetadataModule(crt_exportable_modules, target, runtime, metadata); - } else if (target->kind->name == "llvm") { -#ifdef TVM_LLVM_VERSION - crt_exportable_modules.push_back(target_module); - target_module = CreateLLVMCrtMetadataModule(crt_exportable_modules, target, runtime); -#else // TVM_LLVM_VERSION - LOG(FATAL) << "TVM was not built with LLVM enabled."; -#endif // TVM_LLVM_VERSION - } + return CreateCrtMetadataModule(target_module, target, runtime, metadata, non_crt_exportable_modules, crt_exportable_modules, const_var_ndarray); } else { - if (!non_crt_exportable_modules.empty()) { - runtime::Module binary_const_loader_mod = runtime::ConstLoaderModuleCreate(const_var_ndarray, const_vars_by_symbol); - binary_const_loader_mod.Import(target_module); - for (const auto& it : non_crt_exportable_modules) { - binary_const_loader_mod.Import(it); - } - return binary_const_loader_mod; - } + return CreateCppMetadataModule(target_module, target, runtime, metadata, 
const_vars_by_symbol, + non_crt_exportable_modules, crt_exportable_modules, + const_var_ndarray); } - return target_module; } + } // namespace codegen + } // namespace tvm From 1950d8b73f4bf6ede37dfe2d728f50d032a80422 Mon Sep 17 00:00:00 2001 From: Andrew Reusch Date: Wed, 8 Dec 2021 15:18:28 -0800 Subject: [PATCH 11/41] Add a runtime Module to mux between .text Metadata and live Metadata. --- src/runtime/meta_data.h | 11 +++++++ src/runtime/metadata.cc | 73 +++++++++++++++++++++++++++++++++++++++-- 2 files changed, 81 insertions(+), 3 deletions(-) diff --git a/src/runtime/meta_data.h b/src/runtime/meta_data.h index c19229243f1f..bb833fd24823 100644 --- a/src/runtime/meta_data.h +++ b/src/runtime/meta_data.h @@ -36,6 +36,8 @@ #include #include +#include + #include "runtime_base.h" namespace tvm { @@ -49,6 +51,15 @@ inline String get_name_mangled(const String& module_name, const String& name) { return ss.str(); } +/*! + * \brief Create a metadata module object. + * + * \param metadata Exported metadata structure. + * + * \return The created metadata module. + */ +Module MetadataModuleCreate(metadata::Metadata metadata); + /*! \brief A tag to specify whether or not dynamic shared memory is used */ constexpr const char* kUseDynamicSharedMemoryTag = "tir.use_dyn_shared_memory"; diff --git a/src/runtime/metadata.cc b/src/runtime/metadata.cc index a08e30333a52..a6aecafea7a0 100644 --- a/src/runtime/metadata.cc +++ b/src/runtime/metadata.cc @@ -18,16 +18,27 @@ */ /*! - * \file metadata.cc - * \brief Implementations of the runtime component of Metadata. + * \file tvm/runtime/metadata.h + * \brief Defines implementations of TVM metadata which can exist in the runtime. 
*/ +#include +#include +#include #include +#include namespace tvm { namespace runtime { namespace metadata { +MetadataArray::MetadataArray(Array array, const char* c_type) : MetadataBase{make_object(array, c_type)} {} + +std::string MetadataArrayNode::get_name() { return "MetadataArray"; } + +TVM_REGISTER_OBJECT_TYPE(MetadataBaseNode); +TVM_REGISTER_OBJECT_TYPE(MetadataArrayNode); + ArrayAccessor MetadataNode::inputs() { if (inputs_refs_.get() == nullptr) { inputs_refs_.reset(new ::std::vector()); } return ArrayAccessor(data_->inputs, data_->num_inputs, inputs_refs_); @@ -47,8 +58,64 @@ TVM_REGISTER_OBJECT_TYPE(MetadataNode); TensorInfo::TensorInfo(const struct ::TVMTensorInfo* data) : MetadataBase{make_object(data)} {} std::string TensorInfoNode::get_name() { return std::string{"TensorInfo"}; } -TVM_REGISTER_OBJECT_TYPE(TensorInfoNode); } // namespace metadata + +class MetadataModuleNode : public ::tvm::runtime::ModuleNode { + public: + MetadataModuleNode(runtime::metadata::Metadata metadata) { + // CHECK((metadata.defined() && code.size() > 0) || (!metadata.defined() && code.size() == 0)) + // << "metadata and code must both be either defined (when passed from compiler) or undefined " + // << "(when passed from runtime)"; + metadata_ = metadata; +// code_ = code; + } + + const char* type_key() const { return "metadata_module"; } + + static Module LoadFromBinary() { + return Module(make_object(runtime::metadata::Metadata())); + } + + void SaveToBinary(dmlc::Stream* stream) final {} + + PackedFunc GetFunction(const std::string& name, const ObjectPtr& sptr_to_self) { + if (name == "get_metadata") { + return PackedFunc([this, sptr_to_self](TVMArgs args, TVMRetValue* rv) { + if (!metadata_.defined()) { + TVMFunctionHandle f_handle; + int32_t ret_code = TVMBackendGetFuncFromEnv(this, "get_c_metadata", &f_handle); + CHECK_EQ(ret_code, 0) << "Unable to locate get_c_metadata PackedFunc"; + + TVMValue ret_value; + int ret_type_code; + ret_code = TVMFuncCall(f_handle, 
nullptr, nullptr, 0, &ret_value, &ret_type_code); + CHECK_EQ(ret_code, 0) << "Invoking get_c_metadata: TVMFuncCall returned " << ret_code; + + CHECK_EQ(ret_type_code, kTVMOpaqueHandle) << "Expected kOpaqueHandle returned; got " << ret_type_code; + CHECK(ret_value.v_handle != nullptr) << "get_c_metadata returned nullptr"; + + metadata_ = runtime::metadata::Metadata(static_cast(ret_value.v_handle)); + } + + *rv = metadata_; + return; + }); + } + + return PackedFunc(); + } + + private: + runtime::metadata::Metadata metadata_; +}; + +Module MetadataModuleCreate(metadata::Metadata metadata) { + return Module(make_object(metadata)); +} + +TVM_REGISTER_GLOBAL("runtime.module.loadbinary_metadata_module") +.set_body([](TVMArgs args, TVMRetValue* rv) { *rv = MetadataModuleNode::LoadFromBinary(); }); + } // namespace runtime } // namespace tvm From 2e7123db3cfc20fa6e45a677507195b4d28f90f0 Mon Sep 17 00:00:00 2001 From: Andrew Reusch Date: Wed, 8 Dec 2021 16:33:51 -0800 Subject: [PATCH 12/41] Move launch_param to namespace --- src/runtime/meta_data.h | 4 ++++ src/runtime/thread_storage_scope.h | 3 ++- src/target/build_common.h | 2 +- 3 files changed, 7 insertions(+), 2 deletions(-) diff --git a/src/runtime/meta_data.h b/src/runtime/meta_data.h index bb833fd24823..c66feeca0634 100644 --- a/src/runtime/meta_data.h +++ b/src/runtime/meta_data.h @@ -60,9 +60,13 @@ inline String get_name_mangled(const String& module_name, const String& name) { */ Module MetadataModuleCreate(metadata::Metadata metadata); +namespace launch_param { + /*! \brief A tag to specify whether or not dynamic shared memory is used */ constexpr const char* kUseDynamicSharedMemoryTag = "tir.use_dyn_shared_memory"; +} + /*! 
\brief function information needed by device */ struct FunctionInfo { std::string name; diff --git a/src/runtime/thread_storage_scope.h b/src/runtime/thread_storage_scope.h index d577770db1a9..4122f9d0798e 100644 --- a/src/runtime/thread_storage_scope.h +++ b/src/runtime/thread_storage_scope.h @@ -24,6 +24,7 @@ #ifndef TVM_RUNTIME_THREAD_STORAGE_SCOPE_H_ #define TVM_RUNTIME_THREAD_STORAGE_SCOPE_H_ +#include #include #include @@ -205,7 +206,7 @@ class LaunchParamConfig { std::vector filled(6, false); for (size_t i = 0; i < launch_param_tags.size(); ++i) { const std::string& tag = launch_param_tags[i]; - if (tag == kUseDynamicSharedMemoryTag) { + if (tag == launch_param::kUseDynamicSharedMemoryTag) { ICHECK_EQ(i, launch_param_tags.size() - 1) << "kUseDynamicSharedMemoryTag should be the last tag in launch_param_tags."; use_dyn_shared_memory_ = true; diff --git a/src/target/build_common.h b/src/target/build_common.h index c66c2b52822e..6c94ec8703b7 100644 --- a/src/target/build_common.h +++ b/src/target/build_common.h @@ -58,7 +58,7 @@ inline std::unordered_map ExtractFuncInfo(co } if (auto opt = f->GetAttr(tir::attr::kDeviceUseDynSharedMemory)) { if (opt.value()) { - info.launch_param_tags.push_back(runtime::kUseDynamicSharedMemoryTag); + info.launch_param_tags.push_back(runtime::launch_param::kUseDynamicSharedMemoryTag); } } auto global_symbol = f->GetAttr(tvm::attr::kGlobalSymbol); From 55dfb233e7cd74c0fb3711b074fb52022c8a3ac7 Mon Sep 17 00:00:00 2001 From: Andrew Reusch Date: Thu, 9 Dec 2021 12:08:17 -0800 Subject: [PATCH 13/41] Add test of c++ AOT. 
--- tests/python/relay/aot/test_cpp_aot.py | 94 ++++++++++++++++++++++++++ 1 file changed, 94 insertions(+) create mode 100644 tests/python/relay/aot/test_cpp_aot.py diff --git a/tests/python/relay/aot/test_cpp_aot.py b/tests/python/relay/aot/test_cpp_aot.py new file mode 100644 index 000000000000..faee1d75b52e --- /dev/null +++ b/tests/python/relay/aot/test_cpp_aot.py @@ -0,0 +1,94 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ + +import sys +import textwrap + +import numpy as np +import pytest + +import tvm +from tvm import relay, TVMError +from tvm.ir.module import IRModule +from tvm.relay import backend, testing, transform +from tvm.relay.testing import byoc +from tvm.relay.op.annotation import compiler_begin, compiler_end +from aot_test_utils import ( + AOTTestModel, + AOT_DEFAULT_RUNNER, + generate_ref_data, + convert_to_relay, + compile_and_run, + compile_models, + parametrize_aot_options, +) + + +def print_mod_tree(m, indent=0): + print(f"{' ' * indent} - {m!r}") + for i in m.imported_modules: + print_mod_tree(i, indent + 2) + +def test_conv2d(): + RELAY_MODEL = textwrap.dedent("""\ + #[version = "0.0.5"] + def @main(%data : Tensor[(1, 3, 64, 64), uint8], %weight : Tensor[(8, 3, 5, 5), int8]) { + %1 = nn.conv2d( + %data, + %weight, + padding=[2, 2], + channels=8, + kernel_size=[5, 5], + data_layout="NCHW", + kernel_layout="OIHW", + out_dtype="int32"); + %1 + } + """) + ir_mod = tvm.parser.fromtext(RELAY_MODEL) + + main_func = ir_mod["main"] + shape_dict = {p.name_hint: p.checked_type.concrete_shape for p in main_func.params} + type_dict = {p.name_hint: p.checked_type.dtype for p in main_func.params} + + weight_data = np.ones(shape_dict["weight"]).astype(type_dict["weight"]) + input_data = np.ones(shape_dict["data"]).astype(type_dict["data"]) + + params = {"weight": weight_data} + inputs = {"data": input_data} + output_list = generate_ref_data(ir_mod, inputs, params) + + with tvm.transform.PassContext(opt_level=3, config={"tir.disable_vectorize": True}): + mod = tvm.relay.build(ir_mod, params=params, target="c", + executor=backend.Executor("aot", {"interface-api": "c"})) + + print_mod_tree(mod.module) + + with tvm.contrib.utils.TempDirectory.set_keep_for_debug(): + mod.export_library("test.so") + mod.export_library("test.tar") + runner = tvm.runtime.load_module("test.so") + print_mod_tree(runner) + runner = tvm.runtime.executor.AotModule(runner["default"](tvm.cpu(0))) + 
runner.set_input(**inputs) + runner.run() + assert (runner.get_output(0).asnumpy() == output_list[0]).all() + + +if __name__ == "__main__": + sys.exit(pytest.main([__file__] + sys.argv[1:])) From bd257085b3000b7abe2ab63f22a9a8fb26e6b235 Mon Sep 17 00:00:00 2001 From: Andrew Reusch Date: Thu, 9 Dec 2021 16:54:16 -0800 Subject: [PATCH 14/41] Fix c++ lint and formatting. --- include/tvm/runtime/metadata.h | 26 +++- include/tvm/runtime/metadata_base.h | 41 +++--- include/tvm/support/span.h | 30 ++-- src/relay/backend/aot_executor_codegen.cc | 62 ++++----- src/relay/backend/build_module.cc | 4 +- src/relay/backend/graph_executor_codegen.cc | 1 - src/relay/backend/utils.cc | 4 +- src/relay/backend/vm/compiler.cc | 5 +- src/runtime/aot_executor/aot_executor.cc | 42 +++--- src/runtime/aot_executor/aot_executor.h | 7 +- .../aot_executor/aot_executor_factory.cc | 35 +++-- .../aot_executor/aot_executor_factory.h | 2 +- src/runtime/const_loader_module.cc | 5 +- src/runtime/const_loader_module.h | 3 +- src/runtime/meta_data.h | 5 +- src/runtime/metadata.cc | 50 ++++--- src/target/metadata.cc | 19 +-- src/target/metadata.h | 118 +++++++--------- src/target/metadata_module.cc | 30 ++-- src/target/source/source_module.cc | 131 +++++++++--------- src/target/source/source_module.h | 3 +- 21 files changed, 307 insertions(+), 316 deletions(-) diff --git a/include/tvm/runtime/metadata.h b/include/tvm/runtime/metadata.h index c4911a179bb0..438698373f31 100644 --- a/include/tvm/runtime/metadata.h +++ b/include/tvm/runtime/metadata.h @@ -17,6 +17,9 @@ * under the License. */ +// NOTE: This file is intended to be compileable in C++ and C build processes. +// NOLINT(build/include_order) + /*! * \file tvm/runtime/metadata.h * \brief Defines types which can be used in Metadata. 
@@ -24,6 +27,10 @@ #ifndef TVM_RUNTIME_METADATA_H_ #define TVM_RUNTIME_METADATA_H_ +#include +#include +#include + #include #include #include @@ -67,7 +74,7 @@ class TensorInfo; class MetadataNode : public MetadataBaseNode { public: - MetadataNode(const struct ::TVMMetadata* data) : data_{data} {} + explicit MetadataNode(const struct ::TVMMetadata* data) : data_{data} {} static constexpr const char* _type_key = "metadata.MetadataNode"; std::string get_name() override; inline int64_t version() const { return int64_t(data_->version); } @@ -79,10 +86,13 @@ class MetadataNode : public MetadataBaseNode { ArrayAccessor devices(); inline ::tvm::runtime::String executor() const { return ::tvm::runtime::String(data_->executor); } inline ::tvm::runtime::String mod_name() const { return ::tvm::runtime::String(data_->mod_name); } - inline ::tvm::runtime::String interface_api() const { return ::tvm::runtime::String(data_->interface_api); } - inline bool use_unpacked_api() const { return bool(data_->use_unpacked_api); } + inline ::tvm::runtime::String interface_api() const { + return ::tvm::runtime::String(data_->interface_api); + } + inline bool use_unpacked_api() const { return static_cast(data_->use_unpacked_api); } const struct ::TVMMetadata* data() const { return data_; } TVM_DECLARE_FINAL_OBJECT_INFO(MetadataNode, MetadataBaseNode); + private: const struct ::TVMMetadata* data_; ::std::shared_ptr<::std::vector> inputs_refs_; @@ -92,30 +102,32 @@ class MetadataNode : public MetadataBaseNode { class Metadata : public MetadataBase { public: - Metadata(const struct ::TVMMetadata* data); + explicit Metadata(const struct ::TVMMetadata* data); TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(Metadata, MetadataBase, MetadataNode); }; class TensorInfoNode : public MetadataBaseNode { public: - TensorInfoNode(const struct ::TVMTensorInfo* data) : data_{data} {} + explicit TensorInfoNode(const struct ::TVMTensorInfo* data) : data_{data} {} static constexpr const char* _type_key = 
"metadata.TensorInfoNode"; std::string get_name() override; inline ::tvm::runtime::String name() const { return ::tvm::runtime::String(data_->name); } inline int64_t num_shape() const { return data_->num_shape; } inline ::tvm::support::Span shape() const { - return ::tvm::support::Span(data_->shape, data_->shape + data_->num_shape); + return ::tvm::support::Span(data_->shape, + data_->shape + data_->num_shape); } inline ::tvm::runtime::DataType dtype() const { return ::tvm::runtime::DataType(data_->dtype); } const struct ::TVMTensorInfo* data() const { return data_; } TVM_DECLARE_FINAL_OBJECT_INFO(TensorInfoNode, MetadataBaseNode); + private: const struct ::TVMTensorInfo* data_; }; class TensorInfo : public MetadataBase { public: - TensorInfo(const struct ::TVMTensorInfo* data); + explicit TensorInfo(const struct ::TVMTensorInfo* data); TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(TensorInfo, MetadataBase, TensorInfoNode); }; diff --git a/include/tvm/runtime/metadata_base.h b/include/tvm/runtime/metadata_base.h index 4386818ec298..b155727ea404 100644 --- a/include/tvm/runtime/metadata_base.h +++ b/include/tvm/runtime/metadata_base.h @@ -24,10 +24,14 @@ #ifndef TVM_RUNTIME_METADATA_BASE_H_ #define TVM_RUNTIME_METADATA_BASE_H_ -#include #include #include +#include +#include +#include +#include + namespace tvm { namespace runtime { namespace metadata { @@ -53,9 +57,7 @@ class ArrayIterator { public: ArrayIterator(size_t index, ArrayAccessor* parent) : index_{index}, parent_{parent} {} - inline Ref operator*() { - return (*parent_)[index_]; - } + inline Ref operator*() { return (*parent_)[index_]; } inline ArrayIterator& operator++() { if (index_ < parent_->size()) { @@ -69,9 +71,7 @@ class ArrayIterator { return parent_ == other.parent_ && index_ == other.index_; } - inline bool operator!=(const ArrayIterator& other) { - return !operator==(other); - } + inline bool operator!=(const ArrayIterator& other) { return !operator==(other); } private: size_t index_; @@ -81,9 +81,9 @@ 
class ArrayIterator { template class ArrayAccessor { public: - - template ::value>::type> - ArrayAccessor(const C* data, size_t num_data, ::std::shared_ptr<::std::vector> refs) : data_{data}, num_data_{num_data}, refs_{refs} {} + template ::value>::type> + ArrayAccessor(const C* data, size_t num_data, ::std::shared_ptr<::std::vector> refs) + : data_{data}, num_data_{num_data}, refs_{refs} {} inline size_t size() { return num_data_; } @@ -103,13 +103,9 @@ class ArrayAccessor { return (*refs_)[index]; } - inline ArrayIterator begin() { - return ArrayIterator{0, this}; - } + inline ArrayIterator begin() { return ArrayIterator{0, this}; } - inline ArrayIterator end() { - return ArrayIterator{num_data_, this}; - } + inline ArrayIterator end() { return ArrayIterator{num_data_, this}; } private: const C* data_; @@ -120,7 +116,9 @@ class ArrayAccessor { template <> class ArrayAccessor { public: - ArrayAccessor(const char** data, size_t num_data, ::std::shared_ptr> refs) : data_{data}, num_data_{num_data}, refs_{refs} {} + ArrayAccessor(const char** data, size_t num_data, + ::std::shared_ptr> refs) + : data_{data}, num_data_{num_data}, refs_{refs} {} inline size_t size() { return num_data_; } @@ -160,13 +158,14 @@ enum MetadataTypeIndex : uint8_t { kBool = 2, kString = 3, kHandle = 4, - }; class MetadataArrayNode : public MetadataBaseNode { public: -// MetadataArray(Array array, MetadataTypeIndex type_index) : array{array}, type_index{type_index} {} - MetadataArrayNode(Array array, const char* c_type) : array(std::move(array)), c_type{c_type} {} + // MetadataArray(Array array, MetadataTypeIndex type_index) : array{array}, + // type_index{type_index} {} + MetadataArrayNode(Array array, const char* c_type) + : array(std::move(array)), c_type{c_type} {} std::string get_name() override; @@ -178,7 +177,7 @@ class MetadataArrayNode : public MetadataBaseNode { class MetadataArray : public MetadataBase { public: -// MetadataArray(Array array, MetadataTypeIndex type_index); + // 
MetadataArray(Array array, MetadataTypeIndex type_index); MetadataArray(Array array, const char* c_type); TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(MetadataArray, MetadataBase, MetadataArrayNode); }; diff --git a/include/tvm/support/span.h b/include/tvm/support/span.h index 36c86db6fd5e..4f2da9bf2c98 100644 --- a/include/tvm/support/span.h +++ b/include/tvm/support/span.h @@ -32,31 +32,23 @@ namespace tvm { namespace support { -template //, std::enable_if_t::value>::value> = true> +template class Span { public: class iterator : public std::iterator { public: - inline iterator(T* ptr, T* end) : ptr_{ptr}, end_{end} { - CHECK_GE(end, ptr); - } + inline iterator(T* ptr, T* end) : ptr_{ptr}, end_{end} { CHECK_GE(end, ptr); } - inline W operator*() { - return W(*ptr_); - } + inline W operator*() { return W(*ptr_); } inline iterator& operator++() { if (ptr_ != end_) ptr_++; return *this; } - inline bool operator==(iterator other) { - return ptr_ == other.ptr_ && end_ == other.end_; - } + inline bool operator==(iterator other) { return ptr_ == other.ptr_ && end_ == other.end_; } - inline bool operator!=(iterator other) { - return !(*this == other); - } + inline bool operator!=(iterator other) { return !(*this == other); } private: T* ptr_; @@ -66,13 +58,9 @@ class Span { inline Span(T* begin, int num_elements) : begin_{begin}, end_{begin + num_elements} {} inline Span(T* begin, T* end) : begin_{begin}, end_{end} {} - inline iterator begin() { - return iterator(begin_, end_); - } + inline iterator begin() { return iterator(begin_, end_); } - inline iterator end() { - return iterator(end_, end_); - } + inline iterator end() { return iterator(end_, end_); } inline W operator[](int i) { T* to_return = begin_ + i; @@ -80,9 +68,7 @@ class Span { return W(*to_return); } - inline operator std::vector() { - return std::vector(begin(), end()); - } + inline operator std::vector() { return std::vector(begin(), end()); } protected: T* begin_; diff --git 
a/src/relay/backend/aot_executor_codegen.cc b/src/relay/backend/aot_executor_codegen.cc index 59ceab4fe4c3..145cee0f60ec 100644 --- a/src/relay/backend/aot_executor_codegen.cc +++ b/src/relay/backend/aot_executor_codegen.cc @@ -42,10 +42,10 @@ #include #include +#include "../../target/metadata.h" #include "../op/annotation/annotation.h" #include "../op/call/call.h" #include "../op/memory/device_copy.h" -#include "../../target/metadata.h" #include "../transforms/device_aware_visitors.h" #include "./name_transforms.h" #include "./te_compiler.h" @@ -338,27 +338,21 @@ class AOTExecutorCodegen : public MixedModeVisitor { return data; } } - return data; /*tvm::tir::Call( - DataType::Handle(), - tvm::tir::builtin::tvm_stack_make_array(), - Array({data, tvm::tir::Call(DataType::Handle(), - tvm::tir::builtin::tvm_stack_make_shape(), - {ttype->shape}), - tvm::Integer(0), - tvm::Integer(ttype->shape.size()), - tvm::tir::make_const(ttype->dtype, 0), - tvm::Integer(0)})); */ + return data; } void PushTuple(Tuple tuple, std::vector sids, Array args) { CHECK_EQ(sids.size(), tuple->fields.size()) - << "Relay tuple does not map 1:1 into TIR; AOT can't handle this type of Relay Expr in a CallNode."; + << "Relay tuple does not map 1:1 into TIR; AOT can't handle this type of Relay Expr in a " + "CallNode."; StorageInfo& sinfo = storage_device_map_[tuple]; for (unsigned int i = 0; i < sids.size(); ++i) { - if (std::find(return_sid_.begin(), return_sid_.end(), sinfo->storage_ids[i]) != return_sid_.end()) { + if (std::find(return_sid_.begin(), return_sid_.end(), sinfo->storage_ids[i]) != + return_sid_.end()) { args.push_back(sids[i]); } else { - args.push_back(MakeDLTensor(tuple->fields[i], Downcast(tuple->fields[i]->checked_type()), sids[i])); + args.push_back(MakeDLTensor( + tuple->fields[i], Downcast(tuple->fields[i]->checked_type()), sids[i])); } } } @@ -377,10 +371,11 @@ class AOTExecutorCodegen : public MixedModeVisitor { // Pack the inputs for (const Expr& arg : 
call_lowered_props.arguments) { if (params_by_expr_.find(arg) != params_by_expr_.end()) { - args.push_back(MakeDLTensor(arg, Downcast(arg->checked_type()), - tir::Cast(runtime::DataType(DataType::TypeCode::kHandle, 32, 1), - tvm::tir::Call(DataType::Handle(), tvm::tir::builtin::lookup_param(), - {tir::StringImm(params_by_expr_[arg])})))); + args.push_back(MakeDLTensor( + arg, Downcast(arg->checked_type()), + tir::Cast(runtime::DataType(DataType::TypeCode::kHandle, 32, 1), + tvm::tir::Call(DataType::Handle(), tvm::tir::builtin::lookup_param(), + {tir::StringImm(params_by_expr_[arg])})))); } else { auto sids = FindExpr(arg); if (sids.size() > 1) { @@ -388,7 +383,8 @@ class AOTExecutorCodegen : public MixedModeVisitor { PushTuple(tuple, sids, args); } else { StorageInfo& sinfo = storage_device_map_[arg]; - if (std::find(return_sid_.begin(), return_sid_.end(), sinfo->storage_ids[0]) != return_sid_.end()) { + if (std::find(return_sid_.begin(), return_sid_.end(), sinfo->storage_ids[0]) != + return_sid_.end()) { args.push_back(sids[0]); } else { args.push_back(MakeDLTensor(arg, Downcast(arg->checked_type()), sids[0])); @@ -404,10 +400,12 @@ class AOTExecutorCodegen : public MixedModeVisitor { PushTuple(tuple, result_expr_sid, args); } else { StorageInfo& sinfo = storage_device_map_[result_expr]; - if (std::find(return_sid_.begin(), return_sid_.end(), sinfo->storage_ids[0]) != return_sid_.end()) { + if (std::find(return_sid_.begin(), return_sid_.end(), sinfo->storage_ids[0]) != + return_sid_.end()) { args.push_back(result_expr_sid[0]); } else { - args.push_back(MakeDLTensor(result_expr, Downcast(result_expr->checked_type()), result_expr_sid[0])); + args.push_back(MakeDLTensor(result_expr, Downcast(result_expr->checked_type()), + result_expr_sid[0])); } } @@ -802,7 +800,10 @@ class AOTExecutorCodegen : public MixedModeVisitor { public: AOTExecutorCodegen(runtime::Module* mod, const tec::TargetMap& targets, Target target_host) - : mod_(mod), targets_(targets), 
target_host_(target_host), use_unpacked_api_(Bool(false)), + : mod_(mod), + targets_(targets), + target_host_(target_host), + use_unpacked_api_(Bool(false)), use_call_cpacked_(Bool(false)) {} LoweredOutput Codegen(IRModule mod, relay::Function func, String mod_name) { @@ -823,12 +824,12 @@ class AOTExecutorCodegen : public MixedModeVisitor { // Validate choice of use_unpacked_api_ and use_call_cpacked_ if (runtime_config->name == kTvmRuntimeCrt) { - CHECK(interface_api == "c" || bool(use_unpacked_api_) == false) + CHECK(interface_api == "c" || use_unpacked_api_ == false) << "Either need interface_api == \"c\" (got: " << interface_api << ") or unpacked-api == false (got: " << use_unpacked_api_ << ") when targeting c runtime"; } else if (runtime_config->name == kTvmRuntimeCpp) { - CHECK(bool(use_unpacked_api_) == false && bool(use_call_cpacked_) == true) + CHECK(use_unpacked_api_ == false && static_cast(use_call_cpacked_) == true) << "Need unpacked-api == false (got: " << use_unpacked_api_ << ") and interface-api == \"c\" (got: " << interface_api << ") when targeting c++ runtime"; @@ -971,9 +972,8 @@ class AOTExecutorCodegen : public MixedModeVisitor { for (auto v : input_vars_) { auto ttype = Downcast(v->type_annotation); inputs.push_back( - runtime::metadata::TensorInfo( - make_object( - v->name_hint(), ShapeToJSON(ttype->shape), ttype->dtype))); + runtime::metadata::TensorInfo(make_object( + v->name_hint(), ShapeToJSON(ttype->shape), ttype->dtype))); } LOG(INFO) << "MAKE METADATA? 
"; @@ -984,9 +984,8 @@ class AOTExecutorCodegen : public MixedModeVisitor { std::stringstream name; name << "output" << i; outputs.push_back( - runtime::metadata::TensorInfo( - make_object( - name.str(), ShapeToJSON(ttype->shape), ttype->dtype))); + runtime::metadata::TensorInfo(make_object( + name.str(), ShapeToJSON(ttype->shape), ttype->dtype))); } auto devices = ListDevices(); std::vector devices_vector; @@ -994,7 +993,8 @@ class AOTExecutorCodegen : public MixedModeVisitor { devices_vector.push_back(d.operator std::string()); } auto n = make_object( - kMetadataVersion, inputs, outputs, devices_vector, runtime::kTvmExecutorAot, mod_name, interface_api, use_unpacked_api_); + kMetadataVersion, inputs, outputs, devices_vector, runtime::kTvmExecutorAot, mod_name, + interface_api, use_unpacked_api_); ret.metadata = runtime::metadata::Metadata(std::move(n)); LOG(INFO) << "MAKE METADATA: " << ret.metadata; return ret; diff --git a/src/relay/backend/build_module.cc b/src/relay/backend/build_module.cc index c7e9ee779dbf..35dcd7378382 100644 --- a/src/relay/backend/build_module.cc +++ b/src/relay/backend/build_module.cc @@ -103,7 +103,9 @@ struct ExecutorCodegen { Array ListDevices() { return CallFunc>("get_devices"); } - runtime::metadata::Metadata GetMetadata() { return CallFunc("get_metadata"); } + runtime::metadata::Metadata GetMetadata() { + return CallFunc("get_metadata"); + } virtual ~ExecutorCodegen() {} protected: diff --git a/src/relay/backend/graph_executor_codegen.cc b/src/relay/backend/graph_executor_codegen.cc index 59d3786951c8..f9e5f4eb4a5b 100644 --- a/src/relay/backend/graph_executor_codegen.cc +++ b/src/relay/backend/graph_executor_codegen.cc @@ -302,7 +302,6 @@ class GraphExecutorCodegen : public backend::MemoizedExprTranslator ShapeToJSON(tvm::Array shape) { std::vector ret; for (IndexExpr dim : shape) { - const int64_t* pval = tir::as_const_int(dim); - ret.push_back(*pval); + const int64_t* pval = tir::as_const_int(dim); + ret.push_back(*pval); } 
return ret; } diff --git a/src/relay/backend/vm/compiler.cc b/src/relay/backend/vm/compiler.cc index 6af7991f045d..1253e8527739 100644 --- a/src/relay/backend/vm/compiler.cc +++ b/src/relay/backend/vm/compiler.cc @@ -1162,8 +1162,9 @@ void VMCompiler::Codegen() { lib = tvm::build(per_tvm_target_modules, config_->host_target); } - lib = codegen::CreateMetadataModule(params_, lib, ext_mods, config_->host_target, - Runtime::Create("cpp"), runtime::metadata::Metadata(make_object())); + lib = codegen::CreateMetadataModule( + params_, lib, ext_mods, config_->host_target, Runtime::Create("cpp"), + runtime::metadata::Metadata(make_object())); exec_->SetLib(lib); } diff --git a/src/runtime/aot_executor/aot_executor.cc b/src/runtime/aot_executor/aot_executor.cc index cfa4c1e36ebd..763fb39c3b4b 100644 --- a/src/runtime/aot_executor/aot_executor.cc +++ b/src/runtime/aot_executor/aot_executor.cc @@ -25,14 +25,15 @@ #include "aot_executor.h" +#include + #include namespace tvm { namespace runtime { -AotExecutor::AotExecutor(tvm::runtime::Module module, const std::vector& devs) : - module_{module}, devices_{devs} { - +AotExecutor::AotExecutor(tvm::runtime::Module module, const std::vector& devs) + : module_{module}, devices_{devs} { auto fmetadata = module->GetFunction("get_metadata"); CHECK(fmetadata != nullptr) << "Expected a module with PackedFunc get_metadata"; auto ret_value = fmetadata(); @@ -40,15 +41,18 @@ AotExecutor::AotExecutor(tvm::runtime::Module module, const std::vector& for (auto input : metadata_->inputs()) { // TODO(areusch): Encode device information in Metadata. 
- args_.emplace_back(NDArray::Empty(ShapeTuple(input->shape().begin(), input->shape().end()), input->dtype(), devices_[0])); + args_.emplace_back(NDArray::Empty(ShapeTuple(input->shape().begin(), input->shape().end()), + input->dtype(), devices_[0])); } for (auto output : metadata_->outputs()) { - args_.emplace_back(NDArray::Empty(ShapeTuple(output->shape().begin(), output->shape().end()), output->dtype(), devices_[0])); + args_.emplace_back(NDArray::Empty(ShapeTuple(output->shape().begin(), output->shape().end()), + output->dtype(), devices_[0])); } } -PackedFunc AotExecutor::GetFunction(const std::string& name, const ObjectPtr& sptr_to_self) { +PackedFunc AotExecutor::GetFunction(const std::string& name, + const ObjectPtr& sptr_to_self) { // Return member functions during query. if (name == "set_input") { return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { @@ -129,7 +133,7 @@ void AotExecutor::Run() { call_type_codes.get()[i] = kTVMDLTensorHandle; } - TVMArgs args{call_values, call_type_codes, num_args}; + TVMArgs args{call_values.get(), call_type_codes.get(), num_args}; TVMRetValue rv; pf.CallPacked(args, &rv); } @@ -154,9 +158,7 @@ int AotExecutor::GetOutputIndex(const std::string& name) { return -1; } -void AotExecutor::SetInput(int index, DLTensor* data_ref) { - args_[index].CopyFrom(data_ref); -} +void AotExecutor::SetInput(int index, DLTensor* data_ref) { args_[index].CopyFrom(data_ref); } void AotExecutor::SetInputZeroCopy(int index, DLTensor* data_ref) { ICHECK(false) << "not implemented"; @@ -166,25 +168,15 @@ void AotExecutor::SetOutputZeroCopy(int index, DLTensor* data_ref) { ICHECK(false) << "not implemented"; } -int AotExecutor::NumOutputs() const { - return metadata_->num_outputs(); -} +int AotExecutor::NumOutputs() const { return metadata_->num_outputs(); } -int AotExecutor::NumInputs() const { - return metadata_->num_inputs(); -} +int AotExecutor::NumInputs() const { return metadata_->num_inputs(); } -NDArray 
AotExecutor::GetInput(int index) const { - return args_[index]; -} +NDArray AotExecutor::GetInput(int index) const { return args_[index]; } -NDArray AotExecutor::GetOutput(int index) const { - return args_[metadata_->num_inputs() + index]; -} +NDArray AotExecutor::GetOutput(int index) const { return args_[metadata_->num_inputs() + index]; } -void AotExecutor::CopyOutputTo(int index, DLTensor* data_out) { - GetOutput(index).CopyTo(data_out); -} +void AotExecutor::CopyOutputTo(int index, DLTensor* data_out) { GetOutput(index).CopyTo(data_out); } } // namespace runtime } // namespace tvm diff --git a/src/runtime/aot_executor/aot_executor.h b/src/runtime/aot_executor/aot_executor.h index 591af78284e4..a213ef83bdc2 100644 --- a/src/runtime/aot_executor/aot_executor.h +++ b/src/runtime/aot_executor/aot_executor.h @@ -25,19 +25,18 @@ #ifndef TVM_RUNTIME_AOT_EXECUTOR_AOT_EXECUTOR_H_ #define TVM_RUNTIME_AOT_EXECUTOR_AOT_EXECUTOR_H_ -#include -#include #include #include -#include #include +#include +#include +#include namespace tvm { namespace runtime { class TVM_DLL AotExecutor : public ModuleNode { - public: /*! * \brief Implements member function lookup for this Module for the frontend. diff --git a/src/runtime/aot_executor/aot_executor_factory.cc b/src/runtime/aot_executor/aot_executor_factory.cc index e8ded8573028..4cb3026991fe 100644 --- a/src/runtime/aot_executor/aot_executor_factory.cc +++ b/src/runtime/aot_executor/aot_executor_factory.cc @@ -107,24 +107,23 @@ Module AotExecutorFactoryModuleLoadBinary(void* strm) { return Module(exec); } -TVM_REGISTER_GLOBAL("tvm.aot_executor_factory.create") - .set_body([](TVMArgs args, TVMRetValue* rv) { - ICHECK_GE(args.num_args, 2) << "The expected number of arguments for " - "aot_executor_factory.create needs at least 2, " - "but it has " - << args.num_args; - // The argument order is module, module_name, param0_name, param0_tensor, - // [param1_name, param1_tensor], ... 
- ICHECK_EQ((args.size() - 2) % 2, 0); - std::unordered_map params; - for (size_t i = 2; i < static_cast(args.size()); i += 2) { - std::string name = args[i].operator String(); - params[name] = args[i + 1].operator tvm::runtime::NDArray(); - } - auto exec = make_object(params, args[1]); - exec->Import(args[0]); - *rv = Module(exec); - }); +TVM_REGISTER_GLOBAL("tvm.aot_executor_factory.create").set_body([](TVMArgs args, TVMRetValue* rv) { + ICHECK_GE(args.num_args, 2) << "The expected number of arguments for " + "aot_executor_factory.create needs at least 2, " + "but it has " + << args.num_args; + // The argument order is module, module_name, param0_name, param0_tensor, + // [param1_name, param1_tensor], ... + ICHECK_EQ((args.size() - 2) % 2, 0); + std::unordered_map params; + for (size_t i = 2; i < static_cast(args.size()); i += 2) { + std::string name = args[i].operator String(); + params[name] = args[i + 1].operator tvm::runtime::NDArray(); + } + auto exec = make_object(params, args[1]); + exec->Import(args[0]); + *rv = Module(exec); +}); TVM_REGISTER_GLOBAL("runtime.module.loadbinary_AotExecutorFactory") .set_body_typed(AotExecutorFactoryModuleLoadBinary); diff --git a/src/runtime/aot_executor/aot_executor_factory.h b/src/runtime/aot_executor/aot_executor_factory.h index fbbebe1a4d86..1d6a0a62776e 100644 --- a/src/runtime/aot_executor/aot_executor_factory.h +++ b/src/runtime/aot_executor/aot_executor_factory.h @@ -116,4 +116,4 @@ class TVM_DLL AotExecutorFactory : public runtime::ModuleNode { } // namespace runtime } // namespace tvm -#endif // TVM_RUNTIME_GRAPH_EXECUTOR_GRAPH_EXECUTOR_FACTORY_H_ +#endif // TVM_RUNTIME_AOT_EXECUTOR_AOT_EXECUTOR_FACTORY_H_ diff --git a/src/runtime/const_loader_module.cc b/src/runtime/const_loader_module.cc index fd890e40cd0b..818525def22c 100644 --- a/src/runtime/const_loader_module.cc +++ b/src/runtime/const_loader_module.cc @@ -47,8 +47,9 @@ namespace runtime { */ class ConstLoaderModuleNode : public ModuleNode { public: - 
ConstLoaderModuleNode(const std::unordered_map& const_var_ndarray, - const std::unordered_map>& const_vars_by_symbol) + ConstLoaderModuleNode( + const std::unordered_map& const_var_ndarray, + const std::unordered_map>& const_vars_by_symbol) : const_var_ndarray_(const_var_ndarray), const_vars_by_symbol_(const_vars_by_symbol) { // Only the related submodules are cached to reduce the number of runtime // symbol lookup for initialization. Otherwise, symbols/primitives in the diff --git a/src/runtime/const_loader_module.h b/src/runtime/const_loader_module.h index bd88f15c5bcc..eb548dfcf370 100644 --- a/src/runtime/const_loader_module.h +++ b/src/runtime/const_loader_module.h @@ -25,10 +25,11 @@ #ifndef TVM_RUNTIME_CONST_LOADER_MODULE_H_ #define TVM_RUNTIME_CONST_LOADER_MODULE_H_ +#include + #include #include #include -#include namespace tvm { namespace runtime { diff --git a/src/runtime/meta_data.h b/src/runtime/meta_data.h index c66feeca0634..766b93261ac0 100644 --- a/src/runtime/meta_data.h +++ b/src/runtime/meta_data.h @@ -27,6 +27,7 @@ #include #include #include +#include #include #include #include @@ -36,8 +37,6 @@ #include #include -#include - #include "runtime_base.h" namespace tvm { @@ -65,7 +64,7 @@ namespace launch_param { /*! \brief A tag to specify whether or not dynamic shared memory is used */ constexpr const char* kUseDynamicSharedMemoryTag = "tir.use_dyn_shared_memory"; -} +} // namespace launch_param /*! \brief function information needed by device */ struct FunctionInfo { diff --git a/src/runtime/metadata.cc b/src/runtime/metadata.cc index a6aecafea7a0..2ba14ff84ef2 100644 --- a/src/runtime/metadata.cc +++ b/src/runtime/metadata.cc @@ -22,17 +22,19 @@ * \brief Defines implementations of TVM metadata which can exist in the runtime. 
*/ -#include #include #include #include #include +#include + namespace tvm { namespace runtime { namespace metadata { -MetadataArray::MetadataArray(Array array, const char* c_type) : MetadataBase{make_object(array, c_type)} {} +MetadataArray::MetadataArray(Array array, const char* c_type) + : MetadataBase{make_object(array, c_type)} {} std::string MetadataArrayNode::get_name() { return "MetadataArray"; } @@ -40,35 +42,45 @@ TVM_REGISTER_OBJECT_TYPE(MetadataBaseNode); TVM_REGISTER_OBJECT_TYPE(MetadataArrayNode); ArrayAccessor MetadataNode::inputs() { - if (inputs_refs_.get() == nullptr) { inputs_refs_.reset(new ::std::vector()); } - return ArrayAccessor(data_->inputs, data_->num_inputs, inputs_refs_); + if (inputs_refs_.get() == nullptr) { + inputs_refs_.reset(new ::std::vector()); + } + return ArrayAccessor(data_->inputs, data_->num_inputs, + inputs_refs_); } ArrayAccessor MetadataNode::outputs() { - if (outputs_refs_.get() == nullptr) { outputs_refs_.reset(new ::std::vector()); } - return ArrayAccessor(data_->outputs, data_->num_outputs, outputs_refs_); + if (outputs_refs_.get() == nullptr) { + outputs_refs_.reset(new ::std::vector()); + } + return ArrayAccessor(data_->outputs, data_->num_outputs, + outputs_refs_); } ArrayAccessor MetadataNode::devices() { - if (devices_refs_.get() == nullptr) { devices_refs_.reset(new ::std::vector<::tvm::runtime::String>()); } - return ArrayAccessor(data_->devices, data_->num_devices, devices_refs_); + if (devices_refs_.get() == nullptr) { + devices_refs_.reset(new ::std::vector<::tvm::runtime::String>()); + } + return ArrayAccessor(data_->devices, data_->num_devices, + devices_refs_); } -Metadata::Metadata(const struct ::TVMMetadata* data) : - MetadataBase{make_object(data)} {} +Metadata::Metadata(const struct ::TVMMetadata* data) + : MetadataBase{make_object(data)} {} std::string MetadataNode::get_name() { return std::string{"Metadata"}; } TVM_REGISTER_OBJECT_TYPE(MetadataNode); -TensorInfo::TensorInfo(const struct 
::TVMTensorInfo* data) : - MetadataBase{make_object(data)} {} +TensorInfo::TensorInfo(const struct ::TVMTensorInfo* data) + : MetadataBase{make_object(data)} {} std::string TensorInfoNode::get_name() { return std::string{"TensorInfo"}; } } // namespace metadata class MetadataModuleNode : public ::tvm::runtime::ModuleNode { public: - MetadataModuleNode(runtime::metadata::Metadata metadata) { + explicit MetadataModuleNode(runtime::metadata::Metadata metadata) { // CHECK((metadata.defined() && code.size() > 0) || (!metadata.defined() && code.size() == 0)) - // << "metadata and code must both be either defined (when passed from compiler) or undefined " + // << "metadata and code must both be either defined (when passed from compiler) or undefined + // " // << "(when passed from runtime)"; metadata_ = metadata; -// code_ = code; + // code_ = code; } const char* type_key() const { return "metadata_module"; } @@ -92,10 +104,12 @@ class MetadataModuleNode : public ::tvm::runtime::ModuleNode { ret_code = TVMFuncCall(f_handle, nullptr, nullptr, 0, &ret_value, &ret_type_code); CHECK_EQ(ret_code, 0) << "Invoking get_c_metadata: TVMFuncCall returned " << ret_code; - CHECK_EQ(ret_type_code, kTVMOpaqueHandle) << "Expected kOpaqueHandle returned; got " << ret_type_code; + CHECK_EQ(ret_type_code, kTVMOpaqueHandle) + << "Expected kOpaqueHandle returned; got " << ret_type_code; CHECK(ret_value.v_handle != nullptr) << "get_c_metadata returned nullptr"; - metadata_ = runtime::metadata::Metadata(static_cast(ret_value.v_handle)); + metadata_ = runtime::metadata::Metadata( + static_cast(ret_value.v_handle)); } *rv = metadata_; @@ -115,7 +129,7 @@ Module MetadataModuleCreate(metadata::Metadata metadata) { } TVM_REGISTER_GLOBAL("runtime.module.loadbinary_metadata_module") -.set_body([](TVMArgs args, TVMRetValue* rv) { *rv = MetadataModuleNode::LoadFromBinary(); }); + .set_body([](TVMArgs args, TVMRetValue* rv) { *rv = MetadataModuleNode::LoadFromBinary(); }); } // namespace runtime } // 
namespace tvm diff --git a/src/target/metadata.cc b/src/target/metadata.cc index 193f63c5133a..adf4cba3e610 100644 --- a/src/target/metadata.cc +++ b/src/target/metadata.cc @@ -23,21 +23,24 @@ */ #include "metadata.h" + #include namespace tvm { namespace target { namespace metadata { -TVM_REGISTER_REFLECTION_VTABLE(VisitableMetadataNode, ::tvm::detail::ReflectionTrait) -.set_creator([](const std::string&) -> ObjectPtr { - return ::tvm::runtime::make_object(); -}); +TVM_REGISTER_REFLECTION_VTABLE(VisitableMetadataNode, + ::tvm::detail::ReflectionTrait) + .set_creator([](const std::string&) -> ObjectPtr { + return ::tvm::runtime::make_object(); + }); -TVM_REGISTER_REFLECTION_VTABLE(VisitableTensorInfoNode, ::tvm::detail::ReflectionTrait) -.set_creator([](const std::string&) -> ObjectPtr { - return ::tvm::runtime::make_object(); -}); +TVM_REGISTER_REFLECTION_VTABLE(VisitableTensorInfoNode, + ::tvm::detail::ReflectionTrait) + .set_creator([](const std::string&) -> ObjectPtr { + return ::tvm::runtime::make_object(); + }); } // namespace metadata } // namespace target diff --git a/src/target/metadata.h b/src/target/metadata.h index 0cdf68768dc4..64a9bbf1d562 100644 --- a/src/target/metadata.h +++ b/src/target/metadata.h @@ -21,13 +21,14 @@ * \file tvm/target/metadata.h * \brief Extends Metadata for use in the compiler. 
*/ -#ifndef TVM_TARGET_METADATA_H -#define TVM_TARGET_METADATA_H +#ifndef TVM_TARGET_METADATA_H_ +#define TVM_TARGET_METADATA_H_ + +#include #include #include #include -#include namespace tvm { namespace target { @@ -36,7 +37,7 @@ namespace metadata { class VisitableMetadataNode : public ::tvm::runtime::metadata::MetadataNode { public: explicit VisitableMetadataNode(const struct ::TVMMetadata* data) : MetadataNode{data} {} - explicit VisitableMetadataNode() : MetadataNode{nullptr} {} + VisitableMetadataNode() : MetadataNode{nullptr} {} void VisitAttrs(AttrVisitor* v) { int64_t version_cpp{version()}; @@ -47,7 +48,8 @@ class VisitableMetadataNode : public ::tvm::runtime::metadata::MetadataNode { for (int64_t i = 0; i < num_inputs(); ++i) { inputs_array.push_back(::tvm::runtime::metadata::TensorInfo{inputs_accessor[i]}); } - ::tvm::runtime::metadata::MetadataArray inputs_metadata_array{inputs_array, "struct TVMTensorInfo"}; + ::tvm::runtime::metadata::MetadataArray inputs_metadata_array{inputs_array, + "struct TVMTensorInfo"}; v->Visit("inputs", &inputs_metadata_array); auto outputs_array = Array(); auto outputs_accessor = outputs(); @@ -55,7 +57,8 @@ class VisitableMetadataNode : public ::tvm::runtime::metadata::MetadataNode { for (int64_t i = 0; i < num_outputs(); ++i) { outputs_array.push_back(::tvm::runtime::metadata::TensorInfo{outputs_accessor[i]}); } - ::tvm::runtime::metadata::MetadataArray outputs_metadata_array{outputs_array, "struct TVMTensorInfo"}; + ::tvm::runtime::metadata::MetadataArray outputs_metadata_array{outputs_array, + "struct TVMTensorInfo"}; v->Visit("outputs", &outputs_metadata_array); auto devices_array = Array(); auto devices_accessor = devices(); @@ -78,45 +81,37 @@ class VisitableMetadataNode : public ::tvm::runtime::metadata::MetadataNode { class InMemoryMetadataNode : public ::tvm::target::metadata::VisitableMetadataNode { public: - InMemoryMetadataNode() : InMemoryMetadataNode( - 0 /* version */, - {} /* inputs */, - {} /* outputs */, 
- {} /* devices */, - "" /* executor */, - "" /* mod_name */, - "" /* interface_api */, - false /* use_unpacked_api */ - ) {} - InMemoryMetadataNode( - int64_t version, - const ::std::vector<::tvm::runtime::metadata::TensorInfo>& inputs, - const ::std::vector<::tvm::runtime::metadata::TensorInfo>& outputs, - const ::std::vector<::std::string>& devices, - const ::tvm::runtime::String executor, - const ::tvm::runtime::String mod_name, - const ::tvm::runtime::String interface_api, - bool use_unpacked_api - ) : - VisitableMetadataNode{&storage_}, - inputs_{new struct TVMTensorInfo[inputs.size()]()}, - inputs_objs_{inputs}, - outputs_{new struct TVMTensorInfo[outputs.size()]()}, - outputs_objs_{outputs}, - devices_{new const char*[devices.size()]()}, - executor_{executor}, - mod_name_{mod_name}, - interface_api_{interface_api}, - storage_{ - version, - nullptr, 0, - nullptr, 0, - nullptr, 0, - executor_.c_str(), - mod_name_.c_str(), - interface_api_.c_str(), - use_unpacked_api - } { + InMemoryMetadataNode() + : InMemoryMetadataNode(0 /* version */, {} /* inputs */, {} /* outputs */, {} /* devices */, + "" /* executor */, "" /* mod_name */, "" /* interface_api */, + false /* use_unpacked_api */ + ) {} + InMemoryMetadataNode(int64_t version, + const ::std::vector<::tvm::runtime::metadata::TensorInfo>& inputs, + const ::std::vector<::tvm::runtime::metadata::TensorInfo>& outputs, + const ::std::vector<::std::string>& devices, + const ::tvm::runtime::String executor, const ::tvm::runtime::String mod_name, + const ::tvm::runtime::String interface_api, bool use_unpacked_api) + : VisitableMetadataNode{&storage_}, + inputs_{new struct TVMTensorInfo[inputs.size()]()}, + inputs_objs_{inputs}, + outputs_{new struct TVMTensorInfo[outputs.size()]()}, + outputs_objs_{outputs}, + devices_{new const char*[devices.size()]()}, + executor_{executor}, + mod_name_{mod_name}, + interface_api_{interface_api}, + storage_{version, + nullptr, + 0, + nullptr, + 0, + nullptr, + 0, + 
executor_.c_str(), + mod_name_.c_str(), + interface_api_.c_str(), + use_unpacked_api} { storage_.inputs = inputs_.get(); storage_.num_inputs = inputs.size(); for (unsigned int i = 0; i < inputs.size(); ++i) { @@ -149,7 +144,7 @@ class InMemoryMetadataNode : public ::tvm::target::metadata::VisitableMetadataNo class VisitableTensorInfoNode : public ::tvm::runtime::metadata::TensorInfoNode { public: explicit VisitableTensorInfoNode(const struct ::TVMTensorInfo* data) : TensorInfoNode{data} {} - explicit VisitableTensorInfoNode() : TensorInfoNode{nullptr} {} + VisitableTensorInfoNode() : TensorInfoNode{nullptr} {} void VisitAttrs(AttrVisitor* v) { ::std::string name_cpp{data()->name}; @@ -158,7 +153,7 @@ class VisitableTensorInfoNode : public ::tvm::runtime::metadata::TensorInfoNode auto shape_accessor = shape(); shape_array.reserve(num_shape()); for (int64_t i = 0; i < num_shape(); ++i) { - shape_array.push_back(::tvm::Integer{int(shape_accessor[i])}); + shape_array.push_back(::tvm::Integer{static_cast(shape_accessor[i])}); } ::tvm::runtime::metadata::MetadataArray shape_metadata_array{shape_array, "int64_t"}; v->Visit("shape", &shape_metadata_array); @@ -169,24 +164,13 @@ class VisitableTensorInfoNode : public ::tvm::runtime::metadata::TensorInfoNode class InMemoryTensorInfoNode : public ::tvm::target::metadata::VisitableTensorInfoNode { public: - InMemoryTensorInfoNode() : InMemoryTensorInfoNode( - "", - {}, - ::tvm::runtime::DataType(0, 0, 0) - ) {} - InMemoryTensorInfoNode( - const ::tvm::runtime::String& name, - const ::std::vector& shape, - ::tvm::runtime::DataType dtype - ) : - VisitableTensorInfoNode{&storage_}, - name_{name}, - shape_{new int64_t[shape.size()]()}, - storage_{ - name_.c_str(), - nullptr, 0, - dtype - } { + InMemoryTensorInfoNode() : InMemoryTensorInfoNode("", {}, ::tvm::runtime::DataType(0, 0, 0)) {} + InMemoryTensorInfoNode(const ::tvm::runtime::String& name, const ::std::vector& shape, + ::tvm::runtime::DataType dtype) + : 
VisitableTensorInfoNode{&storage_}, + name_{name}, + shape_{new int64_t[shape.size()]()}, + storage_{name_.c_str(), nullptr, 0, dtype} { storage_.shape = shape_.get(); storage_.num_shape = shape.size(); for (unsigned int i = 0; i < shape.size(); ++i) { @@ -201,7 +185,7 @@ class InMemoryTensorInfoNode : public ::tvm::target::metadata::VisitableTensorIn }; } // namespace metadata -} // namespace runtime +} // namespace target } // namespace tvm -#endif // TVM_TARGET_METADATA_H +#endif // TVM_TARGET_METADATA_H_ diff --git a/src/target/metadata_module.cc b/src/target/metadata_module.cc index bed78a5c2e2c..9e41c6c85c1f 100644 --- a/src/target/metadata_module.cc +++ b/src/target/metadata_module.cc @@ -35,14 +35,11 @@ namespace tvm { namespace codegen { - -static runtime::Module CreateCrtMetadataModule(runtime::Module target_module, Target target, - relay::Runtime runtime, - runtime::metadata::Metadata metadata, - Array non_crt_exportable_modules, - Array crt_exportable_modules, - const std::unordered_map& const_var_ndarray) { - +static runtime::Module CreateCrtMetadataModule( + runtime::Module target_module, Target target, relay::Runtime runtime, + runtime::metadata::Metadata metadata, Array non_crt_exportable_modules, + Array crt_exportable_modules, + const std::unordered_map& const_var_ndarray) { if (!non_crt_exportable_modules.empty()) { std::string non_exportable_modules; for (unsigned int i = 0; i < non_crt_exportable_modules.size(); i++) { @@ -55,7 +52,7 @@ static runtime::Module CreateCrtMetadataModule(runtime::Module target_module, Ta non_exportable_modules += pf_sym().operator std::string(); } else { non_exportable_modules += - std::string{"(module type_key="} + mod->type_key() + std::string{")"}; + std::string{"(module type_key="} + mod->type_key() + std::string{")"}; } } CHECK(false) << "These " << non_crt_exportable_modules.size() @@ -64,7 +61,8 @@ static runtime::Module CreateCrtMetadataModule(runtime::Module target_module, Ta if (target->kind->name == "c") 
{ crt_exportable_modules.push_back(target_module); - target_module = CreateCSourceCrtMetadataModule(crt_exportable_modules, target, runtime, metadata); + target_module = + CreateCSourceCrtMetadataModule(crt_exportable_modules, target, runtime, metadata); } else if (target->kind->name == "llvm") { #ifdef TVM_LLVM_VERSION crt_exportable_modules.push_back(target_module); @@ -78,14 +76,15 @@ static runtime::Module CreateCrtMetadataModule(runtime::Module target_module, Ta } static runtime::Module CreateCppMetadataModule( - runtime::Module target_module, Target target, relay::Runtime runtime, - runtime::metadata::Metadata metadata, + runtime::Module target_module, Target target, relay::Runtime runtime, + runtime::metadata::Metadata metadata, const std::unordered_map>& const_vars_by_symbol, Array non_crt_exportable_modules, Array crt_exportable_modules, const std::unordered_map& const_var_ndarray) { if (!non_crt_exportable_modules.empty()) { - runtime::Module const_loader_mod = runtime::ConstLoaderModuleCreate(const_var_ndarray, const_vars_by_symbol); + runtime::Module const_loader_mod = + runtime::ConstLoaderModuleCreate(const_var_ndarray, const_vars_by_symbol); const_loader_mod.Import(target_module); for (const auto& it : non_crt_exportable_modules) { const_loader_mod.Import(it); @@ -166,7 +165,9 @@ runtime::Module CreateMetadataModule( } if (is_targeting_crt) { - return CreateCrtMetadataModule(target_module, target, runtime, metadata, non_crt_exportable_modules, crt_exportable_modules, const_var_ndarray); + return CreateCrtMetadataModule(target_module, target, runtime, metadata, + non_crt_exportable_modules, crt_exportable_modules, + const_var_ndarray); } else { return CreateCppMetadataModule(target_module, target, runtime, metadata, const_vars_by_symbol, non_crt_exportable_modules, crt_exportable_modules, @@ -174,7 +175,6 @@ runtime::Module CreateMetadataModule( } } - } // namespace codegen } // namespace tvm diff --git a/src/target/source/source_module.cc 
b/src/target/source/source_module.cc index 8d8d8328ffa8..f4ba346cc44f 100644 --- a/src/target/source/source_module.cc +++ b/src/target/source/source_module.cc @@ -23,18 +23,21 @@ */ #include "source_module.h" -#include -#include -#include -#include "../metadata.h" - +#include #include #include +#include #include +#include + +#include +#include +#include #include "../../runtime/file_utils.h" #include "../../support/str_escape.h" #include "../func_registry_generator.h" +#include "../metadata.h" #include "codegen_source_base.h" namespace tvm { @@ -131,7 +134,8 @@ runtime::Module CSourceModuleCreate(const String& code, const String& fmt, class CSourceCrtMetadataModuleNode : public runtime::ModuleNode { public: CSourceCrtMetadataModuleNode(const Array& func_names, const std::string& fmt, - Target target, relay::Runtime runtime, runtime::metadata::Metadata metadata) + Target target, relay::Runtime runtime, + runtime::metadata::Metadata metadata) : fmt_(fmt), func_names_(func_names), target_(target), @@ -287,7 +291,7 @@ class CSourceCrtMetadataModuleNode : public runtime::ModuleNode { auto devices = metadata_->devices(); for (const String& device : devices) { code_ << "devices->" << device; - if (device != devices[devices.size() -1]) { + if (device != devices[devices.size() - 1]) { code_ << ","; } } @@ -340,7 +344,6 @@ class CSourceCrtMetadataModuleNode : public runtime::ModuleNode { } }; - class CMetadataWriterVisitor : public ::tvm::AttrVisitor { private: std::stringstream struct_defs_; @@ -348,36 +351,28 @@ class CMetadataWriterVisitor : public ::tvm::AttrVisitor { std::vector streams_; std::stringstream* current_stream_; - void Visit(const char* key, double* value) override { - (*current_stream_) << *value; - } + void Visit(const char* key, double* value) override { (*current_stream_) << *value; } - void Visit(const char* key, int64_t* value) override { - (*current_stream_) << *value << "L"; - } + void Visit(const char* key, int64_t* value) override { 
(*current_stream_) << *value << "L"; } - void Visit(const char* key, uint64_t* value) override { - (*current_stream_) << *value << "UL"; - } + void Visit(const char* key, uint64_t* value) override { (*current_stream_) << *value << "UL"; } - void Visit(const char* key, int* value) override { - (*current_stream_) << *value; - } + void Visit(const char* key, int* value) override { (*current_stream_) << *value; } void Visit(const char* key, bool* value) override { (*current_stream_) << (value ? "true" : "false"); } void Visit(const char* key, std::string* value) override { - (*current_stream_) << "\"" << value << "\""; // todo: ->replace('\\', "\\\\").replace('\"', "\\\"") << "\""; + (*current_stream_) << "\"" << value + << "\""; // todo: ->replace('\\', "\\\\").replace('\"', "\\\"") << "\""; } - void Visit(const char* key, void** value) override { - (*current_stream_) << *value; - } + void Visit(const char* key, void** value) override { (*current_stream_) << *value; } void Visit(const char* key, DataType* value) override { - (*current_stream_) << "DLDataType{" << value->code() << ", " << value->bits() << ", " << value->lanes() << "}"; + (*current_stream_) << "DLDataType{" << value->code() << ", " << value->bits() << ", " + << value->lanes() << "}"; } void Visit(const char* key, runtime::NDArray* value) override { @@ -385,15 +380,13 @@ class CMetadataWriterVisitor : public ::tvm::AttrVisitor { } void Visit(const char* key, runtime::ObjectRef* value) override { -// if (value->as< + // if (value->as< // todo } - }; class MetadataStructDefiner : public AttrVisitor { public: - void Visit(const char* key, double* value) final { // dns: mangle name code_ << " double " << key << ";" << std::endl; @@ -457,7 +450,6 @@ class MetadataStructDefiner : public AttrVisitor { // } // } - // const ArrayNode* arr = value->as(); // if (arr != nullptr) { // // dns: mangle name @@ -494,16 +486,13 @@ class MetadataStructDefiner : public AttrVisitor { is_first_item_ = old_is_first_item; } - 
std::string GetOutput() { - return code_.str(); - } + std::string GetOutput() { return code_.str(); } private: ::std::stringstream code_; bool is_first_item_; }; - static std::string address_from_parts(const std::vector& parts) { std::stringstream ss; for (unsigned int i = 0; i < parts.size(); ++i) { @@ -518,7 +507,7 @@ static std::string address_from_parts(const std::vector& parts) { class MetadataQueuer : public AttrVisitor { public: using QueueItem = std::tuple; - MetadataQueuer(std::vector* queue) : queue_{queue} {} + explicit MetadataQueuer(std::vector* queue) : queue_{queue} {} void Visit(const char* key, double* value) final {} void Visit(const char* key, int64_t* value) final {} @@ -534,12 +523,14 @@ class MetadataQueuer : public AttrVisitor { address_parts_.push_back(key); if (value->as() != nullptr) { auto metadata = Downcast(*value); - const runtime::metadata::MetadataArrayNode* arr = value->as(); + const runtime::metadata::MetadataArrayNode* arr = + value->as(); std::cout << "Is array? 
" << arr << std::endl; if (arr != nullptr) { for (unsigned int i = 0; i < arr->array.size(); i++) { ObjectRef o = arr->array[i]; - std::cout << "queue-visiting array element " << i << ": " << o->type_index() << " (" << o.operator->() << ")" << std::endl; + std::cout << "queue-visiting array element " << i << ": " << o->type_index() << " (" + << o.operator->() << ")" << std::endl; if (o.as() != nullptr) { std::stringstream ss; ss << i; @@ -553,7 +544,8 @@ class MetadataQueuer : public AttrVisitor { ReflectionVTable::Global()->VisitAttrs(metadata.operator->(), this); } - queue_->push_back(std::make_tuple(address_from_parts(address_parts_), Downcast(*value))); + queue_->push_back(std::make_tuple(address_from_parts(address_parts_), + Downcast(*value))); } address_parts_.pop_back(); } @@ -564,7 +556,7 @@ class MetadataQueuer : public AttrVisitor { }; class MetadataSerializer : public AttrVisitor { -public: + public: static constexpr const char* kGlobalSymbol = "kTvmgenMetadata"; MetadataSerializer() : is_first_item_{true} {} @@ -586,7 +578,7 @@ class MetadataSerializer : public AttrVisitor { void Visit(const char* key, double* value) final { WriteComma(); code_.setf(std::ios::hex | std::ios::showbase | std::ios::fixed | std::ios::scientific, - std::ios::basefield | std::ios::showbase | std::ios::floatfield); + std::ios::basefield | std::ios::showbase | std::ios::floatfield); code_ << *value; WriteKey(key); } @@ -624,8 +616,8 @@ class MetadataSerializer : public AttrVisitor { } void Visit(const char* key, DataType* value) final { WriteComma(); - code_ << "DLDataType{" << value->code() << ", " << value->bits() << ", " - << value->lanes() << "}"; + code_ << "DLDataType{" << value->code() << ", " << value->bits() << ", " << value->lanes() + << "}"; WriteKey(key); } @@ -635,12 +627,14 @@ class MetadataSerializer : public AttrVisitor { } void VisitArray(const runtime::metadata::MetadataArrayNode* array) { - std::cout << "visit array " << array << ": " << array->c_type << " " 
<< array->array.size() << std::endl; + std::cout << "visit array " << array << ": " << array->c_type << " " << array->array.size() + << std::endl; auto old_is_first_item = is_first_item_; is_first_item_ = true; - for (unsigned int i = 0; i < array->array.size(); ++i) { //ObjectRef o : *(array->array)) { + for (unsigned int i = 0; i < array->array.size(); ++i) { // ObjectRef o : *(array->array)) { ObjectRef o = array->array[i]; - std::cout << "visiting array element " << i << ": " << o->type_index() << " (" << o.operator->() << ")" << std::endl; + std::cout << "visiting array element " << i << ": " << o->type_index() << " (" + << o.operator->() << ")" << std::endl; if (o->IsInstance()) { int64_t i = Downcast(o); Visit(nullptr, &i); @@ -660,24 +654,26 @@ class MetadataSerializer : public AttrVisitor { address_.push_back(i_str.str()); Visit(nullptr, &metadata); address_.pop_back(); -// ReflectionVTable::Global()->VisitAttrs(metadata.operator->(), this); + // ReflectionVTable::Global()->VisitAttrs(metadata.operator->(), this); } is_first_item_ = old_is_first_item; } void Visit(const char* key, ObjectRef* value) final { - const runtime::metadata::MetadataArrayNode* arr = value->as(); + const runtime::metadata::MetadataArrayNode* arr = + value->as(); std::cout << "Is array? 
" << arr << std::endl; if (arr != nullptr) { WriteComma(); if (key != nullptr) { address_.push_back(key); } - code_ << address_from_parts(address_) << " , " << arr->array.size() << " /* " << key << "_size */"; + code_ << address_from_parts(address_) << " , " << arr->array.size() << " /* " << key + << "_size */"; if (key != nullptr) { address_.pop_back(); } -// VisitArray(key, Downcast(*value).operator->()); + // VisitArray(key, Downcast(*value).operator->()); // WriteComma(); // code_ << "{"; // if (arr->size() > 0) { @@ -721,10 +717,9 @@ class MetadataSerializer : public AttrVisitor { // } void CodegenMetadata(::tvm::runtime::metadata::Metadata metadata) { - decl_ - << "#include " << std::endl - << "#include " << std::endl - << "#include " << std::endl; + decl_ << "#include " << std::endl + << "#include " << std::endl + << "#include " << std::endl; std::vector queue; MetadataQueuer queuer{&queue}; queuer.Visit(kGlobalSymbol, &metadata); @@ -741,8 +736,8 @@ class MetadataSerializer : public AttrVisitor { if (strcmp(arr->c_type, "const char*") == 0) { const_part = ""; } - code_ << const_part << arr->c_type << " " << struct_name - << "[" << arr->array.size() << "] = {" << std::endl; + code_ << const_part << arr->c_type << " " << struct_name << "[" << arr->array.size() + << "] = {" << std::endl; VisitArray(arr); } else { code_ << "const struct TVMMetadata " << struct_name << " = {" << std::endl; @@ -753,11 +748,9 @@ class MetadataSerializer : public AttrVisitor { } } - std::string GetOutput() { - return decl_.str() + code_.str(); - } + std::string GetOutput() { return decl_.str() + code_.str(); } -private: + private: std::vector address_; std::stringstream decl_; std::stringstream code_; @@ -767,7 +760,8 @@ class MetadataSerializer : public AttrVisitor { }; runtime::Module CreateCSourceCrtMetadataModule(const Array& modules, Target target, - relay::Runtime runtime, runtime::metadata::Metadata metadata) { + relay::Runtime runtime, + runtime::metadata::Metadata 
metadata) { Array func_names; for (runtime::Module mod : modules) { auto pf_funcs = mod.GetFunction("get_func_names"); @@ -787,8 +781,8 @@ runtime::Module CreateCSourceCrtMetadataModule(const Array& mod } runtime::Module CreateCSourceCppMetadataModule(runtime::metadata::Metadata metadata) { -// MetadataStructDefiner definer; -// ReflectionVTable::Global()->VisitAttrs(metadata.operator->(), &definer); + // MetadataStructDefiner definer; + // ReflectionVTable::Global()->VisitAttrs(metadata.operator->(), &definer); MetadataSerializer serializer; serializer.CodegenMetadata(metadata); std::stringstream lookup_func; @@ -796,16 +790,20 @@ runtime::Module CreateCSourceCppMetadataModule(runtime::metadata::Metadata metad << "extern \"C\"\n" << "#endif\n"; - lookup_func << "TVM_DLL int32_t get_c_metadata(TVMValue* arg_values, int* arg_tcodes, int num_args, TVMValue* ret_values, int* ret_tcodes, void* resource_handle) {" << std::endl; - lookup_func << " ret_values[0].v_handle = (void*) &" << MetadataSerializer::kGlobalSymbol << ";" << std::endl; + lookup_func << "TVM_DLL int32_t get_c_metadata(TVMValue* arg_values, int* arg_tcodes, int " + "num_args, TVMValue* ret_values, int* ret_tcodes, void* resource_handle) {" + << std::endl; + lookup_func << " ret_values[0].v_handle = (void*) &" << MetadataSerializer::kGlobalSymbol + << ";" << std::endl; lookup_func << " ret_tcodes[0] = kTVMOpaqueHandle;" << std::endl; lookup_func << " return 0;" << std::endl; lookup_func << "};" << std::endl; auto mod = MetadataModuleCreate(metadata); std::vector func_names{"get_c_metadata"}; - //definer.GetOutput() + - auto c = CSourceModuleCreate(serializer.GetOutput() + lookup_func.str(), "c", func_names, Array()); + // definer.GetOutput() + + auto c = CSourceModuleCreate(serializer.GetOutput() + lookup_func.str(), "c", func_names, + Array()); mod->Import(c); return mod; } @@ -875,7 +873,8 @@ TVM_REGISTER_GLOBAL("runtime.CreateCSourceCrtMetadataModule") .set_body_typed([](const Array& modules, Target 
target, relay::Runtime runtime) { // Note that we don't need metadata when we compile a single operator - return CreateCSourceCrtMetadataModule(modules, target, runtime, runtime::metadata::Metadata()); + return CreateCSourceCrtMetadataModule(modules, target, runtime, + runtime::metadata::Metadata()); }); } // namespace codegen diff --git a/src/target/source/source_module.h b/src/target/source/source_module.h index c7d3302e64b4..9028f6dc410a 100644 --- a/src/target/source/source_module.h +++ b/src/target/source/source_module.h @@ -44,7 +44,8 @@ namespace codegen { * \return The wrapped module. */ runtime::Module CreateCSourceCrtMetadataModule(const Array& modules, Target target, - relay::Runtime runtime, runtime::metadata::Metadata metadata); + relay::Runtime runtime, + runtime::metadata::Metadata metadata); /*! * \brief Create C++-runtime targeted metadata module for "c" backend. From b714c29059097cecc909bf68df50d86809ca17de Mon Sep 17 00:00:00 2001 From: Andrew Reusch Date: Thu, 9 Dec 2021 17:22:28 -0800 Subject: [PATCH 15/41] DNS lint hacks, idk what's up here... --- include/tvm/runtime/metadata.h | 12 +++++------- src/runtime/aot_executor/aot_executor.cc | 3 ++- src/target/source/source_module.cc | 7 ++++--- 3 files changed, 11 insertions(+), 11 deletions(-) diff --git a/include/tvm/runtime/metadata.h b/include/tvm/runtime/metadata.h index 438698373f31..3192ff997b8e 100644 --- a/include/tvm/runtime/metadata.h +++ b/include/tvm/runtime/metadata.h @@ -17,9 +17,6 @@ * under the License. */ -// NOTE: This file is intended to be compileable in C++ and C build processes. -// NOLINT(build/include_order) - /*! * \file tvm/runtime/metadata.h * \brief Defines types which can be used in Metadata. @@ -27,14 +24,15 @@ #ifndef TVM_RUNTIME_METADATA_H_ #define TVM_RUNTIME_METADATA_H_ +#include #include #include #include -#include -#include -#include -#include +// TODO(areusch): idk what's up here. 
+#include // NOLINT(build/include_order) +#include // NOLINT(build/include_order) +#include // NOLINT(build/include_order) #define TVM_METADATA_VERSION 1 static const constexpr int64_t kMetadataVersion = TVM_METADATA_VERSION; diff --git a/src/runtime/aot_executor/aot_executor.cc b/src/runtime/aot_executor/aot_executor.cc index 763fb39c3b4b..24a6a7328890 100644 --- a/src/runtime/aot_executor/aot_executor.cc +++ b/src/runtime/aot_executor/aot_executor.cc @@ -27,7 +27,8 @@ #include -#include +// TODO(areusch): idk what's up here... +#include // NOLINT(build/include_order) namespace tvm { namespace runtime { diff --git a/src/target/source/source_module.cc b/src/target/source/source_module.cc index f4ba346cc44f..1844a4b9efd5 100644 --- a/src/target/source/source_module.cc +++ b/src/target/source/source_module.cc @@ -30,9 +30,10 @@ #include #include -#include -#include -#include +// TODO(areusch): idk what's up here... +#include // NOLINT(build/include_order) +#include // NOLINT(build/include_order) +#include // NOLINT(build/include_order) #include "../../runtime/file_utils.h" #include "../../support/str_escape.h" From f1fbed161b51a8f33a02dde0636fd636f8345cb6 Mon Sep 17 00:00:00 2001 From: Andrew Reusch Date: Thu, 9 Dec 2021 17:22:39 -0800 Subject: [PATCH 16/41] Fix python formatting --- python/tvm/runtime/executor/__init__.py | 17 +++++++++++++++++ tests/python/relay/aot/test_cpp_aot.py | 6 ++++-- 2 files changed, 21 insertions(+), 2 deletions(-) diff --git a/python/tvm/runtime/executor/__init__.py b/python/tvm/runtime/executor/__init__.py index 92a0402549ec..0748bbd00aec 100644 --- a/python/tvm/runtime/executor/__init__.py +++ b/python/tvm/runtime/executor/__init__.py @@ -1,2 +1,19 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Top-level file for the executor module.""" from .aot_executor import AotModule diff --git a/tests/python/relay/aot/test_cpp_aot.py b/tests/python/relay/aot/test_cpp_aot.py index faee1d75b52e..4585a586346b 100644 --- a/tests/python/relay/aot/test_cpp_aot.py +++ b/tests/python/relay/aot/test_cpp_aot.py @@ -45,7 +45,8 @@ def print_mod_tree(m, indent=0): print_mod_tree(i, indent + 2) def test_conv2d(): - RELAY_MODEL = textwrap.dedent("""\ + RELAY_MODEL = textwrap.dedent( + """\ #[version = "0.0.5"] def @main(%data : Tensor[(1, 3, 64, 64), uint8], %weight : Tensor[(8, 3, 5, 5), int8]) { %1 = nn.conv2d( @@ -59,7 +60,8 @@ def @main(%data : Tensor[(1, 3, 64, 64), uint8], %weight : Tensor[(8, 3, 5, 5), out_dtype="int32"); %1 } - """) + """ + ) ir_mod = tvm.parser.fromtext(RELAY_MODEL) main_func = ir_mod["main"] From d940826d57e0c5df3fbf8f3d832bdf8ef1880bdb Mon Sep 17 00:00:00 2001 From: Andrew Reusch Date: Fri, 10 Dec 2021 08:28:47 -0800 Subject: [PATCH 17/41] git-clang-format --- include/tvm/runtime/metadata.h | 3 ++- include/tvm/runtime/metadata_base.h | 2 +- src/target/source/source_module.cc | 6 +++--- 3 files changed, 6 insertions(+), 5 deletions(-) diff --git a/include/tvm/runtime/metadata.h b/include/tvm/runtime/metadata.h index 3192ff997b8e..c29b701291b0 100644 --- a/include/tvm/runtime/metadata.h +++ b/include/tvm/runtime/metadata.h @@ -25,6 +25,7 @@ #define TVM_RUNTIME_METADATA_H_ #include + 
#include #include #include @@ -32,7 +33,7 @@ // TODO(areusch): idk what's up here. #include // NOLINT(build/include_order) #include // NOLINT(build/include_order) -#include // NOLINT(build/include_order) +#include // NOLINT(build/include_order) #define TVM_METADATA_VERSION 1 static const constexpr int64_t kMetadataVersion = TVM_METADATA_VERSION; diff --git a/include/tvm/runtime/metadata_base.h b/include/tvm/runtime/metadata_base.h index b155727ea404..228108215d09 100644 --- a/include/tvm/runtime/metadata_base.h +++ b/include/tvm/runtime/metadata_base.h @@ -29,8 +29,8 @@ #include #include -#include #include +#include namespace tvm { namespace runtime { diff --git a/src/target/source/source_module.cc b/src/target/source/source_module.cc index 1844a4b9efd5..626274fd1877 100644 --- a/src/target/source/source_module.cc +++ b/src/target/source/source_module.cc @@ -23,17 +23,17 @@ */ #include "source_module.h" -#include #include +#include #include #include #include #include // TODO(areusch): idk what's up here... -#include // NOLINT(build/include_order) +#include // NOLINT(build/include_order) #include // NOLINT(build/include_order) -#include // NOLINT(build/include_order) +#include // NOLINT(build/include_order) #include "../../runtime/file_utils.h" #include "../../support/str_escape.h" From 96544e9dd3dd8544f2a750de694ad8c575b28274 Mon Sep 17 00:00:00 2001 From: Andrew Reusch Date: Fri, 10 Dec 2021 08:33:03 -0800 Subject: [PATCH 18/41] fix span.h --- include/tvm/support/span.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/include/tvm/support/span.h b/include/tvm/support/span.h index 4f2da9bf2c98..150798c488f0 100644 --- a/include/tvm/support/span.h +++ b/include/tvm/support/span.h @@ -19,7 +19,7 @@ /*! * - * \file span.h + * \file tvm/support/span.h * \brief Reimplementation of part of C++-20 style span. 
*/ #ifndef TVM_SUPPORT_SPAN_H_ From 06e3b93baa6c58c7bf869d1a395e967a2d2800f3 Mon Sep 17 00:00:00 2001 From: Andrew Reusch Date: Fri, 10 Dec 2021 16:43:02 -0800 Subject: [PATCH 19/41] Move kTvmExecutor consts to runtime and fix improper references. --- include/tvm/relay/runtime.h | 6 ++++++ include/tvm/target/target_kind.h | 6 ------ 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/include/tvm/relay/runtime.h b/include/tvm/relay/runtime.h index c4cabf5a5548..e4a1a63e45e2 100644 --- a/include/tvm/relay/runtime.h +++ b/include/tvm/relay/runtime.h @@ -44,6 +44,12 @@ class AttrRegistry; namespace relay { +/*! \brief Value used with Runtime::name to indicate the C++ runtime. */ +static constexpr const char* kTvmRuntimeCpp = "cpp"; + +/*! \brief Value used with Runtime::name to indicate the C runtime. */ +static constexpr const char* kTvmRuntimeCrt = "c"; + /*! * \brief Runtime information. * diff --git a/include/tvm/target/target_kind.h b/include/tvm/target/target_kind.h index e802a3088d2d..6e9c8445695c 100644 --- a/include/tvm/target/target_kind.h +++ b/include/tvm/target/target_kind.h @@ -162,12 +162,6 @@ class TargetKindAttrMap : public AttrRegistryMap { explicit TargetKindAttrMap(const AttrRegistryMapContainerMap& map) : TParent(map) {} }; -/*! \brief Value used with --runtime in target specs to indicate the C++ runtime. */ -static constexpr const char* kTvmRuntimeCpp = "c++"; - -/*! \brief Value used with --runtime in target specs to indicate the C runtime. */ -static constexpr const char* kTvmRuntimeCrt = "c"; - /*! 
* \brief Helper structure to register TargetKind * \sa TVM_REGISTER_TARGET_KIND From c250f163287704e31e94d0e3b23d95a53f1ec489 Mon Sep 17 00:00:00 2001 From: Andrew Reusch Date: Sun, 12 Dec 2021 20:48:07 -0800 Subject: [PATCH 20/41] git-clang-format --- src/relay/backend/aot_executor_codegen.cc | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/relay/backend/aot_executor_codegen.cc b/src/relay/backend/aot_executor_codegen.cc index 145cee0f60ec..7e528c4fde3e 100644 --- a/src/relay/backend/aot_executor_codegen.cc +++ b/src/relay/backend/aot_executor_codegen.cc @@ -830,9 +830,9 @@ class AOTExecutorCodegen : public MixedModeVisitor { << ") when targeting c runtime"; } else if (runtime_config->name == kTvmRuntimeCpp) { CHECK(use_unpacked_api_ == false && static_cast(use_call_cpacked_) == true) - << "Need unpacked-api == false (got: " << use_unpacked_api_ - << ") and interface-api == \"c\" (got: " << interface_api - << ") when targeting c++ runtime"; + << "Need unpacked-api == false (got: " << use_unpacked_api_ + << ") and interface-api == \"c\" (got: " << interface_api + << ") when targeting c++ runtime"; } else { ICHECK(false) << "runtime_config (" << runtime_config->name << ") is not one of the expected values"; From ff1bd7991b786520ff37b30924941e333747bf08 Mon Sep 17 00:00:00 2001 From: Andrew Reusch Date: Sun, 12 Dec 2021 20:50:26 -0800 Subject: [PATCH 21/41] black format --- tests/python/relay/aot/test_cpp_aot.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/tests/python/relay/aot/test_cpp_aot.py b/tests/python/relay/aot/test_cpp_aot.py index 4585a586346b..98e752cf628c 100644 --- a/tests/python/relay/aot/test_cpp_aot.py +++ b/tests/python/relay/aot/test_cpp_aot.py @@ -42,7 +42,8 @@ def print_mod_tree(m, indent=0): print(f"{' ' * indent} - {m!r}") for i in m.imported_modules: - print_mod_tree(i, indent + 2) + print_mod_tree(i, indent + 2) + def test_conv2d(): RELAY_MODEL = textwrap.dedent( @@ -76,8 +77,12 @@ 
def @main(%data : Tensor[(1, 3, 64, 64), uint8], %weight : Tensor[(8, 3, 5, 5), output_list = generate_ref_data(ir_mod, inputs, params) with tvm.transform.PassContext(opt_level=3, config={"tir.disable_vectorize": True}): - mod = tvm.relay.build(ir_mod, params=params, target="c", - executor=backend.Executor("aot", {"interface-api": "c"})) + mod = tvm.relay.build( + ir_mod, + params=params, + target="c", + executor=backend.Executor("aot", {"interface-api": "c"}), + ) print_mod_tree(mod.module) From 7735395a0f50f8b70658afb0cd258edfb655d1cc Mon Sep 17 00:00:00 2001 From: Andrew Reusch Date: Sun, 12 Dec 2021 21:33:56 -0800 Subject: [PATCH 22/41] git-clang-format --- src/relay/backend/aot_executor_codegen.cc | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/relay/backend/aot_executor_codegen.cc b/src/relay/backend/aot_executor_codegen.cc index 7e528c4fde3e..1a179833ed1c 100644 --- a/src/relay/backend/aot_executor_codegen.cc +++ b/src/relay/backend/aot_executor_codegen.cc @@ -824,12 +824,13 @@ class AOTExecutorCodegen : public MixedModeVisitor { // Validate choice of use_unpacked_api_ and use_call_cpacked_ if (runtime_config->name == kTvmRuntimeCrt) { - CHECK(interface_api == "c" || use_unpacked_api_ == false) + CHECK(interface_api == "c" || static_cast(use_unpacked_api_) == false) << "Either need interface_api == \"c\" (got: " << interface_api << ") or unpacked-api == false (got: " << use_unpacked_api_ << ") when targeting c runtime"; } else if (runtime_config->name == kTvmRuntimeCpp) { - CHECK(use_unpacked_api_ == false && static_cast(use_call_cpacked_) == true) + CHECK(static_cast(use_unpacked_api_) == false && + static_cast(use_call_cpacked_) == true) << "Need unpacked-api == false (got: " << use_unpacked_api_ << ") and interface-api == \"c\" (got: " << interface_api << ") when targeting c++ runtime"; From 1f84fdf278bc29154a3408796b886d5dbb4eec90 Mon Sep 17 00:00:00 2001 From: Andrew Reusch Date: Sun, 12 Dec 2021 22:25:45 -0800 Subject: [PATCH 
23/41] Fix incongruity between kTvmRuntimeCrt constant --- include/tvm/relay/runtime.h | 2 +- src/relay/backend/runtime.cc | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/include/tvm/relay/runtime.h b/include/tvm/relay/runtime.h index e4a1a63e45e2..50f0b07e5005 100644 --- a/include/tvm/relay/runtime.h +++ b/include/tvm/relay/runtime.h @@ -48,7 +48,7 @@ namespace relay { static constexpr const char* kTvmRuntimeCpp = "cpp"; /*! \brief Value used with Runtime::name to indicate the C runtime. */ -static constexpr const char* kTvmRuntimeCrt = "c"; +static constexpr const char* kTvmRuntimeCrt = "crt"; /*! * \brief Runtime information. diff --git a/src/relay/backend/runtime.cc b/src/relay/backend/runtime.cc index 786d6f937f14..923c9b2d5f65 100644 --- a/src/relay/backend/runtime.cc +++ b/src/relay/backend/runtime.cc @@ -88,9 +88,9 @@ RuntimeRegEntry& RuntimeRegEntry::RegisterOrGet(const String& name) { /********** Register Runtimes and options **********/ -TVM_REGISTER_RUNTIME("crt").add_attr_option("system-lib"); +TVM_REGISTER_RUNTIME(kTvmRuntimeCrt).add_attr_option("system-lib"); -TVM_REGISTER_RUNTIME("cpp").add_attr_option("system-lib"); +TVM_REGISTER_RUNTIME(kTvmRuntimeCpp).add_attr_option("system-lib"); /********** Registry **********/ From 2732dc0dfa9a6dbc8bf8d89453621fe3c7fada44 Mon Sep 17 00:00:00 2001 From: Andrew Reusch Date: Tue, 4 Jan 2022 12:16:01 -0800 Subject: [PATCH 24/41] fix segfault with devices --- src/target/metadata.h | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/target/metadata.h b/src/target/metadata.h index 64a9bbf1d562..26dca0d9d079 100644 --- a/src/target/metadata.h +++ b/src/target/metadata.h @@ -98,6 +98,7 @@ class InMemoryMetadataNode : public ::tvm::target::metadata::VisitableMetadataNo outputs_{new struct TVMTensorInfo[outputs.size()]()}, outputs_objs_{outputs}, devices_{new const char*[devices.size()]()}, + devices_objs_{devices}, executor_{executor}, mod_name_{mod_name}, 
interface_api_{interface_api}, @@ -125,7 +126,7 @@ class InMemoryMetadataNode : public ::tvm::target::metadata::VisitableMetadataNo storage_.devices = devices_.get(); storage_.num_devices = devices.size(); for (unsigned int i = 0; i < devices.size(); ++i) { - devices_.get()[i] = devices[i].c_str(); + devices_.get()[i] = devices_objs_[i].c_str(); } } @@ -135,6 +136,7 @@ class InMemoryMetadataNode : public ::tvm::target::metadata::VisitableMetadataNo ::std::unique_ptr outputs_; std::vector<::tvm::runtime::metadata::TensorInfo> outputs_objs_; ::std::unique_ptr devices_; + std::vector<::std::string> devices_objs_; ::std::string executor_; ::std::string mod_name_; ::std::string interface_api_; From 0bedaad7fe2fb8044fa387bc9f254f23b4c2826b Mon Sep 17 00:00:00 2001 From: Andrew Reusch Date: Tue, 4 Jan 2022 22:07:03 -0800 Subject: [PATCH 25/41] fix packed/c interface api restriction --- src/relay/backend/aot_executor_codegen.cc | 9 +++++---- tests/python/relay/aot/test_crt_aot.py | 2 +- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/src/relay/backend/aot_executor_codegen.cc b/src/relay/backend/aot_executor_codegen.cc index 1a179833ed1c..805aad359630 100644 --- a/src/relay/backend/aot_executor_codegen.cc +++ b/src/relay/backend/aot_executor_codegen.cc @@ -824,10 +824,11 @@ class AOTExecutorCodegen : public MixedModeVisitor { // Validate choice of use_unpacked_api_ and use_call_cpacked_ if (runtime_config->name == kTvmRuntimeCrt) { - CHECK(interface_api == "c" || static_cast(use_unpacked_api_) == false) - << "Either need interface_api == \"c\" (got: " << interface_api - << ") or unpacked-api == false (got: " << use_unpacked_api_ - << ") when targeting c runtime"; + if (interface_api == "c") { + CHECK(static_cast(use_unpacked_api_) == true) + << "When interface_api == \"c\", need unpacked-api == true (got: " + << use_unpacked_api_ << ") when targeting c runtime"; + } } else if (runtime_config->name == kTvmRuntimeCpp) { CHECK(static_cast(use_unpacked_api_) == 
false && static_cast(use_call_cpacked_) == true) diff --git a/tests/python/relay/aot/test_crt_aot.py b/tests/python/relay/aot/test_crt_aot.py index 8a2b1f1bb84d..fdcded355065 100644 --- a/tests/python/relay/aot/test_crt_aot.py +++ b/tests/python/relay/aot/test_crt_aot.py @@ -47,7 +47,7 @@ def test_error_c_interface_with_packed_api(): two = relay.add(relay.const(1), relay.const(1)) func = relay.Function([], two) - with pytest.raises(tvm.TVMError, match="Packed interface required for packed operators"): + with pytest.raises(tvm.TVMError, match='When interface_api == "c", need unpacked-api == true'): compile_and_run( AOTTestModel( module=IRModule.from_expr(func), inputs={}, outputs=generate_ref_data(func, {}) From f5a268fac8f484b7a48a37cb3deef90985138f28 Mon Sep 17 00:00:00 2001 From: Andrew Reusch Date: Tue, 4 Jan 2022 22:07:37 -0800 Subject: [PATCH 26/41] Only emit __tvm_module_ctx when using C++ runtime; breaks multiple-models case. --- src/target/source/codegen_c_host.cc | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/target/source/codegen_c_host.cc b/src/target/source/codegen_c_host.cc index c1a763023bb3..4b421cc8643f 100644 --- a/src/target/source/codegen_c_host.cc +++ b/src/target/source/codegen_c_host.cc @@ -23,6 +23,7 @@ #include "codegen_c_host.h" #include +#include #include #include #include @@ -441,7 +442,8 @@ runtime::Module BuildCHost(IRModule mod, Target target) { cg.AddFunction(aot_executor_fn); } - if (aot_executor_fn.defined()) { + relay::Runtime runtime = mod->GetAttr(tvm::attr::kRuntime).value(); + if (aot_executor_fn.defined() && runtime->name == relay::kTvmRuntimeCpp) { cg.InitGlobalContext(); } From 574da50145fb17297c436fc7e8c7fbd0bf281d82 Mon Sep 17 00:00:00 2001 From: Andrew Reusch Date: Tue, 4 Jan 2022 22:08:38 -0800 Subject: [PATCH 27/41] fixup! Return new Metadata from graph-level codegen. 
--- src/relay/backend/utils.h | 1 - 1 file changed, 1 deletion(-) diff --git a/src/relay/backend/utils.h b/src/relay/backend/utils.h index 07c829fbdf62..4a0221809a22 100644 --- a/src/relay/backend/utils.h +++ b/src/relay/backend/utils.h @@ -31,7 +31,6 @@ #include #include #include -#include #include #include From b16f605bd19101d762074e6cc1c3db6c7637675f Mon Sep 17 00:00:00 2001 From: Andrew Reusch Date: Wed, 5 Jan 2022 12:05:49 -0800 Subject: [PATCH 28/41] fixup! Stack-allocate DLTensor instances when necessary. --- src/relay/backend/aot_executor_codegen.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/relay/backend/aot_executor_codegen.cc b/src/relay/backend/aot_executor_codegen.cc index 805aad359630..93d742ee238a 100644 --- a/src/relay/backend/aot_executor_codegen.cc +++ b/src/relay/backend/aot_executor_codegen.cc @@ -445,7 +445,7 @@ class AOTExecutorCodegen : public MixedModeVisitor { func_call, GenerateDeviceHook(context, "Close"), })); - } else if (use_call_cpacked_) { + } else if (use_call_cpacked_ && !use_unpacked_api_) { // call_cpacked calling convention needs a blank context args.push_back(tir::make_zero(DataType::Handle())); tir::Evaluate func_call(tvm::tir::Call(DataType::Int(32), calling_pattern, args)); From 94bcb1cd8c7da38e37a4146877403ea853bbd465 Mon Sep 17 00:00:00 2001 From: Andrew Reusch Date: Wed, 5 Jan 2022 12:06:30 -0800 Subject: [PATCH 29/41] fixup! Stack-allocate DLTensor instances when necessary. 
--- src/relay/backend/aot_executor_codegen.cc | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/relay/backend/aot_executor_codegen.cc b/src/relay/backend/aot_executor_codegen.cc index 93d742ee238a..ac287529efe0 100644 --- a/src/relay/backend/aot_executor_codegen.cc +++ b/src/relay/backend/aot_executor_codegen.cc @@ -820,7 +820,11 @@ class AOTExecutorCodegen : public MixedModeVisitor { Integer workspace_byte_alignment = executor_config->GetAttr("workspace-byte-alignment").value_or(16); use_unpacked_api_ = executor_config->GetAttr("unpacked-api").value_or(Bool(false)); - use_call_cpacked_ = Bool(interface_api == "c"); + use_call_cpacked_ = + (Bool(interface_api == "c") || + // for now, C runtime does not support calling functions on other devices. therefore, + // opt to call PackedFunc directly by name rather than TVMBackendGetFuncFromEnv. + runtime_config->name == kTvmRuntimeCrt); // Validate choice of use_unpacked_api_ and use_call_cpacked_ if (runtime_config->name == kTvmRuntimeCrt) { From 47bde106bd69b2ba432db09717e1399a95da2c2a Mon Sep 17 00:00:00 2001 From: Andrew Reusch Date: Thu, 6 Jan 2022 15:55:35 -0800 Subject: [PATCH 30/41] fix aot executor codegen for C --- src/relay/backend/aot_executor_codegen.cc | 67 +++++++++++++++-------- 1 file changed, 43 insertions(+), 24 deletions(-) diff --git a/src/relay/backend/aot_executor_codegen.cc b/src/relay/backend/aot_executor_codegen.cc index ac287529efe0..be43a3d365b4 100644 --- a/src/relay/backend/aot_executor_codegen.cc +++ b/src/relay/backend/aot_executor_codegen.cc @@ -326,33 +326,40 @@ class AOTExecutorCodegen : public MixedModeVisitor { } } + /*! \brief Return a PrimExpr which contains the arg to be passed down to a PrimFunc. + * + * TODO(areusch): Document the various cases which could necessitate us synthesizing + * a DLTensor on stack. 
+ */ PrimExpr MakeDLTensor(Expr relay_arg, TensorType ttype, PrimExpr data) { - for (Var v : input_vars_) { - if (v == relay_arg) { - return data; - } - } - for (int return_sid : return_sid_) { - auto return_expr = sids_table_[return_sid]; - if (return_expr == relay_arg) { - return data; - } - } return data; } - - void PushTuple(Tuple tuple, std::vector sids, Array args) { - CHECK_EQ(sids.size(), tuple->fields.size()) - << "Relay tuple does not map 1:1 into TIR; AOT can't handle this type of Relay Expr in a " - "CallNode."; + // for (Var v : input_vars_) { + // if (v == relay_arg) { + // return data; + // } + // } + // for (int return_sid : return_sid_) { + // auto return_expr = sids_table_[return_sid]; + // if (return_expr == relay_arg) { + // return data; + // } + // } + // return data; + // } + + void PushTuple(Expr tuple, std::vector sids, Array* args) { +// CHECK_EQ(sids.size(), tuple->fields.size()) +// << "Relay tuple does not map 1:1 into TIR; AOT can't handle this type of Relay Expr in a " +// "CallNode."; StorageInfo& sinfo = storage_device_map_[tuple]; for (unsigned int i = 0; i < sids.size(); ++i) { if (std::find(return_sid_.begin(), return_sid_.end(), sinfo->storage_ids[i]) != return_sid_.end()) { - args.push_back(sids[i]); + args->push_back(sids[i]); } else { - args.push_back(MakeDLTensor( - tuple->fields[i], Downcast(tuple->fields[i]->checked_type()), sids[i])); + args->push_back(sids[i]); //MakeDLTensor( +// tuple->fields[i], Downcast(tuple->fields[i]->checked_type()), sids[i])); } } } @@ -379,8 +386,8 @@ class AOTExecutorCodegen : public MixedModeVisitor { } else { auto sids = FindExpr(arg); if (sids.size() > 1) { - auto tuple = Downcast(arg); - PushTuple(tuple, sids, args); +// auto tuple = Downcast(arg); + PushTuple(arg, sids, &args); } else { StorageInfo& sinfo = storage_device_map_[arg]; if (std::find(return_sid_.begin(), return_sid_.end(), sinfo->storage_ids[0]) != @@ -396,8 +403,19 @@ class AOTExecutorCodegen : public MixedModeVisitor { // 
Pack the return(s) value. A call node can produce multiple outputs auto result_expr_sid = PackSid(result_expr); if (result_expr_sid.size() > 1) { - auto tuple = Downcast(result_expr); - PushTuple(tuple, result_expr_sid, args); + LOG(INFO) << "RESULT EXPR " << result_expr; + LOG(INFO) << "RESULT TYPE " << result_expr->checked_type(); + auto result_storage_device_map = storage_device_map_[result_expr]; + LOG(INFO) << "RESULT STORAGE DEVICE MAP " << result_storage_device_map; + std::stringstream rsid; + for (auto s : result_expr_sid) { + rsid << s << ","; + } + LOG(INFO) << "RESULT_EXPR SID " << rsid.str() << "(end)"; +// auto tuple = Downcast(result_expr); + + PushTuple(result_expr, result_expr_sid, &args); + } else { StorageInfo& sinfo = storage_device_map_[result_expr]; if (std::find(return_sid_.begin(), return_sid_.end(), sinfo->storage_ids[0]) != @@ -821,10 +839,11 @@ class AOTExecutorCodegen : public MixedModeVisitor { executor_config->GetAttr("workspace-byte-alignment").value_or(16); use_unpacked_api_ = executor_config->GetAttr("unpacked-api").value_or(Bool(false)); use_call_cpacked_ = - (Bool(interface_api == "c") || + (!use_unpacked_api_ || // for now, C runtime does not support calling functions on other devices. therefore, // opt to call PackedFunc directly by name rather than TVMBackendGetFuncFromEnv. runtime_config->name == kTvmRuntimeCrt); + LOG(INFO) << "Use call cpacked? 
" << bool(use_call_cpacked_) << "; " << interface_api << ", unpacked=" << use_unpacked_api_; // Validate choice of use_unpacked_api_ and use_call_cpacked_ if (runtime_config->name == kTvmRuntimeCrt) { From b976a2ed569258aa43d038ea1c3f42e56f9bc7a3 Mon Sep 17 00:00:00 2001 From: Andrew Reusch Date: Mon, 10 Jan 2022 07:53:25 -0800 Subject: [PATCH 31/41] random logging code --- src/relay/backend/aot_executor_codegen.cc | 8 ++++++++ src/relay/backend/build_module.cc | 1 + src/target/metadata_module.cc | 1 + src/target/source/codegen_c_host.cc | 1 + src/target/source/source_module.cc | 1 + 5 files changed, 12 insertions(+) diff --git a/src/relay/backend/aot_executor_codegen.cc b/src/relay/backend/aot_executor_codegen.cc index be43a3d365b4..ef6250c2682d 100644 --- a/src/relay/backend/aot_executor_codegen.cc +++ b/src/relay/backend/aot_executor_codegen.cc @@ -150,6 +150,7 @@ class AOTOnDemandAllocator : public transform::DeviceAwareExprVisitor { sid->storage_sizes_in_bytes.begin(), sid->storage_sizes_in_bytes.end()); } + LOG(INFO) << "Visit tuple: " << GetRef(op); storage_device_map_[expr] = StorageInfo(storage_ids, virtual_devices, storage_sizes_in_bytes); AssignReturnSid(expr); } @@ -158,6 +159,7 @@ class AOTOnDemandAllocator : public transform::DeviceAwareExprVisitor { Expr expr = GetRef(op); auto sids = GetStorage(op->tuple); ICHECK_LT(static_cast(op->index), sids->storage_ids.size()); + LOG(INFO) << "Visit TupleGetItem: " << expr; storage_device_map_[expr] = StorageInfo({sids->storage_ids[op->index]}, {sids->virtual_devices[op->index]}, {sids->storage_sizes_in_bytes[op->index]}); @@ -173,7 +175,9 @@ class AOTOnDemandAllocator : public transform::DeviceAwareExprVisitor { private: void AssignReturnSid(Expr e) { if (storage_device_map_.find(e) != storage_device_map_.end()) { + LOG(INFO) << "AssignReturnSid: is now " << e; StorageInfo& sinfo = storage_device_map_[e]; + LOG(INFO) << "AssignReturnSid: storage_device_map_ " << sinfo; return_ids_.clear(); for (auto sid : 
sinfo->storage_ids) { return_ids_.push_back(sid); @@ -249,6 +253,7 @@ class AOTOnDemandAllocator : public transform::DeviceAwareExprVisitor { virtual_devices.push_back(virtual_device); storage_sizes_in_bytes.push_back(GetMemorySizeBytes(ttype)); } + LOG(INFO) << "CreateStorage: " << expr; storage_device_map_[expr] = StorageInfo(std::move(storage_ids), std::move(virtual_devices), std::move(storage_sizes_in_bytes)); } @@ -476,6 +481,7 @@ class AOTExecutorCodegen : public MixedModeVisitor { tir::Stmt body = tir::SeqStmt(create_func_call_stmts); stmts_.push_back(body); + LOG(INFO) << "Create func call " << body; } /*! @@ -865,6 +871,7 @@ class AOTExecutorCodegen : public MixedModeVisitor { // TODO(mbs): Plumb from compiler config VirtualDevice host_virtual_device = VirtualDevice::ForTarget(target_host_); + VLOG(1) << "relay mod:" << std::endl << PrettyPrint(mod); IRModule lowered_mod = tec::LowerTEPass( mod_name, @@ -923,6 +930,7 @@ class AOTExecutorCodegen : public MixedModeVisitor { } CollectDeviceVariables(lowered_mod->GetAttr>("device_contexts").value()); + VLOG(1) << "lowered_main_func:" << std::endl << PrettyPrint(lowered_main_func); VisitExpr(lowered_main_func->body); // Create the runner function. Please note that the function is not legal yet diff --git a/src/relay/backend/build_module.cc b/src/relay/backend/build_module.cc index 35dcd7378382..39c56d66947d 100644 --- a/src/relay/backend/build_module.cc +++ b/src/relay/backend/build_module.cc @@ -410,6 +410,7 @@ class RelayBuildModule : public runtime::ModuleNode { Function func = Downcast(relay_module->Lookup("main")); IRModule func_module = WithAttrs(IRModule::FromExpr(func), {{tvm::attr::kExecutor, executor_}, {tvm::attr::kRuntime, runtime_}}); + LOG(INFO) << "Executor " << executor_; // Generate code for the updated function. 
executor_codegen_ = MakeExecutorCodegen(executor_->name); diff --git a/src/target/metadata_module.cc b/src/target/metadata_module.cc index 9e41c6c85c1f..4a318f780a4d 100644 --- a/src/target/metadata_module.cc +++ b/src/target/metadata_module.cc @@ -165,6 +165,7 @@ runtime::Module CreateMetadataModule( } if (is_targeting_crt) { + LOG(INFO) << "Create CRT metadata: " << metadata.defined(); return CreateCrtMetadataModule(target_module, target, runtime, metadata, non_crt_exportable_modules, crt_exportable_modules, const_var_ndarray); diff --git a/src/target/source/codegen_c_host.cc b/src/target/source/codegen_c_host.cc index 4b421cc8643f..4e21775a4a20 100644 --- a/src/target/source/codegen_c_host.cc +++ b/src/target/source/codegen_c_host.cc @@ -397,6 +397,7 @@ runtime::Module BuildCHost(IRModule mod, Target target) { bool emit_asserts = false; CodeGenCHost cg; cg.Init(output_ssa, emit_asserts, target->str()); + LOG(INFO) << "CodegenCHost: " << mod; Map linked_params; bool found_linked_params = false; diff --git a/src/target/source/source_module.cc b/src/target/source/source_module.cc index 626274fd1877..3bc6c2367f39 100644 --- a/src/target/source/source_module.cc +++ b/src/target/source/source_module.cc @@ -338,6 +338,7 @@ class CSourceCrtMetadataModuleNode : public runtime::ModuleNode { CreateFuncRegistry(); GenerateCrtSystemLib(); } + LOG(INFO) << "Metadata " << metadata_.defined() << " exec " << metadata_->executor(); if (metadata_.defined() && metadata_->executor() == runtime::kTvmExecutorAot) { GenerateAOTDescriptor(); } From b8cb34e83fa13bbbeed1ff402ba59adb5f3aeee4 Mon Sep 17 00:00:00 2001 From: Andrew Reusch Date: Wed, 12 Jan 2022 12:21:08 -0800 Subject: [PATCH 32/41] LLVM serializer --- src/target/llvm/codegen_cpu.cc | 217 +++++++++++++++++++++++ src/target/llvm/llvm_module.cc | 34 ++++ tests/python/relay/aot/aot_test_utils.py | 1 + 3 files changed, 252 insertions(+) diff --git a/src/target/llvm/codegen_cpu.cc b/src/target/llvm/codegen_cpu.cc index 
8f1d76a937bf..f97f2f34a50d 100644 --- a/src/target/llvm/codegen_cpu.cc +++ b/src/target/llvm/codegen_cpu.cc @@ -905,6 +905,223 @@ llvm::Value* CodeGenCPU::RuntimeTVMParallelBarrier() { return GetContextPtr(gv_tvm_parallel_barrier_); } +struct MetadataLlvmTypes { + llvm::Type* t_float64; + llvm::Type* t_uint8; + llvm::Type* t_int64; + llvm::Type* t_bool; + llvm::Type* t_cstring; + llvm::Type* t_void_p; + llvm::Type* t_data_type; + ::std::unordered_map<::std::string, llvm::StructType*> structs; +}; + +class MetadataTypeDefiner : public AttrVisitor { +public: + MetadataTypeDefiner(llvm::Context* ctx, struct MetadataLlvmTypes* llvm_types) : ctx_{ctx}, llvm_types_{llvm_types} {} + + void Visit(const char* key, double* value) final { + elements_.emplace_back(llvm_types_->t_float64); + } + void Visit(const char* key, int64_t* value) final { + elements_.emplace_back(llvm_types_->t_int64); + } + void Visit(const char* key, uint64_t* value) final { + elements_.emplace_back(llvm_types_->t_int64); + } + void Visit(const char* key, int* value) final { + elements_.emplace_back(llvm_types_->t_int64); + } + void Visit(const char* key, bool* value) final { + elements_.emplace_back(llvm_types_->t_bool); + } + void Visit(const char* key, std::string* value) final { + elements_.emplace_back(llvm_types_->t_cstring); + } + void Visit(const char* key, void** value) final { + elements_.emplace_back(llvm_types_->t_void_p); + } + void Visit(const char* key, DataType* value) final { + elements_.emplace_back(llvm_types_->t_data_type); + } + void Visit(const char* key, runtime::NDArray* value) final { + CHECK(false) << "Do not support serializing NDArray"; + } + +private: + void VisitMetadataBase(runtime::metadata::MetadataBase metadata) { + elements_.emplace_back(llvm::PointerType::get(llvm::StructType::create(*ctx_, metadata->get_name()))); + if (visited_.find(metadata->get_name()) != visited_.end()) { + return; + } + + if (to_visit_.find(metadata->get_name()) != to_visit_.end()) { + return; 
+ } + to_visit_[metadata->get_name()] = metadata; + } + +public: + void VisitArray(const runtime::metadata::MetadataArrayNode* arr) { + for (auto o : arr->array) { + if (o->IsInstance()) { + elements_.emplace_back(llvm::PointerType::get(llvm_types_->t_float64, *ctx_)); + } if (o->IsInstance()) { + elements_.emplace_back(llvm::PointerType::get(llvm_types_->t_int64, *ctx_)); + } else if (o->IsInstance()) { + elements_.emplace_back(llvm::PointerType::get(llvm_tpyes_->t_cstring, *ctx_)); + } else { + runtime::metadata::MetadataBase metadata = Downcast(o); + VisitMetadata(metadata); + } + } + } + + void Visit(const char* key, ObjectRef* value) final { + const runtime::metadata::MetadataArrayNode* arr = + value->as(); + if (arr != nullptr) { + VisitArray(arr); + return; + } + + runtime::metadata::MetadataBase metadata = Downcast(*value); + VisitMetadata(metadata); + } + + void DefineTypes(runtime::metadata::Metadata metadata) { + to_visit.insert(metadata); + + while (to_visit_.size() > 0) { + auto it = to_visit_.begin(); + runtime::metadata::MetadataBase node = (*it).second; + visited_.insert((*it).first) + to_visit_.erase(it); + ReflectionVTable::Global()->VisitAttrs(node->operator(), this); + types_->structs_[metadata->get_name()] = llvm::StructType::create(*ctx_, elements_, metadata->get_name()); + elements_.clear(); + } + } + + llvm::LLVMContext* ctx_; + struct MetadataLlvmTypes* types_; + ::std::unordered_set<::std::string> visited_; + ::std::unordered_map<::std::string, runtime::metadata::MetadataBase> to_visit_; + ::std::vector elements_; +} + +class MetadataSerializer : public AttrVisitor { +public: + void Visit(const char* key, double* value) final { + elements_.back().emplace_back(llvm::ConstantFP::get(llvm_types_->t_float64, *value)); + } + void Visit(const char* key, int64_t* value) final { + elements_.back().emplace_back(llvm::ConstantInt::get(llvm_types_->t_int, static_cast(*value), true /* isSigned */)); + } + void Visit(const char* key, uint64_t* value) 
final { + elements_.back().emplace_back(llvm::ConstantInt::get(llvm_types_->t_int, *value, false /* isSigned */)); + } + void Visit(const char* key, int* value) final { + elements_.back().emplace_back(llvm::ConstantInt::get(llvm_types_->t_int, *value, true /* isSigned */)); + } + void Visit(const char* key, bool* value) final { + elements_.back().emplace_back(llvm::ConstantInt::get(llvm_types_->t_bool, static_cast(*value), false /* isSigned */)); + } + void Visit(const char* key, std::string* value) final { + elements_.back().emplace_back(GetConstString(*value)); + } + void Visit(const char* key, void** value) final { + CHECK(false) << "Do not support serializing void*"; + } + void Visit(const char* key, DataType* value) final { + elements_.back().emplace_back(llvm::ConstantStruct::get( + llvm_types_->t_data_type, + llvm::ConstantInt::get(llvm_types_->t_uint8, value->code(), false /* isSigned */), + llvm::ConstantInt::get(llvm_types_->t_uint8, value->bits(), false /* isSigned */), + llvm::ConstantInt::get(llvm_types_->t_uint8, value->lanes(), false /* isSigned */))); + } + + void Visit(const char* key, runtime::NDArray* value) final { + CHECK(false) << "Do not support serializing NDArray"; + } + + llvm::Constant* VisitMetadata(runtime::metadata::MetadataBase) { + elements_.emplace_back(std::vector()); + ReflectionVTable::Global()->VisitAttrs(metadata.operator->(), this); + auto elements = elements_.pop_back(); + return llvm::ConstantStruct::get(llvm_types_->structs[metadata->get_name()], elements); + } + + llvm::Constant* VisitArray(const runtime::metadata::MetadataArrayNode* arr) { + if (arr->array.size() == 0) { + + elements_.emplace_back(std::vector()); + for (auto o : arr->array) { + if (o->IsInstance()) { + Visit(nullptr, &(Downcast(o)->value)); + } if (o->IsInstance()) { + Visit(nullptr, &(Downcast(o)->value)); + } else if (o->IsInstance()) { + ::std::string value = Downcast(o); + Visit(nullptr, &value); + } else { + // nested array not possible. 
+ runtime::metadata::MetadataBase metadata = Downcast(o); + VisitMetadata(metadata); + } + } + return llvm::ConstantArray::get(elements_.pop_back() + } + + void Visit(const char* key, ObjectRef* value) final { + const runtime::metadata::MetadataArrayNode* arr = + value->as(); + if (arr != nullptr) { + elements_.back().emplace_back(VisitArray(arr)); + return; + } + + runtime::metadata::MetadataBase metadata = Downcast(*value); + elements_.back.emplace_back(VisitMetadata(metadata)); + } + + llvm::Constant* Serialize(runtime::metadata::MetadataBase metadata) { + Visit(nullptr, &metadata); + return last_production_; + } + + MetadataLlvmTypes llvm_types_; + llvm::LLVMContext* ctx_; + llvm::Module* module_; + std::vector> elements_; + llvm::Constant* last_production_; +}; + +void CodeGenCPU::DefineMetadata(runtime::metadata::Metadata metadata) { + MetadataLLvmTypes llvm_types{ + .t_float64{t_float64_}, + .t_uint8(llvm::Type::getUint8Ty(*ctx_)), + .t_int64{t_int64_}, + .t_bool{llvm::Type::getInt8Ty(*ctx)}, + .t_cstring{t_char_->getPointerTo()}, + .t_void_p{t_void_p_} + .t_data_type{llvm::StructType::get("DLDataType", t_int8_, t_int8_, t_int8_)}, + }; + + MetadataTypeDefiner definer{ctx_, &llvm_types}; + definer.DefineTypes(metadata); + + MetadataSerializer serializer; + serializer.Serialize(metadata); + + llvm::FunctionType* ftype = llvm::FunctionType::get(t_void_p_, {}, false); + function_ = llvm::Function::Create(ftype, llvm::Function::ExternalLinkage, + "get_c_metadata", module_.get()); + llvm::BasicBlock* entry_point_entry = llvm::BasicBlock::Create(*ctx_, "entry", function_); + builder_->SetInsertPoint(entry_point_entry); + builder_->CreateRet(builder_->CreateBitCast(module, t_void_p_)); +} + void CodeGenCPU::DefineFunctionRegistry(Array func_names) { ICHECK(is_system_lib_) << "Loading of --system-lib modules is yet to be defined for C runtime"; Array symbols; diff --git a/src/target/llvm/llvm_module.cc b/src/target/llvm/llvm_module.cc index 
dc10d7885c25..71977dccc71c 100644 --- a/src/target/llvm/llvm_module.cc +++ b/src/target/llvm/llvm_module.cc @@ -512,6 +512,40 @@ TVM_REGISTER_GLOBAL("codegen.codegen_blob") return runtime::Module(n); }); +runtime::Module CreateLLVMCppMetadataModule(runtime::metadata::Metadata metadata, Target target) { + InitializeLLVM(); + auto tm = GetLLVMTargetMachine(target); + bool system_lib = runtime->GetAttr("system-lib").value_or(Bool(false)); + bool target_c_runtime = runtime->name == "crt"; + auto ctx = std::make_shared(); + std::unique_ptr cg{new CodeGenCPU()}; + + cg->Init("TVMMetadataMod", tm.get(), ctx.get(), system_lib, system_lib, target_c_runtime); + + cg->DefineMetadata(metadata); + + mod->addModuleFlag(llvm::Module::Warning, "tvm_target", + llvm::MDString::get(*ctx, LLVMTargetToString(target))); + mod->addModuleFlag(llvm::Module::Override, "Debug Info Version", llvm::DEBUG_METADATA_VERSION); + + if (tm->getTargetTriple().isOSDarwin()) { + mod->addModuleFlag(llvm::Module::Override, "Dwarf Version", 2); + } + + std::string verify_errors_storage; + llvm::raw_string_ostream verify_errors(verify_errors_storage); + LOG_IF(FATAL, llvm::verifyModule(*mod, &verify_errors)) + << "LLVM module verification failed with the following errors: \n" + << verify_errors.str(); + + auto n = make_object(); + n->Init(std::move(mod), ctx); + for (auto m : modules) { + n->Import(m); + } + return runtime::Module(n); +} + runtime::Module CreateLLVMCrtMetadataModule(const Array& modules, Target target, tvm::relay::Runtime runtime) { Array func_names; diff --git a/tests/python/relay/aot/aot_test_utils.py b/tests/python/relay/aot/aot_test_utils.py index d335528914b0..7dd65ca20ef4 100644 --- a/tests/python/relay/aot/aot_test_utils.py +++ b/tests/python/relay/aot/aot_test_utils.py @@ -206,6 +206,7 @@ def parametrize_aot_options(test): interface_api = ["packed", "c"] use_unpacked_api = [True, False] test_runner = [AOT_DEFAULT_RUNNER, AOT_CORSTONE300_RUNNER] + print("TEST RUNNERS", test_runner) 
all_combinations = itertools.product(interface_api, use_unpacked_api, test_runner) From d42e94d2db4ae6567a6e0f2a0505778566841ebc Mon Sep 17 00:00:00 2001 From: Andrew Reusch Date: Tue, 18 Jan 2022 17:11:54 -0800 Subject: [PATCH 33/41] switch to encoding using MetadataTypeIndex --- include/tvm/runtime/metadata_base.h | 15 ++++--- src/runtime/metadata.cc | 4 +- src/target/metadata.h | 10 +++-- src/target/source/source_module.cc | 68 ++++++++++++++++++++--------- 4 files changed, 64 insertions(+), 33 deletions(-) diff --git a/include/tvm/runtime/metadata_base.h b/include/tvm/runtime/metadata_base.h index 228108215d09..7229af58fbef 100644 --- a/include/tvm/runtime/metadata_base.h +++ b/include/tvm/runtime/metadata_base.h @@ -157,20 +157,21 @@ enum MetadataTypeIndex : uint8_t { kInt64 = 1, kBool = 2, kString = 3, - kHandle = 4, + kMetadata = 4, }; class MetadataArrayNode : public MetadataBaseNode { public: - // MetadataArray(Array array, MetadataTypeIndex type_index) : array{array}, - // type_index{type_index} {} - MetadataArrayNode(Array array, const char* c_type) - : array(std::move(array)), c_type{c_type} {} + MetadataArrayNode(Array array, MetadataTypeIndex type_index, const char* struct_name) : + array{array}, type_index{type_index}, struct_name{struct_name} {} +// MetadataArrayNode(Array array, const char* c_type) +// : array(std::move(array)), c_type{c_type} {} std::string get_name() override; Array array; - const char* c_type; + MetadataTypeIndex type_index; + const char* struct_name; static constexpr const char* _type_key = "metadata.MetadataArrayNode"; TVM_DECLARE_BASE_OBJECT_INFO(MetadataArrayNode, MetadataBaseNode); }; @@ -178,7 +179,7 @@ class MetadataArrayNode : public MetadataBaseNode { class MetadataArray : public MetadataBase { public: // MetadataArray(Array array, MetadataTypeIndex type_index); - MetadataArray(Array array, const char* c_type); + MetadataArray(Array array, MetadataTypeIndex type_index, const char* struct_name); 
TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(MetadataArray, MetadataBase, MetadataArrayNode); }; diff --git a/src/runtime/metadata.cc b/src/runtime/metadata.cc index 2ba14ff84ef2..8415b5742dcd 100644 --- a/src/runtime/metadata.cc +++ b/src/runtime/metadata.cc @@ -33,8 +33,8 @@ namespace tvm { namespace runtime { namespace metadata { -MetadataArray::MetadataArray(Array array, const char* c_type) - : MetadataBase{make_object(array, c_type)} {} +MetadataArray::MetadataArray(Array array, MetadataTypeIndex type_index, const char* struct_name) + : MetadataBase{make_object(array, type_index, struct_name)} {} std::string MetadataArrayNode::get_name() { return "MetadataArray"; } diff --git a/src/target/metadata.h b/src/target/metadata.h index 26dca0d9d079..90b0c3c5cfea 100644 --- a/src/target/metadata.h +++ b/src/target/metadata.h @@ -49,7 +49,8 @@ class VisitableMetadataNode : public ::tvm::runtime::metadata::MetadataNode { inputs_array.push_back(::tvm::runtime::metadata::TensorInfo{inputs_accessor[i]}); } ::tvm::runtime::metadata::MetadataArray inputs_metadata_array{inputs_array, - "struct TVMTensorInfo"}; + runtime::metadata::MetadataTypeIndex::kMetadata, + "TVMTensorInfo"}; v->Visit("inputs", &inputs_metadata_array); auto outputs_array = Array(); auto outputs_accessor = outputs(); @@ -58,7 +59,8 @@ class VisitableMetadataNode : public ::tvm::runtime::metadata::MetadataNode { outputs_array.push_back(::tvm::runtime::metadata::TensorInfo{outputs_accessor[i]}); } ::tvm::runtime::metadata::MetadataArray outputs_metadata_array{outputs_array, - "struct TVMTensorInfo"}; + runtime::metadata::MetadataTypeIndex::kMetadata, + "TVMTensorInfo"}; v->Visit("outputs", &outputs_metadata_array); auto devices_array = Array(); auto devices_accessor = devices(); @@ -66,7 +68,7 @@ class VisitableMetadataNode : public ::tvm::runtime::metadata::MetadataNode { for (int64_t i = 0; i < num_devices(); ++i) { devices_array.push_back(::tvm::runtime::String{devices_accessor[i]}); } - 
::tvm::runtime::metadata::MetadataArray devices_metadata_array{devices_array, "const char*"}; + ::tvm::runtime::metadata::MetadataArray devices_metadata_array{devices_array, runtime::metadata::MetadataTypeIndex::kString, "const char*"}; v->Visit("devices", &devices_metadata_array); ::std::string executor_cpp{data()->executor}; v->Visit("executor", &executor_cpp); @@ -157,7 +159,7 @@ class VisitableTensorInfoNode : public ::tvm::runtime::metadata::TensorInfoNode for (int64_t i = 0; i < num_shape(); ++i) { shape_array.push_back(::tvm::Integer{static_cast(shape_accessor[i])}); } - ::tvm::runtime::metadata::MetadataArray shape_metadata_array{shape_array, "int64_t"}; + ::tvm::runtime::metadata::MetadataArray shape_metadata_array{shape_array, runtime::metadata::MetadataTypeIndex::kInt64, "int64_t"}; v->Visit("shape", &shape_metadata_array); ::tvm::runtime::DataType dtype_cpp{dtype()}; v->Visit("dtype", &dtype_cpp); diff --git a/src/target/source/source_module.cc b/src/target/source/source_module.cc index 3bc6c2367f39..df06858f0d4d 100644 --- a/src/target/source/source_module.cc +++ b/src/target/source/source_module.cc @@ -388,6 +388,8 @@ class CMetadataWriterVisitor : public ::tvm::AttrVisitor { }; class MetadataStructDefiner : public AttrVisitor { + using MetadataTypeIndex = runtime::metadata::MetadataTypeIndex; + public: void Visit(const char* key, double* value) final { // dns: mangle name @@ -436,21 +438,19 @@ class MetadataStructDefiner : public AttrVisitor { } void VisitArray(const char* key, const runtime::metadata::MetadataArrayNode* array) { - code_ << " " << array->c_type << "* " << key << ";" << std::endl; - } - // switch (array->type_index) { - // case MetadataTypeIndex::kUint64: - // code_ << " uint64_t** " << key << ";" << std::endl; - // case MetadataTypeIndex::kInt64: - // code_ << " int64_t** " << key << ";" << std::endl; - // case MetadataTypeIndex::kString: - // code_ << " const char** " << key << ";" << std::endl; - // case MetadataTypeIndex::kHandle: 
- // code_ << " void** " << key << ";" << std::endl; - // default: - // CHECK(false) << "Field " << key << ": unknown MetadataTypeIndex: " << array->type_index; - // } - // } + switch (array->type_index) { + case MetadataTypeIndex::kUint64: + code_ << " uint64_t** " << key << ";" << std::endl; + case MetadataTypeIndex::kInt64: + code_ << " int64_t** " << key << ";" << std::endl; + case MetadataTypeIndex::kBool: + code_ << " bool** " << key << ";" << std::endl; + case MetadataTypeIndex::kString: + code_ << " const char** " << key << ";" << std::endl; + default: + CHECK(false) << "Field " << key << ": unknown MetadataTypeIndex: " << array->type_index; + } + } // const ArrayNode* arr = value->as(); // if (arr != nullptr) { @@ -557,6 +557,31 @@ class MetadataQueuer : public AttrVisitor { std::vector address_parts_; }; +std::string MetadataArrayTypeToCType(const runtime::metadata::MetadataArrayNode* array) { + using MetadataTypeIndex = runtime::metadata::MetadataTypeIndex; + + switch (array->type_index) { + case MetadataTypeIndex::kInt64: + return "int64_t"; + break; + case MetadataTypeIndex::kUint64: + return "uint64_t"; + break; + case MetadataTypeIndex::kBool: + return "int8_t"; + break; + case MetadataTypeIndex::kString: + return "const char*"; + break; + case MetadataTypeIndex::kMetadata: + return ::std::string{"struct "} + array->struct_name; + break; + default: + ICHECK(false) << "Unexpected MetadataTypeIndex " << array->type_index; + return ""; + }; +} + class MetadataSerializer : public AttrVisitor { public: static constexpr const char* kGlobalSymbol = "kTvmgenMetadata"; @@ -629,7 +654,7 @@ class MetadataSerializer : public AttrVisitor { } void VisitArray(const runtime::metadata::MetadataArrayNode* array) { - std::cout << "visit array " << array << ": " << array->c_type << " " << array->array.size() + std::cout << "visit array " << array << ": " << array->type_index << " " << array->array.size() << std::endl; auto old_is_first_item = is_first_item_; 
is_first_item_ = true; @@ -734,11 +759,14 @@ class MetadataSerializer : public AttrVisitor { is_first_item_ = true; address_.push_back(struct_name); if (arr != nullptr) { - const char* const_part = "const "; - if (strcmp(arr->c_type, "const char*") == 0) { - const_part = ""; + std::string c_type{"const "}; + if (arr->type_index == runtime::metadata::MetadataTypeIndex::kString) { + // note drop const + c_type = MetadataArrayTypeToCType(arr); + } else { + c_type += MetadataArrayTypeToCType(arr); } - code_ << const_part << arr->c_type << " " << struct_name << "[" << arr->array.size() + code_ << c_type << "[" << arr->array.size() << "] = {" << std::endl; VisitArray(arr); } else { From 5a5285945db7f4904b3217f6316931545ba175a8 Mon Sep 17 00:00:00 2001 From: Andrew Reusch Date: Tue, 18 Jan 2022 17:13:18 -0800 Subject: [PATCH 34/41] Implement LLVM serializer. --- src/relay/backend/aot_executor_codegen.cc | 2 +- src/target/llvm/codegen_cpu.cc | 114 ++++++++++++++-------- src/target/llvm/codegen_cpu.h | 5 + src/target/llvm/codegen_llvm.h | 5 +- src/target/llvm/llvm_module.cc | 11 +-- src/target/llvm/llvm_module.h | 3 + src/target/metadata_module.cc | 4 + 7 files changed, 91 insertions(+), 53 deletions(-) diff --git a/src/relay/backend/aot_executor_codegen.cc b/src/relay/backend/aot_executor_codegen.cc index ef6250c2682d..239a31060b7a 100644 --- a/src/relay/backend/aot_executor_codegen.cc +++ b/src/relay/backend/aot_executor_codegen.cc @@ -859,7 +859,7 @@ class AOTExecutorCodegen : public MixedModeVisitor { << use_unpacked_api_ << ") when targeting c runtime"; } } else if (runtime_config->name == kTvmRuntimeCpp) { - CHECK(static_cast(use_unpacked_api_) == false && + CHECK(static_cast(use_unpacked_api_) == true || static_cast(use_call_cpacked_) == true) << "Need unpacked-api == false (got: " << use_unpacked_api_ << ") and interface-api == \"c\" (got: " << interface_api diff --git a/src/target/llvm/codegen_cpu.cc b/src/target/llvm/codegen_cpu.cc index f97f2f34a50d..ca813544f0b8 
100644 --- a/src/target/llvm/codegen_cpu.cc +++ b/src/target/llvm/codegen_cpu.cc @@ -912,13 +912,13 @@ struct MetadataLlvmTypes { llvm::Type* t_bool; llvm::Type* t_cstring; llvm::Type* t_void_p; - llvm::Type* t_data_type; + llvm::StructType* t_data_type; ::std::unordered_map<::std::string, llvm::StructType*> structs; }; class MetadataTypeDefiner : public AttrVisitor { public: - MetadataTypeDefiner(llvm::Context* ctx, struct MetadataLlvmTypes* llvm_types) : ctx_{ctx}, llvm_types_{llvm_types} {} + MetadataTypeDefiner(llvm::LLVMContext* ctx, struct MetadataLlvmTypes* llvm_types) : ctx_{ctx}, llvm_types_{llvm_types} {} void Visit(const char* key, double* value) final { elements_.emplace_back(llvm_types_->t_float64); @@ -950,7 +950,7 @@ class MetadataTypeDefiner : public AttrVisitor { private: void VisitMetadataBase(runtime::metadata::MetadataBase metadata) { - elements_.emplace_back(llvm::PointerType::get(llvm::StructType::create(*ctx_, metadata->get_name()))); + elements_.emplace_back(llvm::PointerType::getUnqual(llvm::StructType::create(*ctx_, metadata->get_name()))); if (visited_.find(metadata->get_name()) != visited_.end()) { return; } @@ -964,15 +964,15 @@ class MetadataTypeDefiner : public AttrVisitor { public: void VisitArray(const runtime::metadata::MetadataArrayNode* arr) { for (auto o : arr->array) { - if (o->IsInstance()) { - elements_.emplace_back(llvm::PointerType::get(llvm_types_->t_float64, *ctx_)); + if (o->IsInstance()) { + elements_.emplace_back(llvm::PointerType::getUnqual(llvm_types_->t_float64)); } if (o->IsInstance()) { - elements_.emplace_back(llvm::PointerType::get(llvm_types_->t_int64, *ctx_)); + elements_.emplace_back(llvm::PointerType::getUnqual(llvm_types_->t_int64)); } else if (o->IsInstance()) { - elements_.emplace_back(llvm::PointerType::get(llvm_tpyes_->t_cstring, *ctx_)); + elements_.emplace_back(llvm::PointerType::getUnqual(llvm_types_->t_cstring)); } else { runtime::metadata::MetadataBase metadata = Downcast(o); - 
VisitMetadata(metadata); + VisitMetadataBase(metadata); } } } @@ -986,49 +986,52 @@ class MetadataTypeDefiner : public AttrVisitor { } runtime::metadata::MetadataBase metadata = Downcast(*value); - VisitMetadata(metadata); + VisitMetadataBase(metadata); } void DefineTypes(runtime::metadata::Metadata metadata) { - to_visit.insert(metadata); + to_visit_[metadata->get_name()] = metadata; while (to_visit_.size() > 0) { auto it = to_visit_.begin(); runtime::metadata::MetadataBase node = (*it).second; - visited_.insert((*it).first) + visited_.insert((*it).first); to_visit_.erase(it); - ReflectionVTable::Global()->VisitAttrs(node->operator(), this); - types_->structs_[metadata->get_name()] = llvm::StructType::create(*ctx_, elements_, metadata->get_name()); + ReflectionVTable::Global()->VisitAttrs(node.operator->(), this); + llvm_types_->structs[metadata->get_name()] = llvm::StructType::create(*ctx_, elements_, metadata->get_name()); elements_.clear(); } } llvm::LLVMContext* ctx_; - struct MetadataLlvmTypes* types_; + struct MetadataLlvmTypes* llvm_types_; ::std::unordered_set<::std::string> visited_; ::std::unordered_map<::std::string, runtime::metadata::MetadataBase> to_visit_; ::std::vector elements_; -} +}; class MetadataSerializer : public AttrVisitor { + using MetadataTypeIndex = runtime::metadata::MetadataTypeIndex; public: + MetadataSerializer(CodeGenLLVM* codegen, struct MetadataLlvmTypes* llvm_types) : codegen_{codegen}, llvm_types_{llvm_types} {} + void Visit(const char* key, double* value) final { elements_.back().emplace_back(llvm::ConstantFP::get(llvm_types_->t_float64, *value)); } void Visit(const char* key, int64_t* value) final { - elements_.back().emplace_back(llvm::ConstantInt::get(llvm_types_->t_int, static_cast(*value), true /* isSigned */)); + elements_.back().emplace_back(llvm::ConstantInt::get(llvm_types_->t_int64, static_cast(*value), true /* isSigned */)); } void Visit(const char* key, uint64_t* value) final { - 
elements_.back().emplace_back(llvm::ConstantInt::get(llvm_types_->t_int, *value, false /* isSigned */)); + elements_.back().emplace_back(llvm::ConstantInt::get(llvm_types_->t_int64, *value, false /* isSigned */)); } void Visit(const char* key, int* value) final { - elements_.back().emplace_back(llvm::ConstantInt::get(llvm_types_->t_int, *value, true /* isSigned */)); + elements_.back().emplace_back(llvm::ConstantInt::get(llvm_types_->t_int64, *value, true /* isSigned */)); } void Visit(const char* key, bool* value) final { - elements_.back().emplace_back(llvm::ConstantInt::get(llvm_types_->t_bool, static_cast(*value), false /* isSigned */)); + elements_.back().emplace_back(llvm::ConstantInt::get(llvm_types_->t_uint8, static_cast(*value), false /* isSigned */)); } void Visit(const char* key, std::string* value) final { - elements_.back().emplace_back(GetConstString(*value)); + elements_.back().emplace_back(codegen_->GetConstString(*value)); } void Visit(const char* key, void** value) final { CHECK(false) << "Do not support serializing void*"; @@ -1036,31 +1039,53 @@ class MetadataSerializer : public AttrVisitor { void Visit(const char* key, DataType* value) final { elements_.back().emplace_back(llvm::ConstantStruct::get( llvm_types_->t_data_type, - llvm::ConstantInt::get(llvm_types_->t_uint8, value->code(), false /* isSigned */), - llvm::ConstantInt::get(llvm_types_->t_uint8, value->bits(), false /* isSigned */), - llvm::ConstantInt::get(llvm_types_->t_uint8, value->lanes(), false /* isSigned */))); + {llvm::ConstantInt::get(llvm_types_->t_uint8, value->code(), false /* isSigned */), + llvm::ConstantInt::get(llvm_types_->t_uint8, value->bits(), false /* isSigned */), + llvm::ConstantInt::get(llvm_types_->t_uint8, value->lanes(), false /* isSigned */)})); } void Visit(const char* key, runtime::NDArray* value) final { CHECK(false) << "Do not support serializing NDArray"; } - llvm::Constant* VisitMetadata(runtime::metadata::MetadataBase) { + llvm::Constant* 
VisitMetadata(runtime::metadata::MetadataBase metadata) { elements_.emplace_back(std::vector()); ReflectionVTable::Global()->VisitAttrs(metadata.operator->(), this); - auto elements = elements_.pop_back(); - return llvm::ConstantStruct::get(llvm_types_->structs[metadata->get_name()], elements); + auto struct_elements = elements_.back(); + elements_.pop_back(); + return llvm::ConstantStruct::get(llvm_types_->structs[metadata->get_name()], struct_elements); } llvm::Constant* VisitArray(const runtime::metadata::MetadataArrayNode* arr) { - if (arr->array.size() == 0) { + llvm::Type* element_type; + switch (arr->type_index) { + case MetadataTypeIndex::kInt64: + element_type = llvm_types_->t_int64; + break; + case MetadataTypeIndex::kUint64: + element_type = llvm_types_->t_int64; + break; + case MetadataTypeIndex::kBool: + element_type = llvm_types_->t_uint8; + break; + case MetadataTypeIndex::kString: + element_type = llvm_types_->t_cstring; + break; + case MetadataTypeIndex::kMetadata: + element_type = llvm_types_->structs[arr->struct_name]; + break; + default: + LOG(FATAL) << "unknown metadata type_index " << arr->type_index; + } elements_.emplace_back(std::vector()); for (auto o : arr->array) { if (o->IsInstance()) { - Visit(nullptr, &(Downcast(o)->value)); + double value = Downcast(o)->value;; + Visit(nullptr, &value); } if (o->IsInstance()) { - Visit(nullptr, &(Downcast(o)->value)); + auto value = Downcast(o)->value; + Visit(nullptr, &value); } else if (o->IsInstance()) { ::std::string value = Downcast(o); Visit(nullptr, &value); @@ -1070,7 +1095,9 @@ class MetadataSerializer : public AttrVisitor { VisitMetadata(metadata); } } - return llvm::ConstantArray::get(elements_.pop_back() + auto array = elements_.back(); + elements_.pop_back(); + return llvm::ConstantArray::get(llvm::ArrayType::get(element_type, array.size()), array); } void Visit(const char* key, ObjectRef* value) final { @@ -1082,7 +1109,7 @@ class MetadataSerializer : public AttrVisitor { } 
runtime::metadata::MetadataBase metadata = Downcast(*value); - elements_.back.emplace_back(VisitMetadata(metadata)); + elements_.back().emplace_back(VisitMetadata(metadata)); } llvm::Constant* Serialize(runtime::metadata::MetadataBase metadata) { @@ -1090,7 +1117,8 @@ class MetadataSerializer : public AttrVisitor { return last_production_; } - MetadataLlvmTypes llvm_types_; + CodeGenLLVM* codegen_; + MetadataLlvmTypes* llvm_types_; llvm::LLVMContext* ctx_; llvm::Module* module_; std::vector> elements_; @@ -1098,28 +1126,28 @@ class MetadataSerializer : public AttrVisitor { }; void CodeGenCPU::DefineMetadata(runtime::metadata::Metadata metadata) { - MetadataLLvmTypes llvm_types{ - .t_float64{t_float64_}, - .t_uint8(llvm::Type::getUint8Ty(*ctx_)), - .t_int64{t_int64_}, - .t_bool{llvm::Type::getInt8Ty(*ctx)}, - .t_cstring{t_char_->getPointerTo()}, - .t_void_p{t_void_p_} - .t_data_type{llvm::StructType::get("DLDataType", t_int8_, t_int8_, t_int8_)}, + MetadataLlvmTypes llvm_types{ + t_float64_ /* t_float64 */, + llvm::Type::getInt8Ty(*ctx_) /* t_uint8 */, + t_int64_ /* t_int64 */, + llvm::Type::getInt8Ty(*ctx_) /* t_bool */, + t_char_->getPointerTo() /* t_cstring */, + t_void_p_ /* t_void_p */, + llvm::StructType::create(*ctx_, {t_int8_, t_int8_, t_int8_}, "DLDataType") /* t_data_type */, }; MetadataTypeDefiner definer{ctx_, &llvm_types}; definer.DefineTypes(metadata); - MetadataSerializer serializer; - serializer.Serialize(metadata); + MetadataSerializer serializer{this, &llvm_types}; + llvm::Constant* metadata_constant = serializer.Serialize(metadata); llvm::FunctionType* ftype = llvm::FunctionType::get(t_void_p_, {}, false); function_ = llvm::Function::Create(ftype, llvm::Function::ExternalLinkage, "get_c_metadata", module_.get()); llvm::BasicBlock* entry_point_entry = llvm::BasicBlock::Create(*ctx_, "entry", function_); builder_->SetInsertPoint(entry_point_entry); - builder_->CreateRet(builder_->CreateBitCast(module, t_void_p_)); + 
builder_->CreateRet(builder_->CreateBitCast(metadata_constant, t_void_p_)); } void CodeGenCPU::DefineFunctionRegistry(Array func_names) { diff --git a/src/target/llvm/codegen_cpu.h b/src/target/llvm/codegen_cpu.h index 58e314ec0c6e..6b70d65b3747 100644 --- a/src/target/llvm/codegen_cpu.h +++ b/src/target/llvm/codegen_cpu.h @@ -56,6 +56,11 @@ class CodeGenCPU : public CodeGenLLVM { */ void DefineFunctionRegistry(Array func_names); + /*! + * \brief Serialize the metadata object as data, and implement get_c_metadata function. + * \param metadata The metadata which should be serialized. + */ + void DefineMetadata(runtime::metadata::Metadata metadata); protected: void AddStartupFunction() final; // meta data diff --git a/src/target/llvm/codegen_llvm.h b/src/target/llvm/codegen_llvm.h index 4a9df65951c0..34b94902551c 100644 --- a/src/target/llvm/codegen_llvm.h +++ b/src/target/llvm/codegen_llvm.h @@ -187,6 +187,9 @@ class CodeGenLLVM : public ExprFunctor, void VisitStmt_(const SeqStmtNode* op) override; void VisitStmt_(const EvaluateNode* op) override; + // Get constant string + llvm::Constant* GetConstString(const std::string& str); + protected: /*! * \brief Address and type pair to assist in handling opaque pointers. @@ -298,8 +301,6 @@ class CodeGenLLVM : public ExprFunctor, // Get alignment given index. 
void GetAlignment(DataType t, const VarNode* buf_var, const PrimExpr& index, int* p_alignment, int* p_native_bits); - // Get constant string - llvm::Constant* GetConstString(const std::string& str); // do a scalarize call with f llvm::Value* CreateScalarizedCall(const CallNode* op, llvm::Function* f, const std::vector& args); diff --git a/src/target/llvm/llvm_module.cc b/src/target/llvm/llvm_module.cc index 71977dccc71c..a7cc863cbedd 100644 --- a/src/target/llvm/llvm_module.cc +++ b/src/target/llvm/llvm_module.cc @@ -512,18 +512,18 @@ TVM_REGISTER_GLOBAL("codegen.codegen_blob") return runtime::Module(n); }); -runtime::Module CreateLLVMCppMetadataModule(runtime::metadata::Metadata metadata, Target target) { +runtime::Module CreateLLVMCppMetadataModule(runtime::metadata::Metadata metadata, Target target, + tvm::relay::Runtime runtime) { InitializeLLVM(); auto tm = GetLLVMTargetMachine(target); bool system_lib = runtime->GetAttr("system-lib").value_or(Bool(false)); - bool target_c_runtime = runtime->name == "crt"; auto ctx = std::make_shared(); std::unique_ptr cg{new CodeGenCPU()}; - cg->Init("TVMMetadataMod", tm.get(), ctx.get(), system_lib, system_lib, target_c_runtime); + cg->Init("TVMMetadataMod", tm.get(), ctx.get(), system_lib, system_lib, false /* target_c_runtime */); cg->DefineMetadata(metadata); - + auto mod = cg->Finish(); mod->addModuleFlag(llvm::Module::Warning, "tvm_target", llvm::MDString::get(*ctx, LLVMTargetToString(target))); mod->addModuleFlag(llvm::Module::Override, "Debug Info Version", llvm::DEBUG_METADATA_VERSION); @@ -540,9 +540,6 @@ runtime::Module CreateLLVMCppMetadataModule(runtime::metadata::Metadata metadata auto n = make_object(); n->Init(std::move(mod), ctx); - for (auto m : modules) { - n->Import(m); - } return runtime::Module(n); } diff --git a/src/target/llvm/llvm_module.h b/src/target/llvm/llvm_module.h index 933030e213d2..660d81400b0d 100644 --- a/src/target/llvm/llvm_module.h +++ b/src/target/llvm/llvm_module.h @@ -33,6 +33,9 @@ 
namespace tvm { namespace codegen { +runtime::Module CreateLLVMCppMetadataModule(runtime::metadata::Metadata metadata, Target target, + tvm::relay::Runtime runtime); + runtime::Module CreateLLVMCrtMetadataModule(const Array& modules, Target target, tvm::relay::Runtime runtime); diff --git a/src/target/metadata_module.cc b/src/target/metadata_module.cc index 4a318f780a4d..98e15ef7fa6a 100644 --- a/src/target/metadata_module.cc +++ b/src/target/metadata_module.cc @@ -97,6 +97,10 @@ static runtime::Module CreateCppMetadataModule( auto metadata_module = CreateCSourceCppMetadataModule(metadata); metadata_module->Import(target_module); target_module = metadata_module; + } else if (target->kind->name == "llvm") { + auto metadata_module = CreateLLVMCppMetadataModule(metadata, target, runtime); + metadata_module->Import(target_module); + target_module = metadata_module; } else { CHECK(false) << "Don't know how to create MetadataModule for target type " << target->str(); } From 3deabb205666efc25594b8e338cb8b8705ffa1f1 Mon Sep 17 00:00:00 2001 From: Andrew Reusch Date: Wed, 26 Jan 2022 13:40:36 -0800 Subject: [PATCH 35/41] checkpoint --- include/tvm/support/span.h | 18 ++++-- src/relay/backend/aot_executor_codegen.cc | 3 +- src/target/source/source_module.cc | 1 + tests/cpp/build_module_test.cc | 16 ++++-- tests/cpp/test_metadata.cc | 67 +++++++++++++++++++++++ tests/python/relay/aot/test_cpp_aot.py | 26 +++++++-- tests/python/unittest/test_metadata.py | 20 +++++++ 7 files changed, 133 insertions(+), 18 deletions(-) create mode 100644 tests/cpp/test_metadata.cc create mode 100644 tests/python/unittest/test_metadata.py diff --git a/include/tvm/support/span.h b/include/tvm/support/span.h index 150798c488f0..2bd94b8bf338 100644 --- a/include/tvm/support/span.h +++ b/include/tvm/support/span.h @@ -35,6 +35,12 @@ namespace support { template class Span { public: + using value_type = W; + using reference = W&; + using const_reference = const W&; + using pointer = W*; + using 
const_pointer = const W*; + class iterator : public std::iterator { public: inline iterator(T* ptr, T* end) : ptr_{ptr}, end_{end} { CHECK_GE(end, ptr); } @@ -46,21 +52,23 @@ class Span { return *this; } - inline bool operator==(iterator other) { return ptr_ == other.ptr_ && end_ == other.end_; } + inline bool operator==(iterator other) const { return ptr_ == other.ptr_ && end_ == other.end_; } - inline bool operator!=(iterator other) { return !(*this == other); } + inline bool operator!=(iterator other) const { return !(*this == other); } - private: + protected: T* ptr_; T* end_; }; + using const_iterator = iterator; + inline Span(T* begin, int num_elements) : begin_{begin}, end_{begin + num_elements} {} inline Span(T* begin, T* end) : begin_{begin}, end_{end} {} - inline iterator begin() { return iterator(begin_, end_); } + inline iterator begin() const { return iterator(begin_, end_); } - inline iterator end() { return iterator(end_, end_); } + inline iterator end() const { return iterator(end_, end_); } inline W operator[](int i) { T* to_return = begin_ + i; diff --git a/src/relay/backend/aot_executor_codegen.cc b/src/relay/backend/aot_executor_codegen.cc index 239a31060b7a..7486a683e186 100644 --- a/src/relay/backend/aot_executor_codegen.cc +++ b/src/relay/backend/aot_executor_codegen.cc @@ -253,7 +253,7 @@ class AOTOnDemandAllocator : public transform::DeviceAwareExprVisitor { virtual_devices.push_back(virtual_device); storage_sizes_in_bytes.push_back(GetMemorySizeBytes(ttype)); } - LOG(INFO) << "CreateStorage: " << expr; +// LOG(INFO) << "CreateStorage: " << expr; storage_device_map_[expr] = StorageInfo(std::move(storage_ids), std::move(virtual_devices), std::move(storage_sizes_in_bytes)); } @@ -337,6 +337,7 @@ class AOTExecutorCodegen : public MixedModeVisitor { * a DLTensor on stack. 
*/ PrimExpr MakeDLTensor(Expr relay_arg, TensorType ttype, PrimExpr data) { + LOG(INFO) << "MakeDLTensor: " << relay_arg << " (ttype " << ttype << "): " << data; return data; } // for (Var v : input_vars_) { diff --git a/src/target/source/source_module.cc b/src/target/source/source_module.cc index df06858f0d4d..2b4780899fb4 100644 --- a/src/target/source/source_module.cc +++ b/src/target/source/source_module.cc @@ -317,6 +317,7 @@ class CSourceCrtMetadataModuleNode : public runtime::ModuleNode { code_ << "#endif\n"; if (metadata_->use_unpacked_api()) { + LOG(INFO) << "Generate AOT Descriptor: " << metadata_->interface_api(); if (metadata_->interface_api() == "c") { GenerateCInterfaceEntrypoint(entrypoint_mangled, run_func_mangled, metadata_->mod_name()); } else { diff --git a/tests/cpp/build_module_test.cc b/tests/cpp/build_module_test.cc index d5a4c91a3c43..ff3641cd6982 100644 --- a/tests/cpp/build_module_test.cc +++ b/tests/cpp/build_module_test.cc @@ -107,18 +107,22 @@ TEST(BuildModule, Heterogeneous) { auto elemwise_sub = compute( C->shape, [©, &C](PrimExpr i) { return copy[i] - C[i]; }, "elemwise_sub"); - With cuda_scope(target_cuda); - auto s1 = topi::cuda::schedule_injective(target_cuda, {elemwise_add}); + auto fcreate_s1 = [=]() { + With cuda_scope(target_cuda); + return topi::cuda::schedule_injective(target_cuda, {elemwise_add}); + }; - With llvm_scope(target_llvm); - auto s2 = create_schedule({elemwise_sub->op}); + auto fcreate_s2 = [=]() { + With llvm_scope(target_llvm); + return create_schedule({elemwise_sub->op}); + }; auto args1 = Array({A, B, elemwise_add}); auto args2 = Array({copy, C, elemwise_sub}); std::unordered_map binds; - auto lowered_s1 = LowerSchedule(s1, args1, "elemwise_add", binds); - auto lowered_s2 = LowerSchedule(s2, args2, "elemwise_sub", binds); + auto lowered_s1 = LowerSchedule(fcreate_s1(), args1, "elemwise_add", binds); + auto lowered_s2 = LowerSchedule(fcreate_s2(), args2, "elemwise_sub", binds); Map inputs = {{target_cuda, 
lowered_s1}, {target_llvm, lowered_s2}}; auto module = build(inputs, Target()); diff --git a/tests/cpp/test_metadata.cc b/tests/cpp/test_metadata.cc new file mode 100644 index 000000000000..be8c4fb8be62 --- /dev/null +++ b/tests/cpp/test_metadata.cc @@ -0,0 +1,67 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +#include +#include +#include + +namespace { +const int64_t kNormalInput1Shape[4] = {1, 5, 5, 3}; +const struct TVMTensorInfo kNormalInputs[1] = {{"input1", kNormalInput1Shape, 4, DLDataType{1, 2, 3}}}; + +const int64_t kNormalOutput1Shape[3] = {3, 8, 8}; +const struct TVMTensorInfo kNormalOutputs[1] = {{"output1", kNormalOutput1Shape, 3, DLDataType{3, 4, 5}}}; + +const char* kNormalDevices[2] = {"device1", "device2"}; + +const struct TVMMetadata kNormal = { + TVM_METADATA_VERSION, + kNormalInputs, + 1, + kNormalOutputs, + 1, + kNormalDevices, + 2, + "aot", + "default", + "c", + true, + }; +} + +using ::testing::Eq; +using ::testing::ElementsAre; + +TEST(Metadata, ParseStruct) { + tvm::runtime::metadata::Metadata md = tvm::runtime::metadata::Metadata(&kNormal); + EXPECT_THAT(md->version(), Eq(TVM_METADATA_VERSION)); + EXPECT_THAT(md->num_inputs(), Eq(1)); + + auto input1 = md->inputs()[0]; + EXPECT_THAT(input1->name(), Eq("input1")); + EXPECT_THAT(input1->shape(), ElementsAre(1, 5, 5, 3)); + EXPECT_THAT(input1->dtype(), Eq(tvm::runtime::DataType(DLDataType{1, 2, 3}))); + // auto md_inputs = md->inputs(); + // EXPECT_EQ(md_inputs.size(), 1); + + // auto md_input = md_inputs[0]; + + // EXPECT_EQ(md->get_name(), kNormal.name); +// EXPECT_EQ(md->get_name(), kNormal.name); +} diff --git a/tests/python/relay/aot/test_cpp_aot.py b/tests/python/relay/aot/test_cpp_aot.py index 98e752cf628c..788d36a014a2 100644 --- a/tests/python/relay/aot/test_cpp_aot.py +++ b/tests/python/relay/aot/test_cpp_aot.py @@ -45,21 +45,35 @@ def print_mod_tree(m, indent=0): print_mod_tree(i, indent + 2) -def test_conv2d(): +unpacked_api = tvm.testing.parameter(True, False) + + +def test_conv2d(unpacked_api): RELAY_MODEL = textwrap.dedent( """\ #[version = "0.0.5"] - def @main(%data : Tensor[(1, 3, 64, 64), uint8], %weight : Tensor[(8, 3, 5, 5), int8]) { + def @main(%data : Tensor[(1, 3, 64, 64), uint8], %weight : Tensor[(3, 3, 5, 5), int8]) { %1 = nn.conv2d( %data, %weight, padding=[2, 2], 
- channels=8, + channels=3, + kernel_size=[5, 5], + data_layout="NCHW", + kernel_layout="OIHW", + out_dtype="int32"); + %2 = cast(nn.max_pool2d(%1, pool_size=[3, 3]), dtype="int8"); + %3 = nn.conv2d( + %2, + %weight, + padding=[2, 2], + channels=3, kernel_size=[5, 5], data_layout="NCHW", kernel_layout="OIHW", out_dtype="int32"); - %1 + %4 = nn.max_pool2d(%3, pool_size=[3, 3]); + %4 } """ ) @@ -81,13 +95,13 @@ def @main(%data : Tensor[(1, 3, 64, 64), uint8], %weight : Tensor[(8, 3, 5, 5), ir_mod, params=params, target="c", - executor=backend.Executor("aot", {"interface-api": "c"}), + executor=backend.Executor("aot", {"unpacked-api": unpacked_api, "interface-api": "packed"}), ) print_mod_tree(mod.module) with tvm.contrib.utils.TempDirectory.set_keep_for_debug(): - mod.export_library("test.so") + mod.export_library("test.so", options=["-fpermissive"]) mod.export_library("test.tar") runner = tvm.runtime.load_module("test.so") print_mod_tree(runner) diff --git a/tests/python/unittest/test_metadata.py b/tests/python/unittest/test_metadata.py new file mode 100644 index 000000000000..98e912cbfad4 --- /dev/null +++ b/tests/python/unittest/test_metadata.py @@ -0,0 +1,20 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+import tvm + + +def test_ From fb9bcdda81f807e018194cbb70bc103191d6340e Mon Sep 17 00:00:00 2001 From: Andrew Reusch Date: Fri, 4 Feb 2022 15:58:21 -0800 Subject: [PATCH 36/41] emit cpacked_lowered in llvm codegen --- src/target/llvm/codegen_llvm.cc | 6 +++++- src/target/llvm/codegen_llvm.h | 1 + 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/src/target/llvm/codegen_llvm.cc b/src/target/llvm/codegen_llvm.cc index 6c64f6798e47..65ced34b801e 100644 --- a/src/target/llvm/codegen_llvm.cc +++ b/src/target/llvm/codegen_llvm.cc @@ -1261,7 +1261,11 @@ llvm::Value* CodeGenLLVM::VisitExpr_(const LoadNode* op) { llvm::Value* CodeGenLLVM::VisitExpr_(const CallNode* op) { if (auto* ptr_op = op->op.as()) { auto call_op = GetRef(ptr_op); - if (op->op.same_as(builtin_call_extern_) || op->op.same_as(builtin_call_pure_extern_)) { + if (op->op.same_as(builtin_tvm_call_cpacked_lowered_)) { + auto global_symbol = Downcast(op->args[0]); + return this->CreateCallExtern(GetType(GetRef(op)), global_symbol->value, op->args, + true); + } else if (op->op.same_as(builtin_call_extern_) || op->op.same_as(builtin_call_pure_extern_)) { // call extern intrinsic ICHECK_GE(op->args.size(), 1U); auto global_symbol = Downcast(op->args[0]); diff --git a/src/target/llvm/codegen_llvm.h b/src/target/llvm/codegen_llvm.h index 34b94902551c..cf548bde1606 100644 --- a/src/target/llvm/codegen_llvm.h +++ b/src/target/llvm/codegen_llvm.h @@ -392,6 +392,7 @@ class CodeGenLLVM : public ExprFunctor, const Op& builtin_call_pure_extern_ = builtin::call_pure_extern(); const Op& builtin_call_llvm_intrin_ = builtin::call_llvm_intrin(); const Op& builtin_call_llvm_pure_intrin_ = builtin::call_llvm_pure_intrin(); + const Op& builtin_tvm_call_cpacked_lowered_ = builtin::tvm_call_cpacked_lowered(); /*! \brief Helper struct for debug infos. 
*/ struct DebugInfo { From 0b6268048a63aeb147a543c6a4af6060d80c721e Mon Sep 17 00:00:00 2001 From: Andrew Reusch Date: Fri, 4 Feb 2022 16:21:10 -0800 Subject: [PATCH 37/41] emit tir::lookup_param node in llvm codegen --- src/target/llvm/codegen_llvm.cc | 5 ++++- src/target/llvm/codegen_llvm.h | 1 + src/target/metadata_module.cc | 8 ++++++-- 3 files changed, 11 insertions(+), 3 deletions(-) diff --git a/src/target/llvm/codegen_llvm.cc b/src/target/llvm/codegen_llvm.cc index 65ced34b801e..d1adc203749a 100644 --- a/src/target/llvm/codegen_llvm.cc +++ b/src/target/llvm/codegen_llvm.cc @@ -237,7 +237,7 @@ void CodeGenLLVM::LinkParameters(const Map params) { auto array = NDArrayToLLVMArray(ctx_, kv.second->param); std::string symbol_name = std::string(::tvm::runtime::symbol::tvm_param_prefix) + kv.first; llvm::GlobalVariable* param_symbol = new llvm::GlobalVariable( - *module_, array->getType(), true, llvm::GlobalValue::InternalLinkage, array, symbol_name); + *module_, array->getType(), true, llvm::GlobalValue::ExternalLinkage, array, symbol_name); auto dtype = tvm::runtime::DataType(kv.second->param->dtype); size_t align = std::max(tvm::runtime::GetVectorBytes(dtype), tvm::runtime::kAllocAlignment); #if TVM_LLVM_VERSION >= 100 @@ -1265,6 +1265,9 @@ llvm::Value* CodeGenLLVM::VisitExpr_(const CallNode* op) { auto global_symbol = Downcast(op->args[0]); return this->CreateCallExtern(GetType(GetRef(op)), global_symbol->value, op->args, true); + } else if (op->op.same_as(builtin_lookup_param_)) { + std::string symbol_name = std::string(::tvm::runtime::symbol::tvm_param_prefix) + Downcast(op->args[0])->value; //operator std::string(); + return module_->getGlobalVariable(symbol_name); } else if (op->op.same_as(builtin_call_extern_) || op->op.same_as(builtin_call_pure_extern_)) { // call extern intrinsic ICHECK_GE(op->args.size(), 1U); diff --git a/src/target/llvm/codegen_llvm.h b/src/target/llvm/codegen_llvm.h index cf548bde1606..52b94618caa3 100644 --- 
a/src/target/llvm/codegen_llvm.h +++ b/src/target/llvm/codegen_llvm.h @@ -392,6 +392,7 @@ class CodeGenLLVM : public ExprFunctor, const Op& builtin_call_pure_extern_ = builtin::call_pure_extern(); const Op& builtin_call_llvm_intrin_ = builtin::call_llvm_intrin(); const Op& builtin_call_llvm_pure_intrin_ = builtin::call_llvm_pure_intrin(); + const Op& builtin_lookup_param_ = builtin::lookup_param(); const Op& builtin_tvm_call_cpacked_lowered_ = builtin::tvm_call_cpacked_lowered(); /*! \brief Helper struct for debug infos. */ diff --git a/src/target/metadata_module.cc b/src/target/metadata_module.cc index 98e15ef7fa6a..9e1a8dcf69e1 100644 --- a/src/target/metadata_module.cc +++ b/src/target/metadata_module.cc @@ -97,11 +97,15 @@ static runtime::Module CreateCppMetadataModule( auto metadata_module = CreateCSourceCppMetadataModule(metadata); metadata_module->Import(target_module); target_module = metadata_module; - } else if (target->kind->name == "llvm") { + } +#ifdef TVM_LLVM_VERSION + else if (target->kind->name == "llvm") { auto metadata_module = CreateLLVMCppMetadataModule(metadata, target, runtime); metadata_module->Import(target_module); target_module = metadata_module; - } else { + } +#endif // TVM_LLVM_VERSION + else { CHECK(false) << "Don't know how to create MetadataModule for target type " << target->str(); } } From 3d28fd1a19cdeb65d8dd29ba7f8950346abe3f1f Mon Sep 17 00:00:00 2001 From: Andrew Reusch Date: Fri, 4 Feb 2022 16:21:47 -0800 Subject: [PATCH 38/41] expand test to cover llvm --- tests/python/relay/aot/test_cpp_aot.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/tests/python/relay/aot/test_cpp_aot.py b/tests/python/relay/aot/test_cpp_aot.py index 788d36a014a2..1a290caade7e 100644 --- a/tests/python/relay/aot/test_cpp_aot.py +++ b/tests/python/relay/aot/test_cpp_aot.py @@ -48,7 +48,10 @@ def print_mod_tree(m, indent=0): unpacked_api = tvm.testing.parameter(True, False) -def test_conv2d(unpacked_api): +target_kind = 
tvm.testing.parameter("c", "llvm") + + +def test_conv2d(target_kind, unpacked_api): RELAY_MODEL = textwrap.dedent( """\ #[version = "0.0.5"] @@ -94,7 +97,7 @@ def @main(%data : Tensor[(1, 3, 64, 64), uint8], %weight : Tensor[(3, 3, 5, 5), mod = tvm.relay.build( ir_mod, params=params, - target="c", + target=target_kind, executor=backend.Executor("aot", {"unpacked-api": unpacked_api, "interface-api": "packed"}), ) From 5f0a02ee26e4cf466e1346158a00a237cc82b78f Mon Sep 17 00:00:00 2001 From: Andrew Reusch Date: Fri, 4 Feb 2022 16:22:36 -0800 Subject: [PATCH 39/41] cpptests for Metadata class --- include/tvm/runtime/metadata.h | 9 +-- include/tvm/runtime/metadata_base.h | 59 ++++++-------- src/runtime/metadata.cc | 18 +---- src/target/metadata.h | 6 +- tests/cpp/test_metadata.cc | 114 ++++++++++++++++++++++++++-- 5 files changed, 143 insertions(+), 63 deletions(-) diff --git a/include/tvm/runtime/metadata.h b/include/tvm/runtime/metadata.h index c29b701291b0..3fab610ef055 100644 --- a/include/tvm/runtime/metadata.h +++ b/include/tvm/runtime/metadata.h @@ -94,9 +94,6 @@ class MetadataNode : public MetadataBaseNode { private: const struct ::TVMMetadata* data_; - ::std::shared_ptr<::std::vector> inputs_refs_; - ::std::shared_ptr<::std::vector> outputs_refs_; - ::std::shared_ptr<::std::vector<::tvm::runtime::String>> devices_refs_; }; class Metadata : public MetadataBase { @@ -112,9 +109,9 @@ class TensorInfoNode : public MetadataBaseNode { std::string get_name() override; inline ::tvm::runtime::String name() const { return ::tvm::runtime::String(data_->name); } inline int64_t num_shape() const { return data_->num_shape; } - inline ::tvm::support::Span shape() const { - return ::tvm::support::Span(data_->shape, - data_->shape + data_->num_shape); + inline ::tvm::support::Span shape() const { + return ::tvm::support::Span(data_->shape, + data_->shape + data_->num_shape); } inline ::tvm::runtime::DataType dtype() const { return ::tvm::runtime::DataType(data_->dtype); } const 
struct ::TVMTensorInfo* data() const { return data_; } diff --git a/include/tvm/runtime/metadata_base.h b/include/tvm/runtime/metadata_base.h index 7229af58fbef..b707bdf68a96 100644 --- a/include/tvm/runtime/metadata_base.h +++ b/include/tvm/runtime/metadata_base.h @@ -55,7 +55,9 @@ class ArrayAccessor; template class ArrayIterator { public: - ArrayIterator(size_t index, ArrayAccessor* parent) : index_{index}, parent_{parent} {} + using value_type = Ref; + + ArrayIterator(size_t index, const ArrayAccessor* parent) : index_{index}, parent_{parent} {} inline Ref operator*() { return (*parent_)[index_]; } @@ -75,81 +77,68 @@ class ArrayIterator { private: size_t index_; - ArrayAccessor* parent_; + const ArrayAccessor* parent_; }; template class ArrayAccessor { public: + using value_type = Ref; + using iterator = ArrayIterator; + using const_iterator = ArrayIterator; + template ::value>::type> - ArrayAccessor(const C* data, size_t num_data, ::std::shared_ptr<::std::vector> refs) - : data_{data}, num_data_{num_data}, refs_{refs} {} + ArrayAccessor(const C* data, size_t num_data) : data_{data}, num_data_{num_data} {} - inline size_t size() { return num_data_; } + inline size_t size() const { return num_data_; } - inline Ref operator[](size_t index) { + inline Ref operator[](size_t index) const { if (index >= num_data_) { throw std::runtime_error("Index out of range"); } - if (refs_->size() <= index) { - refs_->resize(num_data_); - } - - if (!(*refs_)[index].defined()) { - (*refs_)[index] = Ref(&data_[index]); - } - - return (*refs_)[index]; + return Ref(&data_[index]); } - inline ArrayIterator begin() { return ArrayIterator{0, this}; } + inline ArrayIterator begin() const { return ArrayIterator{0, this}; } - inline ArrayIterator end() { return ArrayIterator{num_data_, this}; } + inline ArrayIterator end() const { return ArrayIterator{num_data_, this}; } private: const C* data_; size_t num_data_; - ::std::shared_ptr<::std::vector> refs_; }; template <> class ArrayAccessor 
{ public: - ArrayAccessor(const char** data, size_t num_data, - ::std::shared_ptr> refs) - : data_{data}, num_data_{num_data}, refs_{refs} {} + using value_type = ::tvm::runtime::String; + using iterator = ArrayIterator; + using const_iterator = ArrayIterator; - inline size_t size() { return num_data_; } + ArrayAccessor(const char** data, size_t num_data) : data_{data}, num_data_{num_data} {} - inline ::tvm::runtime::String operator[](size_t index) { + inline size_t size() const { return num_data_; } + + inline ::tvm::runtime::String operator[](size_t index) const { if (index >= num_data_) { throw std::runtime_error("Index out of range"); } - if (refs_->size() <= index) { - refs_->resize(num_data_); - } - - if (!(*refs_)[index].defined()) { - (*refs_)[index] = ::tvm::runtime::String(data_[index]); - } - - return (*refs_)[index]; + return ::tvm::runtime::String(data_[index]); } - inline ArrayIterator begin() { + inline ArrayIterator begin() const { return ArrayIterator{0, this}; } - inline ArrayIterator end() { + inline ArrayIterator end() const { return ArrayIterator{num_data_, this}; } private: const char** data_; size_t num_data_; - ::std::shared_ptr<::std::vector<::tvm::runtime::String>> refs_; }; enum MetadataTypeIndex : uint8_t { diff --git a/src/runtime/metadata.cc b/src/runtime/metadata.cc index 8415b5742dcd..021e8244d5bf 100644 --- a/src/runtime/metadata.cc +++ b/src/runtime/metadata.cc @@ -42,25 +42,13 @@ TVM_REGISTER_OBJECT_TYPE(MetadataBaseNode); TVM_REGISTER_OBJECT_TYPE(MetadataArrayNode); ArrayAccessor MetadataNode::inputs() { - if (inputs_refs_.get() == nullptr) { - inputs_refs_.reset(new ::std::vector()); - } - return ArrayAccessor(data_->inputs, data_->num_inputs, - inputs_refs_); + return ArrayAccessor(data_->inputs, data_->num_inputs); } ArrayAccessor MetadataNode::outputs() { - if (outputs_refs_.get() == nullptr) { - outputs_refs_.reset(new ::std::vector()); - } - return ArrayAccessor(data_->outputs, data_->num_outputs, - outputs_refs_); + return 
ArrayAccessor(data_->outputs, data_->num_outputs); } ArrayAccessor MetadataNode::devices() { - if (devices_refs_.get() == nullptr) { - devices_refs_.reset(new ::std::vector<::tvm::runtime::String>()); - } - return ArrayAccessor(data_->devices, data_->num_devices, - devices_refs_); + return ArrayAccessor(data_->devices, data_->num_devices); } Metadata::Metadata(const struct ::TVMMetadata* data) : MetadataBase{make_object(data)} {} diff --git a/src/target/metadata.h b/src/target/metadata.h index 90b0c3c5cfea..adab4810f3d5 100644 --- a/src/target/metadata.h +++ b/src/target/metadata.h @@ -46,11 +46,13 @@ class VisitableMetadataNode : public ::tvm::runtime::metadata::MetadataNode { auto inputs_accessor = inputs(); inputs_array.reserve(num_inputs()); for (int64_t i = 0; i < num_inputs(); ++i) { - inputs_array.push_back(::tvm::runtime::metadata::TensorInfo{inputs_accessor[i]}); + auto ti = ::tvm::runtime::metadata::TensorInfo{inputs_accessor[i]}; + inputs_array.push_back(ti); } - ::tvm::runtime::metadata::MetadataArray inputs_metadata_array{inputs_array, + ::tvm::runtime::metadata::MetadataArray inputs_metadata_array{::std::move(inputs_array), runtime::metadata::MetadataTypeIndex::kMetadata, "TVMTensorInfo"}; + Downcast<::tvm::runtime::metadata::TensorInfo>(inputs_metadata_array->array[0]); v->Visit("inputs", &inputs_metadata_array); auto outputs_array = Array(); auto outputs_accessor = outputs(); diff --git a/tests/cpp/test_metadata.cc b/tests/cpp/test_metadata.cc index be8c4fb8be62..a4afdc40e2e5 100644 --- a/tests/cpp/test_metadata.cc +++ b/tests/cpp/test_metadata.cc @@ -19,6 +19,8 @@ #include #include +#include +#include #include namespace { @@ -45,6 +47,7 @@ const struct TVMMetadata kNormal = { }; } +using ::tvm::runtime::Downcast; using ::testing::Eq; using ::testing::ElementsAre; @@ -57,11 +60,112 @@ TEST(Metadata, ParseStruct) { EXPECT_THAT(input1->name(), Eq("input1")); EXPECT_THAT(input1->shape(), ElementsAre(1, 5, 5, 3)); EXPECT_THAT(input1->dtype(), 
Eq(tvm::runtime::DataType(DLDataType{1, 2, 3}))); - // auto md_inputs = md->inputs(); - // EXPECT_EQ(md_inputs.size(), 1); - // auto md_input = md_inputs[0]; + EXPECT_THAT(md->num_outputs(), Eq(1)); + auto output1 = md->outputs()[0]; + EXPECT_THAT(output1->name(), Eq("output1")); + EXPECT_THAT(::std::vector(output1->shape()), ElementsAre(3, 8, 8)); + EXPECT_THAT(output1->dtype(), Eq(tvm::runtime::DataType(DLDataType{3, 4, 5}))); - // EXPECT_EQ(md->get_name(), kNormal.name); -// EXPECT_EQ(md->get_name(), kNormal.name); + auto devices = md->devices(); + EXPECT_THAT(devices, ElementsAre(::tvm::runtime::String("device1"), + ::tvm::runtime::String("device2"))); + + EXPECT_THAT(md->executor(), Eq("aot")); + EXPECT_THAT(md->mod_name(), Eq("default")); + EXPECT_THAT(md->interface_api(), Eq("c")); + EXPECT_THAT(md->use_unpacked_api(), Eq(true)); +} + +class TestVisitor : public tvm::AttrVisitor { + public: + using Element = ::std::tuple<::std::string, ::tvm::runtime::ObjectRef>; + void Visit(const char* key, double* value) final { + keys.push_back(key); + values.push_back(::tvm::FloatImm(::tvm::runtime::DataType(kDLFloat, 64, 1), *value)); + } + void Visit(const char* key, int64_t* value) final { + keys.push_back(key); + values.push_back(::tvm::IntImm(::tvm::runtime::DataType(kDLInt, 64, 1), *value)); + } + void Visit(const char* key, uint64_t* value) final { + keys.push_back(key); + int64_t v; + *(reinterpret_cast(&v)) = *value; + values.push_back(::tvm::IntImm(::tvm::runtime::DataType(kDLUInt, 64, 1), v)); + } + void Visit(const char* key, int* value) final { + keys.push_back(key); + values.push_back(::tvm::IntImm(::tvm::runtime::DataType(kDLInt, 64, 1), *value)); + } + void Visit(const char* key, bool* value) final { + keys.push_back(key); + values.push_back(::tvm::Bool(*value)); + } + void Visit(const char* key, std::string* value) final { + keys.push_back(key); + values.push_back(::tvm::runtime::String(*value)); + } + void Visit(const char* key, tvm::runtime::DataType* 
value) final { + keys.push_back(key); + values.push_back(::tvm::PrimType(*value)); + } + void Visit(const char* key, tvm::runtime::NDArray* value) final { + keys.push_back(key); + values.push_back(*value); + } + void Visit(const char* key, void** value) final { + CHECK(false) << "Do not expect this type"; + } + + void Visit(const char* key, ::tvm::runtime::ObjectRef* value) final { + keys.push_back(key); + values.push_back(*value); + } + + std::vector keys; + std::vector<::tvm::runtime::ObjectRef> values; +}; + +TEST(Metadata, Visitor) { + tvm::runtime::metadata::Metadata md = tvm::runtime::metadata::Metadata(&kNormal); + TestVisitor v; + ::tvm::ReflectionVTable::Global()->VisitAttrs(md.operator->(), &v); + + EXPECT_THAT(v.keys, ElementsAre( + Eq("version"), + Eq("inputs"), + Eq("outputs"), + Eq("devices"), + Eq("executor"), + Eq("mod_name"), + Eq("interface_api"), + Eq("use_unpacked_api") + )); + + EXPECT_THAT(Downcast(v.values[0])->value, Eq(TVM_METADATA_VERSION)); + + // Just identify the tensor. 
+ auto input_array = Downcast(v.values[1]); + EXPECT_THAT(input_array->type_index, Eq(tvm::runtime::metadata::MetadataTypeIndex::kMetadata)); + EXPECT_THAT(input_array->struct_name, Eq(std::string("TVMTensorInfo"))); + EXPECT_THAT(input_array->array.size(), Eq(1)); + auto array0 = input_array->array[0]; + + auto input1 = Downcast(array0); + EXPECT_THAT(input1->name(), Eq("input1")); + + auto output_array = Downcast(v.values[2]); + EXPECT_THAT(output_array->type_index, Eq(tvm::runtime::metadata::MetadataTypeIndex::kMetadata)); + EXPECT_THAT(output_array->struct_name, Eq("TVMTensorInfo")); + auto output1 = Downcast(output_array->array[0]); + + EXPECT_THAT(output1->name(), Eq("output1")); + + auto devices = Downcast(v.values[3]); + EXPECT_THAT(devices->type_index, Eq(tvm::runtime::metadata::MetadataTypeIndex::kString)); + EXPECT_THAT(Downcast(devices->array[0]), Eq("device1")); + EXPECT_THAT(Downcast(devices->array[1]), Eq("device1")); + +// EXPECT_THAT(Downcast(v.values[0])->value, Eq(TVM_METADATA_VERSION)); } From bb604e7e44ff40c51712f63efcb66c03ac440716 Mon Sep 17 00:00:00 2001 From: Andrew Reusch Date: Thu, 10 Feb 2022 18:00:05 -0800 Subject: [PATCH 40/41] checkpoint --- src/target/llvm/codegen_llvm.cc | 32 +++++++++++++++++++++++++------- src/target/llvm/codegen_llvm.h | 9 ++++++++- 2 files changed, 33 insertions(+), 8 deletions(-) diff --git a/src/target/llvm/codegen_llvm.cc b/src/target/llvm/codegen_llvm.cc index d1adc203749a..af90be20f863 100644 --- a/src/target/llvm/codegen_llvm.cc +++ b/src/target/llvm/codegen_llvm.cc @@ -190,6 +190,18 @@ void CodeGenLLVM::AddFunctionInternal(const PrimFunc& f, bool ret_void) { } } +llvm::GlobalVariable* CodeGenLLVM::GetLinkedParamSymbol(const std::string& param_name, llvm::ConstantArray* array) { + std::string symbol_name = std::string(::tvm::runtime::symbol::tvm_param_prefix) + param_name; + llvm::GlobalVariable* var = module_->getGlobalVariable(symbol_name); + if (var == nullptr) { + var = new llvm::GlobalVariable( + 
*module_, t_void_p_, true, llvm::GlobalValue::CommonLinkage, nullptr, symbol_name); + } + //(array != nullptr ? static_cast(array->getType()) : static_cast(t_void_p_)), + return var; +} + + void CodeGenLLVM::LinkParameters(const Map params) { // It would be nice to de-dupe these declarations frm src/tir/transforms/make_packed_api.cc, // but they are at a different layer in the compiler... @@ -235,9 +247,7 @@ void CodeGenLLVM::LinkParameters(const Map params) { // Add data to the global section. for (auto kv : params) { auto array = NDArrayToLLVMArray(ctx_, kv.second->param); - std::string symbol_name = std::string(::tvm::runtime::symbol::tvm_param_prefix) + kv.first; - llvm::GlobalVariable* param_symbol = new llvm::GlobalVariable( - *module_, array->getType(), true, llvm::GlobalValue::ExternalLinkage, array, symbol_name); + llvm::GlobalVariable* param_symbol = GetLinkedParamSymbol(kv.first, array); auto dtype = tvm::runtime::DataType(kv.second->param->dtype); size_t align = std::max(tvm::runtime::GetVectorBytes(dtype), tvm::runtime::kAllocAlignment); #if TVM_LLVM_VERSION >= 100 @@ -245,8 +255,9 @@ void CodeGenLLVM::LinkParameters(const Map params) { #else param_symbol->setAlignment(align); #endif + param_symbol->setInitializer(array); - llvm::BasicBlock* case_block = llvm::BasicBlock::Create(*ctx_, "case_" + symbol_name, function); + llvm::BasicBlock* case_block = llvm::BasicBlock::Create(*ctx_, "case_" + param_symbol->getName(), function); switch_inst->addCase( llvm::cast(llvm::ConstantInt::get(t_int64_, kv.second->id)), case_block); builder_->SetInsertPoint(case_block); @@ -387,6 +398,10 @@ void CodeGenLLVM::Optimize() { fpass.run(*it); } fpass.doFinalization(); + std::string tmp; + llvm::raw_string_ostream stream(tmp); + module_->print(stream, nullptr); + LOG(INFO) << "LLVM IR: " << stream.str(); mpass.run(*module_); } @@ -1266,8 +1281,8 @@ llvm::Value* CodeGenLLVM::VisitExpr_(const CallNode* op) { return this->CreateCallExtern(GetType(GetRef(op)), 
global_symbol->value, op->args, true); } else if (op->op.same_as(builtin_lookup_param_)) { - std::string symbol_name = std::string(::tvm::runtime::symbol::tvm_param_prefix) + Downcast(op->args[0])->value; //operator std::string(); - return module_->getGlobalVariable(symbol_name); +// return llvm::ConstantInt::get(t_void_p_, 0); + return GetLinkedParamSymbol(Downcast(op->args[0])->value, nullptr); } else if (op->op.same_as(builtin_call_extern_) || op->op.same_as(builtin_call_pure_extern_)) { // call extern intrinsic ICHECK_GE(op->args.size(), 1U); @@ -1279,7 +1294,10 @@ llvm::Value* CodeGenLLVM::VisitExpr_(const CallNode* op) { return this->CreateCallExtern(GetType(GetRef(op)), op_attr_global_symbol_[call_op], op->args, false); } else { - return CreateIntrinsic(op); + VLOG(2) << "CreateIntrinsic: " << GetRef(op); + auto x = CreateIntrinsic(op); + VLOG(2) << "CreateIntrinsic done"; + return x; } } else { ICHECK(op->op.as()); diff --git a/src/target/llvm/codegen_llvm.h b/src/target/llvm/codegen_llvm.h index 52b94618caa3..53721ae80e41 100644 --- a/src/target/llvm/codegen_llvm.h +++ b/src/target/llvm/codegen_llvm.h @@ -141,7 +141,7 @@ class CodeGenLLVM : public ExprFunctor, * \param e The expression to be created value for. * \return created value. */ - llvm::Value* MakeValue(const PrimExpr& e) { return VisitExpr(e); } + llvm::Value* MakeValue(const PrimExpr& e) { auto a = VisitExpr(e); LOG(INFO) << "MakeValue (" << e << "): " << a; return a; } // Short hande code to get a constant int 32 llvm::Constant* ConstInt32(int64_t value) const { return llvm::ConstantInt::getSigned(t_int32_, value); @@ -291,6 +291,13 @@ class CodeGenLLVM : public ExprFunctor, */ llvm::Function* GetIntrinsicDecl(llvm::Intrinsic::ID id, llvm::Type* ret_type, llvm::ArrayRef arg_types); + /*! + * \brief Lookup or create a GlobalVariable whose content is the data field of a DLTensor for a + * given linked_param() CallNode. + * \param param_name Parameter name (e.g. unmangled, from lookup_param node). 
+ * \return the GlobalVariable indicated in the brief. + */ + llvm::GlobalVariable* GetLinkedParamSymbol(const ::std::string& param_name, llvm::ConstantArray* array); /*! * \brief Get the number of elements in the given vector value. * \param vec The value, must be of a vector type. From cd1de42e8185493fba751e556090f77bcc97b2de Mon Sep 17 00:00:00 2001 From: Andrew Reusch Date: Tue, 15 Feb 2022 16:36:08 -0800 Subject: [PATCH 41/41] latest llvm for masa --- src/relay/backend/aot_executor_codegen.cc | 11 ++++-- src/target/llvm/codegen_cpu.cc | 47 +++++++++++++++++------ src/target/llvm/codegen_cpu.h | 4 +- src/target/llvm/codegen_llvm.cc | 14 +++---- src/target/llvm/codegen_llvm.h | 2 +- src/target/llvm/llvm_module.cc | 6 +-- src/tir/transforms/lower_tvm_builtin.cc | 4 +- 7 files changed, 56 insertions(+), 32 deletions(-) diff --git a/src/relay/backend/aot_executor_codegen.cc b/src/relay/backend/aot_executor_codegen.cc index 7486a683e186..a1b86f32c2b5 100644 --- a/src/relay/backend/aot_executor_codegen.cc +++ b/src/relay/backend/aot_executor_codegen.cc @@ -175,9 +175,9 @@ class AOTOnDemandAllocator : public transform::DeviceAwareExprVisitor { private: void AssignReturnSid(Expr e) { if (storage_device_map_.find(e) != storage_device_map_.end()) { - LOG(INFO) << "AssignReturnSid: is now " << e; +// LOG(INFO) << "AssignReturnSid: is now " << e; StorageInfo& sinfo = storage_device_map_[e]; - LOG(INFO) << "AssignReturnSid: storage_device_map_ " << sinfo; +// LOG(INFO) << "AssignReturnSid: storage_device_map_ " << sinfo; return_ids_.clear(); for (auto sid : sinfo->storage_ids) { return_ids_.push_back(sid); @@ -337,7 +337,7 @@ class AOTExecutorCodegen : public MixedModeVisitor { * a DLTensor on stack. 
*/ PrimExpr MakeDLTensor(Expr relay_arg, TensorType ttype, PrimExpr data) { - LOG(INFO) << "MakeDLTensor: " << relay_arg << " (ttype " << ttype << "): " << data; +// LOG(INFO) << "MakeDLTensor: " << relay_arg << " (ttype " << ttype << "): " << data; return data; } // for (Var v : input_vars_) { @@ -471,7 +471,8 @@ class AOTExecutorCodegen : public MixedModeVisitor { })); } else if (use_call_cpacked_ && !use_unpacked_api_) { // call_cpacked calling convention needs a blank context - args.push_back(tir::make_zero(DataType::Handle())); + // TOOD only c runtime +// args.push_back(tir::make_zero(DataType::Handle())); tir::Evaluate func_call(tvm::tir::Call(DataType::Int(32), calling_pattern, args)); create_func_call_stmts.push_back(func_call); } else { @@ -971,8 +972,10 @@ class AOTExecutorCodegen : public MixedModeVisitor { // Legalize AOT if needed. This means that all the packed calls // need to be wrapped in TVMValues (unless use_unpacked_api is set) if (!use_unpacked_api_) { + LOG(INFO) << "Legalize Packed " << mod_run; auto pack_calls = tir::transform::LegalizePackedCalls(); mod_run = pack_calls(mod_run); + LOG(INFO) << "Legalize Packed done " << mod_run; } ret.function_metadata = std::move(function_metadata_); diff --git a/src/target/llvm/codegen_cpu.cc b/src/target/llvm/codegen_cpu.cc index ca813544f0b8..702aa3e38495 100644 --- a/src/target/llvm/codegen_cpu.cc +++ b/src/target/llvm/codegen_cpu.cc @@ -74,7 +74,7 @@ void CodeGenCPU::Init(const std::string& module_name, llvm::TargetMachine* tm, // void* resource_handle); ftype_tvm_backend_packed_c_func_ = llvm::FunctionType::get( t_int_, - {t_tvm_func_handle_, t_tvm_value_->getPointerTo(), t_int_->getPointerTo(), t_int_, + {t_tvm_value_->getPointerTo(), t_int_->getPointerTo(), t_int_, t_tvm_value_->getPointerTo(), t_int_->getPointerTo(), t_void_p_}, false); t_tvm_crt_func_registry_ = llvm::StructType::create( @@ -795,10 +795,13 @@ llvm::Value* CodeGenCPU::GetPackedFuncHandle(const std::string& fname) { 
CodeGenCPU::PackedCall CodeGenCPU::MakeCallPackedLowered(const Array& args, const DataType& r_type, - const int64_t begin, const int64_t end) { + const int64_t begin, const int64_t end, bool use_string_lookup) { PackedCall pc; std::string func_name = args[0].as()->value; - llvm::Value* handle = GetPackedFuncHandle(func_name); + llvm::Value* handle = nullptr; + if (use_string_lookup) { + handle = GetPackedFuncHandle(func_name); + } // call the function int64_t nargs = end - begin; ICHECK_GE(nargs, 0); @@ -813,14 +816,31 @@ CodeGenCPU::PackedCall CodeGenCPU::MakeCallPackedLowered(const Array& ConstInt32(end)); TypedPointer ret_tcode = CreateBufferPtr(DataType::Int(32), stack_tcode, ConstInt32(end)); + llvm::FunctionType* callee_ftype = nullptr; + llvm::Value* callee_value = nullptr; + std::vector call_args; + if (use_string_lookup) { + callee_ftype = ftype_tvm_func_call_; + callee_value = RuntimeTVMFuncCall(); + call_args.push_back(handle); + } else { + callee_ftype = ftype_tvm_backend_packed_c_func_; + callee_value = module_->getFunction(func_name); + if (callee_value == nullptr) { + callee_value = llvm::Function::Create(ftype_tvm_backend_packed_c_func_, llvm::Function::ExternalLinkage, func_name, module_.get()); + } + } + call_args.insert(call_args.end(), {arg_value, arg_tcode.addr, ConstInt32(nargs), ret_value, ret_tcode.addr}); + if (!use_string_lookup) { + call_args.push_back(llvm::ConstantPointerNull::get(t_void_p_)); + } #if TVM_LLVM_VERSION >= 90 - auto call_callee = llvm::FunctionCallee(ftype_tvm_func_call_, RuntimeTVMFuncCall()); + auto call_callee = llvm::FunctionCallee(callee_ftype, callee_value); #else - auto call_callee = RuntimeTVMFuncCall(); + auto call_callee = callee_value; #endif - llvm::Value* call = builder_->CreateCall( - call_callee, - {handle, arg_value, arg_tcode.addr, ConstInt32(nargs), ret_value, ret_tcode.addr}); + llvm::Value* call = builder_->CreateCall(call_callee, call_args); + llvm::BasicBlock* end_block = CheckCallSuccess(call); // 
Load the return value and cast it to the designated type (r_type). @@ -849,17 +869,18 @@ CodeGenCPU::PackedCall CodeGenCPU::MakeCallPackedLowered(const Array& return pc; } -llvm::Value* CodeGenCPU::CreateCallPacked(const CallNode* op) { +llvm::Value* CodeGenCPU::CreateCallPacked(const CallNode* op, bool use_string_lookup) { + LOG(INFO) << "CreateCallPacked: " << GetRef(op); ICHECK_EQ(op->args.size(), 5U); PackedCall pc = MakeCallPackedLowered(op->args, op->dtype, op->args[3].as()->value, - op->args[4].as()->value); + op->args[4].as()->value, use_string_lookup); return pc.ret_value; } llvm::Value* CodeGenCPU::CreateCallTracePacked(const CallNode* op) { ICHECK_EQ(op->args.size(), 6U); PackedCall pc = MakeCallPackedLowered(op->args, op->dtype, op->args[3].as()->value, - op->args[4].as()->value); + op->args[4].as()->value, true); // Get traced value. llvm::Value* traced_value = MakeValue(op->args[5]); // The update_block handles case when we need to update the return value. @@ -1216,9 +1237,11 @@ void CodeGenCPU::AddStartupFunction() { llvm::Value* CodeGenCPU::CreateIntrinsic(const CallNode* op) { if (op->op.same_as(builtin::tvm_call_packed_lowered())) { - return CreateCallPacked(op); + return CreateCallPacked(op, true /* use_string_lookup */); } else if (op->op.same_as(builtin::tvm_call_trace_packed_lowered())) { return CreateCallTracePacked(op); + } else if (op->op.same_as(builtin::tvm_call_cpacked_lowered())) { + return CreateCallPacked(op, false /* use_string_lookup */); } else if (op->op.same_as(builtin::tvm_static_handle())) { return CreateStaticHandle(); } else if (op->op.same_as(builtin::tvm_throw_last_error())) { diff --git a/src/target/llvm/codegen_cpu.h b/src/target/llvm/codegen_cpu.h index 6b70d65b3747..af56088125d5 100644 --- a/src/target/llvm/codegen_cpu.h +++ b/src/target/llvm/codegen_cpu.h @@ -121,9 +121,9 @@ class CodeGenCPU : public CodeGenLLVM { llvm::BasicBlock* end_block; }; PackedCall MakeCallPackedLowered(const Array& args, const DataType& 
r_type, - const int64_t begin, const int64_t end); + const int64_t begin, const int64_t end, bool use_string_lookup); // create call into tvm packed function. - llvm::Value* CreateCallPacked(const CallNode* op); + llvm::Value* CreateCallPacked(const CallNode* op, bool use_string_lookup); // Create trace call into tvm packed function. llvm::Value* CreateCallTracePacked(const CallNode* op); // Create static initialization diff --git a/src/target/llvm/codegen_llvm.cc b/src/target/llvm/codegen_llvm.cc index af90be20f863..510e89eee1ca 100644 --- a/src/target/llvm/codegen_llvm.cc +++ b/src/target/llvm/codegen_llvm.cc @@ -194,10 +194,11 @@ llvm::GlobalVariable* CodeGenLLVM::GetLinkedParamSymbol(const std::string& param std::string symbol_name = std::string(::tvm::runtime::symbol::tvm_param_prefix) + param_name; llvm::GlobalVariable* var = module_->getGlobalVariable(symbol_name); if (var == nullptr) { + CHECK(array != nullptr) << "Expect param symbol " << symbol_name << " to either be defined or for the array to be supplied"; var = new llvm::GlobalVariable( - *module_, t_void_p_, true, llvm::GlobalValue::CommonLinkage, nullptr, symbol_name); + *module_, static_cast(array->getType()), true, llvm::GlobalValue::CommonLinkage, array, symbol_name); } - //(array != nullptr ? static_cast(array->getType()) : static_cast(t_void_p_)), + //(array != nullptr ? 
: static_cast(t_void_p_)), return var; } @@ -401,7 +402,7 @@ void CodeGenLLVM::Optimize() { std::string tmp; llvm::raw_string_ostream stream(tmp); module_->print(stream, nullptr); - LOG(INFO) << "LLVM IR: " << stream.str(); +// LOG(INFO) << "LLVM IR: " << stream.str(); mpass.run(*module_); } @@ -1274,13 +1275,10 @@ llvm::Value* CodeGenLLVM::VisitExpr_(const LoadNode* op) { } llvm::Value* CodeGenLLVM::VisitExpr_(const CallNode* op) { +// LOG(INFO) << "Visit Call:" << GetRef(op); if (auto* ptr_op = op->op.as()) { auto call_op = GetRef(ptr_op); - if (op->op.same_as(builtin_tvm_call_cpacked_lowered_)) { - auto global_symbol = Downcast(op->args[0]); - return this->CreateCallExtern(GetType(GetRef(op)), global_symbol->value, op->args, - true); - } else if (op->op.same_as(builtin_lookup_param_)) { + if (op->op.same_as(builtin_lookup_param_)) { // return llvm::ConstantInt::get(t_void_p_, 0); return GetLinkedParamSymbol(Downcast(op->args[0])->value, nullptr); } else if (op->op.same_as(builtin_call_extern_) || op->op.same_as(builtin_call_pure_extern_)) { diff --git a/src/target/llvm/codegen_llvm.h b/src/target/llvm/codegen_llvm.h index 53721ae80e41..81828735d10b 100644 --- a/src/target/llvm/codegen_llvm.h +++ b/src/target/llvm/codegen_llvm.h @@ -141,7 +141,7 @@ class CodeGenLLVM : public ExprFunctor, * \param e The expression to be created value for. * \return created value. 
*/ - llvm::Value* MakeValue(const PrimExpr& e) { auto a = VisitExpr(e); LOG(INFO) << "MakeValue (" << e << "): " << a; return a; } + llvm::Value* MakeValue(const PrimExpr& e) { auto a = VisitExpr(e); /* LOG(INFO) << "MakeValue (" << e << "): " << a; */ return a; } // Short hande code to get a constant int 32 llvm::Constant* ConstInt32(int64_t value) const { return llvm::ConstantInt::getSigned(t_int32_, value); diff --git a/src/target/llvm/llvm_module.cc b/src/target/llvm/llvm_module.cc index a7cc863cbedd..5ddeb18c9ee0 100644 --- a/src/target/llvm/llvm_module.cc +++ b/src/target/llvm/llvm_module.cc @@ -308,14 +308,14 @@ class LLVMModuleNode final : public runtime::ModuleNode { cg->SetFastMathFlag(fmf); + if (found_linked_params) { + cg->LinkParameters(linked_params); + } cg->AddFunctionsOrdered(funcs.begin(), funcs.end()); if (entry_func.length() != 0) { cg->AddMainFunction(entry_func); } - if (found_linked_params) { - cg->LinkParameters(linked_params); - } module_ = cg->Finish(); module_->addModuleFlag(llvm::Module::Warning, "tvm_target", llvm::MDString::get(*ctx_, LLVMTargetToString(target))); diff --git a/src/tir/transforms/lower_tvm_builtin.cc b/src/tir/transforms/lower_tvm_builtin.cc index a5ecf4ba8296..f8dca1a4f7c6 100644 --- a/src/tir/transforms/lower_tvm_builtin.cc +++ b/src/tir/transforms/lower_tvm_builtin.cc @@ -328,8 +328,8 @@ class BuiltinLower : public StmtExprMutator { // cpacked call resource_handle if (!use_string_lookup) { - tir::Var resource_handle = Downcast(op->args[arg_count]); - packed_args.push_back(StringImm(resource_handle->name_hint)); +// tir::Var resource_handle = Downcast(op->args[arg_count]); +// packed_args.push_back(StringImm(resource_handle->name_hint)); } auto builtin_call = use_string_lookup ? builtin::tvm_call_packed_lowered()