From 95943cf849fcdd6b1736857031f6a84589c72d90 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Fri, 11 Mar 2022 14:19:11 +0900 Subject: [PATCH 01/19] New relay backend for meta schedule task extraction commit 501fac65291c51710911ca49af1577ea1794bcb2 Merge: 076fa33fc ce8c563d0 Author: Masahiro Masuda Date: Fri Mar 11 14:16:47 2022 +0900 New relay backend for meta schedule task extraction commit ce8c563d09eaba2a6b03189d1d3452f7565f4c69 Author: Masahiro Masuda Date: Fri Mar 11 14:12:30 2022 +0900 fix cpplint commit dfa4fb0c20c17049e8ac2c135200074b872ce1ec Author: Masahiro Masuda Date: Fri Mar 11 14:09:11 2022 +0900 update expected op list in test_meta_schedule_integration_extract_from_resnet to remove dep on Ansor commit a98182eed3b85e477c5f2527d5d21ce545bd5c18 Author: Masahiro Masuda Date: Fri Mar 11 13:56:35 2022 +0900 fixed test_meta_schedule_integration_apply_history_best commit 40d52a15b4c1ac9b8d4eac16f98ccec5e2a3e966 Author: Masahiro Masuda Date: Fri Mar 11 13:50:43 2022 +0900 uniquefy task names commit dfaf4964bf3a0b542ead5f11f356c2ec592be725 Author: Masahiro Masuda Date: Fri Mar 11 13:45:30 2022 +0900 dedup tasks commit e49d500299c9c884497410046421853266b60cd2 Author: Masahiro Masuda Date: Fri Mar 11 12:59:45 2022 +0900 return reversed list commit 74636beae0878cdda7dd03aa2b09ab2821c86477 Author: Masahiro Masuda Date: Fri Mar 11 12:39:58 2022 +0900 refactor commit 99f1701eb71d77a85bb0f8457841739dc586a168 Author: Masahiro Masuda Date: Fri Mar 11 12:34:14 2022 +0900 clean up integration.cc and Query interface commit 3f93a1e7645118c002aa10e5b7ff14b71b3f837a Author: Masahiro Masuda Date: Fri Mar 11 11:54:57 2022 +0900 check in minor vnni-related change commit af3e98867f91f99522fee4da2e170dc87311466c Author: Masahiro Masuda Date: Fri Mar 11 07:36:35 2022 +0900 Removed TaskExtraction node commit 7b4d35eb00852db6397d43e0aa6b1fedabae3f63 Author: Masahiro Masuda Date: Fri Mar 11 05:42:56 2022 +0900 add doc to util functions commit 3c5a3184fb42e69ef10619b05b9b9f128f7ea618 Author: Masahiro Masuda Date: Fri Mar 11 05:27:53 2022 +0900 rename to task extraction commit 57f2882a5ed5615ef8eee96cd7284d495f908449 Author: Masahiro Masuda Date: Fri Mar 11 05:24:37 2022 +0900 fixed constant param bind commit f099537d3630d268ad0700c75e93bbdc67831837 Author: Masahiro Masuda Date: Fri Mar 11 05:10:44 2022 +0900 remove unused stuff from python extract_tasks_from_relay commit 4a5e4aae48a7bdc8c24c8f7ae7bd5484034837e4 Author: Masahiro Masuda Date: Fri Mar 11 05:10:30 2022 +0900 move BindParams function to cc file commit efecceaea3958e184de7ef0ff6cb5f3988640afa Author: Masahiro Masuda Date: Fri Mar 11 03:56:05 2022 +0900 refactor param binding commit 109187fc0463728cd44171389e8fc91fb0ac8cf9 Author: Masahiro Masuda Date: Fri Mar 11 02:21:58 2022 +0900 New relay backend for meta schedule task extraction commit 6f019014a4614f43aefcf642981bfb15d64b09f3 Author: Masahiro Masuda Date: Fri Mar 11 11:25:44 2022 +0900 fixed anchor impl selection commit be6c25893dd0546db71b8472415303fc5be9d67f Author: Masahiro Masuda Date: Fri Mar 11 10:57:02 2022 +0900 Forgot visiting arg in ScheduleBuilder CallNode vsit commit 0c6d4a603335ae2cba2771e939eff1ddeb98fbe3 Author: Masahiro Masuda Date: Fri Mar 11 10:45:08 2022 +0900 add public, fix include path convention commit 4cd3a1657c4e2e13abe7281b7cdef5dff73b37ee Author: Masahiro Masuda Date: Thu Mar 10 18:43:15 2022 +0900 removed create_schedule stuff commit eb1bc7e789b66eaf3d4fe01d5154c135ab275dc2 Author: Masahiro Masuda Date: Thu Mar 10 18:13:42 2022 +0900 fixed merge conflict commit 6e68fd9aff9f86412f8b7150b18ae1b374927f86 Author: Masahiro Masuda Date: Thu Mar 10 14:27:34 2022 +0900 Decouple TE compute and schedule lowering in ScheduleBuilder --- include/tvm/meta_schedule/integration.h | 38 +-------- python/tvm/meta_schedule/integration.py | 65 +++++--------- python/tvm/topi/x86/batch_matmul.py | 1 + python/tvm/topi/x86/dense.py | 1 + src/meta_schedule/integration.cc | 62 ++++++-------- src/relay/backend/build_module.cc | 9 +- src/relay/backend/task_extraction.cc | 84 +++++++++++++++++++ src/relay/backend/te_compiler_cache.cc | 36 ++++++-- src/relay/backend/te_compiler_cache.h | 11 +++ src/relay/backend/utils.cc | 50 +++++++++++ src/relay/backend/utils.h | 40 +++------ src/relay/backend/vm/compiler.cc | 9 +- .../test_meta_schedule_integration.py | 83 ++++-------------- 13 files changed, 254 insertions(+), 235 deletions(-) create mode 100644 src/relay/backend/task_extraction.cc diff --git a/include/tvm/meta_schedule/integration.h b/include/tvm/meta_schedule/integration.h index 9a8699b2fab9..3140b4f981e5 100644 --- a/include/tvm/meta_schedule/integration.h +++ b/include/tvm/meta_schedule/integration.h @@ -92,7 +92,7 @@ class MetaScheduleContextNode : public runtime::Object { * 3) relay::Function if `mod` should be dispatched to BYOC workflow * 4) IRModule for unified dispatch */ - virtual Optional Query(runtime::String task_name, IRModule mod, Target target, + virtual IRModule Query(runtime::String task_name, IRModule mod, Target target, Optional> dispatched) = 0; static constexpr const char* _type_key = "meta_schedule.MetaScheduleContext"; @@ -129,7 +129,7 @@ class MetaScheduleContext : public runtime::ObjectRef { * 3) relay::Function if `mod` should be dispatched to BYOC workflow * 4) IRModule for unified dispatch */ - static Optional QueryInsideWithScope(runtime::String task_name, IRModule mod, + static IRModule QueryInsideWithScope(runtime::String task_name, IRModule mod, Target target, Optional> dispatched); @@ -145,38 +145,6 @@ class MetaScheduleContext : public runtime::ObjectRef { void ExitWithScope(); }; -/**************** TaskExtraction ****************/ - -/*! - * \brief An integration context for task extraction - */ -class TaskExtractionNode : public MetaScheduleContextNode { - public: - /*! \brief The extracted tasks */ - Array tasks{nullptr}; - - void VisitAttrs(AttrVisitor* v) { v->Visit("tasks", &tasks); } - - // Inherited from base class - Optional Query(runtime::String task_name, IRModule mod, Target target, - Optional> dispatched) final; - - static constexpr const char* _type_key = "meta_schedule.TaskExtraction"; - TVM_DECLARE_FINAL_OBJECT_INFO(TaskExtractionNode, MetaScheduleContextNode); -}; - -/*! - * \brief Managed reference to TaskExtractionNode - * \sa TaskExtractionNode - */ -class TaskExtraction : public MetaScheduleContext { - public: - /*! \brief The path to a cache file storing extracted tasks */ - TaskExtraction(); - TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(TaskExtraction, MetaScheduleContext, - TaskExtractionNode); -}; - /**************** ApplyHistoryBest ****************/ /*! @@ -193,7 +161,7 @@ class ApplyHistoryBestNode : public MetaScheduleContextNode { } // Inherited from base class - Optional Query(runtime::String task_name, IRModule mod, Target target, + IRModule Query(runtime::String task_name, IRModule mod, Target target, Optional> dispatched) final; static constexpr const char* _type_key = "meta_schedule.ApplyHistoryBest"; diff --git a/python/tvm/meta_schedule/integration.py b/python/tvm/meta_schedule/integration.py index 26b01444e752..eebc2429acdf 100644 --- a/python/tvm/meta_schedule/integration.py +++ b/python/tvm/meta_schedule/integration.py @@ -18,9 +18,12 @@ from contextlib import contextmanager from typing import Callable, Dict, List, Optional, Union -from tvm._ffi import register_object +import numpy as np +import tvm.runtime.ndarray as nd + +from tvm._ffi import register_object, get_global_func from tvm.ir import IRModule, transform -from tvm.relay import Any +from tvm.relay import Any, const from tvm.relay import Function as RelayFunc from tvm.relay import vm from tvm.runtime import NDArray, Object @@ -176,17 +179,6 @@ def __exit__(self, ptype, value, trace) -> None: _ffi_api.MetaScheduleContextExitScope(self) # type: ignore # pylint: disable=no-member -@register_object("meta_schedule.TaskExtraction") -class TaskExtraction(MetaScheduleContext): - """An integration context for task extraction""" - - tasks: List[ExtractedTask] - """The extracted tasks""" - - def __init__(self) -> None: - self.__init_handle_by_constructor__(_ffi_api.TaskExtraction) # type: ignore # pylint: disable=no-member - - @register_object("meta_schedule.ApplyHistoryBest") class ApplyHistoryBest(MetaScheduleContext): """An integration context that allows application of historically best record from database""" @@ -230,45 +222,30 @@ def extract_task_from_relay( The tasks extracted from this network """ - @contextmanager - def _autotvm_silencer(): - from tvm import autotvm # pylint: disable=import-outside-toplevel - - silent = autotvm.GLOBAL_SCOPE.silent - autotvm.GLOBAL_SCOPE.silent = True - try: - yield - finally: - autotvm.GLOBAL_SCOPE.silent = silent + extract_task_func = get_global_func("relay.backend.MetaScheduleExtractTask") + assert extract_task_func - def _thread_run(func: Callable[[], None]) -> None: - import threading # pylint: disable=import-outside-toplevel + target = Target(target) if isinstance(target, str) else target - thread = threading.Thread(target=func) - thread.start() - thread.join() + relay_params = {} + for name, param in params.items(): + if isinstance(param, np.ndarray): + param = nd.array(param) + relay_params[name] = const(param) if disabled_pass is None: disabled_pass = [] - if pass_config is None: - pass_config = {"relay.backend.use_meta_schedule": True} - env = TaskExtraction() if isinstance(mod, RelayFunc): mod = IRModule.from_expr(mod) if not isinstance(target, Target): target = Target(target) - def _func(): - with env, _autotvm_silencer(), transform.PassContext( - config=pass_config, - disabled_pass=disabled_pass, - opt_level=opt_level, - ): - compiler = vm.VMCompiler() - if params: - compiler.set_params(params) - compiler.lower(mod, target) - - _thread_run(_func) - return env.tasks + with target, transform.PassContext( + opt_level=opt_level, + config=pass_config, + disabled_pass=disabled_pass, + ): + tasks = extract_task_func(mod, target, relay_params) + # Tasks are extracted via post order visit, return the reversed list. + return list(reversed(tasks)) diff --git a/python/tvm/topi/x86/batch_matmul.py b/python/tvm/topi/x86/batch_matmul.py index b446c1f0115c..2d32bfe8f096 100644 --- a/python/tvm/topi/x86/batch_matmul.py +++ b/python/tvm/topi/x86/batch_matmul.py @@ -47,6 +47,7 @@ def batch_matmul_vnni_compute(cfg, x, y): axis=ak, ), tag="batch_matmul_vnni", + attrs={"schedule_rule": "batch_matmul_vnni"}, ) _, a_y, _ = z.op.axis diff --git a/python/tvm/topi/x86/dense.py b/python/tvm/topi/x86/dense.py index c8574b971003..cd6350352d98 100644 --- a/python/tvm/topi/x86/dense.py +++ b/python/tvm/topi/x86/dense.py @@ -296,6 +296,7 @@ def dense_vnni_compute(cfg, X, packed_w, bias=None): axis=ak, ), tag="dense_vnni", + attrs={"schedule_rule": "dense_vnni"}, ) if bias is not None: diff --git a/src/meta_schedule/integration.cc b/src/meta_schedule/integration.cc index 4f9055bf5bba..d2cb4b307bbf 100644 --- a/src/meta_schedule/integration.cc +++ b/src/meta_schedule/integration.cc @@ -26,9 +26,8 @@ namespace tvm { namespace meta_schedule { /**************** Utility functions ****************/ - -template -Optional GetOnlyOneFunction(const IRModule& mod) { +template +Optional GetOnlyOneFunctionCommon(const IRModule& mod, Callback on_found) { if (mod->functions.size() != 1) { return NullOpt; } @@ -37,12 +36,23 @@ Optional GetOnlyOneFunction(const IRModule& mod) { if (!func->IsInstance()) { return NullOpt; } else { - return Downcast(func); + return on_found(kv); } } return NullOpt; } +template +Optional GetOnlyOneFunctionKey(const IRModule& mod) { + return GetOnlyOneFunctionCommon(mod, [](auto kv) { return kv.first; }); +} + +template +Optional GetOnlyOneFunction(const IRModule& mod) { + return GetOnlyOneFunctionCommon( + mod, [](auto kv) { return Downcast(kv.second); }); +} + template bool HasOnlyOneFunction(const IRModule& mod) { return GetOnlyOneFunction(mod).defined(); @@ -86,31 +96,13 @@ void MetaScheduleContext::ExitWithScope() { ctx = NullOpt; } -Optional MetaScheduleContext::QueryInsideWithScope( - runtime::String task_name, IRModule mod, Target target, Optional> dispatched) { +IRModule MetaScheduleContext::QueryInsideWithScope(runtime::String task_name, IRModule mod, + Target target, + Optional> dispatched) { if (Optional ctx = MetaScheduleContext::Current()) { return ctx.value()->Query(task_name, mod, target, dispatched); } - return NullOpt; -} - -/**************** TaskExtraction ****************/ - -TaskExtraction::TaskExtraction() { - ObjectPtr n = make_object(); - n->tasks = Array(); - data_ = n; -} - -Optional TaskExtractionNode::Query(runtime::String task_name, IRModule mod, - Target target, Optional> dispatched) { - ICHECK(dispatched.defined()); - ICHECK_EQ(dispatched.value().size(), 1); - IRModule prim_mod = dispatched.value()[0]; - ICHECK(HasOnlyOneFunction(prim_mod)) << prim_mod; - ICHECK(HasOnlyOneFunction(mod)) << mod; - tasks.push_back(ExtractedTask(task_name, mod, target, {prim_mod})); - return NullOpt; + return IRModule{nullptr}; } /**************** ApplyHistoryBest ****************/ @@ -121,14 +113,18 @@ ApplyHistoryBest::ApplyHistoryBest(Database database) { data_ = n; } -Optional ApplyHistoryBestNode::Query(runtime::String task_name, IRModule mod, - Target target, - Optional> dispatched) { +IRModule ApplyHistoryBestNode::Query(runtime::String task_name, IRModule mod, Target target, + Optional> dispatched) { ICHECK(dispatched.defined()); ICHECK_EQ(dispatched.value().size(), 1); ICHECK(HasOnlyOneFunction(mod)) << mod; IRModule prim_mod = dispatched.value()[0]; ICHECK(HasOnlyOneFunction(prim_mod)) << prim_mod; + // TODO(masahi): parse_mod below replaces the orginal function key with "main". + // This is necessary because some scheduling primitives requires the PrimFunc key be "main". + // If we can remove this restriction, there would no need for GetOnlyOneFunction* calls below + // and we can directly return sch->mod(). + auto gv = GetOnlyOneFunctionKey(prim_mod).value(); // Unify func name to make sure it can be found in database const auto* parse_mod_func = runtime::Registry::Get("tvm.meta_schedule.tune.parse_mod"); ICHECK(parse_mod_func) << "Parse mod function not defined!"; @@ -141,11 +137,11 @@ Optional ApplyHistoryBestNode::Query(runtime::String task_name, IRMod /*error_render_level=*/tir::ScheduleErrorRenderLevel::kNone); records[0]->trace->ApplyToSchedule(sch, false); tir::PrimFunc func = GetOnlyOneFunction(sch->mod()).value(); - return func; + return IRModule({{gv, func}}); } } LOG(WARNING) << "Cannot find workload: " << task_name << "\n" << tir::AsTVMScript(prim_mod); - return NullOpt; + return IRModule{nullptr}; } /**************** FFI ****************/ @@ -158,7 +154,6 @@ class MetaScheduleContextInternal { TVM_REGISTER_NODE_TYPE(ExtractedTaskNode); TVM_REGISTER_OBJECT_TYPE(MetaScheduleContextNode); -TVM_REGISTER_NODE_TYPE(TaskExtractionNode); TVM_REGISTER_NODE_TYPE(ApplyHistoryBestNode); TVM_REGISTER_GLOBAL("meta_schedule.ExtractedTask") @@ -176,9 +171,6 @@ TVM_REGISTER_GLOBAL("meta_schedule.MetaScheduleContextQueryInsideWithScope") .set_body_typed(MetaScheduleContext::QueryInsideWithScope); TVM_REGISTER_GLOBAL("meta_schedule.MetaScheduleContextQuery") .set_body_method(&MetaScheduleContextNode::Query); -TVM_REGISTER_GLOBAL("meta_schedule.TaskExtraction").set_body_typed([]() -> TaskExtraction { - return TaskExtraction(); -}); TVM_REGISTER_GLOBAL("meta_schedule.ApplyHistoryBest") .set_body_typed([](Database database) -> ApplyHistoryBest { return ApplyHistoryBest(database); diff --git a/src/relay/backend/build_module.cc b/src/relay/backend/build_module.cc index 89ee61c83f7c..87fe39c389f0 100644 --- a/src/relay/backend/build_module.cc +++ b/src/relay/backend/build_module.cc @@ -333,14 +333,7 @@ class RelayBuildModule : public runtime::ModuleNode { IRModule OptimizeImpl(IRModule relay_module) { ICHECK(relay_module.defined()) << "The IRModule must be defined for the Relay compiler."; - if (!params_.empty()) { - ICHECK(relay_module->ContainGlobalVar("main")) << "Missing the main entry function"; - GlobalVar main_glb_var = relay_module->GetGlobalVar("main"); - Function main_func = Downcast(relay_module->Lookup(main_glb_var)); - auto new_main = BindParamsByName(main_func, params_); - IRModuleNode* relay_module_ptr = relay_module.CopyOnWrite(); - relay_module_ptr->Update(main_glb_var, new_main); - } + backend::BindParamsInModule(relay_module, params_); Array pass_seqs = GetPassPrefix( /*is_homogenous=*/config_->optional_homogeneous_target.defined(), /*is_vm=*/false); diff --git a/src/relay/backend/task_extraction.cc b/src/relay/backend/task_extraction.cc new file mode 100644 index 000000000000..62f103e22a46 --- /dev/null +++ b/src/relay/backend/task_extraction.cc @@ -0,0 +1,84 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include +#include +#include +#include +#include + +#include "../../te/operation/create_primfunc.h" +#include "te_compiler_cache.h" +#include "utils.h" + +namespace tvm { +namespace relay { +namespace backend { + +namespace metaschedule { + +using meta_schedule::ExtractedTask; + +Array ExtractTask(IRModule mod, Target target, Map params) { + backend::BindParamsInModule(mod, params); + + // is_vm=true for backward compatibility + Array pass_seqs = relay::backend::GetPassPrefix(/*is_homogenous=*/true, /*is_vm=*/true); + pass_seqs.push_back(transform::FuseOps()); + + transform::Sequential seq(pass_seqs); + auto opt_mod = seq(std::move(mod)); + + Array tasks; + std::unordered_set cache_; + std::unordered_map name_map; + + PostOrderVisit(opt_mod->Lookup("main"), [target, &tasks, &cache_, &name_map](const Expr& exp) { + if (exp->IsInstance()) { + Function relay_func = Downcast(exp); + tec::CCacheKey cache_key(relay_func, target); + if (relay_func->HasNonzeroAttr(attr::kPrimitive) && cache_.find(cache_key) == cache_.end()) { + Array outputs; + std::string fused_name; + std::tie(outputs, fused_name) = + tec::LowerTECompute(relay_func, target, /*return_inputs*/ true); + auto prim_func = tir::CreatePrimFunc(outputs); + auto prim_fn_var = GlobalVar(fused_name); + auto relay_mod = IRModule({{prim_fn_var, relay_func}}); + auto tir_mod = IRModule({{prim_fn_var, prim_func}}); + auto task_name = tec::GetUniqueName(fused_name, &name_map); + tasks.push_back(ExtractedTask(task_name, relay_mod, target, {tir_mod})); + cache_.insert(cache_key); + } + } + }); + + return tasks; +} + +} // namespace metaschedule + +TVM_REGISTER_GLOBAL("relay.backend.MetaScheduleExtractTask") + .set_body_typed([](IRModule mod, Target target, Map params) { + return metaschedule::ExtractTask(mod, target, params); + }); + +} // namespace backend +} // namespace relay +} // namespace tvm diff --git a/src/relay/backend/te_compiler_cache.cc b/src/relay/backend/te_compiler_cache.cc index ffcce6e1c8da..b05f55099f4e 100644 --- a/src/relay/backend/te_compiler_cache.cc +++ b/src/relay/backend/te_compiler_cache.cc @@ -47,6 +47,7 @@ #include "../../te/operation/create_primfunc.h" #include "../op/memory/memory.h" #include "../transforms/pass_utils.h" +#include "tvm/runtime/object.h" #include "utils.h" namespace tvm { @@ -335,15 +336,14 @@ class ScheduleBuilder : public ExprVisitor { } } if (backend::IsMetaScheduleEnabled()) { - prim_func = tir::CreatePrimFunc(Concat(fn_inputs, tensor_outs)); - Optional opt_mod_or_base_func = - meta_schedule::MetaScheduleContext::QueryInsideWithScope( - prim_fn_var->name_hint, IRModule({{prim_fn_var, relay_func}}), target_, - Array{IRModule({{prim_fn_var, prim_func}})}); - if (const auto* result = opt_mod_or_base_func.as()) { - prim_func = GetRef(result); - } else { - prim_func = tir::PrimFunc(nullptr); + auto relay_mod = IRModule({{prim_fn_var, relay_func}}); + auto tir_mod = + IRModule({{prim_fn_var, tir::CreatePrimFunc(Concat(fn_inputs, tensor_outs))}}); + IRModule scheduled_mod = meta_schedule::MetaScheduleContext::QueryInsideWithScope( + prim_fn_var->name_hint, relay_mod, target_, Array{tir_mod}); + if (scheduled_mod.defined()) { + ICHECK_EQ(scheduled_mod->functions.count(prim_fn_var), 1); + prim_func = Downcast(scheduled_mod->functions[prim_fn_var]); } } @@ -754,6 +754,24 @@ CachedFunc ShapeFuncFor(const Function& prim_func, const Target& target, return MakeShapeFunc().Create(prim_func, target, renamer); } +std::pair, std::string> LowerTECompute(const Function& source_func, Target target, + bool return_inputs) { + LowerToTECompute lower_te_compute(target); + auto outputs = lower_te_compute.Lower(source_func, [&](std::string name) { return name; }); + // Following ScheduleBuilder, remove placeholder ops from outputs. + tvm::Array tensor_outs; + for (const auto& tensor : outputs) { + if (!tensor->op.as()) { + tensor_outs.push_back(tensor); + } + } + if (return_inputs) { + return std::make_pair(Concat(lower_te_compute.fn_inputs_, tensor_outs), + lower_te_compute.candidate_name_); + } + return std::make_pair(tensor_outs, lower_te_compute.candidate_name_); +} + /*! * \brief Get unique name from name. * \param name The orginal name. diff --git a/src/relay/backend/te_compiler_cache.h b/src/relay/backend/te_compiler_cache.h index 2ffca1aa6be7..55f221ac8ba0 100644 --- a/src/relay/backend/te_compiler_cache.h +++ b/src/relay/backend/te_compiler_cache.h @@ -37,6 +37,7 @@ #include #include #include +#include #include "../transforms/infer_layout_utils.h" @@ -204,6 +205,16 @@ class CCacheValue : public ObjectRef { Array GetShape(const Array& shape); +/*! + * \brief Lowers Relay primitive Function to TE Compute + * \param source_func The primitive function to be lowered. + * \param target The target we want to create schedule for. + * \param return_inputs If true, prepend input tensors to the output array of tensors. + * \return Pair of schedule and fused function name. + */ +std::pair, std::string> LowerTECompute(const Function& source_func, Target target, + bool return_inputs = true); + /*! * \brief Create schedule for target. * \param source_func The primitive function to be lowered. diff --git a/src/relay/backend/utils.cc b/src/relay/backend/utils.cc index 7662018e4f71..9883fe85c253 100644 --- a/src/relay/backend/utils.cc +++ b/src/relay/backend/utils.cc @@ -308,6 +308,56 @@ std::vector ShapeToJSON(tvm::Array shape) { return ret; } +relay::Function BindParamsByName(relay::Function func, + const std::unordered_map& params) { + std::unordered_map name_dict; + std::unordered_set repeat_var; + for (auto arg : func->params) { + const auto& name = arg->name_hint(); + if (name_dict.count(name)) { + repeat_var.insert(name_dict[name]); + } else { + name_dict[name] = arg; + } + } + + std::unordered_map bind_dict; + for (auto& kv : params) { + if (name_dict.count(kv.first) == 0) { + continue; + } + auto arg = name_dict.at(kv.first); + if (repeat_var.count(arg)) { + LOG(FATAL) << "Multiple args in the function have name " << kv.first; + } + bind_dict[arg] = Constant(kv.second); + } + Expr bound_expr = relay::Bind(func, bind_dict); + Function ret = Downcast(bound_expr); + ICHECK(ret.defined()) << "The returning type is expected to be a Relay Function." + << "\n"; + return ret; +} + +void BindParamsInModule(IRModule mod, + const std::unordered_map& params) { + if (!params.empty()) { + BaseFunc base_func = mod->Lookup("main"); + ICHECK(base_func->IsInstance()); + auto f = relay::backend::BindParamsByName(Downcast(base_func), params); + auto gvar = mod->GetGlobalVar("main"); + mod->Add(gvar, f); + } +} + +void BindParamsInModule(IRModule mod, Map params) { + std::unordered_map params_tmp; + for (const auto& kv : params) { + params_tmp[kv.first] = kv.second->data; + } + BindParamsInModule(mod, params_tmp); +} + } // namespace backend } // namespace relay } // namespace tvm diff --git a/src/relay/backend/utils.h b/src/relay/backend/utils.h index 3b4d4c18de89..1a39d4330d45 100644 --- a/src/relay/backend/utils.h +++ b/src/relay/backend/utils.h @@ -386,36 +386,18 @@ inline std::string DType2String(const tvm::DataType dtype) { * \param params params dict * \return relay::Function */ -inline relay::Function BindParamsByName( - relay::Function func, const std::unordered_map& params) { - std::unordered_map name_dict; - std::unordered_set repeat_var; - for (auto arg : func->params) { - const auto& name = arg->name_hint(); - if (name_dict.count(name)) { - repeat_var.insert(name_dict[name]); - } else { - name_dict[name] = arg; - } - } +relay::Function BindParamsByName(relay::Function func, + const std::unordered_map& params); - std::unordered_map bind_dict; - for (auto& kv : params) { - if (name_dict.count(kv.first) == 0) { - continue; - } - auto arg = name_dict.at(kv.first); - if (repeat_var.count(arg)) { - LOG(FATAL) << "Multiple args in the function have name " << kv.first; - } - bind_dict[arg] = Constant(kv.second); - } - Expr bound_expr = relay::Bind(func, bind_dict); - Function ret = Downcast(bound_expr); - ICHECK(ret.defined()) << "The returning type is expected to be a Relay Function." - << "\n"; - return ret; -} +/*! + * \brief Bind params to the main function in Relay module, using BindParamsByName + * \param mod Relay module + * \param params params dict + */ +void BindParamsInModule(IRModule mod, + const std::unordered_map& params); + +void BindParamsInModule(IRModule mod, Map params); /*! * \brief Extract the shape from a Relay tensor type. diff --git a/src/relay/backend/vm/compiler.cc b/src/relay/backend/vm/compiler.cc index e94919de7f20..130fb09e7af1 100644 --- a/src/relay/backend/vm/compiler.cc +++ b/src/relay/backend/vm/compiler.cc @@ -1034,14 +1034,7 @@ IRModule VMCompiler::OptimizeModule(IRModule mod, const TargetMap& targets, IRModule VMCompiler::OptimizeModuleImpl(IRModule mod) { VLOG_CONTEXT << "VM Optimize"; - if (params_.size()) { - BaseFunc base_func = mod->Lookup("main"); - ICHECK(base_func->IsInstance()) - << "VM compiler expects to compile relay::Function"; - auto f = relay::backend::BindParamsByName(Downcast(base_func), params_); - auto gvar = mod->GetGlobalVar("main"); - mod->Add(gvar, f); - } + backend::BindParamsInModule(mod, params_); Array pass_seqs = relay::backend::GetPassPrefix( /*is_homogenous=*/config_->optional_homogeneous_target.defined(), /*is_vm=*/true); diff --git a/tests/python/unittest/test_meta_schedule_integration.py b/tests/python/unittest/test_meta_schedule_integration.py index 50dc9289780d..1c3ef8ae0e19 100644 --- a/tests/python/unittest/test_meta_schedule_integration.py +++ b/tests/python/unittest/test_meta_schedule_integration.py @@ -26,7 +26,6 @@ ApplyHistoryBest, ExtractedTask, MetaScheduleContext, - TaskExtraction, ) from tvm.meta_schedule.testing.relay_workload import get_network from tvm.meta_schedule.utils import derived_object @@ -63,80 +62,31 @@ def _has_torch(): requires_torch = pytest.mark.skipif(not _has_torch(), reason="torch is not installed") -def _check_mock_task(tasks: List[ExtractedTask], mod: IRModule): - (task,) = tasks - assert isinstance(task, ExtractedTask) - assert task.task_name == "mock-task" - tvm.ir.assert_structural_equal(task.mod, mod) - (tir_mod,) = task.dispatched - tvm.ir.assert_structural_equal(tir_mod, MockModule) - - -@requires_torch -def test_meta_schedule_integration_task_extraction_query(): - mod, _, _ = get_network(name="resnet_18", input_shape=[1, 3, 224, 224]) - env = TaskExtraction() - env.query(task_name="mock-task", mod=mod, target=Target("llvm"), dispatched=[MockModule]) - _check_mock_task(env.tasks, mod) - - -def test_meta_schedule_integration_current(): - env = TaskExtraction() - with env: - assert MetaScheduleContext.current() == env - - -def test_meta_schedule_integration_no_current(): - assert MetaScheduleContext.current() is None - - -def test_meta_schedule_integration_multiple_current(): - env = TaskExtraction() - with env: - with pytest.raises(ValueError): - with env: - ... - - -@requires_torch -def test_meta_schedule_integration_query_inside_with_scope(): - mod, _, _ = get_network(name="resnet_18", input_shape=[1, 3, 224, 224]) - env = TaskExtraction() - with env: - MetaScheduleContext.query_inside_with_scope( - task_name="mock-task", - mod=mod, - target=Target("llvm"), - dispatched=[MockModule], - ) - _check_mock_task(env.tasks, mod) - - @requires_torch def test_meta_schedule_integration_extract_from_resnet(): mod, params, _ = get_network(name="resnet_18", input_shape=[1, 3, 224, 224]) extracted_tasks = ms.integration.extract_task_from_relay(mod, target="llvm", params=params) expected_task_names = [ - "vm_mod_fused_" + s + "fused_" + s for s in [ - "nn_max_pool2d", - "nn_adaptive_avg_pool2d", - "nn_dense_add", - "nn_conv2d_add", - "nn_conv2d_add_1", "nn_conv2d_add_2", - "nn_conv2d_add_add_nn_relu", + "nn_conv2d_add_1", + "nn_conv2d_add", + "nn_conv2d_add_nn_relu_7", + "nn_max_pool2d", + "nn_conv2d_add_nn_relu_6", + "nn_conv2d_add_add_nn_relu_3", + "nn_conv2d_add_nn_relu_5", + "nn_conv2d_add_nn_relu_4", + "nn_conv2d_add_add_nn_relu_2", + "nn_conv2d_add_nn_relu_3", + "nn_conv2d_add_nn_relu_2", "nn_conv2d_add_add_nn_relu_1", - "nn_conv2d_add_nn_relu", "nn_conv2d_add_nn_relu_1", - "nn_conv2d_add_nn_relu_2", - "nn_conv2d_add_nn_relu_3", - "nn_conv2d_add_nn_relu_4", - "nn_conv2d_add_nn_relu_5", - "nn_contrib_conv2d_winograd_without_weight_transform_add_add_nn_relu", - "nn_contrib_conv2d_winograd_without_weight_transform_add_add_nn_relu_1", - "nn_contrib_conv2d_winograd_without_weight_transform_add_nn_relu", - "nn_contrib_conv2d_winograd_without_weight_transform_add_nn_relu_1", + "nn_conv2d_add_nn_relu", + "nn_conv2d_add_add_nn_relu", + "nn_adaptive_avg_pool2d", + "nn_contrib_dense_pack_add", # The two tasks below are purely spatial and are ruled out by AutoScheduler "layout_transform", "layout_transform_reshape_squeeze", @@ -197,7 +147,6 @@ def print_results(self) -> None: TuningRecord(Schedule(MockModule).trace, [1.0], workload, target, []) ) mod = env.query(task_name="mock-task", mod=mod, target=target, dispatched=[MockModule]) - mod = IRModule({"main": mod}) assert tvm.ir.structural_equal(mod, workload.mod) From 6070fe7b967cea1a3005946a5f636af80142f0da Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Fri, 11 Mar 2022 14:29:15 +0900 Subject: [PATCH 02/19] update integration.h doc --- include/tvm/meta_schedule/integration.h | 23 ++++++++--------------- 1 file changed, 8 insertions(+), 15 deletions(-) diff --git a/include/tvm/meta_schedule/integration.h b/include/tvm/meta_schedule/integration.h index 3140b4f981e5..e66423545e75 100644 --- a/include/tvm/meta_schedule/integration.h +++ b/include/tvm/meta_schedule/integration.h @@ -86,14 +86,11 @@ class MetaScheduleContextNode : public runtime::Object { * \param target Target info * \param dispatched A list of low-level IRs that the high-level IR could potentially dispatch to. * NullOpt means the dispatch needs to be done in the context. - * \return There are different types of the output - * 1) NullOpt if there is no feedback hint - * 2) tir::PrimFunc if `mod` should be lowered to a PrimFunc - * 3) relay::Function if `mod` should be dispatched to BYOC workflow - * 4) IRModule for unified dispatch + * \return IRModule Currently we only have to return tir::PrimFunc, but we wrap it + * under IRModule for more general future use. */ virtual IRModule Query(runtime::String task_name, IRModule mod, Target target, - Optional> dispatched) = 0; + Optional> dispatched) = 0; static constexpr const char* _type_key = "meta_schedule.MetaScheduleContext"; TVM_DECLARE_BASE_OBJECT_INFO(MetaScheduleContextNode, runtime::Object); @@ -123,15 +120,11 @@ class MetaScheduleContext : public runtime::ObjectRef { * \param mod The high-level IR * \param target Target info * \param dispatched A list of low-level IRs that the high-level IR could potentially dispatch to - * \return There are different types of the output - * 1) NullOpt if there is no feedback hint - * 2) tir::PrimFunc if `mod` should be lowered to a PrimFunc - * 3) relay::Function if `mod` should be dispatched to BYOC workflow - * 4) IRModule for unified dispatch + * \return IRModule Currently we only have to return tir::PrimFunc, but we wrap it + * under IRModule for more general future use. */ - static IRModule QueryInsideWithScope(runtime::String task_name, IRModule mod, - Target target, - Optional> dispatched); + static IRModule QueryInsideWithScope(runtime::String task_name, IRModule mod, Target target, + Optional> dispatched); TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(MetaScheduleContext, runtime::ObjectRef, MetaScheduleContextNode); @@ -162,7 +155,7 @@ class ApplyHistoryBestNode : public MetaScheduleContextNode { // Inherited from base class IRModule Query(runtime::String task_name, IRModule mod, Target target, - Optional> dispatched) final; + Optional> dispatched) final; static constexpr const char* _type_key = "meta_schedule.ApplyHistoryBest"; TVM_DECLARE_FINAL_OBJECT_INFO(ApplyHistoryBestNode, MetaScheduleContextNode); From e6b9fb88fef81ce85455770d88ab83d5ac501816 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Fri, 11 Mar 2022 15:55:26 +0900 Subject: [PATCH 03/19] remove unused import --- python/tvm/meta_schedule/integration.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/python/tvm/meta_schedule/integration.py b/python/tvm/meta_schedule/integration.py index eebc2429acdf..e07bb89ede7b 100644 --- a/python/tvm/meta_schedule/integration.py +++ b/python/tvm/meta_schedule/integration.py @@ -15,8 +15,7 @@ # specific language governing permissions and limitations # under the License. """Meta schedule integration with high-level IR""" -from contextlib import contextmanager -from typing import Callable, Dict, List, Optional, Union +from typing import Dict, List, Optional, Union import numpy as np import tvm.runtime.ndarray as nd @@ -25,7 +24,6 @@ from tvm.ir import IRModule, transform from tvm.relay import Any, const from tvm.relay import Function as RelayFunc -from tvm.relay import vm from tvm.runtime import NDArray, Object from tvm.target import Target from tvm.tir import PrimFunc From ab27ddec006e5d14b66c068e6811e2132701ac3b Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Fri, 11 Mar 2022 16:25:17 +0900 Subject: [PATCH 04/19] fix mypy check --- python/tvm/meta_schedule/integration.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/tvm/meta_schedule/integration.py b/python/tvm/meta_schedule/integration.py index e07bb89ede7b..49a7138427b8 100644 --- a/python/tvm/meta_schedule/integration.py +++ b/python/tvm/meta_schedule/integration.py @@ -17,7 +17,7 @@ """Meta schedule integration with high-level IR""" from typing import Dict, List, Optional, Union -import numpy as np +import numpy as np # type: ignore import tvm.runtime.ndarray as nd from tvm._ffi import register_object, get_global_func From 4ff0c25a82becc082580e13991ea219efa3f98fd Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Fri, 11 Mar 2022 18:43:51 +0900 Subject: [PATCH 05/19] use_meta_schedule restored, now extracts the same task as Ansor --- python/tvm/meta_schedule/integration.py | 2 ++ .../test_meta_schedule_integration.py | 34 +++++++++++-------- 2 files changed, 21 insertions(+), 15 deletions(-) diff --git a/python/tvm/meta_schedule/integration.py b/python/tvm/meta_schedule/integration.py index 49a7138427b8..c3f691a9422f 100644 --- a/python/tvm/meta_schedule/integration.py +++ b/python/tvm/meta_schedule/integration.py @@ -233,6 +233,8 @@ def extract_task_from_relay( if disabled_pass is None: disabled_pass = [] + if pass_config is None: + pass_config = {"relay.backend.use_meta_schedule": True} if isinstance(mod, RelayFunc): mod = IRModule.from_expr(mod) diff --git a/tests/python/unittest/test_meta_schedule_integration.py b/tests/python/unittest/test_meta_schedule_integration.py index 1c3ef8ae0e19..4620e83d8ec4 100644 --- a/tests/python/unittest/test_meta_schedule_integration.py +++ b/tests/python/unittest/test_meta_schedule_integration.py @@ -62,6 +62,10 @@ def _has_torch(): requires_torch = pytest.mark.skipif(not _has_torch(), reason="torch is not installed") +def test_meta_schedule_integration_no_current(): + assert MetaScheduleContext.current() is None + + @requires_torch def test_meta_schedule_integration_extract_from_resnet(): mod, params, _ = get_network(name="resnet_18", input_shape=[1, 3, 224, 224]) @@ -69,24 +73,24 @@ def test_meta_schedule_integration_extract_from_resnet(): expected_task_names = [ "fused_" + s for s in [ - "nn_conv2d_add_2", - "nn_conv2d_add_1", - "nn_conv2d_add", - "nn_conv2d_add_nn_relu_7", "nn_max_pool2d", - "nn_conv2d_add_nn_relu_6", - "nn_conv2d_add_add_nn_relu_3", - "nn_conv2d_add_nn_relu_5", - "nn_conv2d_add_nn_relu_4", - "nn_conv2d_add_add_nn_relu_2", - "nn_conv2d_add_nn_relu_3", - "nn_conv2d_add_nn_relu_2", + "nn_adaptive_avg_pool2d", + "nn_dense_add", + "nn_conv2d_add", + "nn_conv2d_add_1", + "nn_conv2d_add_2", + "nn_conv2d_add_add_nn_relu", "nn_conv2d_add_add_nn_relu_1", - "nn_conv2d_add_nn_relu_1", "nn_conv2d_add_nn_relu", - "nn_conv2d_add_add_nn_relu", - "nn_adaptive_avg_pool2d", - "nn_contrib_dense_pack_add", + "nn_conv2d_add_nn_relu_1", + "nn_conv2d_add_nn_relu_2", + "nn_conv2d_add_nn_relu_3", + "nn_conv2d_add_nn_relu_4", + "nn_conv2d_add_nn_relu_5", + "nn_contrib_conv2d_winograd_without_weight_transform_add_add_nn_relu", + "nn_contrib_conv2d_winograd_without_weight_transform_add_add_nn_relu_1", + "nn_contrib_conv2d_winograd_without_weight_transform_add_nn_relu", + "nn_contrib_conv2d_winograd_without_weight_transform_add_nn_relu_1", # The two tasks below are purely spatial and are ruled out by AutoScheduler "layout_transform", "layout_transform_reshape_squeeze", From a66b8b513e6a549533859706e386e2058308d011 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Fri, 11 Mar 2022 19:02:15 +0900 Subject: [PATCH 06/19] python doc update --- python/tvm/meta_schedule/integration.py | 24 +++++++++--------------- 1 file changed, 9 insertions(+), 15 deletions(-) diff --git a/python/tvm/meta_schedule/integration.py b/python/tvm/meta_schedule/integration.py index c3f691a9422f..4d92076b26f6 100644 --- a/python/tvm/meta_schedule/integration.py +++ b/python/tvm/meta_schedule/integration.py @@ -78,7 +78,7 @@ def query( mod: IRModule, target: Target, dispatched: Optional[List[IRModule]], - ) -> Union[IRModule, RelayFunc, PrimFunc, None]: + ) -> IRModule: """The entry point of the integration Parameters @@ -94,12 +94,9 @@ def query( Returns ------- - result : Union[IRModule, RelayFunc, PrimFunc, None] - There are different types of the output: - 1) NullOpt if there is no feedback hint; - 2) tir::PrimFunc if `mod` should be lowered to a PrimFunc; - 3) relay::Function if `mod` should be dispatched to BYOC workflow; - 4) IRModule for unified dispatch + result : IRModule + Currently we only have to return tir::PrimFunc, but we wrap it under IRModule for + more general future use """ return _ffi_api.MetaScheduleContextQuery( # type: ignore # pylint: disable=no-member self, @@ -127,7 +124,7 @@ def query_inside_with_scope( mod: IRModule, target: Target, dispatched: Optional[List[IRModule]], - ) -> Union[IRModule, RelayFunc, PrimFunc, None]: + ) -> IRModule: """The entry point of the integration workflow. The compilation process of the high-level IR should call this method for task extraction and for feedback hints @@ -138,7 +135,7 @@ def query_inside_with_scope( def query_inside_with_scope(task_name, mod, dispatched): ctx = MetaScheduleContext.current() assert ctx is not None - ctx.query(task_name, mod, target, dispatched) + mod = ctx.query(task_name, mod, target, dispatched) Parameters ---------- @@ -153,12 +150,9 @@ def query_inside_with_scope(task_name, mod, dispatched): Returns ------- - result : Union[IRModule, RelayFunc, PrimFunc, None] - There are different types of the output: - 1) NullOpt if there is no feedback hint; - 2) tir::PrimFunc if `mod` should be lowered to a PrimFunc; - 3) relay::Function if `mod` should be dispatched to BYOC workflow; - 4) IRModule for unified dispatch + result : IRModule + Currently we only have to return tir::PrimFunc, but we wrap it under IRModule for + more general future use """ return _ffi_api.MetaScheduleContextQueryInsideWithScope( # type: ignore # pylint: disable=no-member task_name, From bc1c9895019098041aaf2594a2f40d6d88dad209 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Fri, 11 Mar 2022 19:18:54 +0900 Subject: [PATCH 07/19] unused import --- python/tvm/meta_schedule/integration.py | 1 - 1 file changed, 1 deletion(-) diff --git a/python/tvm/meta_schedule/integration.py b/python/tvm/meta_schedule/integration.py index 4d92076b26f6..3d3a1527e3e8 100644 --- a/python/tvm/meta_schedule/integration.py +++ b/python/tvm/meta_schedule/integration.py @@ -26,7 +26,6 @@ from tvm.relay import Function as RelayFunc from tvm.runtime import NDArray, Object from tvm.target import Target -from tvm.tir import PrimFunc from . import _ffi_api from .database import Database From 5a1c93de7d75b5f38474d70c9571f58dbde281bb Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Mon, 14 Mar 2022 06:34:53 +0900 Subject: [PATCH 08/19] cache_ -> cache, suppres "Cannot find workdload" warning --- src/meta_schedule/integration.cc | 3 ++- src/relay/backend/task_extraction.cc | 8 ++++---- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/src/meta_schedule/integration.cc b/src/meta_schedule/integration.cc index d2cb4b307bbf..2a870ca61157 100644 --- a/src/meta_schedule/integration.cc +++ b/src/meta_schedule/integration.cc @@ -140,7 +140,8 @@ IRModule ApplyHistoryBestNode::Query(runtime::String task_name, IRModule mod, Ta return IRModule({{gv, func}}); } } - LOG(WARNING) << "Cannot find workload: " << task_name << "\n" << tir::AsTVMScript(prim_mod); + LOG(WARNING) << "Cannot find workload: " << task_name; + DLOG(INFO) << tir::AsTVMScript(prim_mod); return IRModule{nullptr}; } diff --git a/src/relay/backend/task_extraction.cc b/src/relay/backend/task_extraction.cc index 62f103e22a46..d7757c31bed8 100644 --- a/src/relay/backend/task_extraction.cc +++ b/src/relay/backend/task_extraction.cc @@ -46,14 +46,14 @@ Array ExtractTask(IRModule mod, Target target, Map tasks; - std::unordered_set cache_; + std::unordered_set cache; std::unordered_map name_map; - PostOrderVisit(opt_mod->Lookup("main"), [target, &tasks, &cache_, &name_map](const Expr& exp) { + PostOrderVisit(opt_mod->Lookup("main"), [target, &tasks, &cache, &name_map](const Expr& exp) { if (exp->IsInstance()) { Function relay_func = Downcast(exp); tec::CCacheKey cache_key(relay_func, target); - if (relay_func->HasNonzeroAttr(attr::kPrimitive) && cache_.find(cache_key) == cache_.end()) { + if (relay_func->HasNonzeroAttr(attr::kPrimitive) && cache.find(cache_key) == cache.end()) { Array outputs; std::string fused_name; std::tie(outputs, fused_name) = @@ -64,7 +64,7 @@ Array ExtractTask(IRModule mod, Target target, Map Date: Tue, 15 Mar 2022 12:24:16 +0900 Subject: [PATCH 09/19] Update src/relay/backend/task_extraction.cc and te_compiler_cache.cc Co-authored-by: Junru Shao --- src/meta_schedule/integration.cc | 2 +- src/relay/backend/task_extraction.cc | 6 +++--- src/relay/backend/te_compiler_cache.cc | 7 +++---- 3 files changed, 7 insertions(+), 8 deletions(-) diff --git a/src/meta_schedule/integration.cc b/src/meta_schedule/integration.cc index 2a870ca61157..7311dd5b2160 100644 --- a/src/meta_schedule/integration.cc +++ b/src/meta_schedule/integration.cc @@ -124,7 +124,7 @@ IRModule ApplyHistoryBestNode::Query(runtime::String task_name, IRModule mod, Ta // This is necessary because some scheduling primitives requires the PrimFunc key be "main". // If we can remove this restriction, there would no need for GetOnlyOneFunction* calls below // and we can directly return sch->mod(). - auto gv = GetOnlyOneFunctionKey(prim_mod).value(); + GlobalVar gv = GetOnlyOneFunctionKey(prim_mod).value(); // Unify func name to make sure it can be found in database const auto* parse_mod_func = runtime::Registry::Get("tvm.meta_schedule.tune.parse_mod"); ICHECK(parse_mod_func) << "Parse mod function not defined!"; diff --git a/src/relay/backend/task_extraction.cc b/src/relay/backend/task_extraction.cc index d7757c31bed8..36c242499836 100644 --- a/src/relay/backend/task_extraction.cc +++ b/src/relay/backend/task_extraction.cc @@ -54,10 +54,10 @@ Array ExtractTask(IRModule mod, Target target, Map(exp); tec::CCacheKey cache_key(relay_func, target); if (relay_func->HasNonzeroAttr(attr::kPrimitive) && cache.find(cache_key) == cache.end()) { - Array outputs; + Array inputs_outputs; std::string fused_name; - std::tie(outputs, fused_name) = - tec::LowerTECompute(relay_func, target, /*return_inputs*/ true); + std::tie(inputs_outputs, fused_name) = + tec::LowerTECompute(relay_func, target, /*return_inputs=*/true); auto prim_func = tir::CreatePrimFunc(outputs); auto prim_fn_var = GlobalVar(fused_name); auto relay_mod = IRModule({{prim_fn_var, relay_func}}); diff --git a/src/relay/backend/te_compiler_cache.cc b/src/relay/backend/te_compiler_cache.cc index b05f55099f4e..1d39626e4f23 100644 --- a/src/relay/backend/te_compiler_cache.cc +++ b/src/relay/backend/te_compiler_cache.cc @@ -336,9 +336,8 @@ class ScheduleBuilder : public ExprVisitor { } } if (backend::IsMetaScheduleEnabled()) { - auto relay_mod = IRModule({{prim_fn_var, relay_func}}); - auto tir_mod = - IRModule({{prim_fn_var, tir::CreatePrimFunc(Concat(fn_inputs, tensor_outs))}}); + IRModule relay_mod({{prim_fn_var, relay_func}}); + IRModule tir_mod({{prim_fn_var, tir::CreatePrimFunc(Concat(fn_inputs, tensor_outs))}}); IRModule scheduled_mod = meta_schedule::MetaScheduleContext::QueryInsideWithScope( prim_fn_var->name_hint, relay_mod, target_, Array{tir_mod}); if (scheduled_mod.defined()) { @@ -757,7 +756,7 @@ CachedFunc ShapeFuncFor(const Function& prim_func, const Target& target, std::pair, std::string> LowerTECompute(const Function& source_func, Target target, bool return_inputs) { LowerToTECompute lower_te_compute(target); - auto outputs = lower_te_compute.Lower(source_func, [&](std::string name) { return name; }); + Array outputs = lower_te_compute.Lower(source_func, [&](std::string name) { return name; }); // Following ScheduleBuilder, remove placeholder ops from outputs. tvm::Array tensor_outs; for (const auto& tensor : outputs) { From 560ca6962a277386d222170b27f1a07f82f326b3 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Tue, 15 Mar 2022 12:25:16 +0900 Subject: [PATCH 10/19] removed unnecessary include --- src/relay/backend/te_compiler_cache.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/relay/backend/te_compiler_cache.cc b/src/relay/backend/te_compiler_cache.cc index 1d39626e4f23..7569c4f22224 100644 --- a/src/relay/backend/te_compiler_cache.cc +++ b/src/relay/backend/te_compiler_cache.cc @@ -47,7 +47,6 @@ #include "../../te/operation/create_primfunc.h" #include "../op/memory/memory.h" #include "../transforms/pass_utils.h" -#include "tvm/runtime/object.h" #include "utils.h" namespace tvm { @@ -756,7 +755,8 @@ CachedFunc ShapeFuncFor(const Function& prim_func, const Target& target, std::pair, std::string> LowerTECompute(const Function& source_func, Target target, bool return_inputs) { LowerToTECompute lower_te_compute(target); - Array outputs = lower_te_compute.Lower(source_func, [&](std::string name) { return name; }); + Array outputs = + lower_te_compute.Lower(source_func, [](std::string name) { return name; }); // Following ScheduleBuilder, remove placeholder ops from outputs. tvm::Array tensor_outs; for (const auto& tensor : outputs) { From 425e64be04efc15630062cea6697a5d45cf403ae Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Tue, 15 Mar 2022 12:37:15 +0900 Subject: [PATCH 11/19] fixed build --- src/relay/backend/task_extraction.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/relay/backend/task_extraction.cc b/src/relay/backend/task_extraction.cc index 36c242499836..e4638f657aff 100644 --- a/src/relay/backend/task_extraction.cc +++ b/src/relay/backend/task_extraction.cc @@ -58,7 +58,7 @@ Array ExtractTask(IRModule mod, Target target, Map Date: Tue, 15 Mar 2022 12:42:59 +0900 Subject: [PATCH 12/19] drop relay.const on params --- python/tvm/meta_schedule/integration.py | 3 +-- src/relay/backend/task_extraction.cc | 6 ++++-- src/relay/backend/utils.cc | 5 +++-- src/relay/backend/utils.h | 3 ++- 4 files changed, 10 insertions(+), 7 deletions(-) diff --git a/python/tvm/meta_schedule/integration.py b/python/tvm/meta_schedule/integration.py index 3d3a1527e3e8..9ea3e325972c 100644 --- a/python/tvm/meta_schedule/integration.py +++ b/python/tvm/meta_schedule/integration.py @@ -22,7 +22,7 @@ from tvm._ffi import register_object, get_global_func from tvm.ir import IRModule, transform -from tvm.relay import Any, const +from tvm.relay import Any from tvm.relay import Function as RelayFunc from tvm.runtime import NDArray, Object from tvm.target import Target @@ -222,7 +222,6 @@ def extract_task_from_relay( for name, param in params.items(): if isinstance(param, np.ndarray): param = nd.array(param) - relay_params[name] = const(param) if disabled_pass is None: disabled_pass = [] diff --git a/src/relay/backend/task_extraction.cc b/src/relay/backend/task_extraction.cc index e4638f657aff..c96a2c5564fa 100644 --- a/src/relay/backend/task_extraction.cc +++ b/src/relay/backend/task_extraction.cc @@ -25,6 +25,7 @@ #include "../../te/operation/create_primfunc.h" #include "te_compiler_cache.h" +#include "tvm/runtime/ndarray.h" #include "utils.h" namespace tvm { @@ -35,7 +36,8 @@ namespace metaschedule { using meta_schedule::ExtractedTask; -Array ExtractTask(IRModule mod, Target target, Map params) { +Array ExtractTask(IRModule mod, Target target, + Map params) { backend::BindParamsInModule(mod, params); // is_vm=true for backward compatibility @@ -75,7 +77,7 @@ Array ExtractTask(IRModule mod, Target target, Map params) { + .set_body_typed([](IRModule mod, Target target, Map params) { return metaschedule::ExtractTask(mod, target, params); }); diff --git a/src/relay/backend/utils.cc b/src/relay/backend/utils.cc index 9883fe85c253..fd3ab64fcc1c 100644 --- a/src/relay/backend/utils.cc +++ b/src/relay/backend/utils.cc @@ -29,6 +29,7 @@ #include #include "te_compiler.h" +#include "tvm/runtime/ndarray.h" namespace tvm { namespace relay { @@ -350,10 +351,10 @@ void BindParamsInModule(IRModule mod, } } -void BindParamsInModule(IRModule mod, Map params) { +void BindParamsInModule(IRModule mod, Map params) { std::unordered_map params_tmp; for (const auto& kv : params) { - params_tmp[kv.first] = kv.second->data; + params_tmp[kv.first] = kv.second; } BindParamsInModule(mod, params_tmp); } diff --git a/src/relay/backend/utils.h b/src/relay/backend/utils.h index 1a39d4330d45..37a89d3edced 100644 --- a/src/relay/backend/utils.h +++ b/src/relay/backend/utils.h @@ -46,6 +46,7 @@ #include "../../runtime/meta_data.h" #include "../../target/metadata.h" +#include "tvm/runtime/ndarray.h" namespace tvm { namespace relay { @@ -397,7 +398,7 @@ relay::Function BindParamsByName(relay::Function func, void BindParamsInModule(IRModule mod, const std::unordered_map& params); -void BindParamsInModule(IRModule mod, Map params); +void BindParamsInModule(IRModule mod, Map params); /*! * \brief Extract the shape from a Relay tensor type. From 93461363762940f51748761747474baa13a7c445 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Tue, 15 Mar 2022 13:12:56 +0900 Subject: [PATCH 13/19] updated comment in integration.cc --- python/tvm/meta_schedule/integration.py | 1 + src/meta_schedule/integration.cc | 13 ++++++++++--- 2 files changed, 11 insertions(+), 3 deletions(-) diff --git a/python/tvm/meta_schedule/integration.py b/python/tvm/meta_schedule/integration.py index 9ea3e325972c..8a839f365308 100644 --- a/python/tvm/meta_schedule/integration.py +++ b/python/tvm/meta_schedule/integration.py @@ -222,6 +222,7 @@ def extract_task_from_relay( for name, param in params.items(): if isinstance(param, np.ndarray): param = nd.array(param) + relay_params[name] = param if disabled_pass is None: disabled_pass = [] diff --git a/src/meta_schedule/integration.cc b/src/meta_schedule/integration.cc index 7311dd5b2160..c56c656c686a 100644 --- a/src/meta_schedule/integration.cc +++ b/src/meta_schedule/integration.cc @@ -120,15 +120,21 @@ IRModule ApplyHistoryBestNode::Query(runtime::String task_name, IRModule mod, Ta ICHECK(HasOnlyOneFunction(mod)) << mod; IRModule prim_mod = dispatched.value()[0]; ICHECK(HasOnlyOneFunction(prim_mod)) << prim_mod; + // TODO(masahi): parse_mod below replaces the orginal function key with "main". - // This is necessary because some scheduling primitives requires the PrimFunc key be "main". - // If we can remove this restriction, there would no need for GetOnlyOneFunction* calls below - // and we can directly return sch->mod(). + // This is necessary because, in practice, most use of the scheduling primitive "get_block" + // assumes that the function name is "main". + // If we do not have this requirement, we can remove the call to parse_mod and + // GetOnlyOneFunction* calls below and instead we can directly return sch->mod(). + + // Keep the original func name to be returned later. GlobalVar gv = GetOnlyOneFunctionKey(prim_mod).value(); + // Unify func name to make sure it can be found in database const auto* parse_mod_func = runtime::Registry::Get("tvm.meta_schedule.tune.parse_mod"); ICHECK(parse_mod_func) << "Parse mod function not defined!"; prim_mod = (*parse_mod_func)(prim_mod); + if (database->HasWorkload(prim_mod)) { Array records = database->GetTopK(database->CommitWorkload(prim_mod), 1); if (records.size() == 1) { @@ -137,6 +143,7 @@ IRModule ApplyHistoryBestNode::Query(runtime::String task_name, IRModule mod, Ta /*error_render_level=*/tir::ScheduleErrorRenderLevel::kNone); records[0]->trace->ApplyToSchedule(sch, false); tir::PrimFunc func = GetOnlyOneFunction(sch->mod()).value(); + // Make sure we return the updated PrimFunc paired with the original func name. return IRModule({{gv, func}}); } } From 3d148dedf5fe7581477d9d0aec5667a5639cd57a Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Tue, 15 Mar 2022 13:21:19 +0900 Subject: [PATCH 14/19] update schedule_rule name to prepend "metaschedule" --- python/tvm/topi/x86/batch_matmul.py | 2 +- python/tvm/topi/x86/dense.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/python/tvm/topi/x86/batch_matmul.py b/python/tvm/topi/x86/batch_matmul.py index 2d32bfe8f096..55453df95e7a 100644 --- a/python/tvm/topi/x86/batch_matmul.py +++ b/python/tvm/topi/x86/batch_matmul.py @@ -47,7 +47,7 @@ def batch_matmul_vnni_compute(cfg, x, y): axis=ak, ), tag="batch_matmul_vnni", - attrs={"schedule_rule": "batch_matmul_vnni"}, + attrs={"schedule_rule": "meta_schedule.batch_matmul_vnni"}, ) _, a_y, _ = z.op.axis diff --git a/python/tvm/topi/x86/dense.py b/python/tvm/topi/x86/dense.py index cd6350352d98..1e4ccb7bb8c8 100644 --- a/python/tvm/topi/x86/dense.py +++ b/python/tvm/topi/x86/dense.py @@ -296,7 +296,7 @@ def dense_vnni_compute(cfg, X, packed_w, bias=None): axis=ak, ), tag="dense_vnni", - attrs={"schedule_rule": "dense_vnni"}, + attrs={"schedule_rule": "meta_schedule.dense_vnni"}, ) if bias is not None: From 56f38d47c134ffd19ee05410e8ddb3f16430dff9 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Tue, 15 Mar 2022 13:22:35 +0900 Subject: [PATCH 15/19] typo fix --- src/meta_schedule/integration.cc | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/meta_schedule/integration.cc b/src/meta_schedule/integration.cc index c56c656c686a..6a29030cf5f0 100644 --- a/src/meta_schedule/integration.cc +++ b/src/meta_schedule/integration.cc @@ -121,9 +121,9 @@ IRModule ApplyHistoryBestNode::Query(runtime::String task_name, IRModule mod, Ta IRModule prim_mod = dispatched.value()[0]; ICHECK(HasOnlyOneFunction(prim_mod)) << prim_mod; - // TODO(masahi): parse_mod below replaces the orginal function key with "main". - // This is necessary because, in practice, most use of the scheduling primitive "get_block" - // assumes that the function name is "main". + // TODO(masahi): parse_mod below replaces the original function key with "main". + // This is necessary because, in practice, most uses of the scheduling primitive "get_block" + // assume that the function name is "main". // If we do not have this requirement, we can remove the call to parse_mod and // GetOnlyOneFunction* calls below and instead we can directly return sch->mod(). From 9ba000a5daf461d3b7d43304af8131ed1d4f725a Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Tue, 15 Mar 2022 13:25:54 +0900 Subject: [PATCH 16/19] more nit change --- src/relay/backend/task_extraction.cc | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/relay/backend/task_extraction.cc b/src/relay/backend/task_extraction.cc index c96a2c5564fa..898e76b81b98 100644 --- a/src/relay/backend/task_extraction.cc +++ b/src/relay/backend/task_extraction.cc @@ -61,9 +61,9 @@ Array ExtractTask(IRModule mod, Target target, std::tie(inputs_outputs, fused_name) = tec::LowerTECompute(relay_func, target, /*return_inputs=*/true); auto prim_func = tir::CreatePrimFunc(inputs_outputs); - auto prim_fn_var = GlobalVar(fused_name); - auto relay_mod = IRModule({{prim_fn_var, relay_func}}); - auto tir_mod = IRModule({{prim_fn_var, prim_func}}); + GlobalVar prim_fn_var(fused_name); + IRModule relay_mod({{prim_fn_var, relay_func}}); + IRModule tir_mod({{prim_fn_var, prim_func}}); auto task_name = tec::GetUniqueName(fused_name, &name_map); tasks.push_back(ExtractedTask(task_name, relay_mod, target, {tir_mod})); cache.insert(cache_key); From af794545497765c333d8da42ce1a42a04375b7ba Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Tue, 15 Mar 2022 14:51:49 +0900 Subject: [PATCH 17/19] make the output of Query Optional --- include/tvm/meta_schedule/integration.h | 23 +++++++++++++---------- src/meta_schedule/integration.cc | 16 +++++++++------- src/relay/backend/te_compiler_cache.cc | 8 ++++---- 3 files changed, 26 insertions(+), 21 deletions(-) diff --git a/include/tvm/meta_schedule/integration.h b/include/tvm/meta_schedule/integration.h index e66423545e75..56d8d379df93 100644 --- a/include/tvm/meta_schedule/integration.h +++ b/include/tvm/meta_schedule/integration.h @@ -86,11 +86,12 @@ class MetaScheduleContextNode : public runtime::Object { * \param target Target info * \param dispatched A list of low-level IRs that the high-level IR could potentially dispatch to. * NullOpt means the dispatch needs to be done in the context. - * \return IRModule Currently we only have to return tir::PrimFunc, but we wrap it - * under IRModule for more general future use. + * \return IRModule or NullOpt Currently we only have to return tir::PrimFunc, but we wrap it + * under IRModule for more general future use. NullOpt is returned + * if there is no feedback hint. */ - virtual IRModule Query(runtime::String task_name, IRModule mod, Target target, - Optional> dispatched) = 0; + virtual Optional Query(runtime::String task_name, IRModule mod, Target target, + Optional> dispatched) = 0; static constexpr const char* _type_key = "meta_schedule.MetaScheduleContext"; TVM_DECLARE_BASE_OBJECT_INFO(MetaScheduleContextNode, runtime::Object); @@ -120,11 +121,13 @@ class MetaScheduleContext : public runtime::ObjectRef { * \param mod The high-level IR * \param target Target info * \param dispatched A list of low-level IRs that the high-level IR could potentially dispatch to - * \return IRModule Currently we only have to return tir::PrimFunc, but we wrap it - * under IRModule for more general future use. + * \return IRModule or NullOpt Currently we only have to return tir::PrimFunc, but we wrap it + * under IRModule for more general future use. NullOpt is returned + * if there is no feedback hint */ - static IRModule QueryInsideWithScope(runtime::String task_name, IRModule mod, Target target, - Optional> dispatched); + static Optional QueryInsideWithScope(runtime::String task_name, IRModule mod, + Target target, + Optional> dispatched); TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(MetaScheduleContext, runtime::ObjectRef, MetaScheduleContextNode); @@ -154,8 +157,8 @@ class ApplyHistoryBestNode : public MetaScheduleContextNode { } // Inherited from base class - IRModule Query(runtime::String task_name, IRModule mod, Target target, - Optional> dispatched) final; + Optional Query(runtime::String task_name, IRModule mod, Target target, + Optional> dispatched) final; static constexpr const char* _type_key = "meta_schedule.ApplyHistoryBest"; TVM_DECLARE_FINAL_OBJECT_INFO(ApplyHistoryBestNode, MetaScheduleContextNode); diff --git a/src/meta_schedule/integration.cc b/src/meta_schedule/integration.cc index 6a29030cf5f0..1ebd554d9323 100644 --- a/src/meta_schedule/integration.cc +++ b/src/meta_schedule/integration.cc @@ -21,6 +21,7 @@ #include #include "./utils.h" +#include "tvm/runtime/container/optional.h" namespace tvm { namespace meta_schedule { @@ -96,13 +97,13 @@ void MetaScheduleContext::ExitWithScope() { ctx = NullOpt; } -IRModule MetaScheduleContext::QueryInsideWithScope(runtime::String task_name, IRModule mod, - Target target, - Optional> dispatched) { +Optional MetaScheduleContext::QueryInsideWithScope(runtime::String task_name, + IRModule mod, Target target, + Optional> dispatched) { if (Optional ctx = MetaScheduleContext::Current()) { return ctx.value()->Query(task_name, mod, target, dispatched); } - return IRModule{nullptr}; + return NullOpt; } /**************** ApplyHistoryBest ****************/ @@ -113,8 +114,9 @@ ApplyHistoryBest::ApplyHistoryBest(Database database) { data_ = n; } -IRModule ApplyHistoryBestNode::Query(runtime::String task_name, IRModule mod, Target target, - Optional> dispatched) { +Optional ApplyHistoryBestNode::Query(runtime::String task_name, IRModule mod, + Target target, + Optional> dispatched) { ICHECK(dispatched.defined()); ICHECK_EQ(dispatched.value().size(), 1); ICHECK(HasOnlyOneFunction(mod)) << mod; @@ -149,7 +151,7 @@ IRModule ApplyHistoryBestNode::Query(runtime::String task_name, IRModule mod, Ta } LOG(WARNING) << "Cannot find workload: " << task_name; DLOG(INFO) << tir::AsTVMScript(prim_mod); - return IRModule{nullptr}; + return NullOpt; } /**************** FFI ****************/ diff --git a/src/relay/backend/te_compiler_cache.cc b/src/relay/backend/te_compiler_cache.cc index 7569c4f22224..8b8a1e92f82c 100644 --- a/src/relay/backend/te_compiler_cache.cc +++ b/src/relay/backend/te_compiler_cache.cc @@ -337,11 +337,11 @@ class ScheduleBuilder : public ExprVisitor { if (backend::IsMetaScheduleEnabled()) { IRModule relay_mod({{prim_fn_var, relay_func}}); IRModule tir_mod({{prim_fn_var, tir::CreatePrimFunc(Concat(fn_inputs, tensor_outs))}}); - IRModule scheduled_mod = meta_schedule::MetaScheduleContext::QueryInsideWithScope( + Optional scheduled_mod = meta_schedule::MetaScheduleContext::QueryInsideWithScope( prim_fn_var->name_hint, relay_mod, target_, Array{tir_mod}); - if (scheduled_mod.defined()) { - ICHECK_EQ(scheduled_mod->functions.count(prim_fn_var), 1); - prim_func = Downcast(scheduled_mod->functions[prim_fn_var]); + if (scheduled_mod) { + ICHECK_EQ(scheduled_mod.value()->functions.count(prim_fn_var), 1); + prim_func = Downcast(scheduled_mod.value()->functions[prim_fn_var]); } } From ff7f8270ac9a16df7034a12ec0b959c003feff6e Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Tue, 15 Mar 2022 14:54:53 +0900 Subject: [PATCH 18/19] update py doc --- python/tvm/meta_schedule/integration.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/python/tvm/meta_schedule/integration.py b/python/tvm/meta_schedule/integration.py index 8a839f365308..d9391d0d713f 100644 --- a/python/tvm/meta_schedule/integration.py +++ b/python/tvm/meta_schedule/integration.py @@ -77,7 +77,7 @@ def query( mod: IRModule, target: Target, dispatched: Optional[List[IRModule]], - ) -> IRModule: + ) -> Union[IRModule, None]: """The entry point of the integration Parameters @@ -93,9 +93,9 @@ def query( Returns ------- - result : IRModule + result : IRModule or None Currently we only have to return tir::PrimFunc, but we wrap it under IRModule for - more general future use + more general future use. None is returned if there is no feedback hint. """ return _ffi_api.MetaScheduleContextQuery( # type: ignore # pylint: disable=no-member self, @@ -123,7 +123,7 @@ def query_inside_with_scope( mod: IRModule, target: Target, dispatched: Optional[List[IRModule]], - ) -> IRModule: + ) -> Union[IRModule, None]: """The entry point of the integration workflow. The compilation process of the high-level IR should call this method for task extraction and for feedback hints @@ -149,9 +149,9 @@ def query_inside_with_scope(task_name, mod, dispatched): Returns ------- - result : IRModule + result : IRModule or None Currently we only have to return tir::PrimFunc, but we wrap it under IRModule for - more general future use + more general future use. None is returned if there is no feedback hint. """ return _ffi_api.MetaScheduleContextQueryInsideWithScope( # type: ignore # pylint: disable=no-member task_name, From 40ee540cb4cd8f1867e824a6dae671f294c4fc43 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Tue, 15 Mar 2022 15:04:08 +0900 Subject: [PATCH 19/19] remove TODO comment on parse_mod --- src/meta_schedule/integration.cc | 6 ------ 1 file changed, 6 deletions(-) diff --git a/src/meta_schedule/integration.cc b/src/meta_schedule/integration.cc index 1ebd554d9323..f05e07e0f1c1 100644 --- a/src/meta_schedule/integration.cc +++ b/src/meta_schedule/integration.cc @@ -123,12 +123,6 @@ Optional ApplyHistoryBestNode::Query(runtime::String task_name, IRModu IRModule prim_mod = dispatched.value()[0]; ICHECK(HasOnlyOneFunction(prim_mod)) << prim_mod; - // TODO(masahi): parse_mod below replaces the original function key with "main". - // This is necessary because, in practice, most uses of the scheduling primitive "get_block" - // assume that the function name is "main". - // If we do not have this requirement, we can remove the call to parse_mod and - // GetOnlyOneFunction* calls below and instead we can directly return sch->mod(). - // Keep the original func name to be returned later. GlobalVar gv = GetOnlyOneFunctionKey(prim_mod).value();