From 6ed813be2733033530c973cb5d05fb9cb92fe268 Mon Sep 17 00:00:00 2001 From: Lily Orth-Smith Date: Wed, 1 Sep 2021 03:59:53 +0800 Subject: [PATCH 1/7] Remove LoweredModule --- include/tvm/runtime/container/map.h | 2 +- include/tvm/target/target.h | 1 - src/relay/backend/aot_executor_codegen.cc | 20 ++- src/relay/backend/graph_executor_codegen.cc | 21 ++- src/relay/backend/interpreter.cc | 42 +++--- src/relay/backend/te_compiler.cc | 159 ++++++++------------ src/relay/backend/te_compiler.h | 57 ++----- 7 files changed, 130 insertions(+), 172 deletions(-) diff --git a/include/tvm/runtime/container/map.h b/include/tvm/runtime/container/map.h index 671e38b83581..3fe4f697bb9e 100644 --- a/include/tvm/runtime/container/map.h +++ b/include/tvm/runtime/container/map.h @@ -1353,7 +1353,7 @@ class Map : public ObjectRef { * Otherwise make a new copy of the array to ensure the current handle * hold a unique copy. * - * \return Handle to the internal node container(which ganrantees to be unique) + * \return Handle to the internal node container(which guarantees to be unique) */ MapNode* CopyOnWrite() { if (data_.get() == nullptr) { diff --git a/include/tvm/target/target.h b/include/tvm/target/target.h index deec662e74ad..64a1023158e1 100644 --- a/include/tvm/target/target.h +++ b/include/tvm/target/target.h @@ -31,7 +31,6 @@ #include #include -#include #include #include diff --git a/src/relay/backend/aot_executor_codegen.cc b/src/relay/backend/aot_executor_codegen.cc index b2e862b22b48..af8ac2b4e023 100644 --- a/src/relay/backend/aot_executor_codegen.cc +++ b/src/relay/backend/aot_executor_codegen.cc @@ -583,7 +583,7 @@ class AOTExecutorCodegen : public MixedModeVisitor { // performing the preexisting AOT executor code generation phase. IRModule mod = IRModule::FromExpr(func); - IRModule new_mod = + IRModule lowered_mod = LowerTEPass(targets_, device_context_map, memory_plan, mod_name, [this](Function func) { // We need to maintain the constant map for external // functions so we pass this processing function which @@ -598,9 +598,12 @@ class AOTExecutorCodegen : public MixedModeVisitor { tec::UpdateFunctionMetadata(func, this->function_metadata_); })(mod); - tec::LoweredModule lowered_module = tec::IRModuleToLoweredModule(new_mod); - function_metadata_.Set(runtime::symbol::tvm_module_main, lowered_module.main_func_info); - auto lowered_main = lowered_module.main_module->Lookup("main"); + Optional main_func_info = + lowered_mod->GetAttr("main_func_info"); + ICHECK(main_func_info) << "The attribute \"main_func_info\" should be set at this point."; + function_metadata_.Set(runtime::symbol::tvm_module_main, main_func_info.value()); + auto lowered_main = lowered_mod->Lookup("main"); + auto lowered_main_func = GetRef(lowered_main.as()); // Post-lowering storage map for writing main func - this should be the same map as previously @@ -662,8 +665,13 @@ class AOTExecutorCodegen : public MixedModeVisitor { ret.function_metadata = std::move(function_metadata_); - ret.lowered_funcs = lowered_module.per_target_module; - ret.external_mods = lowered_module.external_mods; + Optional> external_modules = + lowered_mod->GetAttr>("external_mods"); + ICHECK(external_modules) << "Attribute \"external_modules\" should be set at this point."; + + // This is the point where we separate the functions in the module by target + ret.lowered_funcs = tec::GetPerTargetModules(lowered_mod); + ret.external_mods = external_modules.value(); if (ret.lowered_funcs.find(target_host_) != ret.lowered_funcs.end()) { ret.lowered_funcs[target_host_]->Update(mod_run); diff --git a/src/relay/backend/graph_executor_codegen.cc b/src/relay/backend/graph_executor_codegen.cc index 486a6dcd7d87..932947a5ffec 100644 --- a/src/relay/backend/graph_executor_codegen.cc +++ b/src/relay/backend/graph_executor_codegen.cc @@ -221,7 +221,7 @@ class GraphExecutorCodegen : public backend::MemoizedExprTranslatorfunction_metadata_); })(mod); - tec::LoweredModule lowered_module = tec::IRModuleToLoweredModule(new_mod); - function_metadata_.Set(runtime::symbol::tvm_module_main, lowered_module.main_func_info); - auto main_module = lowered_module.main_module; + Optional main_func_info = + lowered_mod->GetAttr("main_func_info"); + ICHECK(main_func_info) << "The attribute \"main_func_info\" should be set at this point."; + function_metadata_.Set(runtime::symbol::tvm_module_main, main_func_info.value()); + + // Get only the Relay functions out of the lowered module so we can run type inference on them + IRModule main_module = tec::GetMainModule(lowered_mod); main_module = relay::transform::InferType()(main_module); relay::Function main_func = Downcast(main_module->Lookup("main")); @@ -270,8 +274,13 @@ class GraphExecutorCodegen : public backend::MemoizedExprTranslator(param_storage_ids_[param.first]), param.second))); } ret.function_metadata = std::move(function_metadata_); - ret.lowered_funcs = lowered_module.per_target_module; - ret.external_mods = lowered_module.external_mods; + + Optional> external_modules = + lowered_mod->GetAttr>("external_mods"); + // This is the point where we separate the functions in the module by target + + ret.lowered_funcs = tec::GetPerTargetModules(lowered_mod); + ret.external_mods = external_modules.value(); return ret; } diff --git a/src/relay/backend/interpreter.cc b/src/relay/backend/interpreter.cc index 76b6f9186eb5..82455bdf925c 100644 --- a/src/relay/backend/interpreter.cc +++ b/src/relay/backend/interpreter.cc @@ -292,7 +292,8 @@ InterpreterState::InterpreterState(Expr current_expr, InterpreterState::Stack st class Interpreter : public ExprFunctor, PatternFunctor { public: - // TODO(mbs): Collapse mod and per_target_module once IRModule subsumes LoweredModule. + // TODO(mbs, electriclilies): Collapse mod and per_target_module once IRModule subsumes + // LoweredModule. Interpreter(IRModule mod, Map per_target_module, Device device, Target target) : mod_(mod), per_target_module_(per_target_module), @@ -902,20 +903,7 @@ class Interpreter : public ExprFunctor, * functions needed by the rewritten module. */ std::pair> Prepare(IRModule mod, Device device, Target target) { - // Run minimal transforms on module to establish invariants needed by interpreter. - transform::Sequential seq({transform::SimplifyInference(), - // FuseOps will mark wrapped calls to prim-ops with the 'Primitive' - // attribute. - transform::FuseOps(/*fuse_opt_level=*/0), transform::ToANormalForm(), - // eta expand to support constructors in argument position - transform::EtaExpand( - /*expand_constructor=*/true, /*expand_global_var=*/false), - transform::InferType()}); - - transform::PassContext pass_ctx = transform::PassContext::Current(); - With ctx(pass_ctx); - mod = seq(mod); - + // Things to initialize to pass into tec::LowerTEPass // We only have one device-specific target. tec::TargetMap targets = {{device.device_type, target}}; @@ -925,13 +913,25 @@ std::pair> Prepare(IRModule mod, Device device, // No need for a memory plan. backend::StaticMemoryPlan memory_plan; /*=nullptr*/ + // Run minimal transforms on module to establish invariants needed by interpreter. + transform::Sequential seq( + {transform::SimplifyInference(), + // FuseOps will mark wrapped calls to prim-ops with the 'Primitive' + // attribute. + transform::FuseOps(/*fuse_opt_level=*/0), transform::ToANormalForm(), + // eta expand to support constructors in argument position + transform::EtaExpand( + /*expand_constructor=*/true, /*expand_global_var=*/false), + transform::InferType(), + tec::LowerTEPass(targets, device_map, memory_plan, /*module_name=*/"intrp", + [](Function func) { /* no-op */ })}); + + transform::PassContext pass_ctx = transform::PassContext::Current(); + With ctx(pass_ctx); + mod = seq(mod); + // Lower all primitive functions reachable from expr. - // TODO(mbs): This should be just another pass in seq above, which requires LoweredModule to - // be merged into IRModule. - LoweredModule lowered_module = - tec::LowerTE(mod, targets, device_map, memory_plan, /*module_name=*/"intrp", - [](Function func) { /* no-op */ }); - return {lowered_module.main_module, lowered_module.per_target_module}; + return {tec::GetMainModule(mod), tec::GetPerTargetModules(mod)}; } /*! \brief Check if an expression could be changed by \p Prepare. diff --git a/src/relay/backend/te_compiler.cc b/src/relay/backend/te_compiler.cc index 06d862b781e1..91a8fcb55dfd 100644 --- a/src/relay/backend/te_compiler.cc +++ b/src/relay/backend/te_compiler.cc @@ -85,33 +85,51 @@ class TECompilerImpl : public TECompilerNode { return LowerShapeFuncInternal(key)->cached_func; } - Map GetLoweredFunctions() { - std::unordered_map - lowered_functions; + IRModule GetLoweredFunctions() { + IRModule mod; + // TODO(@electriclilies): This chunk of code is pretty much the same for + // normal cache and shape func cache. Consider making a helper here to do that. + // Additionaly, might be good to overhaul the mod->Update(mod) function (it's broken!) for (const auto& it : cache_) { auto source_func = it.first; + // TODO(@electriclilies): Does the lowered_func module only contain one function? auto lowered_func = it.second; - auto target = source_func->target; - if (!lowered_functions.count(target)) { - lowered_functions[target] = IRModule(Map({})); - } + IRModule lowered_mod = lowered_func->cached_func->funcs; - lowered_functions[target]->Update(lowered_func->cached_func->funcs); + // Annotate functions with their target and put them in the return module + for (auto kv : lowered_mod->functions) { + const GlobalVar& var = kv.first; + const BaseFunc& func = kv.second; + + if (func->IsInstance()) { + const relay::Function relay_func = Downcast(func); + mod->Update(var, WithAttr(relay_func, tvm::attr::kTarget, source_func->target)); + } else if (func->IsInstance()) { + const tir::PrimFunc& prim_func = Downcast(func); + mod->Update(var, WithAttr(prim_func, tvm::attr::kTarget, source_func->target)); + } else { + LOG(FATAL) << "Expected to find only relay functions and prim functions in the cache, " + "but found: " + << func->GetTypeKey(); + } + } } for (const auto& it : shape_func_cache_) { auto source_func = it.first; auto lowered_func = it.second; auto target = source_func->target; + IRModule lowered_mod = lowered_func->cached_func->funcs; - if (!lowered_functions.count(target)) { - lowered_functions[target] = IRModule(Map({})); + for (auto kv : lowered_mod->functions) { + const GlobalVar& var = kv.first; + const BaseFunc& func = kv.second; + const tir::PrimFunc& prim_func = Downcast(func); + mod->Update(var, WithAttr(prim_func, tvm::attr::kTarget, source_func->target)); } - - lowered_functions[target]->Update(lowered_func->cached_func->funcs); } - return backend::TargetStrModuleMapToTargetModuleMap(lowered_functions); + return mod; } Array LowerExternalFunctions() { @@ -830,9 +848,9 @@ void UpdateFunctionMetadata(Function relay_func, function_metadata.Set(prim_fn_var.value()->name_hint, fi); } -LoweredModule LowerTE(const IRModule& module, TargetMap targets, DeviceMap device_context_map, - backend::StaticMemoryPlan memory_plan, const String& module_name, - std::function process_fn) { +IRModule LowerTE(const IRModule& module, TargetMap targets, DeviceMap device_context_map, + backend::StaticMemoryPlan memory_plan, const String& module_name, + std::function process_fn) { DLOG(INFO) << "lowering module:\n" << PrettyPrint(module); TECompiler compiler; @@ -864,76 +882,24 @@ LoweredModule LowerTE(const IRModule& module, TargetMap targets, DeviceMap devic (*te_compiler_update_weights)(weight_map); } - LoweredModule lowered_module; - lowered_module.main_module = updated_module; - lowered_module.per_target_module = compiler->GetLoweredFunctions(); - lowered_module.external_mods = compiler->LowerExternalFunctions(); - lowered_module.main_func_info = func_info; - return lowered_module; -} - -IRModule LoweredModuleToIRModule(LoweredModule mod) { - IRModule unified_module; - - // Copy the main module and its typedefs - for (const auto& kv : mod.main_module->functions) { - unified_module->Add(kv.first, kv.second); - } - for (const auto& kv : mod.main_module->type_definitions) { - unified_module->AddTypeDef(kv.first, kv.second); - } + // Copy the lowered functions into the return module + std::cout << "Getting lowered funcs" << std::endl; + updated_module->Update(compiler->GetLoweredFunctions()); - // Annotate the per-target functions with their target and add them to the unified module - for (const auto& kv : mod.per_target_module) { - const Target target = kv.first; - const IRModule target_module = kv.second; - - // Right now, per-target functions are TIR functions, which don't have type definitions, so - // there should be no type defs in the per_target_modules - size_t ty_def_size = target_module->type_definitions.size(); - ICHECK(ty_def_size == 0) - << "Expected there to be no type definitions in the per_target_modules, but found " - << ty_def_size; - - for (const auto& kv : target_module->functions) { - const GlobalVar& var = kv.first; - const BaseFunc& func = kv.second; - if (func->IsInstance()) { - tir::PrimFunc primFunc = - WithAttr(Downcast(std::move(func)), tvm::attr::kTarget, target); - unified_module->Add(var, primFunc); - } else if (func->IsInstance()) { - relay::Function relayFunc = - WithAttr(Downcast(std::move(func)), tvm::attr::kTarget, target); - unified_module->Add(var, relayFunc); - } else { - LOG(FATAL) - << "We expected to only have PrimFuncs or RelayFuncs in the target modules, but found " - << func->GetTypeKey(); - } - } - } + // Annotate the module with the external modules and function info + updated_module = WithAttr(updated_module, "external_mods", compiler->LowerExternalFunctions()); + updated_module = WithAttr(updated_module, "main_func_info", func_info); - IRModule ret_mod = WithAttr(unified_module, "external_mods", mod.external_mods); - ret_mod = WithAttr(ret_mod, "main_func_info", mod.main_func_info); - return ret_mod; + return updated_module; } -LoweredModule IRModuleToLoweredModule(IRModule mod) { - IRModule main_mod; - // Copy just the TypeDefs from the IRModule to the LoweredModule's main module - // This is the only time we need to do this since there are no TypeDefs in TIR - for (const auto& kv : mod->type_definitions) { - main_mod->AddTypeDef(kv.first, kv.second); - } - - Map per_target_modules; +Map GetPerTargetModules(IRModule mod) { + std::unordered_map + per_target_modules; for (const auto& kv : mod->functions) { const GlobalVar& var = kv.first; const BaseFunc& func = kv.second; - if (func->IsInstance()) { - main_mod->Add(var, func); - } else if (func->IsInstance()) { + if (func->IsInstance()) { // Extract target Optional target = func->GetAttr(tvm::attr::kTarget); ICHECK(target) << "Target should be set at this point"; @@ -943,34 +909,36 @@ LoweredModule IRModuleToLoweredModule(IRModule mod) { // Initialize the IRModule for this target and add the function IRModule target_module; target_module->Add(var, func); - per_target_modules.Set(target.value(), target_module); + per_target_modules[target.value()] = target_module; } else { // The IRModule for this target is initialized, so just add the function. IRModule target_module = per_target_modules.at(target.value()); target_module->Add(var, func); } - } else { + } else if (!func->IsInstance()) { LOG(FATAL) << "The function types in the IRModule should be RelayFunction or PrimFunc, but got " << func->GetTypeKey(); } } + return per_target_modules; +} - // Put the LoweredModule together - LoweredModule lowered_module; - lowered_module.main_module = main_mod; - lowered_module.per_target_module = per_target_modules; - - // Extract external modules and main func info, add to lowered module if they exist - auto external_mods = mod->GetAttr>("external_mods"); - if (external_mods) { - lowered_module.external_mods = external_mods.value(); +IRModule GetMainModule(IRModule mod) { + IRModule main_module; + // Copy the type defs + for (const auto& kv : mod->type_definitions) { + main_module->AddTypeDef(kv.first, kv.second); } - auto main_func_info = mod->GetAttr("main_func_info"); - if (main_func_info) { - lowered_module.main_func_info = main_func_info.value(); + // Copy all Relay functions (we don't include PrimFuncs) + for (auto kv : mod->functions) { + const GlobalVar& var = kv.first; + const BaseFunc& func = kv.second; + if (func->IsInstance()) { + main_module->Add(var, func); + } } - return lowered_module; + return main_module; } Pass LowerTEPass(TargetMap targets, DeviceMap device_context_map, @@ -978,8 +946,7 @@ Pass LowerTEPass(TargetMap targets, DeviceMap device_context_map, std::function process_fn) { runtime::TypedPackedFunc pass_func = [=](IRModule module, PassContext ctx) { - return LoweredModuleToIRModule( - LowerTE(module, targets, device_context_map, memory_plan, module_name, process_fn)); + return LowerTE(module, targets, device_context_map, memory_plan, module_name, process_fn); }; return tvm::transform::CreateModulePass(pass_func, 1, "LowerTE", {}); } diff --git a/src/relay/backend/te_compiler.h b/src/relay/backend/te_compiler.h index 65ba67ac7e1b..6555ababea8c 100644 --- a/src/relay/backend/te_compiler.h +++ b/src/relay/backend/te_compiler.h @@ -96,8 +96,9 @@ class TECompilerNode : public Object { */ virtual CachedFunc Lower(const CCacheKey& key, const String mod_name) = 0; - /* Return all functions which have been lowered by the compiler, keyed by target. */ - virtual Map GetLoweredFunctions() = 0; + /* Return all functions which have been lowered by the compiler in an IRModule, annotated with + * their target. */ + virtual IRModule GetLoweredFunctions() = 0; /*! * \brief Just in time compile to get a PackedFunc. @@ -113,7 +114,7 @@ class TECompilerNode : public Object { virtual CachedFunc LowerShapeFunc(const CCacheKey& key) = 0; /*! * \brief Lower the external function using external codegen tools. - * \return The runtime moduels for each needed external codegen tool. + * \return The runtime modules for each needed external codegen tool. */ virtual tvm::Array LowerExternalFunctions() = 0; @@ -137,23 +138,6 @@ class TECompiler : public ObjectRef { using ContainerType = TECompilerNode; }; -/*! \brief The result of lowering a module, for now we need to pass an aggregate data structure - * which contains more then a single module in order to interact with the today API. - */ -struct LoweredModule { - /*! \brief The module which contains the Relay code. */ - IRModule main_module; - /*! \brief The module which contains per target code. */ - Map per_target_module; - /*! \brief The external runtime modules which must be combined with the lowered code. */ - Array external_mods; - // TODO(@electriclilies): THis might need to become a map - /*! \brief The info for this function (not sure what a better description is??) - * - */ - backend::FunctionInfo main_func_info; -}; - /*! * \brief A function to create the function metadata for an input function (ie calculate buffer * input/output sizes) @@ -174,27 +158,19 @@ void UpdateFunctionMetadata(Function relay_func, */ Target GetTargetFromInteger(DLDeviceType dev_type, TargetMap targets); -/*! \brief Utility to convert a LoweredModule to an IRModule. +/*! \brief Utility to separate the the functions in an IRModule by Target. * - * This function takes all the target specific modules in LoweredModule and - * annotates their functions with the correct target, and puts all those functions - * in one IRModule. - * The purpose of this utility is to allow us to slowly remove LoweredModule from the codebase. - * - * \param mod The LoweredModule to convert. - * \return The IRModule form of the input LoweredModule. + * \param mod The IRModule to extract the per target module from + * \return The map from Target to IRModule */ -IRModule LoweredModuleToIRModule(LoweredModule mod); +Map GetPerTargetModules(IRModule mod); -/*! \brief Utility to convert an IRModule to a LoweredModule. - * - * This function takes all the functions in the IRModule and moves them into target-specific - * IRModules stored inside a LoweredModule. - * The purpose of this utility is to allow us to slowly remove LoweredModule from the codebase. - * \param mod The IRModule to convert. - * \return The LoweredModule form of the input IRModule. +/*! + * \brief Utility to extract all the Relay functions from an IRModule, with no PrimFuncs. + * \param mod The IRModule to extract the Relay functions from + * \return An IRModule containing only the Relay functions that are in the input mod (no PrimFuncs) */ -LoweredModule IRModuleToLoweredModule(IRModule mod); +IRModule GetMainModule(IRModule mod); /*! \brief Lower an IRModule's primitive functions to TIR. * @@ -210,7 +186,7 @@ LoweredModule IRModuleToLoweredModule(IRModule mod); * each function that we lower * \return The lowered module, see above. */ -LoweredModule LowerTE( +IRModule LowerTE( const IRModule& module, TargetMap targets, DeviceMap device_map, backend::StaticMemoryPlan memory_plan, const String& module_name, ProcessFn process_fn = [](Function f) {}); @@ -218,9 +194,8 @@ LoweredModule LowerTE( /*! \brief Pass to lower an IRModule's primitive functions to TIR. * * This is the "back half" of the Relay compiler which lowers "primitive functions" - * to TE expressions, schedules them, and then to TIR. This Pass calls LowerTE, and - * uses LoweredModuleToIRModule utility to convert the output LowerTE's output - * LoweredModule into an IRModule before returning it. + * to TE expressions, schedules them, and then to TIR. It annotates all functions + * with their target. * * \param targets The mapping for devices to targets. * \param device_context_map An analysis result mapping each sub-expression to a device. From 6317d8c31e6af40841d85e25c3fb4b0c2747fdf9 Mon Sep 17 00:00:00 2001 From: electriclilies Date: Wed, 1 Sep 2021 10:27:10 -0700 Subject: [PATCH 2/7] Clean up some comments --- src/relay/backend/te_compiler.cc | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/relay/backend/te_compiler.cc b/src/relay/backend/te_compiler.cc index 91a8fcb55dfd..a6bcb0d540ce 100644 --- a/src/relay/backend/te_compiler.cc +++ b/src/relay/backend/te_compiler.cc @@ -87,12 +87,9 @@ class TECompilerImpl : public TECompilerNode { IRModule GetLoweredFunctions() { IRModule mod; - // TODO(@electriclilies): This chunk of code is pretty much the same for - // normal cache and shape func cache. Consider making a helper here to do that. - // Additionaly, might be good to overhaul the mod->Update(mod) function (it's broken!) + // Extract lowered functions from the cache for (const auto& it : cache_) { auto source_func = it.first; - // TODO(@electriclilies): Does the lowered_func module only contain one function? auto lowered_func = it.second; IRModule lowered_mod = lowered_func->cached_func->funcs; @@ -102,6 +99,8 @@ class TECompilerImpl : public TECompilerNode { const GlobalVar& var = kv.first; const BaseFunc& func = kv.second; + // TODO(@electriclilies): There shouldn't be a Relay function in here. + // Figure out where it's coming from! if (func->IsInstance()) { const relay::Function relay_func = Downcast(func); mod->Update(var, WithAttr(relay_func, tvm::attr::kTarget, source_func->target)); @@ -115,13 +114,14 @@ class TECompilerImpl : public TECompilerNode { } } } - + // Extract lowered frunctions from the shape cache for (const auto& it : shape_func_cache_) { auto source_func = it.first; auto lowered_func = it.second; auto target = source_func->target; IRModule lowered_mod = lowered_func->cached_func->funcs; + // Annotate functions with their target and put them in the return module for (auto kv : lowered_mod->functions) { const GlobalVar& var = kv.first; const BaseFunc& func = kv.second; From 7da7cec0ad13f89145fe56b4090494b515b4111e Mon Sep 17 00:00:00 2001 From: electriclilies Date: Wed, 1 Sep 2021 12:01:23 -0700 Subject: [PATCH 3/7] QEMU flaky tests From d51edb402e78fdcc08fd09519e1accdc0e428630 Mon Sep 17 00:00:00 2001 From: electriclilies Date: Wed, 1 Sep 2021 16:40:21 -0700 Subject: [PATCH 4/7] Don't add external functions to the LoweredFunctions module --- src/relay/backend/te_compiler.cc | 16 +++++----------- 1 file changed, 5 insertions(+), 11 deletions(-) diff --git a/src/relay/backend/te_compiler.cc b/src/relay/backend/te_compiler.cc index a6bcb0d540ce..c7a3a50fe4bd 100644 --- a/src/relay/backend/te_compiler.cc +++ b/src/relay/backend/te_compiler.cc @@ -99,18 +99,13 @@ class TECompilerImpl : public TECompilerNode { const GlobalVar& var = kv.first; const BaseFunc& func = kv.second; - // TODO(@electriclilies): There shouldn't be a Relay function in here. - // Figure out where it's coming from! - if (func->IsInstance()) { - const relay::Function relay_func = Downcast(func); - mod->Update(var, WithAttr(relay_func, tvm::attr::kTarget, source_func->target)); - } else if (func->IsInstance()) { + // Only add functions that are not external functions + if (!func->GetAttr(attr::kCompiler).defined()) { + ICHECK(func->IsInstance()) + << "Expected all functions that are not external to be PrimFuncs, but found " + << func->GetTypeKey(); const tir::PrimFunc& prim_func = Downcast(func); mod->Update(var, WithAttr(prim_func, tvm::attr::kTarget, source_func->target)); - } else { - LOG(FATAL) << "Expected to find only relay functions and prim functions in the cache, " - "but found: " - << func->GetTypeKey(); } } } @@ -883,7 +878,6 @@ IRModule LowerTE(const IRModule& module, TargetMap targets, DeviceMap device_con } // Copy the lowered functions into the return module - std::cout << "Getting lowered funcs" << std::endl; updated_module->Update(compiler->GetLoweredFunctions()); // Annotate the module with the external modules and function info From 914275f5b882acea48be9d2a028e671354174811 Mon Sep 17 00:00:00 2001 From: electriclilies Date: Wed, 1 Sep 2021 17:50:14 -0700 Subject: [PATCH 5/7] QEMU flaky test From ca53534504c2d8dfb654e361296a311c66490b9c Mon Sep 17 00:00:00 2001 From: electriclilies Date: Thu, 2 Sep 2021 11:28:13 -0700 Subject: [PATCH 6/7] Respond to feedback --- src/relay/backend/graph_executor_codegen.cc | 3 ++- src/relay/backend/te_compiler.cc | 4 +++- src/relay/backend/te_compiler.h | 2 +- 3 files changed, 6 insertions(+), 3 deletions(-) diff --git a/src/relay/backend/graph_executor_codegen.cc b/src/relay/backend/graph_executor_codegen.cc index 932947a5ffec..b7b388431ca1 100644 --- a/src/relay/backend/graph_executor_codegen.cc +++ b/src/relay/backend/graph_executor_codegen.cc @@ -277,8 +277,9 @@ class GraphExecutorCodegen : public backend::MemoizedExprTranslator> external_modules = lowered_mod->GetAttr>("external_mods"); - // This is the point where we separate the functions in the module by target + ICHECK(external_modules) << "Attribute \"external_modules\" should be set at this point."; + // This is the point where we separate the functions in the module by target ret.lowered_funcs = tec::GetPerTargetModules(lowered_mod); ret.external_mods = external_modules.value(); return ret; diff --git a/src/relay/backend/te_compiler.cc b/src/relay/backend/te_compiler.cc index c7a3a50fe4bd..9244e06c8c02 100644 --- a/src/relay/backend/te_compiler.cc +++ b/src/relay/backend/te_compiler.cc @@ -109,7 +109,7 @@ class TECompilerImpl : public TECompilerNode { } } } - // Extract lowered frunctions from the shape cache + // Extract lowered dynamic shape functions from the shape cache for (const auto& it : shape_func_cache_) { auto source_func = it.first; auto lowered_func = it.second; @@ -942,6 +942,8 @@ Pass LowerTEPass(TargetMap targets, DeviceMap device_context_map, PassContext ctx) { return LowerTE(module, targets, device_context_map, memory_plan, module_name, process_fn); }; + // TODO(@electriclilies, mbs): Fold InferType() pass into LowerTEPass since it will always need to + // be called afterwards return tvm::transform::CreateModulePass(pass_func, 1, "LowerTE", {}); } } // namespace tec diff --git a/src/relay/backend/te_compiler.h b/src/relay/backend/te_compiler.h index 6555ababea8c..082cd8c4491a 100644 --- a/src/relay/backend/te_compiler.h +++ b/src/relay/backend/te_compiler.h @@ -158,7 +158,7 @@ void UpdateFunctionMetadata(Function relay_func, */ Target GetTargetFromInteger(DLDeviceType dev_type, TargetMap targets); -/*! \brief Utility to separate the the functions in an IRModule by Target. +/*! \brief Utility to separate the functions in an IRModule by Target. * * \param mod The IRModule to extract the per target module from * \return The map from Target to IRModule From d4b2bd79b5d44dbda84a9388b74dd47be8208fb9 Mon Sep 17 00:00:00 2001 From: electriclilies Date: Thu, 2 Sep 2021 16:12:13 -0700 Subject: [PATCH 7/7] flaky test