From f8513cc0eb6e3cf09ada47082224725955a4ece0 Mon Sep 17 00:00:00 2001 From: Junru Shao Date: Wed, 3 Nov 2021 22:58:00 -0700 Subject: [PATCH 1/3] [MetaSchedule] Task Extraction Co-authored-by: Siyuan Feng Co-authored-by: Bohan Hou <32121147+spectrometerHBH@users.noreply.github.com> Co-authored-by: Hongyi Jin <3231950289@qq.com> Co-authored-by: Ruihang Lai Co-authored-by: Wuwei Lin --- include/tvm/meta_schedule/integration.h | 214 ++++++++++++++++ python/tvm/meta_schedule/__init__.py | 1 + python/tvm/meta_schedule/integration.py | 241 ++++++++++++++++++ python/tvm/meta_schedule/testing/__init__.py | 19 ++ .../{testing.py => testing/local_rpc.py} | 2 +- .../meta_schedule/testing/relay_workload.py | 88 +++++++ python/tvm/te/__init__.py | 2 +- python/tvm/te/operation.py | 26 +- src/meta_schedule/integration.cc | 151 +++++++++++ src/relay/backend/te_compiler.cc | 25 +- src/relay/backend/te_compiler_cache.cc | 49 +++- src/relay/backend/te_compiler_cache.h | 11 +- src/relay/backend/utils.h | 9 + src/te/operation/create_primfunc.cc | 38 ++- .../test_meta_schedule_integration.py | 120 +++++++++ 15 files changed, 961 insertions(+), 35 deletions(-) create mode 100644 include/tvm/meta_schedule/integration.h create mode 100644 python/tvm/meta_schedule/integration.py create mode 100644 python/tvm/meta_schedule/testing/__init__.py rename python/tvm/meta_schedule/{testing.py => testing/local_rpc.py} (97%) create mode 100644 python/tvm/meta_schedule/testing/relay_workload.py create mode 100644 src/meta_schedule/integration.cc create mode 100644 tests/python/unittest/test_meta_schedule_integration.py diff --git a/include/tvm/meta_schedule/integration.h b/include/tvm/meta_schedule/integration.h new file mode 100644 index 000000000000..c6cb3a5fac28 --- /dev/null +++ b/include/tvm/meta_schedule/integration.h @@ -0,0 +1,214 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#ifndef TVM_META_SCHEDULE_INTEGRATION_H_ +#define TVM_META_SCHEDULE_INTEGRATION_H_ + +#include +#include + +#include + +namespace tvm { +namespace meta_schedule { + +/**************** ExtractedTask ****************/ + +/*! + * \brief A tuning task extracted from the high-level IR + */ +class ExtractedTaskNode : public runtime::Object { + public: + /*! \brief The name of the task extracted */ + String task_name; + /*! \brief The high-level IR */ + IRModule mod; + /*! \brief A list of low-level IRs that the high-level IR could potentially dispatch to */ + Array dispatched; + + void VisitAttrs(AttrVisitor* v) { + v->Visit("task_name", &task_name); + v->Visit("mod", &mod); + v->Visit("dispatched", &dispatched); + } + + static constexpr const char* _type_key = "meta_schedule.ExtractedTask"; + TVM_DECLARE_FINAL_OBJECT_INFO(ExtractedTaskNode, runtime::Object); +}; + +/*! 
+ * \brief Managed reference to ExtractedTaskNode
+ * \sa ExtractedTaskNode
+ */
+class ExtractedTask : public runtime::ObjectRef {
+ public:
+  /*!
+   * \brief Constructor
+   * \param task_name The name of the task extracted
+   * \param mod The high-level IR
+   * \param dispatched A list of low-level IRs that the high-level IR could potentially dispatch to
+   */
+  explicit ExtractedTask(String task_name, IRModule mod, Array<IRModule> dispatched);
+  TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(ExtractedTask, runtime::ObjectRef, ExtractedTaskNode);
+};
+
+/**************** MetaScheduleContext ****************/
+
+/*!
+ * \brief A context manager interface for the integration
+ */
+class MetaScheduleContextNode : public runtime::Object {
+ public:
+  /*! \brief Default destructor */
+  virtual ~MetaScheduleContextNode() = default;
+  /*!
+   * \brief The entry point of the integration
+   * \param task_name The name of the task
+   * \param mod The high-level IR
+   * \param dispatched A list of low-level IRs that the high-level IR could potentially dispatch to.
+   * NullOpt means the dispatch needs to be done in the context.
+   * \return There are different types of the output
+   * 1) NullOpt if there is no feedback hint
+   * 2) tir::PrimFunc if `mod` should be lowered to a PrimFunc
+   * 3) relay::Function if `mod` should be dispatched to BYOC workflow
+   * 4) IRModule for unified dispatch
+   */
+  virtual Optional<ObjectRef> Query(runtime::String task_name, IRModule mod,
+                                    Optional<Array<IRModule>> dispatched) = 0;
+
+  static constexpr const char* _type_key = "meta_schedule.MetaScheduleContext";
+  TVM_DECLARE_BASE_OBJECT_INFO(MetaScheduleContextNode, runtime::Object);
+};
+
+/*!
+ * \brief Managed reference to MetaScheduleContextNode
+ * \sa MetaScheduleContextNode
+ */
+class MetaScheduleContext : public runtime::ObjectRef {
+  friend class MetaScheduleContextInternal;
+  friend class With<MetaScheduleContext>;
+
+ public:
+  /*! \brief Default destructor */
+  virtual ~MetaScheduleContext() = default;
+  /*!
+   * \brief The context manager in the current scope
+   * \return The MetaScheduleContext in the current scope. NullOpt if it's currently not under any
+   * MetaScheduleContext.
+   */
+  static Optional<MetaScheduleContext> Current();
+  /*!
+   * \brief The entry point of the integration workflow. The compilation process of the high-level
+   * IR should call this method for task extraction and for feedback hints
+   * \param task_name The name of the task
+   * \param mod The high-level IR
+   * \param dispatched A list of low-level IRs that the high-level IR could potentially dispatch to
+   * \return There are different types of the output
+   * 1) NullOpt if there is no feedback hint
+   * 2) tir::PrimFunc if `mod` should be lowered to a PrimFunc
+   * 3) relay::Function if `mod` should be dispatched to BYOC workflow
+   * 4) IRModule for unified dispatch
+   */
+  static Optional<ObjectRef> QueryInWithScope(runtime::String task_name, IRModule mod,
+                                              Optional<Array<IRModule>> dispatched);
+
+  TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(MetaScheduleContext, runtime::ObjectRef,
+                                                    MetaScheduleContextNode);
+
+ protected:
+  /*! \brief Default constructor */
+  MetaScheduleContext() = default;
+  /*! \brief Entering the scope of the context manager */
+  void EnterWithScope();
+  /*! \brief Exiting the scope of the context manager */
+  void ExitWithScope();
+};
+
+/**************** TaskExtraction ****************/
+
+/*!
+ * \brief An integration context for task extraction
+ */
+class TaskExtractionNode : public MetaScheduleContextNode {
+ public:
+  /*! \brief The extracted tasks */
+  Array<ExtractedTask> tasks{nullptr};
+
+  void VisitAttrs(AttrVisitor* v) { v->Visit("tasks", &tasks); }
+
+  // Inherited from base class
+  Optional<ObjectRef> Query(runtime::String task_name, IRModule mod,
+                            Optional<Array<IRModule>> dispatched) final;
+
+  static constexpr const char* _type_key = "meta_schedule.TaskExtraction";
+  TVM_DECLARE_FINAL_OBJECT_INFO(TaskExtractionNode, MetaScheduleContextNode);
+};
+
+/*!
+ * \brief Managed reference to TaskExtractionNode
+ * \sa TaskExtractionNode
+ */
+class TaskExtraction : public MetaScheduleContext {
+ public:
+  /*! \brief Default constructor */
+  TaskExtraction();
+  TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(TaskExtraction, MetaScheduleContext,
+                                                    TaskExtractionNode);
+};
+
+/**************** ApplyHistoryBest ****************/
+
+/*!
+ * \brief An integration context that allows application of historically best records from a
+ * database
+ */
+class ApplyHistoryBestNode : public MetaScheduleContextNode {
+ public:
+  /*! \brief The database to be queried from */
+  Database database{nullptr};
+
+  void VisitAttrs(AttrVisitor* v) {
+    v->Visit("database", &database);  //
+  }
+
+  // Inherited from base class
+  Optional<ObjectRef> Query(runtime::String task_name, IRModule mod,
+                            Optional<Array<IRModule>> dispatched) final;
+
+  static constexpr const char* _type_key = "meta_schedule.ApplyHistoryBest";
+  TVM_DECLARE_FINAL_OBJECT_INFO(ApplyHistoryBestNode, MetaScheduleContextNode);
+};
+
+/*!
+ * \brief Managed reference to ApplyHistoryBestNode
+ * \sa ApplyHistoryBestNode
+ */
+class ApplyHistoryBest : public MetaScheduleContext {
+ public:
+  /*!
+   * \brief Constructor
+   * \param database The database to be queried from
+   */
+  explicit ApplyHistoryBest(Database database);
+  TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(ApplyHistoryBest, MetaScheduleContext,
+                                                    ApplyHistoryBestNode);
+};
+
+}  // namespace meta_schedule
+}  // namespace tvm
+
+#endif  // TVM_META_SCHEDULE_INTEGRATION_H_
diff --git a/python/tvm/meta_schedule/__init__.py b/python/tvm/meta_schedule/__init__.py
index 2e280ef20ac3..47b3dda5a36e 100644
--- a/python/tvm/meta_schedule/__init__.py
+++ b/python/tvm/meta_schedule/__init__.py
@@ -21,4 +21,5 @@
 from . import runner
 from . import space_generator
 from . import search_strategy
+from . import integration
 from .tune_context import TuneContext
diff --git a/python/tvm/meta_schedule/integration.py b/python/tvm/meta_schedule/integration.py
new file mode 100644
index 000000000000..d52b30387526
--- /dev/null
+++ b/python/tvm/meta_schedule/integration.py
@@ -0,0 +1,241 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Meta schedule integration with high-level IR""" +from contextlib import contextmanager +from typing import Callable, Dict, List, Optional, Union + +from tvm._ffi import register_object +from tvm.ir import IRModule, transform +from tvm.relay import Any, Function as RelayFunc, vm +from tvm.runtime import NDArray, Object +from tvm.target import Target +from tvm.tir import PrimFunc + +from . import _ffi_api + + +@register_object("meta_schedule.ExtractedTask") +class ExtractedTask(Object): + """A tuning task extracted from the high-level IR + + Parameters + ---------- + task_name : str + The name of the task extracted + mod : IRModule + The high-level IR + dispatched : List[IRModule] + A list of low-level IRs that the high-level IR could potentially dispatch to + """ + + task_name: str + mod: IRModule + dispatched: List[IRModule] + + def __init__( + self, + task_name: str, + mod: IRModule, + dispatched: List[IRModule], + ) -> None: + self.__init_handle_by_constructor__( + _ffi_api.ExtractedTask, # type: ignore # pylint: disable=no-member + task_name, + mod, + dispatched, + ) + + +@register_object("meta_schedule.MetaScheduleContext") +class MetaScheduleContext(Object): + """A context manager interface for the integration""" + + def query( + self, + task_name: str, + mod: IRModule, + dispatched: Optional[List[IRModule]], + ) -> Union[IRModule, RelayFunc, PrimFunc, None]: + """The entry point of the integration + + Parameters + ---------- + task_name : str + The name of the task extracted + mod : IRModule + The high-level IR + dispatched : Optional[List[IRModule]] + A list of low-level IRs that the high-level IR could potentially dispatch to + + Returns + ------- + result : Union[IRModule, RelayFunc, PrimFunc, None] + There are different types of the output: + 1) NullOpt if there is no feedback hint; + 2) tir::PrimFunc if `mod` should be lowered to a PrimFunc; + 3) relay::Function if `mod` should be dispatched to BYOC workflow; + 4) IRModule for unified dispatch + """ + return _ffi_api.MetaScheduleContextQuery( # type: ignore # pylint: disable=no-member + self, + task_name, + mod, + dispatched, + ) + + @staticmethod + def current() -> Optional["MetaScheduleContext"]: + """The context manager in the current scope + + Returns + ------- + ctx : Optional[MetaScheduleContext] + The MetaScheduleContext in the current scope. + NullOpt if it's currently not under any MetaScheduleContext. + """ + return _ffi_api.MetaScheduleContextCurrent() # type: ignore # pylint: disable=no-member + + @staticmethod + def query_in_with_scope( + task_name: str, + mod: IRModule, + dispatched: Optional[List[IRModule]], + ) -> Union[IRModule, RelayFunc, PrimFunc, None]: + """The entry point of the integration workflow. 
The compilation process of the high-level + IR should call this method for task extraction and for feedback hints + + Parameters + ---------- + task_name : str + The name of the task + mod : IRModule + The high-level IR + dispatched : Optional[List[IRModule]] + A list of low-level IRs that the high-level IR could potentially dispatch to + + Returns + ------- + result : Union[IRModule, RelayFunc, PrimFunc, None] + There are different types of the output: + 1) NullOpt if there is no feedback hint; + 2) tir::PrimFunc if `mod` should be lowered to a PrimFunc; + 3) relay::Function if `mod` should be dispatched to BYOC workflow; + 4) IRModule for unified dispatch + """ + return _ffi_api.MetaScheduleContextQueryInWithScope( # type: ignore # pylint: disable=no-member + task_name, + mod, + dispatched, + ) + + def __enter__(self) -> "MetaScheduleContext": + """Entering the scope of the context manager""" + _ffi_api.MetaScheduleContextEnterScope(self) # type: ignore # pylint: disable=no-member + return self + + def __exit__(self, ptype, value, trace) -> None: + """Exiting the scope of the context manager""" + _ffi_api.MetaScheduleContextExitScope(self) # type: ignore # pylint: disable=no-member + + +@register_object("meta_schedule.TaskExtraction") +class TaskExtraction(MetaScheduleContext): + """An integration context for task extraction""" + + tasks: List[ExtractedTask] + """The extracted tasks""" + + def __init__(self) -> None: + self.__init_handle_by_constructor__(_ffi_api.TaskExtraction) # type: ignore # pylint: disable=no-member + + +@register_object("meta_schedule.ApplyHistoryBest") +class ApplyHistoryBest(MetaScheduleContext): + pass + + +def extract_task( + mod: Union[IRModule, RelayFunc], + target: Target, + params: Optional[Dict[str, NDArray]] = None, + *, + opt_level: int = 3, + pass_config: Dict[str, Any] = { + "relay.backend.use_meta_schedule": True, + }, + disabled_pass: List[str] = [], +) -> List[ExtractedTask]: + """Extract tuning tasks from a relay program. 
+ + Parameters + ---------- + mod : Union[tvm.IRModule, tvm.relay.Function] + The module or function to tune + target : tvm.target.Target + The compilation target + params : Optional[Dict[str, tvm.runtime.NDArray]] + The associated parameters of the program + opt_level : int + The optimization level of the compiler + pass_config : Dict[str, Any] + The pass config of the compiler + disabled_pass : List[str] + The list of disabled passes of the compiler + + Returns + ------- + tasks: List[ExtractedTask] + The tasks extracted from this network + """ + + @contextmanager + def _autotvm_silencer(): + from tvm import autotvm # pylint: disable=import-outside-toplevel + + silent = autotvm.GLOBAL_SCOPE.silent + autotvm.GLOBAL_SCOPE.silent = True + try: + yield + finally: + autotvm.GLOBAL_SCOPE.silent = silent + + def _thread_run(func: Callable[[], None]) -> None: + import threading # pylint: disable=import-outside-toplevel + + thread = threading.Thread(target=func) + thread.start() + thread.join() + + env = TaskExtraction() + if isinstance(mod, RelayFunc): + mod = IRModule.from_expr(mod) + if not isinstance(target, Target): + target = Target(target) + + def _func(): + with env, _autotvm_silencer(), transform.PassContext( + config=pass_config, + disabled_pass=disabled_pass, + opt_level=opt_level, + ): + compiler = vm.VMCompiler() + if params: + compiler.set_params(params) + compiler.lower(mod, target) + + _thread_run(_func) + return env.tasks diff --git a/python/tvm/meta_schedule/testing/__init__.py b/python/tvm/meta_schedule/testing/__init__.py new file mode 100644 index 000000000000..7e516a510f66 --- /dev/null +++ b/python/tvm/meta_schedule/testing/__init__.py @@ -0,0 +1,19 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Testing utilities in meta schedule""" +from .local_rpc import LocalRPC +from .relay_workload import get_network diff --git a/python/tvm/meta_schedule/testing.py b/python/tvm/meta_schedule/testing/local_rpc.py similarity index 97% rename from python/tvm/meta_schedule/testing.py rename to python/tvm/meta_schedule/testing/local_rpc.py index b286e3b18a93..cd1221124cc9 100644 --- a/python/tvm/meta_schedule/testing.py +++ b/python/tvm/meta_schedule/testing/local_rpc.py @@ -14,7 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
-"""Testing utilities in meta schedule""" +"""RPC tracker and server running locally""" from tvm.rpc.tracker import Tracker from tvm.rpc.server import Server diff --git a/python/tvm/meta_schedule/testing/relay_workload.py b/python/tvm/meta_schedule/testing/relay_workload.py new file mode 100644 index 000000000000..1eb9950f7fc7 --- /dev/null +++ b/python/tvm/meta_schedule/testing/relay_workload.py @@ -0,0 +1,88 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Workloads in Relay IR""" +from typing import Dict, Tuple + +import tvm.relay.testing # pylint: disable=unused-import +from tvm import relay +from tvm.ir import IRModule +from tvm.runtime import NDArray + + +def get_network( + name: str, + batch_size: int, + layout: str = "NHWC", + dtype: str = "float32", +) -> Tuple[IRModule, Dict[str, NDArray], Tuple[int, int, int, int], Tuple[int, int]]: + """Get the symbol definition and random weight of a network""" + # meta-schedule prefers NHWC layout + if layout == "NHWC": + image_shape = (224, 224, 3) + elif layout == "NCHW": + image_shape = (3, 224, 224) + else: + raise ValueError("Invalid layout: " + layout) + + input_shape: Tuple[int, int, int, int] = (batch_size,) + image_shape + output_shape: Tuple[int, int] = (batch_size, 1000) + + if name.startswith("resnet-"): + n_layer = int(name.split("-")[1]) + mod, params = relay.testing.resnet.get_workload( + num_layers=n_layer, + batch_size=batch_size, + layout=layout, + dtype=dtype, + image_shape=image_shape, + ) + elif name.startswith("resnet3d-"): + n_layer = int(name.split("-")[1]) + mod, params = relay.testing.resnet.get_workload( + num_layers=n_layer, + batch_size=batch_size, + layout=layout, + dtype=dtype, + image_shape=image_shape, + ) + elif name == "mobilenet": + mod, params = relay.testing.mobilenet.get_workload( + batch_size=batch_size, layout=layout, dtype=dtype, image_shape=image_shape + ) + elif name == "squeezenet_v1.1": + assert layout == "NCHW", "squeezenet_v1.1 only supports NCHW layout" + mod, params = relay.testing.squeezenet.get_workload( + version="1.1", + batch_size=batch_size, + dtype=dtype, + image_shape=image_shape, + ) + elif name == "inception_v3": + input_shape = (batch_size, 3, 299, 299) if layout == "NCHW" else (batch_size, 299, 299, 3) + mod, params = relay.testing.inception_v3.get_workload(batch_size=batch_size, dtype=dtype) + elif name == "mxnet": + from mxnet.gluon.model_zoo.vision import get_model # type: ignore # pylint: disable=import-outside-toplevel + + assert layout == "NCHW" + block = get_model("resnet50_v1", pretrained=True) + mod, params = relay.frontend.from_mxnet(block, shape={"data": input_shape}, dtype=dtype) + net = mod["main"] + net = relay.Function( + net.params, relay.nn.softmax(net.body), None, net.type_params, net.attrs + ) + mod = IRModule.from_expr(net) 
+ return mod, params, input_shape, output_shape diff --git a/python/tvm/te/__init__.py b/python/tvm/te/__init__.py index 250c165caf9a..308257085e51 100644 --- a/python/tvm/te/__init__.py +++ b/python/tvm/te/__init__.py @@ -33,7 +33,7 @@ from .tag import tag_scope from .operation import placeholder, compute, scan, extern, var, size_var from .operation import thread_axis, reduce_axis -from .operation import create_prim_func +from .operation import create_prim_func, create_prim_func_from_outputs from .tensor import PlaceholderOp, ComputeOp, TensorComputeOp, ScanOp, ExternOp, HybridOp from .autodiff import gradient diff --git a/python/tvm/te/operation.py b/python/tvm/te/operation.py index cb0305d49e4a..5cb58a85ed10 100644 --- a/python/tvm/te/operation.py +++ b/python/tvm/te/operation.py @@ -17,14 +17,14 @@ """ Operation class for computation declaration.""" # pylint: disable=invalid-name from numbers import Integral as _Integral -from typing import List +from typing import List, Union import tvm._ffi -import tvm.tir -import tvm.tir._ffi_api from tvm._ffi.base import string_types from tvm.ir import Array from tvm.runtime import convert +import tvm.tir +import tvm.tir._ffi_api from . import _ffi_api from . import tag as _tag @@ -482,3 +482,23 @@ def tir_matmul(a: T.handle, b: T.handle, c: T.handle) -> None: if not isinstance(ops, (list, tuple, Array)): ops = [ops] return _ffi_api.CreatePrimFunc(ops) + + +def create_prim_func_from_outputs( + outputs: Union[_tensor.Tensor, List[_tensor.Tensor]], +) -> tvm.tir.PrimFunc: + """Create a TensorIR PrimFunc from output tensor(s) in TE + + Parameters + ---------- + outputs : Union[Tensor, List[Tensor]] + The source expression. + + Returns + ------- + func : tir.PrimFunc + The created function. + """ + if not isinstance(outputs, (list, tuple, Array)): + outputs = [outputs] + return _ffi_api.CreatePrimFuncFromOutputs(outputs) diff --git a/src/meta_schedule/integration.cc b/src/meta_schedule/integration.cc new file mode 100644 index 000000000000..5dc3d795a2ef --- /dev/null +++ b/src/meta_schedule/integration.cc @@ -0,0 +1,151 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */
+#include
+#include
+#include
+
+namespace tvm {
+namespace meta_schedule {
+
+/**************** Utility functions ****************/
+
+template <class FunctionType>
+bool HasOnlyOneFunction(const IRModule& mod) {
+  if (mod->functions.size() != 1) {
+    return false;
+  }
+  for (const auto& kv : mod->functions) {
+    const BaseFunc& func = kv.second;
+    if (!func->IsInstance<typename FunctionType::ContainerType>()) {
+      return false;
+    }
+  }
+  return true;
+}
+
+/**************** ExtractedTask ****************/
+
+ExtractedTask::ExtractedTask(String task_name, IRModule mod, Array<IRModule> dispatched) {
+  ObjectPtr<ExtractedTaskNode> n = make_object<ExtractedTaskNode>();
+  n->task_name = task_name;
+  n->mod = mod;
+  n->dispatched = dispatched;
+  data_ = n;
+}
+
+/**************** MetaScheduleContext ****************/
+
+struct MetaScheduleContextThreadLocalEntry {
+  Optional<MetaScheduleContext> ctx;
+};
+
+using MetaScheduleContextThreadLocalStore =
+    dmlc::ThreadLocalStore<MetaScheduleContextThreadLocalEntry>;
+
+Optional<MetaScheduleContext> MetaScheduleContext::Current() {
+  return MetaScheduleContextThreadLocalStore::Get()->ctx;
+}
+
+void MetaScheduleContext::EnterWithScope() {
+  Optional<MetaScheduleContext>& ctx = MetaScheduleContextThreadLocalStore::Get()->ctx;
+  CHECK(!ctx.defined())
+      << "ValueError: Nested MetaScheduleContext context managers are not allowed";
+  ctx = *this;
+}
+
+void MetaScheduleContext::ExitWithScope() {
+  Optional<MetaScheduleContext>& ctx = MetaScheduleContextThreadLocalStore::Get()->ctx;
+  ICHECK(ctx.defined());
+  ctx = NullOpt;
+}
+
+Optional<ObjectRef> MetaScheduleContext::QueryInWithScope(runtime::String task_name, IRModule mod,
+                                                          Optional<Array<IRModule>> dispatched) {
+  if (Optional<MetaScheduleContext> ctx = MetaScheduleContext::Current()) {
+    return ctx.value()->Query(task_name, mod, dispatched);
+  }
+  return NullOpt;
+}
+
+/**************** TaskExtraction ****************/
+
+TaskExtraction::TaskExtraction() {
+  ObjectPtr<TaskExtractionNode> n = make_object<TaskExtractionNode>();
+  n->tasks = Array<ExtractedTask>();
+  data_ = n;
+}
+
+Optional<ObjectRef> TaskExtractionNode::Query(runtime::String task_name, IRModule mod,
+                                              Optional<Array<IRModule>> dispatched) {
+  ICHECK(dispatched.defined());
+  ICHECK_EQ(dispatched.value().size(), 1);
+  IRModule prim_mod = dispatched.value()[0];
+  ICHECK(HasOnlyOneFunction<tir::PrimFunc>(prim_mod)) << prim_mod;
+  ICHECK(HasOnlyOneFunction<relay::Function>(mod)) << mod;
+  tasks.push_back(ExtractedTask(task_name, mod, {prim_mod}));
+  return NullOpt;
+}
+
+/**************** ApplyHistoryBest ****************/
+
+ApplyHistoryBest::ApplyHistoryBest(Database database) {
+  ObjectPtr<ApplyHistoryBestNode> n = make_object<ApplyHistoryBestNode>();
+  n->database = database;
+  data_ = n;
+}
+
+Optional<ObjectRef> ApplyHistoryBestNode::Query(runtime::String task_name, IRModule mod,
+                                                Optional<Array<IRModule>> dispatched) {
+  throw;
+}
+
+/**************** FFI ****************/
+
+class MetaScheduleContextInternal {
+ public:
+  static void EnterScope(MetaScheduleContext ctx) { ctx.EnterWithScope(); }
+  static void ExitScope(MetaScheduleContext ctx) { ctx.ExitWithScope(); }
+};
+
+TVM_REGISTER_NODE_TYPE(ExtractedTaskNode);
+TVM_REGISTER_OBJECT_TYPE(MetaScheduleContextNode);
+TVM_REGISTER_NODE_TYPE(TaskExtractionNode);
+TVM_REGISTER_NODE_TYPE(ApplyHistoryBestNode);
+
+TVM_REGISTER_GLOBAL("meta_schedule.ExtractedTask")
+    .set_body_typed([](String task_name, IRModule mod,
+                       Array<IRModule> dispatched) -> ExtractedTask {
+      return ExtractedTask(task_name, mod, dispatched);
+    });
+TVM_REGISTER_GLOBAL("meta_schedule.MetaScheduleContextEnterScope")
+    .set_body_typed(MetaScheduleContextInternal::EnterScope);
+TVM_REGISTER_GLOBAL("meta_schedule.MetaScheduleContextExitScope")
+    .set_body_typed(MetaScheduleContextInternal::ExitScope);
+TVM_REGISTER_GLOBAL("meta_schedule.MetaScheduleContextCurrent")
+    .set_body_typed(MetaScheduleContext::Current);
+TVM_REGISTER_GLOBAL("meta_schedule.MetaScheduleContextQueryInWithScope") + .set_body_typed(MetaScheduleContext::QueryInWithScope); +TVM_REGISTER_GLOBAL("meta_schedule.MetaScheduleContextQuery") + .set_body_method(&MetaScheduleContextNode::Query); +TVM_REGISTER_GLOBAL("meta_schedule.TaskExtraction").set_body_typed([]() -> TaskExtraction { + return TaskExtraction(); +}); + +} // namespace meta_schedule +} // namespace tvm diff --git a/src/relay/backend/te_compiler.cc b/src/relay/backend/te_compiler.cc index e1ed3d47d36d..b284fc8bc0ca 100644 --- a/src/relay/backend/te_compiler.cc +++ b/src/relay/backend/te_compiler.cc @@ -222,7 +222,8 @@ class TECompilerImpl : public TECompilerNode { auto global_var = GlobalVar(func_name); global_var->checked_type_ = key->source_func->checked_type(); ir_module->Add(global_var, key->source_func); - value->cached_func = CachedFunc(target, global_var, {}, {}, te::Schedule(), {}, ir_module); + value->cached_func = CachedFunc(target, global_var, {}, {}, te::Schedule{nullptr}, + tir::PrimFunc{nullptr}, {}, ir_module); return value; } @@ -243,16 +244,19 @@ class TECompilerImpl : public TECompilerNode { return value; } } - - // NOTE: array will copy on write. - Array all_args = Array(cfunc->inputs); - for (te::Tensor arg : cfunc->outputs) { - all_args.push_back(arg); + if (cfunc->prim_func.defined()) { + cfunc->funcs->Update(cfunc->prim_fn_var, cfunc->prim_func.value()); + } else { + // NOTE: array will copy on write. + Array all_args = Array(cfunc->inputs); + for (te::Tensor arg : cfunc->outputs) { + all_args.push_back(arg); + } + // lower the function + std::unordered_map binds; + auto func_name = cfunc->prim_fn_var->name_hint; + cfunc->funcs->Update(tvm::LowerSchedule(cfunc->schedule, all_args, func_name, binds)); } - - std::unordered_map binds; - auto func_name = cfunc->prim_fn_var->name_hint; - cfunc->funcs->Update(tvm::LowerSchedule(cfunc->schedule, all_args, func_name, binds)); value->cached_func = cfunc; return value; } @@ -319,6 +323,7 @@ TECompiler& TECompiler::Global() { return *inst; } TVM_REGISTER_PASS_CONFIG_OPTION("relay.backend.use_auto_scheduler", Bool); +TVM_REGISTER_PASS_CONFIG_OPTION("relay.backend.use_meta_schedule", Bool); TVM_REGISTER_GLOBAL("relay.backend._TECompilerGlobal").set_body_typed([]() { return TECompiler::Global(); diff --git a/src/relay/backend/te_compiler_cache.cc b/src/relay/backend/te_compiler_cache.cc index 3970b0e806f0..4988d42d9887 100644 --- a/src/relay/backend/te_compiler_cache.cc +++ b/src/relay/backend/te_compiler_cache.cc @@ -70,7 +70,8 @@ CCacheKey::CCacheKey(Function source_func, Target target) { CachedFunc::CachedFunc(tvm::Target target, GlobalVar prim_fn_var, tvm::Array inputs, tvm::Array outputs, te::Schedule schedule, - tvm::Array shape_func_param_states, IRModule funcs) { + tir::PrimFunc prim_func, tvm::Array shape_func_param_states, + IRModule funcs) { auto n = make_object(); n->target = target; n->prim_fn_var = prim_fn_var; @@ -117,11 +118,12 @@ class ScheduleBuilder : public backend::MemoizedExprTranslator create_schedule_(create_schedule) { // Whether to use auto_scheduler schedule. 
     use_auto_scheduler_ = backend::IsAutoSchedulerEnabled();
+    use_meta_schedule_ = backend::IsMetaScheduleEnabled();
   }
 
-  CachedFunc Create(const Function& prim_func, std::function<std::string(std::string)> renamer) {
+  CachedFunc Create(const Function& relay_func, std::function<std::string(std::string)> renamer) {
     Array<te::Tensor> fn_inputs;
-    for (Var param : prim_func->params) {
+    for (Var param : relay_func->params) {
       Array<te::Tensor> inputs;
       for (const auto& ttype : FlattenTupleType(param->checked_type())) {
         tvm::te::Tensor tensor = tvm::te::placeholder(GetShape(ttype->shape), ttype->dtype);
@@ -131,7 +133,7 @@ class ScheduleBuilder : public backend::MemoizedExprTranslator<Array<te::Tensor>>
       memo_[param] = inputs;
     }
     readable_name_stream_ << "fused";
-    auto outputs = this->VisitExpr(prim_func->body);
+    auto outputs = this->VisitExpr(relay_func->body);
     auto candidate_name = readable_name_stream_.str();
     constexpr static size_t kMaxFuncNameLength = 80;
     // WARNING: Please make sure to also update TVM_CRT_MAX_STRLEN_FUNCTION_NAME
@@ -151,7 +153,7 @@ class ScheduleBuilder : public backend::MemoizedExprTranslator<Array<te::Tensor>>
       prim_fn_name = renamer(prim_fn_name);
     }
     auto prim_fn_var = GlobalVar(prim_fn_name);
-    prim_fn_var->checked_type_ = prim_func->checked_type();
+    prim_fn_var->checked_type_ = relay_func->checked_type();
 
     // Fusion over tupled results may leave identity relationships
     // between inputs and outputs, and those should not be scheduled.
@@ -163,7 +165,8 @@ class ScheduleBuilder : public backend::MemoizedExprTranslator<Array<te::Tensor>>
       }
     }
 
-    te::Schedule schedule;
+    te::Schedule schedule{nullptr};
+    tir::PrimFunc prim_func{nullptr};
     // No need to register schedule for device copy op.
     if (anchor_attrs_.as<DeviceCopyAttrs>() == nullptr && create_schedule_) {
       if (use_auto_scheduler_) {
@@ -176,20 +179,39 @@ class ScheduleBuilder : public backend::MemoizedExprTranslator<Array<te::Tensor>>
           schedule = Downcast<te::Schedule>(obj);
         }
       }
+      if (use_meta_schedule_) {
+        const auto* f_create_func = runtime::Registry::Get("te.CreatePrimFuncFromOutputs");
+        const auto* f_meta_schedule =
+            runtime::Registry::Get("meta_schedule.MetaScheduleContextQueryInWithScope");
+        ICHECK(f_create_func) << "te.CreatePrimFuncFromOutputs is not registered";
+        ICHECK(f_meta_schedule)
+            << "meta_schedule.MetaScheduleContextQueryInWithScope is not registered";
+        prim_func = (*f_create_func)(tensor_outs);
+        Optional<ObjectRef> opt_mod_or_base_func =
+            (*f_meta_schedule)(prim_fn_name, IRModule({{GlobalVar(prim_fn_name), relay_func}}),
+                               Array<IRModule>{IRModule({{GlobalVar(prim_fn_name), prim_func}})});
+        if (const auto* result = opt_mod_or_base_func.as<tir::PrimFuncNode>()) {
+          prim_func = GetRef<tir::PrimFunc>(result);
+        } else {
+          prim_func = tir::PrimFunc(nullptr);
+        }
+      }
       // Use TOPI schedule if user specified, or the function has no auto_scheduler schedule.
- if (!schedule.defined()) { + if (!schedule.defined() && !prim_func.defined()) { ICHECK(anchor_implementation_.defined()); schedule = anchor_implementation_.Schedule(anchor_attrs_, tensor_outs, target_); } - for (const auto& scalar : scalars_) { - if (schedule->Contain(scalar)) { - schedule[scalar].compute_inline(); + if (schedule.defined()) { + for (const auto& scalar : scalars_) { + if (schedule->Contain(scalar)) { + schedule[scalar].compute_inline(); + } } } } - return CachedFunc(target_, prim_fn_var, fn_inputs, outputs, schedule, {}); + return CachedFunc(target_, prim_fn_var, fn_inputs, outputs, schedule, prim_func, {}); } Array VisitExpr_(const VarNode* op) final { @@ -336,6 +358,7 @@ class ScheduleBuilder : public backend::MemoizedExprTranslator std::ostringstream readable_name_stream_; Array scalars_; bool use_auto_scheduler_; + bool use_meta_schedule_; // Cache device copy op for equivalence checking to reduce registry lookup // overhead for each invocation of call node when retrieving schedules. const Op& device_copy_op_; @@ -450,8 +473,8 @@ class MakeShapeFunc : public backend::MemoizedExprTranslator> std::unordered_map binds; IRModule ir_module = tvm::LowerSchedule(schedule, all_args, func_name, binds); - return CachedFunc(target, prim_fn_gvar, inputs, outputs, schedule, shape_func_param_states, - ir_module); + return CachedFunc(target, prim_fn_gvar, inputs, outputs, schedule, tir::PrimFunc{nullptr}, + shape_func_param_states, ir_module); } Array VisitExpr(const Expr& expr) final { diff --git a/src/relay/backend/te_compiler_cache.h b/src/relay/backend/te_compiler_cache.h index 7975ef873173..2171880fd6a5 100644 --- a/src/relay/backend/te_compiler_cache.h +++ b/src/relay/backend/te_compiler_cache.h @@ -129,16 +129,18 @@ class CCacheKey : public ObjectRef { /*! \brief Node container to represent a cached function. */ struct CachedFuncNode : public Object { - /* \brief compiled target */ + /*! \brief compiled target */ tvm::Target target; /*! \brief Primitive Function Name */ GlobalVar prim_fn_var; - /* \brief The inputs to the function */ + /*! \brief The inputs to the function */ tvm::Array inputs; - /* \brief The outputs to the function */ + /*! \brief The outputs to the function */ tvm::Array outputs; /*! \brief The schedule to the function */ te::Schedule schedule; + /*! \brief The TIR function if lowering in the meta schedule path */ + Optional prim_func; /*! \brief Parameter usage states in the shape function. */ tvm::Array shape_func_param_states; /*! \brief The lowered functions to support the function. */ @@ -150,6 +152,7 @@ struct CachedFuncNode : public Object { v->Visit("inputs", &inputs); v->Visit("outputs", &outputs); v->Visit("schedule", &schedule); + v->Visit("prim_func", &prim_func); v->Visit("funcs", &funcs); v->Visit("shape_func_param_states", &shape_func_param_states); } @@ -161,7 +164,7 @@ struct CachedFuncNode : public Object { class CachedFunc : public ObjectRef { public: CachedFunc(tvm::Target target, GlobalVar prim_fn_name, tvm::Array inputs, - tvm::Array outputs, te::Schedule schedule, + tvm::Array outputs, te::Schedule schedule, tir::PrimFunc prim_func, tvm::Array shape_func_param_states, IRModule funcs = IRModule(Map({}))); diff --git a/src/relay/backend/utils.h b/src/relay/backend/utils.h index f89a099b0d4f..16cbe0e8dbca 100644 --- a/src/relay/backend/utils.h +++ b/src/relay/backend/utils.h @@ -427,6 +427,15 @@ inline bool IsAutoSchedulerEnabled() { .value(); } +/*! + * \brief Return whether the meta schedule is enabled in the pass context. 
+ */
+inline bool IsMetaScheduleEnabled() {
+  return transform::PassContext::Current()
+      ->GetConfig<Bool>("relay.backend.use_meta_schedule", Bool(false))
+      .value();
+}
+
 /*!
  * \brief Get the sequence of Relay optimization passes based on backend type.
  * The prefix of the Relay passes almost overlaps between the vm and graph backend, with some slight
diff --git a/src/te/operation/create_primfunc.cc b/src/te/operation/create_primfunc.cc
index 657dc121961c..d90681a1c0db 100644
--- a/src/te/operation/create_primfunc.cc
+++ b/src/te/operation/create_primfunc.cc
@@ -22,6 +22,7 @@
 #include
 #include
+#include
 
 #include "../schedule/graph.h"
 
@@ -300,9 +301,40 @@ PrimFunc CreatePrimFunc(const Array<te::Tensor>& arg_list) {
   return (*complete)(func, info.root_alloc);
 }  // namespace tir
 
-TVM_REGISTER_GLOBAL("te.CreatePrimFunc").set_body_typed([](const Array<te::Tensor>& tensors) {
-  return CreatePrimFunc(tensors);
-});
+PrimFunc CreatePrimFuncFromOutputs(const Array<te::Tensor>& outputs) {
+  std::vector<te::Tensor> stack;
+  std::unordered_set<const te::TensorNode*> visited;
+  for (const te::Tensor& output : outputs) {
+    if (!visited.count(output.get())) {
+      visited.insert(output.get());
+      stack.push_back(output);
+    }
+  }
+
+  Array<te::Tensor> arg_list;
+  while (!stack.empty()) {
+    te::Tensor tensor = stack.back();
+    stack.pop_back();
+    if (tensor->op->IsInstance<te::PlaceholderOpNode>()) {
+      arg_list.push_back(tensor);
+    } else if (tensor->op->IsInstance<te::ComputeOpNode>()) {
+      Array<te::Tensor> inputs = tensor->op->InputTensors();
+      for (const te::Tensor& input : inputs) {
+        if (!visited.count(input.get())) {
+          visited.insert(input.get());
+          stack.push_back(input);
+        }
+      }
+    }
+  }
+  for (const te::Tensor& output : outputs) {
+    arg_list.push_back(output);
+  }
+  return CreatePrimFunc(arg_list);
+}
+
+TVM_REGISTER_GLOBAL("te.CreatePrimFunc").set_body_typed(CreatePrimFunc);
+TVM_REGISTER_GLOBAL("te.CreatePrimFuncFromOutputs").set_body_typed(CreatePrimFuncFromOutputs);
 }  // namespace tir
 }  // namespace tvm
diff --git a/tests/python/unittest/test_meta_schedule_integration.py b/tests/python/unittest/test_meta_schedule_integration.py
new file mode 100644
index 000000000000..794e1cac56b7
--- /dev/null
+++ b/tests/python/unittest/test_meta_schedule_integration.py
@@ -0,0 +1,120 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import sys +from typing import List + +import pytest + +import tvm +from tvm import meta_schedule as ms +from tvm.ir.module import IRModule +from tvm.meta_schedule.integration import ( + ExtractedTask, + MetaScheduleContext, + TaskExtraction, +) +from tvm.meta_schedule.testing import get_network +from tvm.script import tir as T + +# pylint: disable=invalid-name,no-member,line-too-long,too-many-nested-blocks,missing-docstring,unbalanced-tuple-unpacking + + +@tvm.script.ir_module +class MockModule: + @T.prim_func + def main(a: T.handle, b: T.handle) -> None: # pylint: disable=no-self-argument + T.func_attr({"global_symbol": "main", "tir.noalias": True}) + A = T.match_buffer(a, (16,), "float32") + B = T.match_buffer(b, (16,), "float32") + for i in T.serial(0, 16): + with T.block("matmul"): + vi = T.axis.remap("S", [i]) + B[vi] = A[vi] + + +# pylint: enable=invalid-name,no-member,line-too-long,too-many-nested-blocks,missing-docstring,unbalanced-tuple-unpacking + + +def _check_mock_task(tasks: List[ExtractedTask], mod: IRModule): + (task,) = tasks + assert isinstance(task, ExtractedTask) + assert task.task_name == "mock-task" + tvm.ir.assert_structural_equal(task.mod, mod) + (tir_mod,) = task.dispatched + tvm.ir.assert_structural_equal(tir_mod, MockModule) + + +def test_meta_schedule_integration_task_extraction_query(): + mod, _, _, _ = get_network( + name="resnet-18", + batch_size=1, + layout="NHWC", + dtype="float32", + ) + env = TaskExtraction() + env.query(task_name="mock-task", mod=mod, dispatched=[MockModule]) + _check_mock_task(env.tasks, mod) + + +def test_meta_schedule_integration_current(): + env = TaskExtraction() + with env: + assert MetaScheduleContext.current() == env + + +def test_meta_schedule_integration_no_current(): + assert MetaScheduleContext.current() is None + + +def test_meta_schedule_integration_multiple_current(): + env = TaskExtraction() + with env: + with pytest.raises(ValueError): + with env: + ... 
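+
+
+# A sketch of a property implied by the C++ side (hypothetical test, not part of
+# this patch): QueryInWithScope returns NullOpt when no MetaScheduleContext is
+# active, which surfaces as None in Python.
+def test_meta_schedule_integration_query_without_current():
+    assert (
+        MetaScheduleContext.query_in_with_scope(
+            task_name="mock-task",
+            mod=MockModule,
+            dispatched=[MockModule],
+        )
+        is None
+    )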
+ + +def test_meta_schedule_integration_query_in_with_scope(): + mod, _, _, _ = get_network( + name="resnet-18", + batch_size=1, + layout="NHWC", + dtype="float32", + ) + env = TaskExtraction() + with env: + MetaScheduleContext.query_in_with_scope( + task_name="mock-task", + mod=mod, + dispatched=[MockModule], + ) + _check_mock_task(env.tasks, mod) + + +def test_meta_schedule_integration_extract_from_resnet(): + mod, params, _, _ = get_network( + name="resnet-18", + batch_size=1, + layout="NHWC", + dtype="float32", + ) + extracted_tasks = ms.integration.extract_task(mod, target="llvm", params=params) + assert len(extracted_tasks) == 30 + + +if __name__ == "__main__": + sys.exit(pytest.main([__file__] + sys.argv[1:])) From 6c1cf9c708d5e254c15304c11767281c1885b690 Mon Sep 17 00:00:00 2001 From: Junru Shao Date: Thu, 4 Nov 2021 10:33:46 -0700 Subject: [PATCH 2/3] renaming and docs --- include/tvm/meta_schedule/integration.h | 4 ++-- python/tvm/meta_schedule/integration.py | 13 +++++++++++-- src/meta_schedule/integration.cc | 8 ++++---- src/relay/backend/te_compiler_cache.cc | 4 ++-- .../unittest/test_meta_schedule_integration.py | 4 ++-- 5 files changed, 21 insertions(+), 12 deletions(-) diff --git a/include/tvm/meta_schedule/integration.h b/include/tvm/meta_schedule/integration.h index c6cb3a5fac28..ee35505d65b7 100644 --- a/include/tvm/meta_schedule/integration.h +++ b/include/tvm/meta_schedule/integration.h @@ -123,8 +123,8 @@ class MetaScheduleContext : public runtime::ObjectRef { * 3) relay::Function if `mod` should be dispatched to BYOC workflow * 4) IRModule for unified dispatch */ - static Optional QueryInWithScope(runtime::String task_name, IRModule mod, - Optional> dispatched); + static Optional QueryInsideWithScope(runtime::String task_name, IRModule mod, + Optional> dispatched); TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(MetaScheduleContext, runtime::ObjectRef, MetaScheduleContextNode); diff --git a/python/tvm/meta_schedule/integration.py b/python/tvm/meta_schedule/integration.py index d52b30387526..47003c6faa25 100644 --- a/python/tvm/meta_schedule/integration.py +++ b/python/tvm/meta_schedule/integration.py @@ -110,7 +110,7 @@ def current() -> Optional["MetaScheduleContext"]: return _ffi_api.MetaScheduleContextCurrent() # type: ignore # pylint: disable=no-member @staticmethod - def query_in_with_scope( + def query_inside_with_scope( task_name: str, mod: IRModule, dispatched: Optional[List[IRModule]], @@ -118,6 +118,15 @@ def query_in_with_scope( """The entry point of the integration workflow. The compilation process of the high-level IR should call this method for task extraction and for feedback hints + Basically, this method is equivalent to: + + .. 
code-block:: python + + def query_inside_with_scope(task_name, mod, dispatched): + ctx = MetaScheduleContext.current() + assert ctx is not None + ctx.query(task_name, mod, dispatched) + Parameters ---------- task_name : str @@ -136,7 +145,7 @@ def query_in_with_scope( 3) relay::Function if `mod` should be dispatched to BYOC workflow; 4) IRModule for unified dispatch """ - return _ffi_api.MetaScheduleContextQueryInWithScope( # type: ignore # pylint: disable=no-member + return _ffi_api.MetaScheduleContextQueryInsideWithScope( # type: ignore # pylint: disable=no-member task_name, mod, dispatched, diff --git a/src/meta_schedule/integration.cc b/src/meta_schedule/integration.cc index 5dc3d795a2ef..cf4262814947 100644 --- a/src/meta_schedule/integration.cc +++ b/src/meta_schedule/integration.cc @@ -75,8 +75,8 @@ void MetaScheduleContext::ExitWithScope() { ctx = NullOpt; } -Optional MetaScheduleContext::QueryInWithScope(runtime::String task_name, IRModule mod, - Optional> dispatched) { +Optional MetaScheduleContext::QueryInsideWithScope( + runtime::String task_name, IRModule mod, Optional> dispatched) { if (Optional ctx = MetaScheduleContext::Current()) { return ctx.value()->Query(task_name, mod, dispatched); } @@ -139,8 +139,8 @@ TVM_REGISTER_GLOBAL("meta_schedule.MetaScheduleContextExitScope") .set_body_typed(MetaScheduleContextInternal::ExitScope); TVM_REGISTER_GLOBAL("meta_schedule.MetaScheduleContextCurrent") .set_body_typed(MetaScheduleContext::Current); -TVM_REGISTER_GLOBAL("meta_schedule.MetaScheduleContextQueryInWithScope") - .set_body_typed(MetaScheduleContext::QueryInWithScope); +TVM_REGISTER_GLOBAL("meta_schedule.MetaScheduleContextQueryInsideWithScope") + .set_body_typed(MetaScheduleContext::QueryInsideWithScope); TVM_REGISTER_GLOBAL("meta_schedule.MetaScheduleContextQuery") .set_body_method(&MetaScheduleContextNode::Query); TVM_REGISTER_GLOBAL("meta_schedule.TaskExtraction").set_body_typed([]() -> TaskExtraction { diff --git a/src/relay/backend/te_compiler_cache.cc b/src/relay/backend/te_compiler_cache.cc index 4988d42d9887..266bd719545a 100644 --- a/src/relay/backend/te_compiler_cache.cc +++ b/src/relay/backend/te_compiler_cache.cc @@ -182,10 +182,10 @@ class ScheduleBuilder : public backend::MemoizedExprTranslator if (use_meta_schedule_) { const auto* f_create_func = runtime::Registry::Get("te.CreatePrimFuncFromOutputs"); const auto* f_meta_schedule = - runtime::Registry::Get("meta_schedule.MetaScheduleContextQueryInWithScope"); + runtime::Registry::Get("meta_schedule.MetaScheduleContextQueryInsideWithScope"); ICHECK(f_create_func) << "te.CreatePrimFuncFromOutputs is not registered"; ICHECK(f_meta_schedule) - << "meta_schedule.MetaScheduleContextQueryInWithScope is not registered"; + << "meta_schedule.MetaScheduleContextQueryInsideWithScope is not registered"; prim_func = (*f_create_func)(tensor_outs); Optional opt_mod_or_base_func = (*f_meta_schedule)(prim_fn_name, IRModule({{GlobalVar(prim_fn_name), relay_func}}), diff --git a/tests/python/unittest/test_meta_schedule_integration.py b/tests/python/unittest/test_meta_schedule_integration.py index 794e1cac56b7..f508c7d252e1 100644 --- a/tests/python/unittest/test_meta_schedule_integration.py +++ b/tests/python/unittest/test_meta_schedule_integration.py @@ -88,7 +88,7 @@ def test_meta_schedule_integration_multiple_current(): ... 
-def test_meta_schedule_integration_query_in_with_scope():
+def test_meta_schedule_integration_query_inside_with_scope():
     mod, _, _, _ = get_network(
         name="resnet-18",
         batch_size=1,
@@ -97,7 +97,7 @@ def test_meta_schedule_integration_query_in_with_scope():
     )
     env = TaskExtraction()
     with env:
-        MetaScheduleContext.query_in_with_scope(
+        MetaScheduleContext.query_inside_with_scope(
             task_name="mock-task",
             mod=mod,
             dispatched=[MockModule],

From 9b796d4e4826bf9bacf0b31bde5b042cf5f61cb8 Mon Sep 17 00:00:00 2001
From: Junru Shao
Date: Thu, 4 Nov 2021 13:53:27 -0700
Subject: [PATCH 3/3] Update integration.h

---
 include/tvm/meta_schedule/integration.h | 18 +++++++++---------
 1 file changed, 9 insertions(+), 9 deletions(-)

diff --git a/include/tvm/meta_schedule/integration.h b/include/tvm/meta_schedule/integration.h
index ee35505d65b7..c6925eed91c4 100644
--- a/include/tvm/meta_schedule/integration.h
+++ b/include/tvm/meta_schedule/integration.h
@@ -80,12 +80,12 @@ class MetaScheduleContextNode : public runtime::Object {
    * \param task_name The name of the task
    * \param mod The high-level IR
    * \param dispatched A list of low-level IRs that the high-level IR could potentially dispatch to.
-   * NullOpt means the dispatch needs to be done in the context.
+   *        NullOpt means the dispatch needs to be done in the context.
    * \return There are different types of the output
-   * 1) NullOpt if there is no feedback hint
-   * 2) tir::PrimFunc if `mod` should be lowered to a PrimFunc
-   * 3) relay::Function if `mod` should be dispatched to BYOC workflow
-   * 4) IRModule for unified dispatch
+   *        1) NullOpt if there is no feedback hint
+   *        2) tir::PrimFunc if `mod` should be lowered to a PrimFunc
+   *        3) relay::Function if `mod` should be dispatched to BYOC workflow
+   *        4) IRModule for unified dispatch
    */
   virtual Optional<ObjectRef> Query(runtime::String task_name, IRModule mod,
                                     Optional<Array<IRModule>> dispatched) = 0;
@@ -118,10 +118,10 @@ class MetaScheduleContext : public runtime::ObjectRef {
    * \param mod The high-level IR
    * \param dispatched A list of low-level IRs that the high-level IR could potentially dispatch to
    * \return There are different types of the output
-   * 1) NullOpt if there is no feedback hint
-   * 2) tir::PrimFunc if `mod` should be lowered to a PrimFunc
-   * 3) relay::Function if `mod` should be dispatched to BYOC workflow
-   * 4) IRModule for unified dispatch
+   *        1) NullOpt if there is no feedback hint
+   *        2) tir::PrimFunc if `mod` should be lowered to a PrimFunc
+   *        3) relay::Function if `mod` should be dispatched to BYOC workflow
+   *        4) IRModule for unified dispatch
    */
   static Optional<ObjectRef> QueryInsideWithScope(runtime::String task_name, IRModule mod,
                                                   Optional<Array<IRModule>> dispatched);
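
A minimal usage sketch of the API this series adds, mirroring
test_meta_schedule_integration_extract_from_resnet above; the "llvm" target
string is illustrative, and any target accepted by the compiler works:

.. code-block:: python

    from tvm.meta_schedule.integration import extract_task
    from tvm.meta_schedule.testing import get_network

    # Build a Relay workload, then extract tuning tasks without running any tuning.
    mod, params, _, _ = get_network(name="resnet-18", batch_size=1, layout="NHWC", dtype="float32")
    tasks = extract_task(mod, target="llvm", params=params)
    for task in tasks:
        # Each ExtractedTask pairs the Relay subgraph (task.mod) with the TensorIR
        # candidate(s) it may dispatch to (task.dispatched).
        print(task.task_name, len(task.dispatched))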