Convert AOT to TECompiler #8697
@@ -38,15 +38,14 @@
#include <string>
#include <vector>

#include "compile_engine.h"
#include "te_compiler.h"
#include "utils.h"

namespace tvm {
namespace relay {
namespace backend {

using IntegerArray = Array<Integer>;
using TargetsMap = std::unordered_map<int, Target>;
using StorageMap =
    std::unordered_map<Expr, StorageInfo, runtime::ObjectPtrHash, runtime::ObjectPtrEqual>;

@@ -287,7 +286,6 @@ class AOTExecutorCodegen : public ExprVisitor {
  void CreateFuncCall(Call call, std::string func_name) {
    tvm::Array<PrimExpr> args{tvm::tir::StringImm(func_name)};
    std::vector<tir::Stmt> create_func_call_stmts;

    // Pack the inputs
    for (Expr arg : call->args) {
      if (params_by_expr_.find(arg) != params_by_expr_.end()) {

@@ -365,155 +363,21 @@ class AOTExecutorCodegen : public ExprVisitor {
    return ss.str();
  }

  /*!
   * \brief Update the "main" control function's metadata
   *
   * \param func The main function that contains calls to operator tir primitive functions
   */
  void UpdateMainWorkspaceSize(const tir::PrimFunc& primfunc, const relay::Function& func) {
    auto workspace_byte_alignment = target_host_->GetAttr<Integer>("workspace-byte-alignment")
                                        .value_or(tvm::runtime::kDefaultWorkspaceAlignment);
    Integer workspace_size = CalculateWorkspaceBytes(primfunc, workspace_byte_alignment);
    // Populate FunctionInfo
    auto fi_node = make_object<FunctionInfoNode>();
    // Initialize all target workspaces to zero
    for (const auto& kv : targets_) {
      auto tgt = kv.second;
      fi_node->workspace_sizes.Set(tgt, 0);
    }
    fi_node->workspace_sizes.Set(target_host_, workspace_size);
    fi_node->relay_primfuncs.Set(target_host_, func);

    int64_t io_size = 0;
    for (const auto& input : input_vars_) {
      io_size += CalculateRelayExprSizeBytes(input->checked_type());
    }
    io_size += CalculateRelayExprSizeBytes(func->body->checked_type());
    fi_node->io_sizes.Set(target_host_, io_size);

    int64_t const_size = 0;
    for (const auto& kv : params_by_expr_) {
      const_size += CalculateRelayExprSizeBytes(kv.first->checked_type());
    }
    fi_node->constant_sizes.Set(target_host_, const_size);
    function_metadata_.Set(String(runtime::symbol::tvm_module_main), FunctionInfo(fi_node));
  }

  /*!
   * \brief Update the function metadata for a given cached function and its relay
   * primitive function.
   *
   * \param cfunc The cached function as provided the by the compile engine
   * \param relay_func The source relay primitive function
   * \param relay_target The target associated with relay primitive function
   */
  void UpdateFunctionMetadata(const CachedFunc& cfunc, const Function& relay_func,
                              const Target& relay_target) {
    auto fi_node = make_object<FunctionInfoNode>();
    for (const auto& kv : cfunc->funcs->functions) {
      auto primfunc = Downcast<tir::PrimFunc>(kv.second);
      auto workspace_byte_alignment =
          target_host_->GetAttr<Integer>("workspace-byte-alignment").value_or(16);
      Integer workspace_size = CalculateWorkspaceBytes(primfunc, workspace_byte_alignment);
      Target primfunc_target = relay_target;
      if (primfunc->attrs->dict.count("target")) {
        primfunc_target = Downcast<Target>(primfunc->attrs->dict["target"]);
      }
      fi_node->workspace_sizes.Set(primfunc_target, workspace_size);
      // Calculating size for I/O
      for (auto const& param : primfunc->params) {
        auto p_shape = primfunc->buffer_map[param]->shape;
        int num_of_elements = 1;
        for (const auto& dim_index_expr : p_shape) {
          if (dim_index_expr->IsInstance<IntImmNode>()) {
            num_of_elements *= dim_index_expr.as<IntImmNode>()->value;
          } else {
            // If shape is dynamic, we cannot calculate workspace in compile time.
            num_of_elements = 0;
          }
        }
        int element_size = primfunc->buffer_map[param]->dtype.bytes();
        fi_node->io_sizes.Set(primfunc_target, element_size * num_of_elements);
      }
      fi_node->constant_sizes.Set(primfunc_target, 0);
      fi_node->tir_primfuncs.Set(primfunc_target, primfunc);
      fi_node->relay_primfuncs.Set(primfunc_target, relay_func);
    }
    function_metadata_.Set(cfunc->prim_fn_var->name_hint, FunctionInfo(fi_node));
  }

  void VisitExpr_(const CallNode* op) override {
    // Descend the call tree
    for (auto arg : op->args) {
      VisitExpr(arg);
    }

    Expr expr = GetRef<Expr>(op);
    Function func;
    if (op->op.as<OpNode>()) {
      LOG(FATAL) << "Operators should be transformed away; try applying"
                 << "the fuse_ops transformation to the expression.";
    } else if (op->op.as<GlobalVarNode>()) {
      LOG(FATAL) << "Not implemented";
    } else if (op->op.as<FunctionNode>()) {
      func = GetRef<Function>(op->op.as<FunctionNode>());
      GlobalVar node = GetRef<GlobalVar>(op->op.as<GlobalVarNode>());
      CreateFuncCall(GetRef<Call>(op), node->name_hint);
    } else {
      LOG(FATAL) << "TVM runtime does not support calls to " << op->op->GetTypeKey();
    }
    if (!func->HasNonzeroAttr(attr::kPrimitive)) {
      LOG(FATAL) << "TVM only support calls to primitive functions "
                 << "(i.e functions composed of fusable operator invocations)";
    }

    Target target;

    // Handle external function
    if (func->GetAttr<String>(attr::kCompiler).defined()) {
      target = Target("ext_dev");
      CCacheKey key = CCacheKey(func, target);
      CachedFunc ext_func = compile_engine_->Lower(key, mod_name_);
      ICHECK(ext_func.defined()) << "External function is not defined.";
      UpdateConstants(func, &params_);

      // Generate the TIR function call
      CreateFuncCall(GetRef<Call>(op), ext_func->prim_fn_var->name_hint);
      return;
    }

    ICHECK_GE(storage_device_map_.count(expr), 0);
    StorageInfo& sinfo = storage_device_map_[expr];
    auto call_dev_type = sinfo->device_types[0];
    // Normal Relay Function
    if (targets_.size() == 1) {
      // homogeneous execution.
      const auto& it = targets_.begin();
      target = (*it).second;
    } else {
      // heterogeneous execution.
      std::string call_dev_name;
      if (call_dev_type == 0) {
        call_dev_name = "llvm";
      } else {
        call_dev_name = runtime::DeviceName(call_dev_type);
      }
      if (targets_.count(call_dev_type) == 0) {
        LOG(FATAL) << "No target is provided for device " << call_dev_name;
      }
      target = targets_[call_dev_type];
    }

    CCacheKey key = CCacheKey(func, target);
    CachedFunc lowered_func = compile_engine_->Lower(key, mod_name_);

    if (!lowered_funcs_.count(target->str())) {
      lowered_funcs_[target->str()] = IRModule(Map<GlobalVar, BaseFunc>({}));
    }
    lowered_funcs_[target->str()]->Update(lowered_func->funcs);
    // Update function metadata via looking at all primfuncs
    UpdateFunctionMetadata(lowered_func, func, target);

    // Generate the TIR function call
    CreateFuncCall(GetRef<Call>(op), lowered_func->prim_fn_var->name_hint);
  }

  void VisitExpr_(const VarNode* op) override {

@@ -598,7 +462,7 @@ class AOTExecutorCodegen : public ExprVisitor {
  // Create the main PrimFunc to execute the graph. Please note that
  // the packed function calls don't pack their arguments. The AOT
  // runner function needs to be legalized by the LegalizePackedCalls pass.
  tir::PrimFunc CreateMainFunc(unsigned int relay_params) {
  tir::PrimFunc CreateMainFunc(String mod_name, unsigned int relay_params) {
    tir::Stmt body = tir::SeqStmt(stmts_);

    // Allocate the sids

@@ -637,7 +501,7 @@ class AOTExecutorCodegen : public ExprVisitor {
    // Define the PrimFunc attributes
    Map<String, ObjectRef> dict_attrs;
    String run_func_name =
        runtime::get_name_mangled(mod_name_, runtime::symbol::tvm_run_func_suffix);
        runtime::get_name_mangled(mod_name, runtime::symbol::tvm_run_func_suffix);
    dict_attrs.Set("global_symbol", run_func_name);
    dict_attrs.Set("runner_function", Bool(true));

@@ -654,7 +518,7 @@ class AOTExecutorCodegen : public ExprVisitor {
  /*! \brief input and output variables belonging to the main function signature */
  Array<tir::Var> main_signature_;
  /*! \brief target device */
  TargetsMap targets_;
  tec::TargetMap targets_;
  /*! \brief target host */
  Target target_host_;
  /*!

@@ -684,35 +548,70 @@ class AOTExecutorCodegen : public ExprVisitor {
  /*! \brief mapping sid -> tir::Var */
  std::unordered_map<int, te::Var> sids_table_;
  /*! \brief lowered funcs */
  std::unordered_map<std::string, IRModule> lowered_funcs_;
  /*! \brief lowered funcs */
  Map<String, FunctionInfo> function_metadata_;
  /*! \brief compile engine */
  CompileEngine compile_engine_;
  /*! \brief the set of statements that make the program */
  std::vector<tir::Stmt> stmts_;
  /*! \brief the list of return sids (note that the function might return more then one output */
  std::vector<int> return_sid_;
  /*! \brief the module name we use to mangle the function names */
  String mod_name_;

 public:
  AOTExecutorCodegen(runtime::Module* mod, const TargetsMap& targets, Target target_host)
  AOTExecutorCodegen(runtime::Module* mod, const tec::TargetMap& targets, Target target_host)
      : mod_(mod),
        targets_(targets),
        target_host_(target_host),
        use_unpacked_api_(target_host->GetAttr<Bool>("unpacked-api").value_or(Bool(false))),
        compile_engine_(CompileEngine::Global()) {}
        use_unpacked_api_(target_host->GetAttr<Bool>("unpacked-api").value_or(Bool(false))) {}

  LoweredOutput Codegen(relay::Function func, String mod_name) {
    auto aot_allocator = AOTOnDemandAllocator();
    aot_allocator.Run(func);

    // Retrieve the storage map
    storage_device_map_ = aot_allocator.GetStorageMap();
    mod_name_ = mod_name;
    // Pre-lowering storage map and memory plan
    StorageMap initial_storage_map = aot_allocator.GetStorageMap();
    StaticMemoryPlan memory_plan(initial_storage_map);

    // Build a map from each operation to device.
    tec::DeviceMap device_context_map;

Member
@mbs-octoml I don't think @Mousius needs to do this in this patch, but this is where we should split device planning and storage planning. I think we can remove the need to pre-storage plan at all if we can obtain the device information pre-lowering, then storage plan after the lowering.

Contributor
Got it.
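A rough sketch of the ordering that comment proposes: device planning before lowering, and storage planning only on the lowered module. The PlanDevices helper and the reduced LowerTE signature are assumptions for illustration, not APIs in this patch.

```cpp
// Proposed ordering only; not code from this PR.
LoweredOutput Codegen(relay::Function func, String mod_name) {
  // 1. Device planning on the un-lowered function: decide which device each
  //    expression runs on. PlanDevices is a hypothetical helper.
  tec::DeviceMap device_context_map = PlanDevices(func, targets_);

  // 2. Lower the Relay primitive functions to TIR, driven only by the device
  //    information (assumes LowerTE no longer needs a pre-lowering StaticMemoryPlan).
  IRModule mod = IRModule::FromExpr(func);
  auto lowered_module = tec::LowerTE(mod, targets_, device_context_map, mod_name);

  // 3. Storage planning runs once, on the lowered "main", removing the need for
  //    the pre-lowering storage map and its post-lowering reconstruction below.
  auto allocator = AOTOnDemandAllocator();
  auto lowered_main_func =
      Downcast<Function>(lowered_module.main_module->Lookup("main"));
  allocator.Run(lowered_main_func);
  storage_device_map_ = allocator.GetStorageMap();

  // ...the rest of the AOT codegen would continue unchanged from here.
  return LoweredOutput();
}
```
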
    for (const auto& it : memory_plan->expr_to_storage_info) {
      auto expr = it.first;
      auto storage_info = it.second;
      auto device_types = storage_info->device_types;
      // CHECK_EQ(device_types.size(), 1);
      tvm::Device dev;
      dev.device_id = 0;
      dev.device_type = device_types[0];
      device_context_map.insert({expr, dev});
    }

    // This first phase moves from implicit use of compile engine,
    // to instead explicitly lowering the incoming IRModule, and then
    // performing the preexisting AOT executor code generation phase.
    IRModule mod = IRModule::FromExpr(func);
    auto lowered_module = tec::LowerTE(
        mod, targets_, device_context_map, memory_plan, mod_name, [this](Function func) {
          // We need to maintain the constant map for external
          // functions so we pass this processing function which
          // allows us to process each function as we lower it.
          if (func->GetAttr<String>(attr::kCompiler).defined()) {
            UpdateConstants(func, &params_);
          }

          // TODO(@areusch, @jroesch): We should refactor this to
          // execute as a further pass, instead writing data to the
          // lowering process directly.
          tec::UpdateFunctionMetadata(func, this->function_metadata_);
        });

    for (auto input : func->params) {
    function_metadata_.Set(runtime::symbol::tvm_module_main, lowered_module.main_func_info);

Member
@mbs-octoml I think my point is we should be able to do like

    auto lowered_main = lowered_module.main_module->Lookup("main");
    auto lowered_main_func = GetRef<Function>(lowered_main.as<FunctionNode>());

    // Post-lowering storage map for writing main func - this should be the same map as previously
    // created, just referencing the new expressions created from lowering
    auto new_allocator = AOTOnDemandAllocator();
    new_allocator.Run(lowered_main_func);
    storage_device_map_ = new_allocator.GetStorageMap();

Contributor
Feel free to leave a TODO(mbs) to remove this reconstruction since I'm trying to replace these Expr->Storage and Expr->Device side maps with attrs.

Member (Author)
I think it's ok without an explicit TODO here as it makes sense to replan the allocations? I appreciate it'll get improved later 😸

Contributor
Sure, but then can you comment the storage map should be morally the same as the original, just with the keys updated to follow along with the rewritten primitive calls. Or at least that's what I think should be happening, is that right?

Member (Author)
Yip, that's correct, I've updated the comment to clarify that - what do you think now?

Contributor
lgtm
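As a concrete reading of "morally the same", a small illustrative check (not part of the patch) could compare the pre- and post-lowering plans; only the Expr keys should differ.

```cpp
// Illustrative only: AOTOnDemandAllocator and GetStorageMap are the helpers used in
// this file; the sanity check itself is hypothetical.
auto pre_plan = AOTOnDemandAllocator();
pre_plan.Run(func);                // keys are the original Relay call expressions
auto post_plan = AOTOnDemandAllocator();
post_plan.Run(lowered_main_func);  // keys are the rewritten calls to lowered PrimFuncs
// The same set of allocations is expected; only the Expr keys have been rewritten.
ICHECK_EQ(pre_plan.GetStorageMap().size(), post_plan.GetStorageMap().size());
```
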
    for (auto input : lowered_main_func->params) {
      input_vars_.push_back(input);
      main_signature_.push_back(tir::Var("input", DataType::Handle()));
    }

@@ -732,13 +631,12 @@ class AOTExecutorCodegen : public ExprVisitor {
      main_signature_.push_back(tir::Var("output", DataType::Handle()));
    }

    VisitExpr(func->body);
    VisitExpr(lowered_main_func->body);

    // Create the runner function. Please note that the function is not legal yet
    // because the packed calls arguments are not wrapped in TVMValues. To make this happen we need
    // to run the LegalizePackedCalls pass.
    auto prim_func = CreateMainFunc(func->params.size());
    UpdateMainWorkspaceSize(prim_func, func);
    auto prim_func = CreateMainFunc(mod_name, lowered_main_func->params.size());
    LoweredOutput ret;

    ret.params = std::unordered_map<std::string, std::pair<int, const tvm::runtime::NDArray>>();

@@ -748,17 +646,7 @@ class AOTExecutorCodegen : public ExprVisitor {
          std::make_pair(static_cast<int>(param_storage_ids_[param.first]), param.second)));
    }

    for (auto& kv : lowered_funcs_) {
      if (ret.lowered_funcs.count(kv.first) == 0) {
        ret.lowered_funcs.Set(kv.first, IRModule(Map<GlobalVar, BaseFunc>({})));
      }
      auto& mod = ret.lowered_funcs[kv.first];
      mod->Update(kv.second);
      ret.lowered_funcs.Set(kv.first, mod);
    }
    ret.external_mods = compile_engine_->LowerExternalFunctions();

    // Build the TIR IRModule
    // Build the TIR IRModule for the AOT function
    Map<GlobalVar, BaseFunc> symbol_map;
    symbol_map.Set(GlobalVar(::tvm::runtime::symbol::tvm_run_func_suffix), prim_func);
    IRModule mod_run(symbol_map);

@@ -774,14 +662,17 @@ class AOTExecutorCodegen : public ExprVisitor {
      mod_run = pack_calls(mod_run);
    }

    // Update the lowered functions
    ret.function_metadata = std::move(function_metadata_);

Member
Is it possible for us to remove the specialized

Member (Author)
That field is still useful, and we can unify it between Graph and AOT to refactor altogether later; I've just pushed up a change to use the

    ret.lowered_funcs = lowered_module.per_target_module;
    ret.external_mods = lowered_module.external_mods;

    auto target_host_str = target_host_->str();
    if (ret.lowered_funcs.find(target_host_str) != ret.lowered_funcs.end()) {
      ret.lowered_funcs[target_host_str]->Update(mod_run);
    } else {
      ret.lowered_funcs.Set(target_host_str, mod_run);
    }
    ret.function_metadata = std::move(function_metadata_);

    std::vector<String> input_var_names(input_vars_.size());
    std::transform(input_vars_.begin(), input_vars_.end(), input_var_names.begin(),

@@ -845,15 +736,15 @@ class AOTExecutorCodegenModule : public runtime::ModuleNode {

 private:
  void init(void* mod, Map<Integer, tvm::Target> tmp) {
    TargetsMap targets;
    tec::TargetMap targets;
    Target target_host;
    for (const auto& it : tmp) {
      auto dev_type = it.first.as<tir::IntImmNode>();
      if (!target_host.defined() && it.second->kind->device_type == kDLCPU) {
        target_host = it.second;
      }
      ICHECK(dev_type);
      targets[dev_type->value] = it.second;
      targets[static_cast<DLDeviceType>(dev_type->value)] = it.second;
    }
    codegen_ = std::make_shared<AOTExecutorCodegen>(reinterpret_cast<runtime::Module*>(mod),
                                                    targets, target_host);

I think we should maintain this check -- maybe inside CreateFuncCall?

Post-lowering every function is essentially a GlobalVar, so this path was never called. If there's a test case that shows this I can re-introduce it.

I think this is an assumption turned assertion that all calls are made to primitive functions. This guarantees the relay lowering is done before the respective executor codegen is invoked.
But I see your point -- it was never checked as it should've been. Maybe it's worth checking that the function attached to the GlobalVar has this property?

Everything is already lowered at this point as it's been through LowerTE before this runs, so we don't have to make the assumption - it's guaranteed 😸
I'd also suggest that we don't add defensive code which we can't craft a way to invoke?

Ah I see.
So it's kind of passed onto LowerTE.
Yeah, if the check itself is not invoked then no point of having it there.
LGTM.
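For reference, a minimal sketch of the kind of check discussed above, assuming the codegen had the IRModule at hand; it is illustrative only and was not added by this PR (the thread concluded it is unnecessary once LowerTE has run).

```cpp
// Sketch only, not part of this PR: verify that the callee bound to a GlobalVar has
// already been lowered by LowerTE (a TIR PrimFunc) or is a Relay function marked
// primitive, instead of relying on the implicit guarantee.
void CheckCalleeIsLowered(const IRModule& mod, const GlobalVar& gvar) {
  BaseFunc callee = mod->Lookup(gvar);
  bool lowered = callee->IsInstance<tir::PrimFuncNode>() ||
                 callee->HasNonzeroAttr(attr::kPrimitive);
  ICHECK(lowered) << "AOT codegen expects calls to lowered primitive functions; "
                  << gvar->name_hint << " has not been lowered.";
}
```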