From de2ad603550e3ab0ca8df8ce9172d971822eb32b Mon Sep 17 00:00:00 2001 From: electriclilies Date: Thu, 12 Aug 2021 15:15:07 -0700 Subject: [PATCH 01/12] Initial commit Initial stab at IRModule -> LoweredModule conversion func, notes Add external_mods and main_func_info to conversion funcs MTest lowered module to ir module fix problem with conversion funcs + print stmts Add LowerTE pass Add pLowerTEPass AAdd LowerTEPass to graph_executor_codegen.cc Use LowerTEPass instead of LowerTe in graph_executor_codegen.cc Code cleanup Add docs, more cleanup Formatting --- include/tvm/relay/function.h | 2 + .../relay/backend/graph_executor_codegen.py | 2 +- src/relay/backend/aot_executor_codegen.cc | 19 +-- src/relay/backend/graph_executor_codegen.cc | 18 +-- src/relay/backend/te_compiler.cc | 135 +++++++++++++++++- src/relay/backend/te_compiler.h | 27 +++- 6 files changed, 180 insertions(+), 23 deletions(-) diff --git a/include/tvm/relay/function.h b/include/tvm/relay/function.h index fccd1f937a06..d6a2e06787ff 100644 --- a/include/tvm/relay/function.h +++ b/include/tvm/relay/function.h @@ -144,6 +144,8 @@ constexpr const char* kComposite = "Composite"; constexpr const char* kInline = "Inline"; /*! \brief Indicate the function was created by the Pattern Partitioning Pass. */ constexpr const char* kPartitionedFromPattern = "PartitionedFromPattern"; +/*! \brief Indicate the target that the function should be lowered to. */ +constexpr const char* kTarget = "Target"; /*! \brief Mark the function as only composed of reshape operations. 
*/ constexpr const char* kReshapeOnly = "relay.reshape_only"; diff --git a/python/tvm/relay/backend/graph_executor_codegen.py b/python/tvm/relay/backend/graph_executor_codegen.py index 58717a0ab482..a5d3956fc02b 100644 --- a/python/tvm/relay/backend/graph_executor_codegen.py +++ b/python/tvm/relay/backend/graph_executor_codegen.py @@ -53,7 +53,7 @@ def __init__(self, mod, target): self._get_irmodule = self._mod["get_irmodule"] self._setup(mod, target) - def _setup(self, mod, target): + def _setup(self, mod, target: Dict[int, Target]): tgts = {} if isinstance(target, dict): for dev, tgt in target.items(): diff --git a/src/relay/backend/aot_executor_codegen.cc b/src/relay/backend/aot_executor_codegen.cc index 54a10add2f07..addc3e2bb9b0 100644 --- a/src/relay/backend/aot_executor_codegen.cc +++ b/src/relay/backend/aot_executor_codegen.cc @@ -48,6 +48,7 @@ namespace backend { using IntegerArray = Array; using StorageMap = std::unordered_map; +using namespace tec; /** * This is an on demand allocator for AOT. A new temporary @@ -518,7 +519,7 @@ class AOTExecutorCodegen : public ExprVisitor { /*! \brief input and output variables belonging to the main function signature */ Array main_signature_; /*! \brief target device */ - tec::TargetMap targets_; + TargetMap targets_; /*! \brief target host */ Target target_host_; /*! @@ -555,7 +556,7 @@ class AOTExecutorCodegen : public ExprVisitor { std::vector return_sid_; public: - AOTExecutorCodegen(runtime::Module* mod, const tec::TargetMap& targets, Target target_host) + AOTExecutorCodegen(runtime::Module* mod, const TargetMap& targets, Target target_host) : mod_(mod), targets_(targets), target_host_(target_host), @@ -570,7 +571,7 @@ class AOTExecutorCodegen : public ExprVisitor { StaticMemoryPlan memory_plan(initial_storage_map); // Build a map from each operation to device. 
- tec::DeviceMap device_context_map; + DeviceMap device_context_map; for (const auto& it : memory_plan->expr_to_storage_info) { auto expr = it.first; auto storage_info = it.second; @@ -586,8 +587,9 @@ class AOTExecutorCodegen : public ExprVisitor { // to instead explicitly lowering the incoming IRModule, and then // performing the preexisting AOT executor code generation phase. IRModule mod = IRModule::FromExpr(func); - auto lowered_module = tec::LowerTE( - mod, targets_, device_context_map, memory_plan, mod_name, [this](Function func) { + + IRModule new_mod = + LowerTEPass(targets_, device_context_map, memory_plan, mod_name, [this](Function func) { // We need to maintain the constant map for external // functions so we pass this processing function which // allows us to process each function as we lower it. @@ -598,9 +600,10 @@ class AOTExecutorCodegen : public ExprVisitor { // TODO(@areusch, @jroesch): We should refactor this to // execute as a further pass, instead writing data to the // lowering process directly. 
- tec::UpdateFunctionMetadata(func, this->function_metadata_); - }); + UpdateFunctionMetadata(func, this->function_metadata_); + })(mod); + LoweredModule lowered_module = IRModuleToLoweredModule(new_mod); function_metadata_.Set(runtime::symbol::tvm_module_main, lowered_module.main_func_info); auto lowered_main = lowered_module.main_module->Lookup("main"); auto lowered_main_func = GetRef(lowered_main.as()); @@ -736,7 +739,7 @@ class AOTExecutorCodegenModule : public runtime::ModuleNode { private: void init(void* mod, Map tmp) { - tec::TargetMap targets; + TargetMap targets; Target target_host; for (const auto& it : tmp) { auto dev_type = it.first.as(); diff --git a/src/relay/backend/graph_executor_codegen.cc b/src/relay/backend/graph_executor_codegen.cc index cc54a52be200..aeaece4a912e 100644 --- a/src/relay/backend/graph_executor_codegen.cc +++ b/src/relay/backend/graph_executor_codegen.cc @@ -55,6 +55,7 @@ using GraphAttrs = std::unordered_map; using GraphObjectPtr = std::shared_ptr; using GraphInputObjectPtr = std::shared_ptr; using GraphOpObjectPtr = std::shared_ptr; +using namespace tec; /*! 
\brief Node types */ enum GraphNodeType { @@ -183,7 +184,7 @@ class GraphOpNode : public GraphNode { */ class GraphExecutorCodegen : public backend::MemoizedExprTranslator> { public: - GraphExecutorCodegen(runtime::Module* mod, const tec::TargetMap& targets) : mod_(mod) { + GraphExecutorCodegen(runtime::Module* mod, const TargetMap& targets) : mod_(mod) { targets_ = targets; } @@ -209,7 +210,7 @@ class GraphExecutorCodegen : public backend::MemoizedExprTranslatorexpr_to_storage_info) { auto expr = it.first; auto storage_info = it.second; @@ -221,8 +222,8 @@ class GraphExecutorCodegen : public backend::MemoizedExprTranslatorfunction_metadata_); - }); + UpdateFunctionMetadata(func, this->function_metadata_); + })(mod); + LoweredModule lowered_module = IRModuleToLoweredModule(new_mod); function_metadata_.Set(runtime::symbol::tvm_module_main, lowered_module.main_func_info); auto main_module = lowered_module.main_module; main_module = relay::transform::InferType()(main_module); @@ -579,7 +581,7 @@ class GraphExecutorCodegen : public backend::MemoizedExprTranslator> var_map_; /*! \brief target device */ - tec::TargetMap targets_; + TargetMap targets_; /*! * \brief parameters (i.e. ConstantNodes found in the graph). * These are take as inputs to the GraphExecutor. 
@@ -608,7 +610,7 @@ class GraphExecutorCodegenModule : public runtime::ModuleNode { << "runtime::Module mod and Map targets"; void* mod = args[0]; Map tmp = args[1]; - tec::TargetMap targets; + TargetMap targets; for (const auto& it : tmp) { auto dev_type = it.first.as(); ICHECK(dev_type); diff --git a/src/relay/backend/te_compiler.cc b/src/relay/backend/te_compiler.cc index 93fcf73b17a2..98605a429a1d 100644 --- a/src/relay/backend/te_compiler.cc +++ b/src/relay/backend/te_compiler.cc @@ -420,6 +420,13 @@ class LowerTensorExprMutator : public ExprMutator { return {ext_func->prim_fn_var, Attrs()}; } + ICHECK_GE(device_context_map_.count(expr), 0) + << "Could not find an entry in the device context map for " << PrettyPrint(expr) + << "The memory planning was either not performed for this precise node, or there is bug " + "in the memory planner."; + + auto& device_context = this->device_context_map_[expr]; + target = GetTargetFromInteger(device_context.device_type, targets_); // Non-External Relay Function DLOG(INFO) << "lowering to target '" << target->str() << "' for primitive:\n" << PrettyPrint(func); @@ -593,6 +600,14 @@ Pass LowerTensorExpr(TargetMap targets, DeviceMap device_context_map, return CreateFunctionPass(pass_func, 0, "LowerTensorExpr", {}); } +/*! + * \brief Obtain the Target from the device type. + * If homogenous compilation, this will return the only target. + * If heteregenous compilation, this will select associated using the targets_ Map. + * + * \param dev_type + * \return Target + */ Target GetTargetFromInteger(DLDeviceType dev_type, TargetMap targets) { if (targets.size() == 1) { // The homogeneous execution case, return the only target. @@ -749,8 +764,6 @@ backend::FunctionInfo UpdateMainWorkspaceSize(const IRModule& mod, TargetMap tar relay_primfuncs); } -// TODO(@electriclilies): Is the function passed in here relay_func?? -// Also should this be inlined? /*! 
* \brief A function to create the function metadata for an input function (ie calculate buffer * input/output sizes) @@ -830,9 +843,6 @@ void UpdateFunctionMetadata(Function relay_func, function_metadata.Set(prim_fn_var.value()->name_hint, fi); } -// TODO(mbs): Make this an IRModule->IRModule pass by folding LoweredModule back into IRModule. -// Currently we rely on accumulating bindings inside the local TECompiler which we then -// host into the LoweredModule result. LoweredModule LowerTE(const IRModule& module, TargetMap targets, DeviceMap device_context_map, backend::StaticMemoryPlan memory_plan, const String& module_name, std::function process_fn) { @@ -875,6 +885,121 @@ LoweredModule LowerTE(const IRModule& module, TargetMap targets, DeviceMap devic return lowered_module; } +IRModule LoweredModuleToIRModule(LoweredModule mod) { + Map unified_funcs; + Map unified_type_defs; + + // copy main module funcs to unified funcs (what target do we need to annotate with here?) + for (const auto& kv : mod.main_module->functions) { + const GlobalVar& var = kv.first; + const BaseFunc& func = kv.second; + ICHECK(!func->IsInstance()); + unified_funcs.Set(var, func); + } + + // copy the type definitions for the main module + for (const auto& kv : mod.main_module->type_definitions) { + const GlobalTypeVar& ty_var = kv.first; + const TypeData& ty_data = kv.second; + unified_type_defs.Set(ty_var, ty_data); + } + // Move functions in per target IRModule into unified module + // Also move the type definitions + for (const auto& kv : mod.per_target_module) { + const String target = kv.first; + const IRModule target_module = kv.second; + // Move the per module functions, and annotate the funcs with their target + for (const auto& kv : target_module->functions) { + const GlobalVar& var = kv.first; + const BaseFunc& func = kv.second; + ICHECK(func->IsInstance()) + << "We expect the target_module to contain only PrimFuncs at this point, but got " + << func->GetTypeKey(); + tir::PrimFunc 
primFunc = WithAttr(Downcast(std::move(func)), attr::kTarget, + runtime::String(target)); + unified_funcs.Set(var, primFunc); + } + + // Move the type definitions for the per target IRModule + for (const auto& kv : target_module->type_definitions) { + const GlobalTypeVar& ty_var = kv.first; + const TypeData& ty_data = kv.second; + unified_type_defs.Set(ty_var, ty_data); + } + } + + IRModule ret_mod = + WithAttr(IRModule(unified_funcs, unified_type_defs), "external_mods", mod.external_mods); + ret_mod = WithAttr(ret_mod, "main_func_info", mod.main_func_info); + return ret_mod; +} + +LoweredModule IRModuleToLoweredModule(IRModule mod) { + Map main_mod_funcs; + Map> target_funcs; + for (const auto& kv : mod->functions) { + const GlobalVar& var = kv.first; + const BaseFunc& func = kv.second; + if (func->IsInstance()) { + main_mod_funcs.Set(var, func); + } else if (func->IsInstance()) { + // Extract target + auto target = func->GetAttr(attr::kTarget); + ICHECK(!target) << "Target should be set at this point"; + + // Put the function in target_funcs + if (!target_funcs.count(target.value())) { + // Initialize the map and put it in target_funcs + Map funcs; + funcs.Set(var, func); + target_funcs.Set(target.value(), funcs); + + } else { + // The map is initialized, so just add the function. + Map funcs = target_funcs.at(target.value()); + funcs.Set(var, func); + } + } else { + LOG(FATAL) + << "The function types in the IRModule should be RelayFunction or PrimFunc, but got " + << func->GetTypeKey(); + } + } + // Create the per_target_module map + Map per_target_modules; + for (const auto& kv : target_funcs) { + String target = kv.first; + Map funcs = kv.second; + // Here, we just copy the type defs to every module. Since TIR doesn't use the type defs, + // this duplication should be OK. 
+ per_target_modules.Set(target, IRModule(funcs, mod->type_definitions)); + } + LoweredModule lowered_module; + lowered_module.main_module = IRModule(main_mod_funcs, mod->type_definitions); + lowered_module.per_target_module = per_target_modules; + + // Extract external modules and main func info, add to lowered module if they exist + auto external_mods = mod->GetAttr>("external_mods"); + if (external_mods) { + lowered_module.external_mods = external_mods.value(); + } + auto main_func_info = mod->GetAttr("main_func_info"); + if (main_func_info) { + lowered_module.main_func_info = main_func_info.value(); + } + return lowered_module; +} + +Pass LowerTEPass(TargetMap targets, DeviceMap device_context_map, + backend::StaticMemoryPlan memory_plan, const String& module_name, + std::function process_fn) { + runtime::TypedPackedFunc pass_func = [=](IRModule module, + PassContext ctx) { + return LoweredModuleToIRModule( + LowerTE(module, targets, device_context_map, memory_plan, module_name, process_fn)); + }; + return tvm::transform::CreateModulePass(pass_func, 1, "LowerTE", {}); +} } // namespace tec } // namespace relay } // namespace tvm diff --git a/src/relay/backend/te_compiler.h b/src/relay/backend/te_compiler.h index 8376b99d79cd..7c6e0320a588 100644 --- a/src/relay/backend/te_compiler.h +++ b/src/relay/backend/te_compiler.h @@ -173,6 +173,28 @@ void UpdateFunctionMetadata(Function relay_func, */ Target GetTargetFromInteger(DLDeviceType dev_type, TargetMap targets); +/*! \brief Utility to convert a LoweredModule to an IRModule. + * + * This function takes all the target specific modules in LoweredModule and + * annotates their functions with the correct target, and puts all those functions + * in one IRModule. + * The purpose of this utility is to allow us to slowly remove LoweredModule from the codebase. + * + * \param mod The LoweredModule to convert. + * \return The IRModule form of the input LoweredModule. 
+ */ +IRModule LoweredModuleToIRModule(LoweredModule mod); + +/*! \brief Utility to convert an IRModule to a LoweredModule. + * + * This function takes all the functions in the IRModule and moves them into target-specific + * IRModules stored inside a LoweredModule. + * The purpose of this utility is to allow us to slowly remove LoweredModule from the codebase. + * \param mod The IRModule to convert. + * \return The LoweredModule form of the input IRModule. + */ +LoweredModule IRModuleToLoweredModule(IRModule mod); + /*! \brief Lower an IRModule's primitive functions to TIR. * * This is the "back half" of the Relay compiler which lowers "primitive functions" @@ -184,12 +206,15 @@ Target GetTargetFromInteger(DLDeviceType dev_type, TargetMap targets); * \param device_map An analysis result mapping each sub-expression to a device. * \return The lowered module, see above. */ -// TODO(@electriclilies): Not sure if this default initialization is correct... LoweredModule LowerTE( const IRModule& module, TargetMap targets, DeviceMap device_map, backend::StaticMemoryPlan memory_plan, const String& module_name, ProcessFn process_fn = [](Function f) {}); +using namespace transform; +Pass LowerTEPass(TargetMap targets, DeviceMap device_context_map, + backend::StaticMemoryPlan memory_plan, const String& module_name, + std::function process_fn); } // namespace tec } // namespace relay } // namespace tvm From 15ca3f3223f6fff04acf911e68d9eb996bf46c0d Mon Sep 17 00:00:00 2001 From: electriclilies Date: Thu, 19 Aug 2021 14:47:15 -0700 Subject: [PATCH 02/12] Fix bad rebase --- src/relay/backend/te_compiler.cc | 7 ------- 1 file changed, 7 deletions(-) diff --git a/src/relay/backend/te_compiler.cc b/src/relay/backend/te_compiler.cc index 98605a429a1d..ef98d2aee3bb 100644 --- a/src/relay/backend/te_compiler.cc +++ b/src/relay/backend/te_compiler.cc @@ -420,13 +420,6 @@ class LowerTensorExprMutator : public ExprMutator { return {ext_func->prim_fn_var, Attrs()}; } - 
ICHECK_GE(device_context_map_.count(expr), 0) - << "Could not find an entry in the device context map for " << PrettyPrint(expr) - << "The memory planning was either not performed for this precise node, or there is bug " - "in the memory planner."; - - auto& device_context = this->device_context_map_[expr]; - target = GetTargetFromInteger(device_context.device_type, targets_); // Non-External Relay Function DLOG(INFO) << "lowering to target '" << target->str() << "' for primitive:\n" << PrettyPrint(func); From 5a97dd099dbb9b155494b03218d3316236a6af74 Mon Sep 17 00:00:00 2001 From: electriclilies Date: Thu, 19 Aug 2021 19:14:15 -0700 Subject: [PATCH 03/12] Address 1st round of comments --- .../relay/backend/graph_executor_codegen.py | 2 +- src/relay/backend/aot_executor_codegen.cc | 13 ++- src/relay/backend/graph_executor_codegen.cc | 13 ++- src/relay/backend/te_compiler.cc | 97 ++++++++----------- src/relay/backend/te_compiler.h | 8 +- 5 files changed, 55 insertions(+), 78 deletions(-) diff --git a/python/tvm/relay/backend/graph_executor_codegen.py b/python/tvm/relay/backend/graph_executor_codegen.py index a5d3956fc02b..58717a0ab482 100644 --- a/python/tvm/relay/backend/graph_executor_codegen.py +++ b/python/tvm/relay/backend/graph_executor_codegen.py @@ -53,7 +53,7 @@ def __init__(self, mod, target): self._get_irmodule = self._mod["get_irmodule"] self._setup(mod, target) - def _setup(self, mod, target: Dict[int, Target]): + def _setup(self, mod, target): tgts = {} if isinstance(target, dict): for dev, tgt in target.items(): diff --git a/src/relay/backend/aot_executor_codegen.cc b/src/relay/backend/aot_executor_codegen.cc index addc3e2bb9b0..942bc0d1d44a 100644 --- a/src/relay/backend/aot_executor_codegen.cc +++ b/src/relay/backend/aot_executor_codegen.cc @@ -48,7 +48,6 @@ namespace backend { using IntegerArray = Array; using StorageMap = std::unordered_map; -using namespace tec; /** * This is an on demand allocator for AOT. 
A new temporary @@ -519,7 +518,7 @@ class AOTExecutorCodegen : public ExprVisitor { /*! \brief input and output variables belonging to the main function signature */ Array main_signature_; /*! \brief target device */ - TargetMap targets_; + tec::TargetMap targets_; /*! \brief target host */ Target target_host_; /*! @@ -556,7 +555,7 @@ class AOTExecutorCodegen : public ExprVisitor { std::vector return_sid_; public: - AOTExecutorCodegen(runtime::Module* mod, const TargetMap& targets, Target target_host) + AOTExecutorCodegen(runtime::Module* mod, const tec::TargetMap& targets, Target target_host) : mod_(mod), targets_(targets), target_host_(target_host), @@ -571,7 +570,7 @@ class AOTExecutorCodegen : public ExprVisitor { StaticMemoryPlan memory_plan(initial_storage_map); // Build a map from each operation to device. - DeviceMap device_context_map; + tec::DeviceMap device_context_map; for (const auto& it : memory_plan->expr_to_storage_info) { auto expr = it.first; auto storage_info = it.second; @@ -600,10 +599,10 @@ class AOTExecutorCodegen : public ExprVisitor { // TODO(@areusch, @jroesch): We should refactor this to // execute as a further pass, instead writing data to the // lowering process directly. 
- UpdateFunctionMetadata(func, this->function_metadata_); + tec::UpdateFunctionMetadata(func, this->function_metadata_); })(mod); - LoweredModule lowered_module = IRModuleToLoweredModule(new_mod); + tec::LoweredModule lowered_module = tec::IRModuleToLoweredModule(new_mod); function_metadata_.Set(runtime::symbol::tvm_module_main, lowered_module.main_func_info); auto lowered_main = lowered_module.main_module->Lookup("main"); auto lowered_main_func = GetRef(lowered_main.as()); @@ -739,7 +738,7 @@ class AOTExecutorCodegenModule : public runtime::ModuleNode { private: void init(void* mod, Map tmp) { - TargetMap targets; + tec::TargetMap targets; Target target_host; for (const auto& it : tmp) { auto dev_type = it.first.as(); diff --git a/src/relay/backend/graph_executor_codegen.cc b/src/relay/backend/graph_executor_codegen.cc index aeaece4a912e..486a6dcd7d87 100644 --- a/src/relay/backend/graph_executor_codegen.cc +++ b/src/relay/backend/graph_executor_codegen.cc @@ -55,7 +55,6 @@ using GraphAttrs = std::unordered_map; using GraphObjectPtr = std::shared_ptr; using GraphInputObjectPtr = std::shared_ptr; using GraphOpObjectPtr = std::shared_ptr; -using namespace tec; /*! 
\brief Node types */ enum GraphNodeType { @@ -184,7 +183,7 @@ class GraphOpNode : public GraphNode { */ class GraphExecutorCodegen : public backend::MemoizedExprTranslator> { public: - GraphExecutorCodegen(runtime::Module* mod, const TargetMap& targets) : mod_(mod) { + GraphExecutorCodegen(runtime::Module* mod, const tec::TargetMap& targets) : mod_(mod) { targets_ = targets; } @@ -210,7 +209,7 @@ class GraphExecutorCodegen : public backend::MemoizedExprTranslatorexpr_to_storage_info) { auto expr = it.first; auto storage_info = it.second; @@ -234,10 +233,10 @@ class GraphExecutorCodegen : public backend::MemoizedExprTranslatorfunction_metadata_); + tec::UpdateFunctionMetadata(func, this->function_metadata_); })(mod); - LoweredModule lowered_module = IRModuleToLoweredModule(new_mod); + tec::LoweredModule lowered_module = tec::IRModuleToLoweredModule(new_mod); function_metadata_.Set(runtime::symbol::tvm_module_main, lowered_module.main_func_info); auto main_module = lowered_module.main_module; main_module = relay::transform::InferType()(main_module); @@ -581,7 +580,7 @@ class GraphExecutorCodegen : public backend::MemoizedExprTranslator> var_map_; /*! \brief target device */ - TargetMap targets_; + tec::TargetMap targets_; /*! * \brief parameters (i.e. ConstantNodes found in the graph). * These are take as inputs to the GraphExecutor. 
@@ -610,7 +609,7 @@ class GraphExecutorCodegenModule : public runtime::ModuleNode { << "runtime::Module mod and Map targets"; void* mod = args[0]; Map tmp = args[1]; - TargetMap targets; + tec::TargetMap targets; for (const auto& it : tmp) { auto dev_type = it.first.as(); ICHECK(dev_type); diff --git a/src/relay/backend/te_compiler.cc b/src/relay/backend/te_compiler.cc index ef98d2aee3bb..7e2a74bb3d0f 100644 --- a/src/relay/backend/te_compiler.cc +++ b/src/relay/backend/te_compiler.cc @@ -593,14 +593,7 @@ Pass LowerTensorExpr(TargetMap targets, DeviceMap device_context_map, return CreateFunctionPass(pass_func, 0, "LowerTensorExpr", {}); } -/*! - * \brief Obtain the Target from the device type. - * If homogenous compilation, this will return the only target. - * If heteregenous compilation, this will select associated using the targets_ Map. - * - * \param dev_type - * \return Target - */ + Target GetTargetFromInteger(DLDeviceType dev_type, TargetMap targets) { if (targets.size() == 1) { // The homogeneous execution case, return the only target. @@ -879,78 +872,71 @@ LoweredModule LowerTE(const IRModule& module, TargetMap targets, DeviceMap devic } IRModule LoweredModuleToIRModule(LoweredModule mod) { - Map unified_funcs; - Map unified_type_defs; + IRModule unified_module; - // copy main module funcs to unified funcs (what target do we need to annotate with here?) 
- for (const auto& kv : mod.main_module->functions) { - const GlobalVar& var = kv.first; - const BaseFunc& func = kv.second; - ICHECK(!func->IsInstance()); - unified_funcs.Set(var, func); - } - - // copy the type definitions for the main module + // Copy the main module and its typedefs + unified_module->Update(mod.main_module); for (const auto& kv : mod.main_module->type_definitions) { - const GlobalTypeVar& ty_var = kv.first; - const TypeData& ty_data = kv.second; - unified_type_defs.Set(ty_var, ty_data); + unified_module->AddTypeDef(kv.first, kv.second); } - // Move functions in per target IRModule into unified module - // Also move the type definitions + + // Annotate the per-target functions with thier target and add them to the unified module for (const auto& kv : mod.per_target_module) { const String target = kv.first; const IRModule target_module = kv.second; - // Move the per module functions, and annotate the funcs with their target + + // Right now, per-target functions are TIR functions, which don't have type definitions, so there should be no type defs in the per_target_modules + size_t ty_def_size = target_module->type_definitions->size(); + ICHECK(ty_def_size == 0) << "Expected there to be no type definitions in the per_target_modules, but found " << ty_def_size; + for (const auto& kv : target_module->functions) { const GlobalVar& var = kv.first; const BaseFunc& func = kv.second; ICHECK(func->IsInstance()) << "We expect the target_module to contain only PrimFuncs at this point, but got " << func->GetTypeKey(); + // TODO(@electriclilies): Change to Target object if possible tir::PrimFunc primFunc = WithAttr(Downcast(std::move(func)), attr::kTarget, runtime::String(target)); - unified_funcs.Set(var, primFunc); - } - - // Move the type definitions for the per target IRModule - for (const auto& kv : target_module->type_definitions) { - const GlobalTypeVar& ty_var = kv.first; - const TypeData& ty_data = kv.second; - unified_type_defs.Set(ty_var, ty_data); 
+ unified_module->Add(var, primFunc); } } IRModule ret_mod = - WithAttr(IRModule(unified_funcs, unified_type_defs), "external_mods", mod.external_mods); + WithAttr(unified_module, "external_mods", mod.external_mods); ret_mod = WithAttr(ret_mod, "main_func_info", mod.main_func_info); return ret_mod; } LoweredModule IRModuleToLoweredModule(IRModule mod) { - Map main_mod_funcs; - Map> target_funcs; + IRModule main_mod; + // Copy just the TypeDefs from the IRModule to the LoweredModule's main module + // This is the only time we need to do this since there are no TypeDefs in TIR + for (const auto& kv : mod->type_definitions) { + main_mod->AddTypeDef(kv.first, kv.second); + } + + Map per_target_modules; for (const auto& kv : mod->functions) { const GlobalVar& var = kv.first; const BaseFunc& func = kv.second; if (func->IsInstance()) { - main_mod_funcs.Set(var, func); + main_mod->Add(var, func); } else if (func->IsInstance()) { // Extract target auto target = func->GetAttr(attr::kTarget); - ICHECK(!target) << "Target should be set at this point"; - - // Put the function in target_funcs - if (!target_funcs.count(target.value())) { - // Initialize the map and put it in target_funcs - Map funcs; - funcs.Set(var, func); - target_funcs.Set(target.value(), funcs); - + ICHECK(target) << "Target should be set at this point"; + + // Put the function in per_target_modules + if (!per_target_modules.count(target.value())) { + // Initialize the IRModule for this target and add the function + IRModule target_module; + target_module->Add(var, func); + per_target_modules.Set(target.value(), target_module); } else { - // The map is initialized, so just add the function. - Map funcs = target_funcs.at(target.value()); - funcs.Set(var, func); + // The IRModule for this target is initialized, so just add the function. 
+ IRModule target_module = per_target_modules.at(target.value()); + target_module->Add(var, func); } } else { LOG(FATAL) @@ -958,17 +944,10 @@ LoweredModule IRModuleToLoweredModule(IRModule mod) { << func->GetTypeKey(); } } - // Create the per_target_module map - Map per_target_modules; - for (const auto& kv : target_funcs) { - String target = kv.first; - Map funcs = kv.second; - // Here, we just copy the type defs to every module. Since TIR doesn't use the type defs, - // this duplication should be OK. - per_target_modules.Set(target, IRModule(funcs, mod->type_definitions)); - } + + // Put the LoweredModule together LoweredModule lowered_module; - lowered_module.main_module = IRModule(main_mod_funcs, mod->type_definitions); + lowered_module.main_module = main_mod; lowered_module.per_target_module = per_target_modules; // Extract external modules and main func info, add to lowered module if they exist diff --git a/src/relay/backend/te_compiler.h b/src/relay/backend/te_compiler.h index 7c6e0320a588..e6efe7592afe 100644 --- a/src/relay/backend/te_compiler.h +++ b/src/relay/backend/te_compiler.h @@ -67,7 +67,7 @@ struct EnumClassHash { } }; -// TODO(@jroesch, @chrisS) these should be a tvm::Map for uniformity sake +// TODO(@jroesch, @chrisS) these shoumakeld be a tvm::Map for uniformity sake // we should a version of context which works in Map using TargetMap = std::unordered_map; using DeviceMap = @@ -166,7 +166,8 @@ void UpdateFunctionMetadata(Function relay_func, /*! * \brief Obtain the Target from the device type. * If homogenous compilation, this will return the only target. - * If heteregenous compilation, this will select associated using the targets_ Map. + * If heteregenous compilation, this will select the associated target using the + * targets_ Map. 
* * \param dev_type * \return Target @@ -211,8 +212,7 @@ LoweredModule LowerTE( backend::StaticMemoryPlan memory_plan, const String& module_name, ProcessFn process_fn = [](Function f) {}); -using namespace transform; -Pass LowerTEPass(TargetMap targets, DeviceMap device_context_map, +transform::Pass LowerTEPass(TargetMap targets, DeviceMap device_context_map, backend::StaticMemoryPlan memory_plan, const String& module_name, std::function process_fn); } // namespace tec From 818bf0f3e5b1a195d468984abb4cb695a3b847ff Mon Sep 17 00:00:00 2001 From: electriclilies Date: Thu, 19 Aug 2021 19:43:02 -0700 Subject: [PATCH 04/12] Use tir kTarget instead of relay one --- include/tvm/relay/function.h | 3 --- src/relay/backend/te_compiler.cc | 7 ++++--- 2 files changed, 4 insertions(+), 6 deletions(-) diff --git a/include/tvm/relay/function.h b/include/tvm/relay/function.h index d6a2e06787ff..9170bc53ea02 100644 --- a/include/tvm/relay/function.h +++ b/include/tvm/relay/function.h @@ -144,9 +144,6 @@ constexpr const char* kComposite = "Composite"; constexpr const char* kInline = "Inline"; /*! \brief Indicate the function was created by the Pattern Partitioning Pass. */ constexpr const char* kPartitionedFromPattern = "PartitionedFromPattern"; -/*! \brief Indicate the target that the function should be lowered to. */ -constexpr const char* kTarget = "Target"; - /*! \brief Mark the function as only composed of reshape operations. 
*/ constexpr const char* kReshapeOnly = "relay.reshape_only"; } // namespace attr diff --git a/src/relay/backend/te_compiler.cc b/src/relay/backend/te_compiler.cc index 7e2a74bb3d0f..cc7fd8259aad 100644 --- a/src/relay/backend/te_compiler.cc +++ b/src/relay/backend/te_compiler.cc @@ -20,6 +20,7 @@ #include "te_compiler.h" #include +#include #include #include #include @@ -886,7 +887,7 @@ IRModule LoweredModuleToIRModule(LoweredModule mod) { const IRModule target_module = kv.second; // Right now, per-target functions are TIR functions, which don't have type definitions, so there should be no type defs in the per_target_modules - size_t ty_def_size = target_module->type_definitions->size(); + size_t ty_def_size = target_module->type_definitions.size(); ICHECK(ty_def_size == 0) << "Expected there to be no type definitions in the per_target_modules, but found " << ty_def_size; for (const auto& kv : target_module->functions) { @@ -896,7 +897,7 @@ IRModule LoweredModuleToIRModule(LoweredModule mod) { << "We expect the target_module to contain only PrimFuncs at this point, but got " << func->GetTypeKey(); // TODO(@electriclilies): Change to Target object if possible - tir::PrimFunc primFunc = WithAttr(Downcast(std::move(func)), attr::kTarget, + tir::PrimFunc primFunc = WithAttr(Downcast(std::move(func)), tvm::attr::kTarget, runtime::String(target)); unified_module->Add(var, primFunc); } @@ -924,7 +925,7 @@ LoweredModule IRModuleToLoweredModule(IRModule mod) { main_mod->Add(var, func); } else if (func->IsInstance()) { // Extract target - auto target = func->GetAttr(attr::kTarget); + auto target = func->GetAttr(tvm::attr::kTarget); ICHECK(target) << "Target should be set at this point"; // Put the function in per_target_modules From ad0059b2eb78531cd5ea79dbac1704c067eb825d Mon Sep 17 00:00:00 2001 From: electriclilies Date: Thu, 19 Aug 2021 20:04:18 -0700 Subject: [PATCH 05/12] Change target string to Target obj --- src/relay/backend/aot_executor_codegen.cc | 9 +++--- 
src/relay/backend/interpreter.cc | 14 ++++----- src/relay/backend/te_compiler.cc | 37 ++++++++++++----------- src/relay/backend/te_compiler.h | 8 ++--- src/relay/backend/utils.h | 2 +- 5 files changed, 35 insertions(+), 35 deletions(-) diff --git a/src/relay/backend/aot_executor_codegen.cc b/src/relay/backend/aot_executor_codegen.cc index 942bc0d1d44a..2b88f0489321 100644 --- a/src/relay/backend/aot_executor_codegen.cc +++ b/src/relay/backend/aot_executor_codegen.cc @@ -669,11 +669,10 @@ class AOTExecutorCodegen : public ExprVisitor { ret.lowered_funcs = lowered_module.per_target_module; ret.external_mods = lowered_module.external_mods; - auto target_host_str = target_host_->str(); - if (ret.lowered_funcs.find(target_host_str) != ret.lowered_funcs.end()) { - ret.lowered_funcs[target_host_str]->Update(mod_run); + if (ret.lowered_funcs.find(target_host_) != ret.lowered_funcs.end()) { + ret.lowered_funcs[target_host_]->Update(mod_run); } else { - ret.lowered_funcs.Set(target_host_str, mod_run); + ret.lowered_funcs.Set(target_host_, mod_run); } std::vector input_var_names(input_vars_.size()); @@ -778,7 +777,7 @@ class AOTExecutorCodegenModule : public runtime::ModuleNode { return (*it).second.first; } - Map get_irmodule() { return this->output_.lowered_funcs; } + Map get_irmodule() { return this->output_.lowered_funcs; } std::shared_ptr codegen_; LoweredOutput output_; diff --git a/src/relay/backend/interpreter.cc b/src/relay/backend/interpreter.cc index af2cbae1f72d..7646a32fd2e3 100644 --- a/src/relay/backend/interpreter.cc +++ b/src/relay/backend/interpreter.cc @@ -289,7 +289,7 @@ class Interpreter : public ExprFunctor, PatternFunctor { public: // TODO(mbs): Collapse mod and per_target_module once IRModule subsumes LoweredModule. 
- Interpreter(IRModule mod, Map per_target_module, Device device, Target target) + Interpreter(IRModule mod, Map per_target_module, Device device, Target target) : mod_(mod), per_target_module_(per_target_module), device_(device), @@ -382,7 +382,7 @@ class Interpreter : public ExprFunctor, // Project out just the function(s) we need. IRModule lowered_projected_mod; - auto mod_itr = per_target_module_.find(target->str()); + auto mod_itr = per_target_module_.find(target); ICHECK(mod_itr != per_target_module_.end()) << "No target module for target '" << target->str() << "'"; const IRModule& target_module = (*mod_itr).second; @@ -407,7 +407,7 @@ class Interpreter : public ExprFunctor, PackedFunc packed_func = runtime_module.GetFunction(var->name_hint); ICHECK(packed_func != nullptr) << "No packed function for global var '" << var->name_hint << "' in compiled module for target '" << target->str() << "'"; - compiled_packed_funcs_.emplace(std::make_pair(target->str(), var->name_hint), packed_func); + compiled_packed_funcs_.emplace(std::make_pair(target, var->name_hint), packed_func); } // Return just what we need for this call. @@ -874,7 +874,7 @@ class Interpreter : public ExprFunctor, // Map from target key to lowered TIR functions derived from mod_. // Note that primitives are implicitly executed on target_, while shape functions are implicitly // executed on the default 'cpu' host. Thus this map has at most two entries. - Map per_target_module_; + Map per_target_module_; // Cached packed functions for the primitives and shape functions, keyed by target and // global var name. std::unordered_map, PackedFunc, PairHash> @@ -895,7 +895,7 @@ class Interpreter : public ExprFunctor, * rewritten \p mod and target-specific modules containing bindings for all TIR primitive * functions needed by the rewritten module. 
*/ -std::pair> Prepare(IRModule mod, Device device, Target target) { +std::pair> Prepare(IRModule mod, Device device, Target target) { // Run minimal transforms on module to establish invariants needed by interpreter. transform::Sequential seq({transform::SimplifyInference(), // FuseOps will mark wrapped calls to prim-ops with the 'Primitive' @@ -1014,7 +1014,7 @@ TypedPackedFunc)> EvalFunction(IRModule mod, Expr expr, De // and can just eval it directly. expr_to_eval = expr; } - std::pair> main_and_lowered = + std::pair> main_and_lowered = Prepare(mod_with_expr, device, target); std::shared_ptr intrp = std::make_shared( /*mod=*/main_and_lowered.first, /*per_target_module=*/main_and_lowered.second, device, @@ -1057,7 +1057,7 @@ ObjectRef Eval(Expr expr, Map type_definitions, std::unordered_set import_set, Device device, Target target) { std::pair mod_and_global = IRModule::FromExprInContext(expr, /*global_funcs=*/{}, type_definitions, import_set); - std::pair> main_and_lowered = + std::pair> main_and_lowered = Prepare(mod_and_global.first, device, target); Interpreter intrp( /*mod=*/main_and_lowered.first, /*per_target_module=*/main_and_lowered.second, device, diff --git a/src/relay/backend/te_compiler.cc b/src/relay/backend/te_compiler.cc index cc7fd8259aad..2ee88eadfdfb 100644 --- a/src/relay/backend/te_compiler.cc +++ b/src/relay/backend/te_compiler.cc @@ -85,18 +85,18 @@ class TECompilerImpl : public TECompilerNode { return LowerShapeFuncInternal(key)->cached_func; } - Map GetLoweredFunctions() { - Map lowered_functions; + Map GetLoweredFunctions() { + Map lowered_functions; for (const auto& it : cache_) { auto source_func = it.first; auto lowered_func = it.second; auto target = source_func->target; - if (!lowered_functions.count(target->str())) { - lowered_functions.Set(target->str(), IRModule(Map({}))); + if (!lowered_functions.count(target)) { + lowered_functions.Set(target, IRModule(Map({}))); } - 
lowered_functions[target->str()]->Update(lowered_func->cached_func->funcs); + lowered_functions[target]->Update(lowered_func->cached_func->funcs); } for (const auto& it : shape_func_cache_) { @@ -104,11 +104,11 @@ class TECompilerImpl : public TECompilerNode { auto lowered_func = it.second; auto target = source_func->target; - if (!lowered_functions.count(target->str())) { - lowered_functions.Set(target->str(), IRModule(Map({}))); + if (!lowered_functions.count(target)) { + lowered_functions.Set(target, IRModule(Map({}))); } - lowered_functions[target->str()]->Update(lowered_func->cached_func->funcs); + lowered_functions[target]->Update(lowered_func->cached_func->funcs); } return lowered_functions; } @@ -594,7 +594,6 @@ Pass LowerTensorExpr(TargetMap targets, DeviceMap device_context_map, return CreateFunctionPass(pass_func, 0, "LowerTensorExpr", {}); } - Target GetTargetFromInteger(DLDeviceType dev_type, TargetMap targets) { if (targets.size() == 1) { // The homogeneous execution case, return the only target. 
@@ -883,12 +882,15 @@ IRModule LoweredModuleToIRModule(LoweredModule mod) { // Annotate the per-target functions with thier target and add them to the unified module for (const auto& kv : mod.per_target_module) { - const String target = kv.first; + const Target target = kv.first; const IRModule target_module = kv.second; - // Right now, per-target functions are TIR functions, which don't have type definitions, so there should be no type defs in the per_target_modules + // Right now, per-target functions are TIR functions, which don't have type definitions, so + // there should be no type defs in the per_target_modules size_t ty_def_size = target_module->type_definitions.size(); - ICHECK(ty_def_size == 0) << "Expected there to be no type definitions in the per_target_modules, but found " << ty_def_size; + ICHECK(ty_def_size == 0) + << "Expected there to be no type definitions in the per_target_modules, but found " + << ty_def_size; for (const auto& kv : target_module->functions) { const GlobalVar& var = kv.first; @@ -897,14 +899,13 @@ IRModule LoweredModuleToIRModule(LoweredModule mod) { << "We expect the target_module to contain only PrimFuncs at this point, but got " << func->GetTypeKey(); // TODO(@electriclilies): Change to Target object if possible - tir::PrimFunc primFunc = WithAttr(Downcast(std::move(func)), tvm::attr::kTarget, - runtime::String(target)); + tir::PrimFunc primFunc = + WithAttr(Downcast(std::move(func)), tvm::attr::kTarget, target); unified_module->Add(var, primFunc); } } - IRModule ret_mod = - WithAttr(unified_module, "external_mods", mod.external_mods); + IRModule ret_mod = WithAttr(unified_module, "external_mods", mod.external_mods); ret_mod = WithAttr(ret_mod, "main_func_info", mod.main_func_info); return ret_mod; } @@ -917,7 +918,7 @@ LoweredModule IRModuleToLoweredModule(IRModule mod) { main_mod->AddTypeDef(kv.first, kv.second); } - Map per_target_modules; + Map per_target_modules; for (const auto& kv : mod->functions) { const GlobalVar& 
var = kv.first; const BaseFunc& func = kv.second; @@ -925,7 +926,7 @@ LoweredModule IRModuleToLoweredModule(IRModule mod) { main_mod->Add(var, func); } else if (func->IsInstance()) { // Extract target - auto target = func->GetAttr(tvm::attr::kTarget); + Optional target = func->GetAttr(tvm::attr::kTarget); ICHECK(target) << "Target should be set at this point"; // Put the function in per_target_modules diff --git a/src/relay/backend/te_compiler.h b/src/relay/backend/te_compiler.h index e6efe7592afe..93a64909c376 100644 --- a/src/relay/backend/te_compiler.h +++ b/src/relay/backend/te_compiler.h @@ -97,7 +97,7 @@ class TECompilerNode : public Object { virtual CachedFunc Lower(const CCacheKey& key, const String mod_name) = 0; /* Return all functions which have been lowered by the compiler, keyed by target. */ - virtual Map GetLoweredFunctions() = 0; + virtual Map GetLoweredFunctions() = 0; /*! * \brief Just in time compile to get a PackedFunc. @@ -144,7 +144,7 @@ struct LoweredModule { /*! \brief The module which contains the Relay code. */ IRModule main_module; /*! \brief The module which contains per target code. */ - Map per_target_module; + Map per_target_module; /*! \brief The external runtime modules which must be combined with the lowered code. 
*/ Array external_mods; // TODO(@electriclilies): THis might need to become a map @@ -213,8 +213,8 @@ LoweredModule LowerTE( ProcessFn process_fn = [](Function f) {}); transform::Pass LowerTEPass(TargetMap targets, DeviceMap device_context_map, - backend::StaticMemoryPlan memory_plan, const String& module_name, - std::function process_fn); + backend::StaticMemoryPlan memory_plan, const String& module_name, + std::function process_fn); } // namespace tec } // namespace relay } // namespace tvm diff --git a/src/relay/backend/utils.h b/src/relay/backend/utils.h index a0c7a5aad26d..bf13715b7d46 100644 --- a/src/relay/backend/utils.h +++ b/src/relay/backend/utils.h @@ -139,7 +139,7 @@ int64_t CalculateRelayExprSizeBytes(const Type& expr_type); */ struct LoweredOutput { std::string graph_json; - Map lowered_funcs; + Map lowered_funcs; Array external_mods; Map function_metadata; std::unordered_map> params; From 6be39dd53c543e4518d75843646aa0725b69ed16 Mon Sep 17 00:00:00 2001 From: electriclilies Date: Thu, 19 Aug 2021 20:26:40 -0700 Subject: [PATCH 06/12] removing target string causing issues --- src/relay/backend/interpreter.cc | 4 ++-- src/relay/backend/te_compiler.h | 20 +++++++++++++++++++- 2 files changed, 21 insertions(+), 3 deletions(-) diff --git a/src/relay/backend/interpreter.cc b/src/relay/backend/interpreter.cc index 7646a32fd2e3..4c3952410b22 100644 --- a/src/relay/backend/interpreter.cc +++ b/src/relay/backend/interpreter.cc @@ -373,7 +373,7 @@ class Interpreter : public ExprFunctor, */ PackedFunc TIRToPackedFunc(const GlobalVar& tir_fn_var, const Array& all_tir_fn_vars, Target target) { - std::pair packed_func_key(target->str(), tir_fn_var->name_hint); + std::pair packed_func_key(target, tir_fn_var->name_hint); auto packed_itr = compiled_packed_funcs_.find(packed_func_key); if (packed_itr != compiled_packed_funcs_.end()) { // Already compiled. 
@@ -877,7 +877,7 @@ class Interpreter : public ExprFunctor, Map per_target_module_; // Cached packed functions for the primitives and shape functions, keyed by target and // global var name. - std::unordered_map, PackedFunc, PairHash> + std::unordered_map, PackedFunc, PairHash> compiled_packed_funcs_; // Unique device on which primitives (but not shape functions) will be executed. // (For simplicity we only run the interpreter on a single device.) diff --git a/src/relay/backend/te_compiler.h b/src/relay/backend/te_compiler.h index 93a64909c376..26aab969dfbb 100644 --- a/src/relay/backend/te_compiler.h +++ b/src/relay/backend/te_compiler.h @@ -201,10 +201,13 @@ LoweredModule IRModuleToLoweredModule(IRModule mod); * This is the "back half" of the Relay compiler which lowers "primitive functions" * to TE expressions, schedules them, and then to TIR. * - * \param compiler The TE-to-TIR compliler (which caches lowered functions) * \param module The IRModule. * \param targets The mapping for devices to targets. * \param device_map An analysis result mapping each sub-expression to a device. + * \param memory_plan The memory plan used during lowering + * \param module_name The name of this module + * \param process_fn Callback allowing one-level up code generators to process + * each function that we lower * \return The lowered module, see above. */ LoweredModule LowerTE( @@ -212,6 +215,21 @@ LoweredModule LowerTE( backend::StaticMemoryPlan memory_plan, const String& module_name, ProcessFn process_fn = [](Function f) {}); +/*! \brief Pass to lower an IRModule's primitive functions to TIR. + * + * This is the "back half" of the Relay compiler which lowers "primitive functions" + * to TE expressions, schedules them, and then to TIR. This Pass calls LowerTE, and + * uses LoweredModuleToIRModule utility to convert the output LowerTE's output + * LoweredModule into an IRModule before returning it. + * + * \param targets The mapping for devices to targets. 
+ * \param device_map An analysis result mapping each sub-expression to a device. + * \param memory_plan The memory plan used during lowering + * \param module_name The name of this module + * \param process_fn Callback allowing one-level up code generators to process + * each function that we lower + * \returns The pass which lowers primitive functions to TIR + */ transform::Pass LowerTEPass(TargetMap targets, DeviceMap device_context_map, backend::StaticMemoryPlan memory_plan, const String& module_name, std::function process_fn); From d698857945d6434b57151a74ff0aa26af08fdc6e Mon Sep 17 00:00:00 2001 From: electriclilies Date: Fri, 20 Aug 2021 09:11:07 -0700 Subject: [PATCH 07/12] Fix typos --- src/relay/backend/te_compiler.cc | 2 +- src/relay/backend/te_compiler.h | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/relay/backend/te_compiler.cc b/src/relay/backend/te_compiler.cc index 2ee88eadfdfb..24bf4d7be317 100644 --- a/src/relay/backend/te_compiler.cc +++ b/src/relay/backend/te_compiler.cc @@ -880,7 +880,7 @@ IRModule LoweredModuleToIRModule(LoweredModule mod) { unified_module->AddTypeDef(kv.first, kv.second); } - // Annotate the per-target functions with thier target and add them to the unified module + // Annotate the per-target functions with their target and add them to the unified module for (const auto& kv : mod.per_target_module) { const Target target = kv.first; const IRModule target_module = kv.second; diff --git a/src/relay/backend/te_compiler.h b/src/relay/backend/te_compiler.h index 26aab969dfbb..3c199ccbc387 100644 --- a/src/relay/backend/te_compiler.h +++ b/src/relay/backend/te_compiler.h @@ -67,7 +67,7 @@ struct EnumClassHash { } }; -// TODO(@jroesch, @chrisS) these shoumakeld be a tvm::Map for uniformity sake +// TODO(@jroesch, @chrisS) these should be a tvm::Map for uniformity sake // we should a version of context which works in Map using TargetMap = std::unordered_map; using DeviceMap = @@ -166,7 +166,7 @@ void 
UpdateFunctionMetadata(Function relay_func, /*! * \brief Obtain the Target from the device type. * If homogenous compilation, this will return the only target. - * If heteregenous compilation, this will select the associated target using the + * If heterogeneous compilation, this will select the associated target using the * targets_ Map. * * \param dev_type From 120b14a450db055a955186432c52d8a998818ede Mon Sep 17 00:00:00 2001 From: electriclilies Date: Fri, 20 Aug 2021 09:29:34 -0700 Subject: [PATCH 08/12] Revert target str -> target obj changes --- src/relay/backend/aot_executor_codegen.cc | 9 +++++---- src/relay/backend/interpreter.cc | 18 +++++++++--------- src/relay/backend/te_compiler.cc | 22 +++++++++++----------- src/relay/backend/te_compiler.h | 4 ++-- src/relay/backend/utils.h | 2 +- 5 files changed, 28 insertions(+), 27 deletions(-) diff --git a/src/relay/backend/aot_executor_codegen.cc b/src/relay/backend/aot_executor_codegen.cc index 2b88f0489321..942bc0d1d44a 100644 --- a/src/relay/backend/aot_executor_codegen.cc +++ b/src/relay/backend/aot_executor_codegen.cc @@ -669,10 +669,11 @@ class AOTExecutorCodegen : public ExprVisitor { ret.lowered_funcs = lowered_module.per_target_module; ret.external_mods = lowered_module.external_mods; - if (ret.lowered_funcs.find(target_host_) != ret.lowered_funcs.end()) { - ret.lowered_funcs[target_host_]->Update(mod_run); + auto target_host_str = target_host_->str(); + if (ret.lowered_funcs.find(target_host_str) != ret.lowered_funcs.end()) { + ret.lowered_funcs[target_host_str]->Update(mod_run); } else { - ret.lowered_funcs.Set(target_host_, mod_run); + ret.lowered_funcs.Set(target_host_str, mod_run); } std::vector input_var_names(input_vars_.size()); @@ -777,7 +778,7 @@ class AOTExecutorCodegenModule : public runtime::ModuleNode { return (*it).second.first; } - Map get_irmodule() { return this->output_.lowered_funcs; } + Map get_irmodule() { return this->output_.lowered_funcs; } std::shared_ptr codegen_; LoweredOutput 
output_; diff --git a/src/relay/backend/interpreter.cc b/src/relay/backend/interpreter.cc index 4c3952410b22..af2cbae1f72d 100644 --- a/src/relay/backend/interpreter.cc +++ b/src/relay/backend/interpreter.cc @@ -289,7 +289,7 @@ class Interpreter : public ExprFunctor, PatternFunctor { public: // TODO(mbs): Collapse mod and per_target_module once IRModule subsumes LoweredModule. - Interpreter(IRModule mod, Map per_target_module, Device device, Target target) + Interpreter(IRModule mod, Map per_target_module, Device device, Target target) : mod_(mod), per_target_module_(per_target_module), device_(device), @@ -373,7 +373,7 @@ class Interpreter : public ExprFunctor, */ PackedFunc TIRToPackedFunc(const GlobalVar& tir_fn_var, const Array& all_tir_fn_vars, Target target) { - std::pair packed_func_key(target, tir_fn_var->name_hint); + std::pair packed_func_key(target->str(), tir_fn_var->name_hint); auto packed_itr = compiled_packed_funcs_.find(packed_func_key); if (packed_itr != compiled_packed_funcs_.end()) { // Already compiled. @@ -382,7 +382,7 @@ class Interpreter : public ExprFunctor, // Project out just the function(s) we need. IRModule lowered_projected_mod; - auto mod_itr = per_target_module_.find(target); + auto mod_itr = per_target_module_.find(target->str()); ICHECK(mod_itr != per_target_module_.end()) << "No target module for target '" << target->str() << "'"; const IRModule& target_module = (*mod_itr).second; @@ -407,7 +407,7 @@ class Interpreter : public ExprFunctor, PackedFunc packed_func = runtime_module.GetFunction(var->name_hint); ICHECK(packed_func != nullptr) << "No packed function for global var '" << var->name_hint << "' in compiled module for target '" << target->str() << "'"; - compiled_packed_funcs_.emplace(std::make_pair(target, var->name_hint), packed_func); + compiled_packed_funcs_.emplace(std::make_pair(target->str(), var->name_hint), packed_func); } // Return just what we need for this call. 
@@ -874,10 +874,10 @@ class Interpreter : public ExprFunctor, // Map from target key to lowered TIR functions derived from mod_. // Note that primitives are implicitly executed on target_, while shape functions are implicitly // executed on the default 'cpu' host. Thus this map has at most two entries. - Map per_target_module_; + Map per_target_module_; // Cached packed functions for the primitives and shape functions, keyed by target and // global var name. - std::unordered_map, PackedFunc, PairHash> + std::unordered_map, PackedFunc, PairHash> compiled_packed_funcs_; // Unique device on which primitives (but not shape functions) will be executed. // (For simplicity we only run the interpreter on a single device.) @@ -895,7 +895,7 @@ class Interpreter : public ExprFunctor, * rewritten \p mod and target-specific modules containing bindings for all TIR primitive * functions needed by the rewritten module. */ -std::pair> Prepare(IRModule mod, Device device, Target target) { +std::pair> Prepare(IRModule mod, Device device, Target target) { // Run minimal transforms on module to establish invariants needed by interpreter. transform::Sequential seq({transform::SimplifyInference(), // FuseOps will mark wrapped calls to prim-ops with the 'Primitive' @@ -1014,7 +1014,7 @@ TypedPackedFunc)> EvalFunction(IRModule mod, Expr expr, De // and can just eval it directly. 
expr_to_eval = expr; } - std::pair> main_and_lowered = + std::pair> main_and_lowered = Prepare(mod_with_expr, device, target); std::shared_ptr intrp = std::make_shared( /*mod=*/main_and_lowered.first, /*per_target_module=*/main_and_lowered.second, device, @@ -1057,7 +1057,7 @@ ObjectRef Eval(Expr expr, Map type_definitions, std::unordered_set import_set, Device device, Target target) { std::pair mod_and_global = IRModule::FromExprInContext(expr, /*global_funcs=*/{}, type_definitions, import_set); - std::pair> main_and_lowered = + std::pair> main_and_lowered = Prepare(mod_and_global.first, device, target); Interpreter intrp( /*mod=*/main_and_lowered.first, /*per_target_module=*/main_and_lowered.second, device, diff --git a/src/relay/backend/te_compiler.cc b/src/relay/backend/te_compiler.cc index 24bf4d7be317..3bf03ce4c7bd 100644 --- a/src/relay/backend/te_compiler.cc +++ b/src/relay/backend/te_compiler.cc @@ -85,18 +85,18 @@ class TECompilerImpl : public TECompilerNode { return LowerShapeFuncInternal(key)->cached_func; } - Map GetLoweredFunctions() { - Map lowered_functions; + Map GetLoweredFunctions() { + Map lowered_functions; for (const auto& it : cache_) { auto source_func = it.first; auto lowered_func = it.second; auto target = source_func->target; - if (!lowered_functions.count(target)) { - lowered_functions.Set(target, IRModule(Map({}))); + if (!lowered_functions.count(target->str())) { + lowered_functions.Set(target->str(), IRModule(Map({}))); } - lowered_functions[target]->Update(lowered_func->cached_func->funcs); + lowered_functions[target->str()]->Update(lowered_func->cached_func->funcs); } for (const auto& it : shape_func_cache_) { @@ -104,11 +104,11 @@ class TECompilerImpl : public TECompilerNode { auto lowered_func = it.second; auto target = source_func->target; - if (!lowered_functions.count(target)) { - lowered_functions.Set(target, IRModule(Map({}))); + if (!lowered_functions.count(target->str())) { + lowered_functions.Set(target->str(), 
IRModule(Map({}))); } - lowered_functions[target]->Update(lowered_func->cached_func->funcs); + lowered_functions[target->str()]->Update(lowered_func->cached_func->funcs); } return lowered_functions; } @@ -882,7 +882,7 @@ IRModule LoweredModuleToIRModule(LoweredModule mod) { // Annotate the per-target functions with their target and add them to the unified module for (const auto& kv : mod.per_target_module) { - const Target target = kv.first; + const String target = kv.first; const IRModule target_module = kv.second; // Right now, per-target functions are TIR functions, which don't have type definitions, so @@ -918,7 +918,7 @@ LoweredModule IRModuleToLoweredModule(IRModule mod) { main_mod->AddTypeDef(kv.first, kv.second); } - Map per_target_modules; + Map per_target_modules; for (const auto& kv : mod->functions) { const GlobalVar& var = kv.first; const BaseFunc& func = kv.second; @@ -926,7 +926,7 @@ LoweredModule IRModuleToLoweredModule(IRModule mod) { main_mod->Add(var, func); } else if (func->IsInstance()) { // Extract target - Optional target = func->GetAttr(tvm::attr::kTarget); + Optional target = func->GetAttr(tvm::attr::kTarget); ICHECK(target) << "Target should be set at this point"; // Put the function in per_target_modules diff --git a/src/relay/backend/te_compiler.h b/src/relay/backend/te_compiler.h index 3c199ccbc387..91379e6364f9 100644 --- a/src/relay/backend/te_compiler.h +++ b/src/relay/backend/te_compiler.h @@ -97,7 +97,7 @@ class TECompilerNode : public Object { virtual CachedFunc Lower(const CCacheKey& key, const String mod_name) = 0; /* Return all functions which have been lowered by the compiler, keyed by target. */ - virtual Map GetLoweredFunctions() = 0; + virtual Map GetLoweredFunctions() = 0; /*! * \brief Just in time compile to get a PackedFunc. @@ -144,7 +144,7 @@ struct LoweredModule { /*! \brief The module which contains the Relay code. */ IRModule main_module; /*! \brief The module which contains per target code. 
*/ - Map per_target_module; + Map per_target_module; /*! \brief The external runtime modules which must be combined with the lowered code. */ Array external_mods; // TODO(@electriclilies): THis might need to become a map diff --git a/src/relay/backend/utils.h b/src/relay/backend/utils.h index bf13715b7d46..a0c7a5aad26d 100644 --- a/src/relay/backend/utils.h +++ b/src/relay/backend/utils.h @@ -139,7 +139,7 @@ int64_t CalculateRelayExprSizeBytes(const Type& expr_type); */ struct LoweredOutput { std::string graph_json; - Map lowered_funcs; + Map lowered_funcs; Array external_mods; Map function_metadata; std::unordered_map> params; From 9015163a27a56d7545ac8ea12c8983715208f1e4 Mon Sep 17 00:00:00 2001 From: electriclilies Date: Fri, 20 Aug 2021 15:50:46 -0700 Subject: [PATCH 09/12] Don't use Update : IRModule because it is broken --- src/relay/backend/te_compiler.cc | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/relay/backend/te_compiler.cc b/src/relay/backend/te_compiler.cc index 3bf03ce4c7bd..2ac627594841 100644 --- a/src/relay/backend/te_compiler.cc +++ b/src/relay/backend/te_compiler.cc @@ -875,7 +875,9 @@ IRModule LoweredModuleToIRModule(LoweredModule mod) { IRModule unified_module; // Copy the main module and its typedefs - unified_module->Update(mod.main_module); + for (const auto& kv : mod.main_module->functions) { + unified_module->Add(kv.first, kv.second); + } for (const auto& kv : mod.main_module->type_definitions) { unified_module->AddTypeDef(kv.first, kv.second); } From ce57b2ebe1e51380a0ddabccfa1ec7dd9b12ca32 Mon Sep 17 00:00:00 2001 From: electriclilies Date: Fri, 20 Aug 2021 17:23:50 -0700 Subject: [PATCH 10/12] Fix check --- src/relay/backend/te_compiler.cc | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/src/relay/backend/te_compiler.cc b/src/relay/backend/te_compiler.cc index 2ac627594841..a1df16b4c8b6 100644 --- a/src/relay/backend/te_compiler.cc +++ b/src/relay/backend/te_compiler.cc @@ 
-897,13 +897,18 @@ IRModule LoweredModuleToIRModule(LoweredModule mod) { for (const auto& kv : target_module->functions) { const GlobalVar& var = kv.first; const BaseFunc& func = kv.second; - ICHECK(func->IsInstance()) - << "We expect the target_module to contain only PrimFuncs at this point, but got " - << func->GetTypeKey(); - // TODO(@electriclilies): Change to Target object if possible - tir::PrimFunc primFunc = + if (func->IsInstance()) { + tir::PrimFunc primFunc = WithAttr(Downcast(std::move(func)), tvm::attr::kTarget, target); - unified_module->Add(var, primFunc); + unified_module->Add(var, primFunc); + } else if (func->IsInstance()) { + relay::Function relayFunc = + WithAttr(Downcast(std::move(func)), tvm::attr::kTarget, target); + unified_module->Add(var, relayFunc); + } else { + LOG(FATAL) << "We expected to only have PrimFuncs or RelayFuncs in the target modules, but found " + << func->GetTypeKey(); + } } } From 8521ce2650035712589f526608a5488ea69d3095 Mon Sep 17 00:00:00 2001 From: electriclilies Date: Mon, 23 Aug 2021 11:54:16 -0700 Subject: [PATCH 11/12] flaky test? 
From 772ae10379bf725b2de9e44a687f8a4becb92dd3 Mon Sep 17 00:00:00 2001 From: electriclilies Date: Mon, 23 Aug 2021 12:22:06 -0700 Subject: [PATCH 12/12] lint --- src/relay/backend/te_compiler.cc | 9 +++++---- src/relay/backend/te_compiler.h | 2 +- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/src/relay/backend/te_compiler.cc b/src/relay/backend/te_compiler.cc index a1df16b4c8b6..71ac752ec680 100644 --- a/src/relay/backend/te_compiler.cc +++ b/src/relay/backend/te_compiler.cc @@ -899,15 +899,16 @@ IRModule LoweredModuleToIRModule(LoweredModule mod) { const BaseFunc& func = kv.second; if (func->IsInstance()) { tir::PrimFunc primFunc = - WithAttr(Downcast(std::move(func)), tvm::attr::kTarget, target); + WithAttr(Downcast(std::move(func)), tvm::attr::kTarget, target); unified_module->Add(var, primFunc); } else if (func->IsInstance()) { relay::Function relayFunc = - WithAttr(Downcast(std::move(func)), tvm::attr::kTarget, target); + WithAttr(Downcast(std::move(func)), tvm::attr::kTarget, target); unified_module->Add(var, relayFunc); } else { - LOG(FATAL) << "We expected to only have PrimFuncs or RelayFuncs in the target modules, but found " - << func->GetTypeKey(); + LOG(FATAL) + << "We expected to only have PrimFuncs or RelayFuncs in the target modules, but found " + << func->GetTypeKey(); } } } diff --git a/src/relay/backend/te_compiler.h b/src/relay/backend/te_compiler.h index 91379e6364f9..e9cfb0d62e66 100644 --- a/src/relay/backend/te_compiler.h +++ b/src/relay/backend/te_compiler.h @@ -223,7 +223,7 @@ LoweredModule LowerTE( * LoweredModule into an IRModule before returning it. * * \param targets The mapping for devices to targets. - * \param device_map An analysis result mapping each sub-expression to a device. + * \param device_context_map An analysis result mapping each sub-expression to a device. 
* \param memory_plan The memory plan used during lowering * \param module_name The name of this module * \param process_fn Callback allowing one-level up code generators to process