diff --git a/include/tvm/meta_schedule/integration.h b/include/tvm/meta_schedule/integration.h
index 9a8699b2fab9..56d8d379df93 100644
--- a/include/tvm/meta_schedule/integration.h
+++ b/include/tvm/meta_schedule/integration.h
@@ -86,14 +86,12 @@ class MetaScheduleContextNode : public runtime::Object {
    * \param target Target info
    * \param dispatched A list of low-level IRs that the high-level IR could potentially dispatch to.
    *                   NullOpt means the dispatch needs to be done in the context.
-   * \return There are different types of the output
-   *         1) NullOpt if there is no feedback hint
-   *         2) tir::PrimFunc if `mod` should be lowered to a PrimFunc
-   *         3) relay::Function if `mod` should be dispatched to BYOC workflow
-   *         4) IRModule for unified dispatch
+   * \return IRModule or NullOpt. Currently we only have to return a tir::PrimFunc, but we wrap
+   *         it in an IRModule for more general future use. NullOpt is returned
+   *         if there is no feedback hint.
    */
-  virtual Optional<ObjectRef> Query(runtime::String task_name, IRModule mod, Target target,
-                                    Optional<Array<IRModule>> dispatched) = 0;
+  virtual Optional<IRModule> Query(runtime::String task_name, IRModule mod, Target target,
+                                   Optional<Array<IRModule>> dispatched) = 0;
 
   static constexpr const char* _type_key = "meta_schedule.MetaScheduleContext";
   TVM_DECLARE_BASE_OBJECT_INFO(MetaScheduleContextNode, runtime::Object);
@@ -123,15 +121,13 @@ class MetaScheduleContext : public runtime::ObjectRef {
    * \param mod The high-level IR
    * \param target Target info
    * \param dispatched A list of low-level IRs that the high-level IR could potentially dispatch to
-   * \return There are different types of the output
-   *         1) NullOpt if there is no feedback hint
-   *         2) tir::PrimFunc if `mod` should be lowered to a PrimFunc
-   *         3) relay::Function if `mod` should be dispatched to BYOC workflow
-   *         4) IRModule for unified dispatch
+   * \return IRModule or NullOpt. Currently we only have to return a tir::PrimFunc, but we wrap
+   *         it in an IRModule for more general future use. NullOpt is returned
+   *         if there is no feedback hint.
    */
-  static Optional<ObjectRef> QueryInsideWithScope(runtime::String task_name, IRModule mod,
-                                                  Target target,
-                                                  Optional<Array<IRModule>> dispatched);
+  static Optional<IRModule> QueryInsideWithScope(runtime::String task_name, IRModule mod,
+                                                 Target target,
+                                                 Optional<Array<IRModule>> dispatched);
 
   TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(MetaScheduleContext, runtime::ObjectRef,
                                                     MetaScheduleContextNode);
@@ -145,38 +141,6 @@ class MetaScheduleContext : public runtime::ObjectRef {
   void ExitWithScope();
 };
 
-/**************** TaskExtraction ****************/
-
-/*!
- * \brief An integration context for task extraction
- */
-class TaskExtractionNode : public MetaScheduleContextNode {
- public:
-  /*! \brief The extracted tasks */
-  Array<ExtractedTask> tasks{nullptr};
-
-  void VisitAttrs(AttrVisitor* v) { v->Visit("tasks", &tasks); }
-
-  // Inherited from base class
-  Optional<ObjectRef> Query(runtime::String task_name, IRModule mod, Target target,
-                            Optional<Array<IRModule>> dispatched) final;
-
-  static constexpr const char* _type_key = "meta_schedule.TaskExtraction";
-  TVM_DECLARE_FINAL_OBJECT_INFO(TaskExtractionNode, MetaScheduleContextNode);
-};
-
-/*!
- * \brief Managed reference to TaskExtractionNode
- * \sa TaskExtractionNode
- */
-class TaskExtraction : public MetaScheduleContext {
- public:
-  /*! \brief The path to a cache file storing extracted tasks */
-  TaskExtraction();
-  TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(TaskExtraction, MetaScheduleContext,
-                                                    TaskExtractionNode);
-};
-
 /**************** ApplyHistoryBest ****************/
 
 /*!
@@ -193,8 +157,8 @@ class ApplyHistoryBestNode : public MetaScheduleContextNode {
   }
 
   // Inherited from base class
-  Optional<ObjectRef> Query(runtime::String task_name, IRModule mod, Target target,
-                            Optional<Array<IRModule>> dispatched) final;
+  Optional<IRModule> Query(runtime::String task_name, IRModule mod, Target target,
+                           Optional<Array<IRModule>> dispatched) final;
 
   static constexpr const char* _type_key = "meta_schedule.ApplyHistoryBest";
   TVM_DECLARE_FINAL_OBJECT_INFO(ApplyHistoryBestNode, MetaScheduleContextNode);
diff --git a/python/tvm/meta_schedule/integration.py b/python/tvm/meta_schedule/integration.py
index 26b01444e752..d9391d0d713f 100644
--- a/python/tvm/meta_schedule/integration.py
+++ b/python/tvm/meta_schedule/integration.py
@@ -15,17 +15,17 @@
 # specific language governing permissions and limitations
 # under the License.
 """Meta schedule integration with high-level IR"""
-from contextlib import contextmanager
-from typing import Callable, Dict, List, Optional, Union
+from typing import Dict, List, Optional, Union
 
-from tvm._ffi import register_object
+import numpy as np  # type: ignore
+import tvm.runtime.ndarray as nd
+
+from tvm._ffi import register_object, get_global_func
 from tvm.ir import IRModule, transform
 from tvm.relay import Any
 from tvm.relay import Function as RelayFunc
-from tvm.relay import vm
 from tvm.runtime import NDArray, Object
 from tvm.target import Target
-from tvm.tir import PrimFunc
 
 from . import _ffi_api
 from .database import Database
@@ -77,7 +77,7 @@ def query(
         mod: IRModule,
         target: Target,
         dispatched: Optional[List[IRModule]],
-    ) -> Union[IRModule, RelayFunc, PrimFunc, None]:
+    ) -> Union[IRModule, None]:
         """The entry point of the integration
 
         Parameters
@@ -93,12 +93,9 @@ def query(
 
         Returns
         -------
-        result : Union[IRModule, RelayFunc, PrimFunc, None]
-            There are different types of the output:
-            1) NullOpt if there is no feedback hint;
-            2) tir::PrimFunc if `mod` should be lowered to a PrimFunc;
-            3) relay::Function if `mod` should be dispatched to BYOC workflow;
-            4) IRModule for unified dispatch
+        result : IRModule or None
+            Currently we only have to return a tir::PrimFunc, but we wrap it in an IRModule for
+            more general future use. None is returned if there is no feedback hint.
         """
         return _ffi_api.MetaScheduleContextQuery(  # type: ignore # pylint: disable=no-member
             self,
@@ -126,7 +123,7 @@ def query_inside_with_scope(
         mod: IRModule,
         target: Target,
         dispatched: Optional[List[IRModule]],
-    ) -> Union[IRModule, RelayFunc, PrimFunc, None]:
+    ) -> Union[IRModule, None]:
         """The entry point of the integration workflow. The compilation process of the high-level
         IR should call this method for task extraction and for feedback hints
 
@@ -137,7 +134,7 @@ def query_inside_with_scope(task_name, mod, dispatched):
                 ctx = MetaScheduleContext.current()
                 assert ctx is not None
-                ctx.query(task_name, mod, target, dispatched)
+                mod = ctx.query(task_name, mod, target, dispatched)
 
         Parameters
         ----------
@@ -152,12 +149,9 @@ def query_inside_with_scope(task_name, mod, dispatched):
 
         Returns
        -------
-        result : Union[IRModule, RelayFunc, PrimFunc, None]
-            There are different types of the output:
-            1) NullOpt if there is no feedback hint;
-            2) tir::PrimFunc if `mod` should be lowered to a PrimFunc;
-            3) relay::Function if `mod` should be dispatched to BYOC workflow;
-            4) IRModule for unified dispatch
+        result : IRModule or None
+            Currently we only have to return a tir::PrimFunc, but we wrap it in an IRModule for
+            more general future use. None is returned if there is no feedback hint.
""" return _ffi_api.MetaScheduleContextQueryInsideWithScope( # type: ignore # pylint: disable=no-member task_name, @@ -176,17 +170,6 @@ def __exit__(self, ptype, value, trace) -> None: _ffi_api.MetaScheduleContextExitScope(self) # type: ignore # pylint: disable=no-member -@register_object("meta_schedule.TaskExtraction") -class TaskExtraction(MetaScheduleContext): - """An integration context for task extraction""" - - tasks: List[ExtractedTask] - """The extracted tasks""" - - def __init__(self) -> None: - self.__init_handle_by_constructor__(_ffi_api.TaskExtraction) # type: ignore # pylint: disable=no-member - - @register_object("meta_schedule.ApplyHistoryBest") class ApplyHistoryBest(MetaScheduleContext): """An integration context that allows application of historically best record from database""" @@ -230,45 +213,32 @@ def extract_task_from_relay( The tasks extracted from this network """ - @contextmanager - def _autotvm_silencer(): - from tvm import autotvm # pylint: disable=import-outside-toplevel - - silent = autotvm.GLOBAL_SCOPE.silent - autotvm.GLOBAL_SCOPE.silent = True - try: - yield - finally: - autotvm.GLOBAL_SCOPE.silent = silent + extract_task_func = get_global_func("relay.backend.MetaScheduleExtractTask") + assert extract_task_func - def _thread_run(func: Callable[[], None]) -> None: - import threading # pylint: disable=import-outside-toplevel + target = Target(target) if isinstance(target, str) else target - thread = threading.Thread(target=func) - thread.start() - thread.join() + relay_params = {} + for name, param in params.items(): + if isinstance(param, np.ndarray): + param = nd.array(param) + relay_params[name] = param if disabled_pass is None: disabled_pass = [] if pass_config is None: pass_config = {"relay.backend.use_meta_schedule": True} - env = TaskExtraction() if isinstance(mod, RelayFunc): mod = IRModule.from_expr(mod) if not isinstance(target, Target): target = Target(target) - def _func(): - with env, _autotvm_silencer(), transform.PassContext( - config=pass_config, - disabled_pass=disabled_pass, - opt_level=opt_level, - ): - compiler = vm.VMCompiler() - if params: - compiler.set_params(params) - compiler.lower(mod, target) - - _thread_run(_func) - return env.tasks + with target, transform.PassContext( + opt_level=opt_level, + config=pass_config, + disabled_pass=disabled_pass, + ): + tasks = extract_task_func(mod, target, relay_params) + # Tasks are extracted via post order visit, return the reversed list. 
+        return list(reversed(tasks))
diff --git a/python/tvm/topi/x86/batch_matmul.py b/python/tvm/topi/x86/batch_matmul.py
index b446c1f0115c..55453df95e7a 100644
--- a/python/tvm/topi/x86/batch_matmul.py
+++ b/python/tvm/topi/x86/batch_matmul.py
@@ -47,6 +47,7 @@ def batch_matmul_vnni_compute(cfg, x, y):
             axis=ak,
         ),
         tag="batch_matmul_vnni",
+        attrs={"schedule_rule": "meta_schedule.batch_matmul_vnni"},
     )
 
     _, a_y, _ = z.op.axis
diff --git a/python/tvm/topi/x86/dense.py b/python/tvm/topi/x86/dense.py
index c8574b971003..1e4ccb7bb8c8 100644
--- a/python/tvm/topi/x86/dense.py
+++ b/python/tvm/topi/x86/dense.py
@@ -296,6 +296,7 @@ def dense_vnni_compute(cfg, X, packed_w, bias=None):
             axis=ak,
         ),
         tag="dense_vnni",
+        attrs={"schedule_rule": "meta_schedule.dense_vnni"},
     )
 
     if bias is not None:
diff --git a/src/meta_schedule/integration.cc b/src/meta_schedule/integration.cc
index 4f9055bf5bba..f05e07e0f1c1 100644
--- a/src/meta_schedule/integration.cc
+++ b/src/meta_schedule/integration.cc
@@ -21,14 +21,14 @@
 #include
 
 #include "./utils.h"
+#include "tvm/runtime/container/optional.h"
 
 namespace tvm {
 namespace meta_schedule {
 
 /**************** Utility functions ****************/
-
-template <class FuncType>
-Optional<FuncType> GetOnlyOneFunction(const IRModule& mod) {
+template <class FuncType, class RetType, class Callback>
+Optional<RetType> GetOnlyOneFunctionCommon(const IRModule& mod, Callback on_found) {
   if (mod->functions.size() != 1) {
     return NullOpt;
   }
@@ -37,12 +37,23 @@ Optional<FuncType> GetOnlyOneFunction(const IRModule& mod) {
     if (!func->IsInstance<typename FuncType::ContainerType>()) {
       return NullOpt;
     } else {
-      return Downcast<FuncType>(func);
+      return on_found(kv);
     }
   }
   return NullOpt;
 }
 
+template <class FuncType>
+Optional<GlobalVar> GetOnlyOneFunctionKey(const IRModule& mod) {
+  return GetOnlyOneFunctionCommon<FuncType, GlobalVar>(mod, [](auto kv) { return kv.first; });
+}
+
+template <class FuncType>
+Optional<FuncType> GetOnlyOneFunction(const IRModule& mod) {
+  return GetOnlyOneFunctionCommon<FuncType, FuncType>(
+      mod, [](auto kv) { return Downcast<FuncType>(kv.second); });
+}
+
 template <class FuncType>
 bool HasOnlyOneFunction(const IRModule& mod) {
   return GetOnlyOneFunction<FuncType>(mod).defined();
@@ -86,33 +97,15 @@ void MetaScheduleContext::ExitWithScope() {
   ctx = NullOpt;
 }
 
-Optional<ObjectRef> MetaScheduleContext::QueryInsideWithScope(
-    runtime::String task_name, IRModule mod, Target target, Optional<Array<IRModule>> dispatched) {
+Optional<IRModule> MetaScheduleContext::QueryInsideWithScope(runtime::String task_name,
+                                                             IRModule mod, Target target,
+                                                             Optional<Array<IRModule>> dispatched) {
   if (Optional<MetaScheduleContext> ctx = MetaScheduleContext::Current()) {
     return ctx.value()->Query(task_name, mod, target, dispatched);
   }
   return NullOpt;
 }
 
-/**************** TaskExtraction ****************/
-
-TaskExtraction::TaskExtraction() {
-  ObjectPtr<TaskExtractionNode> n = make_object<TaskExtractionNode>();
-  n->tasks = Array<ExtractedTask>();
-  data_ = n;
-}
-
-Optional<ObjectRef> TaskExtractionNode::Query(runtime::String task_name, IRModule mod,
-                                              Target target, Optional<Array<IRModule>> dispatched) {
-  ICHECK(dispatched.defined());
-  ICHECK_EQ(dispatched.value().size(), 1);
-  IRModule prim_mod = dispatched.value()[0];
-  ICHECK(HasOnlyOneFunction<tir::PrimFunc>(prim_mod)) << prim_mod;
-  ICHECK(HasOnlyOneFunction<relay::Function>(mod)) << mod;
-  tasks.push_back(ExtractedTask(task_name, mod, target, {prim_mod}));
-  return NullOpt;
-}
-
 /**************** ApplyHistoryBest ****************/
 
 ApplyHistoryBest::ApplyHistoryBest(Database database) {
@@ -121,18 +114,23 @@ ApplyHistoryBest::ApplyHistoryBest(Database database) {
   data_ = n;
 }
 
-Optional<ObjectRef> ApplyHistoryBestNode::Query(runtime::String task_name, IRModule mod,
-                                                Target target,
-                                                Optional<Array<IRModule>> dispatched) {
+Optional<IRModule> ApplyHistoryBestNode::Query(runtime::String task_name, IRModule mod,
+                                               Target target,
+                                               Optional<Array<IRModule>> dispatched) {
   ICHECK(dispatched.defined());
   ICHECK_EQ(dispatched.value().size(), 1);
   ICHECK(HasOnlyOneFunction<relay::Function>(mod)) << mod;
   IRModule prim_mod = dispatched.value()[0];
   ICHECK(HasOnlyOneFunction<tir::PrimFunc>(prim_mod)) << prim_mod;
+
+  // Keep the original func name to be returned later.
+  GlobalVar gv = GetOnlyOneFunctionKey<tir::PrimFunc>(prim_mod).value();
+
   // Unify func name to make sure it can be found in database
   const auto* parse_mod_func = runtime::Registry::Get("tvm.meta_schedule.tune.parse_mod");
   ICHECK(parse_mod_func) << "Parse mod function not defined!";
   prim_mod = (*parse_mod_func)(prim_mod);
+
   if (database->HasWorkload(prim_mod)) {
     Array<TuningRecord> records = database->GetTopK(database->CommitWorkload(prim_mod), 1);
     if (records.size() == 1) {
@@ -141,10 +139,12 @@ Optional<ObjectRef> ApplyHistoryBestNode::Query(runtime::String task_name, IRMod
                              /*error_render_level=*/tir::ScheduleErrorRenderLevel::kNone);
       records[0]->trace->ApplyToSchedule(sch, false);
       tir::PrimFunc func = GetOnlyOneFunction<tir::PrimFunc>(sch->mod()).value();
-      return func;
+      // Make sure we return the updated PrimFunc paired with the original func name.
+      return IRModule({{gv, func}});
     }
   }
-  LOG(WARNING) << "Cannot find workload: " << task_name << "\n" << tir::AsTVMScript(prim_mod);
+  LOG(WARNING) << "Cannot find workload: " << task_name;
+  DLOG(INFO) << tir::AsTVMScript(prim_mod);
   return NullOpt;
 }
 
@@ -158,7 +158,6 @@ class MetaScheduleContextInternal {
 
 TVM_REGISTER_NODE_TYPE(ExtractedTaskNode);
 TVM_REGISTER_OBJECT_TYPE(MetaScheduleContextNode);
-TVM_REGISTER_NODE_TYPE(TaskExtractionNode);
 TVM_REGISTER_NODE_TYPE(ApplyHistoryBestNode);
 
 TVM_REGISTER_GLOBAL("meta_schedule.ExtractedTask")
@@ -176,9 +175,6 @@ TVM_REGISTER_GLOBAL("meta_schedule.MetaScheduleContextQueryInsideWithScope")
     .set_body_typed(MetaScheduleContext::QueryInsideWithScope);
 TVM_REGISTER_GLOBAL("meta_schedule.MetaScheduleContextQuery")
     .set_body_method<MetaScheduleContext>(&MetaScheduleContextNode::Query);
-TVM_REGISTER_GLOBAL("meta_schedule.TaskExtraction").set_body_typed([]() -> TaskExtraction {
-  return TaskExtraction();
-});
 TVM_REGISTER_GLOBAL("meta_schedule.ApplyHistoryBest")
     .set_body_typed([](Database database) -> ApplyHistoryBest {
       return ApplyHistoryBest(database);
diff --git a/src/relay/backend/build_module.cc b/src/relay/backend/build_module.cc
index 89ee61c83f7c..87fe39c389f0 100644
--- a/src/relay/backend/build_module.cc
+++ b/src/relay/backend/build_module.cc
@@ -333,14 +333,7 @@ class RelayBuildModule : public runtime::ModuleNode {
   IRModule OptimizeImpl(IRModule relay_module) {
     ICHECK(relay_module.defined()) << "The IRModule must be defined for the Relay compiler.";
 
-    if (!params_.empty()) {
-      ICHECK(relay_module->ContainGlobalVar("main")) << "Missing the main entry function";
-      GlobalVar main_glb_var = relay_module->GetGlobalVar("main");
-      Function main_func = Downcast<Function>(relay_module->Lookup(main_glb_var));
-      auto new_main = BindParamsByName(main_func, params_);
-      IRModuleNode* relay_module_ptr = relay_module.CopyOnWrite();
-      relay_module_ptr->Update(main_glb_var, new_main);
-    }
+    backend::BindParamsInModule(relay_module, params_);
 
     Array<Pass> pass_seqs = GetPassPrefix(
         /*is_homogenous=*/config_->optional_homogeneous_target.defined(), /*is_vm=*/false);
diff --git a/src/relay/backend/task_extraction.cc b/src/relay/backend/task_extraction.cc
new file mode 100644
index 000000000000..898e76b81b98
--- /dev/null
+++ b/src/relay/backend/task_extraction.cc
@@ -0,0 +1,86 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#include
+#include
+#include
+#include
+#include
+
+#include "../../te/operation/create_primfunc.h"
+#include "te_compiler_cache.h"
+#include "tvm/runtime/ndarray.h"
+#include "utils.h"
+
+namespace tvm {
+namespace relay {
+namespace backend {
+
+namespace metaschedule {
+
+using meta_schedule::ExtractedTask;
+
+Array<ExtractedTask> ExtractTask(IRModule mod, Target target,
+                                 Map<String, runtime::NDArray> params) {
+  backend::BindParamsInModule(mod, params);
+
+  // is_vm=true for backward compatibility
+  Array<Pass> pass_seqs = relay::backend::GetPassPrefix(/*is_homogenous=*/true, /*is_vm=*/true);
+  pass_seqs.push_back(transform::FuseOps());
+
+  transform::Sequential seq(pass_seqs);
+  auto opt_mod = seq(std::move(mod));
+
+  Array<ExtractedTask> tasks;
+  std::unordered_set<tec::CCacheKey> cache;
+  std::unordered_map<std::string, int> name_map;
+
+  PostOrderVisit(opt_mod->Lookup("main"), [target, &tasks, &cache, &name_map](const Expr& exp) {
+    if (exp->IsInstance<FunctionNode>()) {
+      Function relay_func = Downcast<Function>(exp);
+      tec::CCacheKey cache_key(relay_func, target);
+      if (relay_func->HasNonzeroAttr(attr::kPrimitive) && cache.find(cache_key) == cache.end()) {
+        Array<te::Tensor> inputs_outputs;
+        std::string fused_name;
+        std::tie(inputs_outputs, fused_name) =
+            tec::LowerTECompute(relay_func, target, /*return_inputs=*/true);
+        auto prim_func = tir::CreatePrimFunc(inputs_outputs);
+        GlobalVar prim_fn_var(fused_name);
+        IRModule relay_mod({{prim_fn_var, relay_func}});
+        IRModule tir_mod({{prim_fn_var, prim_func}});
+        auto task_name = tec::GetUniqueName(fused_name, &name_map);
+        tasks.push_back(ExtractedTask(task_name, relay_mod, target, {tir_mod}));
+        cache.insert(cache_key);
+      }
+    }
+  });
+
+  return tasks;
+}
+
+}  // namespace metaschedule
+
+TVM_REGISTER_GLOBAL("relay.backend.MetaScheduleExtractTask")
+    .set_body_typed([](IRModule mod, Target target, Map<String, runtime::NDArray> params) {
+      return metaschedule::ExtractTask(mod, target, params);
+    });
+
+}  // namespace backend
+}  // namespace relay
+}  // namespace tvm
diff --git a/src/relay/backend/te_compiler_cache.cc b/src/relay/backend/te_compiler_cache.cc
index ffcce6e1c8da..8b8a1e92f82c 100644
--- a/src/relay/backend/te_compiler_cache.cc
+++ b/src/relay/backend/te_compiler_cache.cc
@@ -335,15 +335,13 @@ class ScheduleBuilder : public ExprVisitor {
       }
     }
     if (backend::IsMetaScheduleEnabled()) {
-      prim_func = tir::CreatePrimFunc(Concat(fn_inputs, tensor_outs));
-      Optional<ObjectRef> opt_mod_or_base_func =
-          meta_schedule::MetaScheduleContext::QueryInsideWithScope(
-              prim_fn_var->name_hint, IRModule({{prim_fn_var, relay_func}}), target_,
-              Array<IRModule>{IRModule({{prim_fn_var, prim_func}})});
-      if (const auto* result = opt_mod_or_base_func.as<tir::PrimFuncNode>()) {
-        prim_func = GetRef<tir::PrimFunc>(result);
-      } else {
-        prim_func = tir::PrimFunc(nullptr);
+      IRModule relay_mod({{prim_fn_var, relay_func}});
+      IRModule tir_mod({{prim_fn_var, tir::CreatePrimFunc(Concat(fn_inputs, tensor_outs))}});
+      Optional<IRModule> scheduled_mod = meta_schedule::MetaScheduleContext::QueryInsideWithScope(
+          prim_fn_var->name_hint, relay_mod, target_, Array<IRModule>{tir_mod});
+      if (scheduled_mod) {
+        ICHECK_EQ(scheduled_mod.value()->functions.count(prim_fn_var), 1);
+        prim_func = Downcast<tir::PrimFunc>(scheduled_mod.value()->functions[prim_fn_var]);
       }
     }
 
@@ -754,6 +752,25 @@ CachedFunc ShapeFuncFor(const Function& prim_func, const Target& target,
   return MakeShapeFunc().Create(prim_func, target, renamer);
 }
 
+std::pair<Array<te::Tensor>, std::string> LowerTECompute(const Function& source_func, Target target,
+                                                         bool return_inputs) {
+  LowerToTECompute lower_te_compute(target);
+  Array<te::Tensor> outputs =
+      lower_te_compute.Lower(source_func, [](std::string name) { return name; });
+  // Following ScheduleBuilder, remove placeholder ops from outputs.
+  tvm::Array<te::Tensor> tensor_outs;
+  for (const auto& tensor : outputs) {
+    if (!tensor->op.as<te::PlaceholderOpNode>()) {
+      tensor_outs.push_back(tensor);
+    }
+  }
+  if (return_inputs) {
+    return std::make_pair(Concat(lower_te_compute.fn_inputs_, tensor_outs),
+                          lower_te_compute.candidate_name_);
+  }
+  return std::make_pair(tensor_outs, lower_te_compute.candidate_name_);
+}
+
 /*!
  * \brief Get unique name from name.
  * \param name The orginal name.
diff --git a/src/relay/backend/te_compiler_cache.h b/src/relay/backend/te_compiler_cache.h
index 2ffca1aa6be7..55f221ac8ba0 100644
--- a/src/relay/backend/te_compiler_cache.h
+++ b/src/relay/backend/te_compiler_cache.h
@@ -37,6 +37,7 @@
 #include
 #include
 #include
+#include
 
 #include "../transforms/infer_layout_utils.h"
 
@@ -204,6 +205,16 @@ class CCacheValue : public ObjectRef {
 
 Array<IndexExpr> GetShape(const Array<IndexExpr>& shape);
 
+/*!
+ * \brief Lowers Relay primitive Function to TE Compute
+ * \param source_func The primitive function to be lowered.
+ * \param target The target we want to create a schedule for.
+ * \param return_inputs If true, prepend input tensors to the output array of tensors.
+ * \return Pair of the lowered output tensors and the fused function name.
+ */
+std::pair<Array<te::Tensor>, std::string> LowerTECompute(const Function& source_func, Target target,
+                                                         bool return_inputs = true);
+
 /*!
  * \brief Create schedule for target.
  * \param source_func The primitive function to be lowered.
diff --git a/src/relay/backend/utils.cc b/src/relay/backend/utils.cc
index 7662018e4f71..fd3ab64fcc1c 100644
--- a/src/relay/backend/utils.cc
+++ b/src/relay/backend/utils.cc
@@ -29,6 +29,7 @@
 #include
 
 #include "te_compiler.h"
+#include "tvm/runtime/ndarray.h"
 
 namespace tvm {
 namespace relay {
@@ -308,6 +309,56 @@ std::vector<int64_t> ShapeToJSON(tvm::Array<IndexExpr> shape) {
   return ret;
 }
 
+relay::Function BindParamsByName(relay::Function func,
+                                 const std::unordered_map<std::string, runtime::NDArray>& params) {
+  std::unordered_map<std::string, relay::Var> name_dict;
+  std::unordered_set<relay::Var, ObjectPtrHash, ObjectPtrEqual> repeat_var;
+  for (auto arg : func->params) {
+    const auto& name = arg->name_hint();
+    if (name_dict.count(name)) {
+      repeat_var.insert(name_dict[name]);
+    } else {
+      name_dict[name] = arg;
+    }
+  }
+
+  std::unordered_map<relay::Var, Expr, ObjectPtrHash, ObjectPtrEqual> bind_dict;
+  for (auto& kv : params) {
+    if (name_dict.count(kv.first) == 0) {
+      continue;
+    }
+    auto arg = name_dict.at(kv.first);
+    if (repeat_var.count(arg)) {
+      LOG(FATAL) << "Multiple args in the function have name " << kv.first;
+    }
+    bind_dict[arg] = Constant(kv.second);
+  }
+  Expr bound_expr = relay::Bind(func, bind_dict);
+  Function ret = Downcast<Function>(bound_expr);
+  ICHECK(ret.defined()) << "The returning type is expected to be a Relay Function."
+ << "\n"; + return ret; +} + +void BindParamsInModule(IRModule mod, + const std::unordered_map& params) { + if (!params.empty()) { + BaseFunc base_func = mod->Lookup("main"); + ICHECK(base_func->IsInstance()); + auto f = relay::backend::BindParamsByName(Downcast(base_func), params); + auto gvar = mod->GetGlobalVar("main"); + mod->Add(gvar, f); + } +} + +void BindParamsInModule(IRModule mod, Map params) { + std::unordered_map params_tmp; + for (const auto& kv : params) { + params_tmp[kv.first] = kv.second; + } + BindParamsInModule(mod, params_tmp); +} + } // namespace backend } // namespace relay } // namespace tvm diff --git a/src/relay/backend/utils.h b/src/relay/backend/utils.h index 3b4d4c18de89..37a89d3edced 100644 --- a/src/relay/backend/utils.h +++ b/src/relay/backend/utils.h @@ -46,6 +46,7 @@ #include "../../runtime/meta_data.h" #include "../../target/metadata.h" +#include "tvm/runtime/ndarray.h" namespace tvm { namespace relay { @@ -386,36 +387,18 @@ inline std::string DType2String(const tvm::DataType dtype) { * \param params params dict * \return relay::Function */ -inline relay::Function BindParamsByName( - relay::Function func, const std::unordered_map& params) { - std::unordered_map name_dict; - std::unordered_set repeat_var; - for (auto arg : func->params) { - const auto& name = arg->name_hint(); - if (name_dict.count(name)) { - repeat_var.insert(name_dict[name]); - } else { - name_dict[name] = arg; - } - } +relay::Function BindParamsByName(relay::Function func, + const std::unordered_map& params); - std::unordered_map bind_dict; - for (auto& kv : params) { - if (name_dict.count(kv.first) == 0) { - continue; - } - auto arg = name_dict.at(kv.first); - if (repeat_var.count(arg)) { - LOG(FATAL) << "Multiple args in the function have name " << kv.first; - } - bind_dict[arg] = Constant(kv.second); - } - Expr bound_expr = relay::Bind(func, bind_dict); - Function ret = Downcast(bound_expr); - ICHECK(ret.defined()) << "The returning type is expected to be a Relay Function." - << "\n"; - return ret; -} +/*! + * \brief Bind params to the main function in Relay module, using BindParamsByName + * \param mod Relay module + * \param params params dict + */ +void BindParamsInModule(IRModule mod, + const std::unordered_map& params); + +void BindParamsInModule(IRModule mod, Map params); /*! * \brief Extract the shape from a Relay tensor type. 
diff --git a/src/relay/backend/vm/compiler.cc b/src/relay/backend/vm/compiler.cc
index e94919de7f20..130fb09e7af1 100644
--- a/src/relay/backend/vm/compiler.cc
+++ b/src/relay/backend/vm/compiler.cc
@@ -1034,14 +1034,7 @@ IRModule VMCompiler::OptimizeModule(IRModule mod, const TargetMap& targets,
 IRModule VMCompiler::OptimizeModuleImpl(IRModule mod) {
   VLOG_CONTEXT << "VM Optimize";
 
-  if (params_.size()) {
-    BaseFunc base_func = mod->Lookup("main");
-    ICHECK(base_func->IsInstance<FunctionNode>())
-        << "VM compiler expects to compile relay::Function";
-    auto f = relay::backend::BindParamsByName(Downcast<Function>(base_func), params_);
-    auto gvar = mod->GetGlobalVar("main");
-    mod->Add(gvar, f);
-  }
+  backend::BindParamsInModule(mod, params_);
 
   Array<Pass> pass_seqs = relay::backend::GetPassPrefix(
       /*is_homogenous=*/config_->optional_homogeneous_target.defined(), /*is_vm=*/true);
diff --git a/tests/python/unittest/test_meta_schedule_integration.py b/tests/python/unittest/test_meta_schedule_integration.py
index 50dc9289780d..4620e83d8ec4 100644
--- a/tests/python/unittest/test_meta_schedule_integration.py
+++ b/tests/python/unittest/test_meta_schedule_integration.py
@@ -26,7 +26,6 @@
     ApplyHistoryBest,
     ExtractedTask,
     MetaScheduleContext,
-    TaskExtraction,
 )
 from tvm.meta_schedule.testing.relay_workload import get_network
 from tvm.meta_schedule.utils import derived_object
@@ -63,61 +62,16 @@ def _has_torch():
 requires_torch = pytest.mark.skipif(not _has_torch(), reason="torch is not installed")
 
 
-def _check_mock_task(tasks: List[ExtractedTask], mod: IRModule):
-    (task,) = tasks
-    assert isinstance(task, ExtractedTask)
-    assert task.task_name == "mock-task"
-    tvm.ir.assert_structural_equal(task.mod, mod)
-    (tir_mod,) = task.dispatched
-    tvm.ir.assert_structural_equal(tir_mod, MockModule)
-
-
-@requires_torch
-def test_meta_schedule_integration_task_extraction_query():
-    mod, _, _ = get_network(name="resnet_18", input_shape=[1, 3, 224, 224])
-    env = TaskExtraction()
-    env.query(task_name="mock-task", mod=mod, target=Target("llvm"), dispatched=[MockModule])
-    _check_mock_task(env.tasks, mod)
-
-
-def test_meta_schedule_integration_current():
-    env = TaskExtraction()
-    with env:
-        assert MetaScheduleContext.current() == env
-
-
 def test_meta_schedule_integration_no_current():
     assert MetaScheduleContext.current() is None
 
 
-def test_meta_schedule_integration_multiple_current():
-    env = TaskExtraction()
-    with env:
-        with pytest.raises(ValueError):
-            with env:
-                ...
-
-
-@requires_torch
-def test_meta_schedule_integration_query_inside_with_scope():
-    mod, _, _ = get_network(name="resnet_18", input_shape=[1, 3, 224, 224])
-    env = TaskExtraction()
-    with env:
-        MetaScheduleContext.query_inside_with_scope(
-            task_name="mock-task",
-            mod=mod,
-            target=Target("llvm"),
-            dispatched=[MockModule],
-        )
-    _check_mock_task(env.tasks, mod)
-
-
 @requires_torch
 def test_meta_schedule_integration_extract_from_resnet():
     mod, params, _ = get_network(name="resnet_18", input_shape=[1, 3, 224, 224])
     extracted_tasks = ms.integration.extract_task_from_relay(mod, target="llvm", params=params)
     expected_task_names = [
-        "vm_mod_fused_" + s
+        "fused_" + s
         for s in [
            "nn_max_pool2d",
            "nn_adaptive_avg_pool2d",
@@ -197,7 +151,6 @@ def print_results(self) -> None:
             TuningRecord(Schedule(MockModule).trace, [1.0], workload, target, [])
         )
         mod = env.query(task_name="mock-task", mod=mod, target=target, dispatched=[MockModule])
-        mod = IRModule({"main": mod})
         assert tvm.ir.structural_equal(mod, workload.mod)
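
Reviewer note (not part of the patch): the sketch below shows how the pieces in this diff are intended to fit together end to end -- task extraction through the new relay.backend.MetaScheduleExtractTask global (wrapped by extract_task_from_relay), then applying tuned records through ApplyHistoryBest while building. It is a minimal sketch only; the JSONDatabase construction and its file paths are assumptions for illustration, not something this patch adds.

    # Reviewer sketch -- illustrative only, not part of this patch.
    # Assumes torch is installed (for get_network) and that a tuned meta schedule
    # database already exists at the hypothetical paths below.
    import tvm
    from tvm import meta_schedule as ms
    from tvm import relay
    from tvm.meta_schedule.integration import ApplyHistoryBest, extract_task_from_relay
    from tvm.meta_schedule.testing.relay_workload import get_network

    mod, params, _ = get_network(name="resnet_18", input_shape=[1, 3, 224, 224])
    target = tvm.target.Target("llvm")

    # 1. Task extraction now runs in C++ via the "relay.backend.MetaScheduleExtractTask"
    #    packed function; extract_task_from_relay is the Python wrapper changed above.
    tasks = extract_task_from_relay(mod, target=target, params=params)
    for task in tasks:
        print(task.task_name)  # names now use the "fused_" prefix instead of "vm_mod_fused_"

    # 2. After tuning, ApplyHistoryBest answers Query() with an IRModule that carries the
    #    tuned PrimFunc keyed by the original GlobalVar. JSONDatabase is one Database
    #    implementation; the file names here are hypothetical.
    database = ms.database.JSONDatabase(
        path_workload="workload.json",
        path_tuning_record="records.json",
    )
    with ApplyHistoryBest(database):
        with tvm.transform.PassContext(
            opt_level=3,
            config={"relay.backend.use_meta_schedule": True},
        ):
            lib = relay.build(mod, target=target, params=params)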