diff --git a/src/relay/backend/aot_executor_codegen.cc b/src/relay/backend/aot_executor_codegen.cc
index 70779ac58abf..ad9ba1b2069d 100644
--- a/src/relay/backend/aot_executor_codegen.cc
+++ b/src/relay/backend/aot_executor_codegen.cc
@@ -38,8 +38,8 @@
 #include <...>
 #include <...>
 
-#include "te_compiler.h"
-#include "utils.h"
+#include "./te_compiler.h"
+#include "./utils.h"
 
 namespace tvm {
 namespace relay {
@@ -583,8 +583,16 @@ class AOTExecutorCodegen : public MixedModeVisitor {
     // performing the preexisting AOT executor code generation phase.
     IRModule mod = IRModule::FromExpr(func);
 
+    backend::FunctionInfo func_info;
+
+    if (memory_plan.defined()) {
+      // TODO(@electriclilies, @jroesch): remove UpdateMainWorkspaceSize
+      func_info = tec::UpdateMainWorkspaceSize(mod, targets_, memory_plan->expr_to_storage_info);
+      mod = WithAttr(mod, "main_func_info", func_info);
+    }
+
     IRModule lowered_mod =
-        LowerTEPass(targets_, device_context_map, memory_plan, mod_name, [this](Function func) {
+        tec::LowerTEPass(targets_, device_context_map, mod_name, [this](Function func) {
           // We need to maintain the constant map for external
           // functions so we pass this processing function which
           // allows us to process each function as we lower it.
@@ -661,7 +669,7 @@ class AOTExecutorCodegen : public MixedModeVisitor {
     Optional<backend::FunctionInfo> main_func_info =
         lowered_mod->GetAttr<backend::FunctionInfo>("main_func_info");
-    ICHECK(main_func_info) << "The attribute \"main_func_info\" should be set at this point.";
+
     main_func_info.value()->workspace_sizes.Set(target_host_, main_workspace_size);
     function_metadata_.Set(runtime::symbol::tvm_module_main, main_func_info.value());
diff --git a/src/relay/backend/graph_executor_codegen.cc b/src/relay/backend/graph_executor_codegen.cc
index aca95db34c4e..92e7568d9f38 100644
--- a/src/relay/backend/graph_executor_codegen.cc
+++ b/src/relay/backend/graph_executor_codegen.cc
@@ -36,8 +36,8 @@
 #include <...>
 #include <...>
 
-#include "te_compiler.h"
-#include "utils.h"
+#include "./te_compiler.h"
+#include "./utils.h"
 
 namespace tvm {
 namespace relay {
@@ -221,8 +221,17 @@ class GraphExecutorCodegen : public backend::MemoizedExprTranslator<std::vector<GraphNodeRef>> {
     IRModule mod = IRModule::FromExpr(func);
 
+    backend::FunctionInfo func_info;
+
+    if (memory_plan_.defined()) {
+      // TODO(@electriclilies, @jroesch): remove UpdateMainWorkspaceSize
+      func_info =
+          tec::UpdateMainWorkspaceSize(mod, targets_, memory_plan_->expr_to_storage_info);
+      mod = WithAttr(mod, "main_func_info", func_info);
+    }
+
     IRModule lowered_mod =
-        LowerTEPass(targets_, device_context_map, memory_plan_, mod_name_, [this](Function func) {
+        tec::LowerTEPass(targets_, device_context_map, mod_name_, [this](Function func) {
           // We need to maintain the constant map for external
           // functions so we pass this processing function which
           // allows us to process each function as we lower it.
@@ -238,7 +247,7 @@ class GraphExecutorCodegen : public backend::MemoizedExprTranslator<std::vector<GraphNodeRef>> {
     Optional<backend::FunctionInfo> main_func_info =
         lowered_mod->GetAttr<backend::FunctionInfo>("main_func_info");
-    ICHECK(main_func_info) << "The attribute \"main_func_info\" should be set at this point.";
+
     function_metadata_.Set(runtime::symbol::tvm_module_main, main_func_info.value());
 
     Function lowered_main_func = Downcast<Function>(lowered_mod->Lookup("main"));
diff --git a/src/relay/backend/interpreter.cc b/src/relay/backend/interpreter.cc
index df14b9e078b6..d87cf9811bc7 100644
--- a/src/relay/backend/interpreter.cc
+++ b/src/relay/backend/interpreter.cc
@@ -542,7 +542,7 @@ class Interpreter : public ExprFunctor<ObjectRef(const Expr& n)>,
  *
  * @param prim_fn_var Global bound to lowered primitive.
  * @param all_prim_fn_vars All globals references by lowered primitive, plus prim_fn_var itself.
- * @param prim_shape_fn_var Global bound to lowered shape function for primitive, if neeeded.
+ * @param prim_shape_fn_var Global bound to lowered shape function for primitive, if needed.
  * @param all_prim_shape_fn_vars All globals references by lowered shape function, plus
  * prim_shape_fn_var itself.
  * @param prim_shape_fn_states Records whether shape and/or data is needed by the dynamic
@@ -763,7 +763,7 @@ class Interpreter : public ExprFunctor<ObjectRef(const Expr& n)>,
   ObjectRef VisitExpr_(const TupleGetItemNode* op) final {
     ObjectRef val = Eval(op->tuple);
     const auto* adt_obj = val.as<ADTObj>();
-    ICHECK(adt_obj) << "interal error: when evaluating TupleGetItem expected an ADT value";
+    ICHECK(adt_obj) << "internal error: when evaluating TupleGetItem expected an ADT value";
     auto adt = GetRef<ADT>(adt_obj);
     ICHECK_LT(static_cast<size_t>(op->index), adt.size()) << "internal error: index out of bounds";
     return adt[op->index];
@@ -902,21 +902,17 @@ IRModule Prepare(IRModule mod, Device device, Target target) {
   // All calls to primitives will use the unique target.
   tec::DeviceMap device_map;
 
-  // No need for a memory plan.
-  backend::StaticMemoryPlan memory_plan; /*=nullptr*/
-
   // Run minimal transforms on module to establish invariants needed by interpreter.
-  transform::Sequential seq(
-      {transform::SimplifyInference(),
-       // FuseOps will mark wrapped calls to prim-ops with the 'Primitive'
-       // attribute.
-       transform::FuseOps(/*fuse_opt_level=*/0), transform::ToANormalForm(),
-       // eta expand to support constructors in argument position
-       transform::EtaExpand(
-           /*expand_constructor=*/true, /*expand_global_var=*/false),
-       transform::InferType(),
-       tec::LowerTEPass(targets, device_map, memory_plan, /*module_name=*/"intrp",
-                        [](Function func) { /* no-op */ })});
+  transform::Sequential seq({transform::SimplifyInference(),
+                             // FuseOps will mark wrapped calls to prim-ops with the 'Primitive'
+                             // attribute.
+                             transform::FuseOps(/*fuse_opt_level=*/0), transform::ToANormalForm(),
+                             // eta expand to support constructors in argument position
+                             transform::EtaExpand(
+                                 /*expand_constructor=*/true, /*expand_global_var=*/false),
+                             transform::InferType(),
+                             tec::LowerTEPass(targets, device_map, /*module_name=*/"intrp",
+                                              [](Function func) { /* no-op */ })});
 
   transform::PassContext pass_ctx = transform::PassContext::Current();
   With<transform::PassContext> ctx(pass_ctx);
diff --git a/src/relay/backend/te_compiler.cc b/src/relay/backend/te_compiler.cc
index 0393fdfec70d..2e7eb6f9aa6b 100644
--- a/src/relay/backend/te_compiler.cc
+++ b/src/relay/backend/te_compiler.cc
@@ -17,7 +17,7 @@
  * under the License.
  */
 
-#include "te_compiler.h"
+#include "./te_compiler.h"
 
 #include <...>
 #include <...>
@@ -42,8 +42,8 @@
 #include <...>
 #include <...>
 
-#include "te_compiler_cache.h"
-#include "utils.h"
+#include "./te_compiler_cache.h"
+#include "./utils.h"
 
 namespace tvm {
 namespace relay {
@@ -596,19 +596,7 @@ class LowerTensorExprMutator : public ExprMutator {
   const Op& debug_op_;
 };
 
-Pass LowerTensorExpr(TargetMap targets, DeviceMap device_context_map,
-                     backend::StaticMemoryPlan memory_plan, const String& module_name,
-                     TECompiler compiler, std::function<void(Function)> process_fn) {
-  runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func =
-      [=](Function func, IRModule module, PassContext ctx) {
-        LowerTensorExprMutator lower_te(module, targets, device_context_map, process_fn,
-                                        module_name, compiler);
-        return Downcast<Function>(lower_te.Mutate(func));
-      };
-  return CreateFunctionPass(pass_func, 0, "LowerTensorExpr", {});
-}
-
-Target GetTargetFromInteger(DLDeviceType dev_type, TargetMap targets) {
+Target GetTargetFromInteger(DLDeviceType dev_type, tec::TargetMap targets) {
   if (targets.size() == 1) {
     // The homogeneous execution case, return the only target.
     const auto& it = targets.begin();
@@ -638,26 +626,30 @@ Target GetTargetFromInteger(DLDeviceType dev_type, TargetMap targets) {
   }
 }
 
-/*!
- * \brief Update the "main" control function's metadata
- *
- * \param mod The module
- * \param targets Map of targets
- * \return function_infos Function info for each function in the module
- */
+Pass LowerTensorExpr(TargetMap targets, DeviceMap device_context_map, const String& module_name,
+                     TECompiler compiler, std::function<void(Function)> process_fn) {
+  runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func =
+      [=](Function func, IRModule module, PassContext ctx) {
+        LowerTensorExprMutator lower_te(module, targets, device_context_map, process_fn,
+                                        module_name, compiler);
+        return Downcast<Function>(lower_te.Mutate(func));
+      };
+  return CreateFunctionPass(pass_func, 0, "LowerTensorExpr", {});
+}
 
-backend::FunctionInfo UpdateMainWorkspaceSize(const IRModule& mod, TargetMap targets,
+backend::FunctionInfo UpdateMainWorkspaceSize(const IRModule& mod, tec::TargetMap targets,
                                               Map<Expr, backend::StorageInfo> storage_info_map) {
   CHECK_EQ(mod->functions.size(), 1)
       << "There should only be one function in the module passed to UpdateMainWorkspaceSize";
   Function func = Downcast<Function>(mod->Lookup("main"));
 
   // This is a Map<device, Map<storage_id, size>>
-  std::unordered_map<DLDeviceType, std::unordered_map<int, int>, EnumClassHash> sid_workspace;
+  std::unordered_map<DLDeviceType, std::unordered_map<int, int>, backend::EnumClassHash>
+      sid_workspace;
   // This is a Map<device, size_of_inputs_and_outputs>
-  std::unordered_map<DLDeviceType, int, EnumClassHash> device_io;
+  std::unordered_map<DLDeviceType, int, backend::EnumClassHash> device_io;
   // This is a Map<device, size_of_constants>
-  std::unordered_map<DLDeviceType, int, EnumClassHash> device_consts;
+  std::unordered_map<DLDeviceType, int, backend::EnumClassHash> device_consts;
 
   // Initialize the mapping from all storage identifiers to workspace sizes,
   // the amount of device io, and the device constants.
@@ -723,7 +715,7 @@ backend::FunctionInfo UpdateMainWorkspaceSize(const IRModule& mod, TargetMap tar
   }
 
   // This is a Map<device, workspace_size>
-  std::unordered_map<DLDeviceType, int, EnumClassHash> device_workspace;
+  std::unordered_map<DLDeviceType, int, backend::EnumClassHash> device_workspace;
   // Once we know the sizes of sids, we need to accumulate per device
   for (const auto& dev_sid_size : sid_workspace) {
     auto dev = dev_sid_size.first;
@@ -746,17 +738,17 @@ backend::FunctionInfo UpdateMainWorkspaceSize(const IRModule& mod, TargetMap tar
   }
 
   for (const auto& dev_and_size : device_workspace) {
-    auto tgt = GetTargetFromInteger(dev_and_size.first, targets);
+    auto tgt = tec::GetTargetFromInteger(dev_and_size.first, targets);
     workspace_sizes.Set(tgt, dev_and_size.second);
     relay_primfuncs.Set(tgt, func);
   }
   for (const auto& dev_and_size : device_io) {
-    auto tgt = GetTargetFromInteger(dev_and_size.first, targets);
+    auto tgt = tec::GetTargetFromInteger(dev_and_size.first, targets);
     io_sizes.Set(tgt, dev_and_size.second);
   }
 
   for (const auto& dev_and_size : device_consts) {
-    auto tgt = GetTargetFromInteger(dev_and_size.first, targets);
+    auto tgt = tec::GetTargetFromInteger(dev_and_size.first, targets);
     constant_sizes.Set(tgt, dev_and_size.second);
   }
@@ -844,20 +836,13 @@ void UpdateFunctionMetadata(Function relay_func,
 }
 
 IRModule LowerTE(const IRModule& module, TargetMap targets, DeviceMap device_context_map,
-                 backend::StaticMemoryPlan memory_plan, const String& module_name,
-                 std::function<void(Function)> process_fn) {
+                 const String& module_name, std::function<void(Function)> process_fn) {
   DLOG(INFO) << "lowering module:\n" << PrettyPrint(module);
 
   TECompiler compiler;
 
-  backend::FunctionInfo func_info;
-  if (memory_plan.defined()) {
-    // TODO(@electriclilies, @jroesch): remove UpdateMainWorkspaceSize
-    func_info = UpdateMainWorkspaceSize(module, targets, memory_plan->expr_to_storage_info);
-  }
-
-  auto updated_module = LowerTensorExpr(targets, device_context_map, memory_plan, module_name,
-                                        compiler, process_fn)(module);
+  auto updated_module =
+      LowerTensorExpr(targets, device_context_map, module_name, compiler, process_fn)(module);
 
   // A temporary solution until we can rewrite the auto-scheduler task extraction code to work
   // in a more reasonable way.
@@ -882,7 +867,6 @@ IRModule LowerTE(const IRModule& module, TargetMap targets, DeviceMap device_con
 
   // Annotate the module with the external modules and function info
   updated_module = WithAttr(updated_module, "external_mods", compiler->LowerExternalFunctions());
-  updated_module = WithAttr(updated_module, "main_func_info", func_info);
 
   return updated_module;
 }
@@ -919,12 +903,11 @@ Map<Target, IRModule> GetPerTargetModules(IRModule mod) {
   return per_target_modules;
 }
 
-Pass LowerTEPass(TargetMap targets, DeviceMap device_context_map,
-                 backend::StaticMemoryPlan memory_plan, const String& module_name,
+Pass LowerTEPass(TargetMap targets, DeviceMap device_context_map, const String& module_name,
                  std::function<void(Function)> process_fn) {
   runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> pass_func =
       [=](IRModule module, PassContext ctx) {
-        return LowerTE(module, targets, device_context_map, memory_plan, module_name, process_fn);
+        return LowerTE(module, targets, device_context_map, module_name, process_fn);
       };
   return tvm::transform::Sequential(
       {tvm::transform::CreateModulePass(pass_func, 0, "LowerTE", {}), InferType()});
diff --git a/src/relay/backend/te_compiler.h b/src/relay/backend/te_compiler.h
index 9d0eb1078ee0..d5135e6301c4 100644
--- a/src/relay/backend/te_compiler.h
+++ b/src/relay/backend/te_compiler.h
@@ -52,24 +52,15 @@
 #include "../transforms/infer_layout_utils.h"
 #include "../transforms/pass_utils.h"
 #include "./te_compiler_cache.h"
-#include "utils.h"
+#include "./utils.h"
 
 namespace tvm {
 namespace relay {
 namespace tec {
 
-// This class is needed to avoid a GCC 5 bug that prevents maps containing enums
-// from being compiled. If i386 GCC version is increased, we can remove it.
-struct EnumClassHash {
-  template <typename T>
-  std::size_t operator()(T t) const {
-    return static_cast<std::size_t>(t);
-  }
-};
-
 // TODO(@jroesch, @chrisS) these should be a tvm::Map for uniformity sake
 // we should a version of context which works in Map
-using TargetMap = std::unordered_map<DLDeviceType, Target, EnumClassHash>;
+using TargetMap = std::unordered_map<DLDeviceType, Target, backend::EnumClassHash>;
 using DeviceMap =
     std::unordered_map<Expr, tvm::Device, runtime::ObjectPtrHash, runtime::ObjectPtrEqual>;
 using ProcessFn = std::function<void(Function)>;
@@ -158,6 +149,16 @@ void UpdateFunctionMetadata(Function relay_func,
  */
 Target GetTargetFromInteger(DLDeviceType dev_type, TargetMap targets);
 
+/*!
+ * \brief Update the "main" control function's metadata
+ *
+ * \param mod The module
+ * \param targets Map of targets
+ * \return function_infos Function info for each function in the module
+ */
+backend::FunctionInfo UpdateMainWorkspaceSize(const IRModule& mod, tec::TargetMap targets,
+                                              Map<Expr, backend::StorageInfo> storage_info_map);
+
 /*! \brief Utility to separate the functions in an IRModule by Target.
  *
  * \param mod The IRModule to extract the per target module from
@@ -192,15 +193,13 @@ IRModule LowerTE(
  *
  * \param targets The mapping for devices to targets.
 * \param device_context_map An analysis result mapping each sub-expression to a device.
- * \param memory_plan The memory plan used during lowering
 * \param module_name The name of this module
 * \param process_fn Callback allowing one-level up code generators to process
 * each function that we lower
 * \returns The pass which lowers primative functions to TIR
 */
 transform::Pass LowerTEPass(TargetMap targets, DeviceMap device_context_map,
-                            backend::StaticMemoryPlan memory_plan, const String& module_name,
-                            std::function<void(Function)> process_fn);
+                            const String& module_name, std::function<void(Function)> process_fn);
 }  // namespace tec
 }  // namespace relay
 }  // namespace tvm
diff --git a/src/relay/backend/utils.h b/src/relay/backend/utils.h
index cf8a2dd4b8e0..ae8d7d2c2360 100644
--- a/src/relay/backend/utils.h
+++ b/src/relay/backend/utils.h
@@ -146,6 +146,17 @@ struct LoweredOutput {
   runtime::Metadata metadata;
 };
 
+/*!
+ * \brief This class is needed to avoid a GCC 5 bug that prevents maps containing enums from being
+ compiled. If i386 GCC version is increased, we can remove it.
+ */
+struct EnumClassHash {
+  template <typename T>
+  std::size_t operator()(T t) const {
+    return static_cast<std::size_t>(t);
+  }
+};
+
 /*!
  * \brief A helper to expand the params by adding the ones used in a given expression.
  */
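
---

Note on the new calling convention (not part of the patch): after this change, LowerTE no longer computes or attaches "main_func_info". Each executor codegen is expected to run UpdateMainWorkspaceSize itself and attach the result to the module before lowering, then read the attribute back afterwards. A minimal sketch of that sequence, assembled from the hunks above rather than taken from tested code; the names `targets_`, `memory_plan_`, `mod_name_`, and `device_context_map` mirror the codegen members shown in the diff, and the no-op lambda stands in for the per-function processing hook:

    // Before lowering: derive the "main" function's workspace/io/constant sizes
    // from the static memory plan, if one was computed, and stash them on the
    // module so they survive the lowering passes.
    backend::FunctionInfo func_info;
    if (memory_plan_.defined()) {
      func_info = tec::UpdateMainWorkspaceSize(mod, targets_, memory_plan_->expr_to_storage_info);
      mod = WithAttr(mod, "main_func_info", func_info);
    }

    // Lowering: LowerTEPass no longer takes a memory plan; it only lowers
    // primitive functions and annotates the module with "external_mods".
    IRModule lowered_mod = tec::LowerTEPass(targets_, device_context_map, mod_name_,
                                            [](Function func) { /* no-op */ })(mod);

    // After lowering: consumers read the attribute back instead of relying on
    // LowerTE to have set it.
    Optional<backend::FunctionInfo> main_func_info =
        lowered_mod->GetAttr<backend::FunctionInfo>("main_func_info");

One consequence worth noting for reviewers: the removed ICHECK means "main_func_info" is now optional by design, so `main_func_info.value()` in the executor codegens assumes the caller attached the attribute (i.e., a memory plan was defined) before lowering.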