From b26d443d43a75feff061cc53a9814695b7f2a44b Mon Sep 17 00:00:00 2001 From: Tristan Konolige Date: Tue, 22 Nov 2022 09:27:04 -0800 Subject: [PATCH 1/2] [MetaSchedule] Use current pass context in compile_relay, extract_tasks Adds the pass config information necessary for tuning and compiling relay with metaschedule to the existing pass context instead of overriding the existing one. Allows users to pass in their own pass instruments, required passes, and disabled passes. This also keeps the same API used to compile relay with autotvm and auto_scheduler. --- python/tvm/meta_schedule/relay_integration.py | 75 ++++++------------- .../metaschedule_e2e/test_resnet50_int8.py | 6 +- 2 files changed, 22 insertions(+), 59 deletions(-) diff --git a/python/tvm/meta_schedule/relay_integration.py b/python/tvm/meta_schedule/relay_integration.py index 876dba106c38..8910dc17b202 100644 --- a/python/tvm/meta_schedule/relay_integration.py +++ b/python/tvm/meta_schedule/relay_integration.py @@ -71,15 +71,8 @@ def _normalize_params( mod: IRModule, target: Union[Target, str], params: Optional[Dict[str, NDArray]], - pass_config: Mapping[str, Any], executor: Optional["relay.backend.Executor"], -) -> Tuple[ - IRModule, - Target, - Dict[str, NDArray], - Dict[str, Any], - Optional["relay.backend.Executor"], -]: +) -> Tuple[IRModule, Target, Dict[str, NDArray], Optional["relay.backend.Executor"],]: from tvm import relay # pylint: disable=import-outside-toplevel if isinstance(mod, relay.Function): @@ -102,8 +95,7 @@ def _normalize_params( else: executor = mod.get_attr("executor") - pass_config = dict(pass_config) - return mod, target, relay_params, pass_config, executor + return mod, target, relay_params, executor def extract_tasks( @@ -111,16 +103,8 @@ def extract_tasks( target: Union[Target, str], params: Optional[Dict[str, NDArray]], *, - opt_level: int = 3, - pass_config: Mapping[str, Any] = MappingProxyType( - { - "relay.backend.use_meta_schedule": True, - "relay.backend.tir_converter": "default", - } - ), executor: Optional["relay.backend.Executor"] = None, module_equality: str = "structural", - disabled_pass: Optional[Union[List[str], Set[str], Tuple[str]]] = None, ) -> List[ExtractedTask]: """Extract tuning tasks from a relay program. @@ -132,10 +116,6 @@ def extract_tasks( The compilation target params : Optional[Dict[str, tvm.runtime.NDArray]] The associated parameters of the program - opt_level : int - The optimization level of the compilation - pass_config : Mapping[str, Any] - The pass configuration executor : Optional[relay.backend.Executor] The executor to use module_equality : Optional[str] @@ -148,8 +128,6 @@ def extract_tasks( given module. The "ignore-ndarray" varint is used for the extracted blocks or in case no anchor block is found. For the definition of the anchor block, see tir/analysis/analysis.py. - disabled_pass : Optional[Union[List[str], Set[str], Tuple[str]]] - The list of disabled passes Returns ------- @@ -160,21 +138,25 @@ def extract_tasks( from tvm import autotvm # pylint: enable=import-outside-toplevel - mod, target, params, pass_config, _ = _normalize_params( - mod, target, params, pass_config, executor - ) + mod, target, params, _ = _normalize_params(mod, target, params, executor) if target.kind.name != "cuda" and isinstance( autotvm.DispatchContext.current, autotvm.FallbackContext ): tophub_context = autotvm.tophub.context(target) else: tophub_context = autotvm.utils.EmptyContext() + pass_ctx = transform.PassContext.current() + pass_config = dict(pass_ctx.config) + pass_config.setdefault("relay.backend.use_meta_schedule", True) + pass_config.setdefault("relay.backend.tir_converter", "default") with Profiler.timeit("TaskExtraction"): with target, _autotvm_silencer(), tophub_context: with transform.PassContext( - opt_level=opt_level, + opt_level=pass_ctx.opt_level, + required_pass=pass_ctx.required_pass, + disabled_pass=pass_ctx.disabled_pass, + instruments=pass_ctx.instruments, config=pass_config, - disabled_pass=disabled_pass, ): return list(_extract_task(mod, target, params, module_equality)) @@ -254,7 +236,6 @@ def tune_relay( seed: Optional[int] = None, module_equality: str = "structural", num_tuning_cores: Union[Literal["physical", "logical"], int] = "physical", - disabled_pass: Optional[Union[List[str], Set[str], Tuple[str]]] = None, ) -> Database: """Tune a Relay program. @@ -304,8 +285,6 @@ def tune_relay( For the definition of the anchor block, see tir/analysis/analysis.py. num_tuning_cores : Union[Literal["physical", "logical"], int] The number of CPU cores to use during tuning. - disabled_pass : Optional[Union[List[str], Set[str], Tuple[str]]] - The list of disabled passes during tasks extraction Returns ------- @@ -313,9 +292,7 @@ def tune_relay( The database that contains the tuning records """ tasks, task_weights = extracted_tasks_to_tune_contexts( - extracted_tasks=extract_tasks( - mod, target, params, module_equality=module_equality, disabled_pass=disabled_pass - ), + extracted_tasks=extract_tasks(mod, target, params, module_equality=module_equality), work_dir=work_dir, space=space, strategy=strategy, @@ -346,15 +323,7 @@ def compile_relay( params: Optional[Dict[str, NDArray]], *, backend: Literal["graph", "vm"] = "graph", - opt_level: int = 3, - pass_config: Mapping[str, Any] = MappingProxyType( - { - "relay.backend.use_meta_schedule": True, - "relay.backend.tir_converter": "default", - } - ), executor: Optional["relay.backend.Executor"] = None, - disabled_pass: Optional[Union[List[str], Set[str], Tuple[str]]] = None, ): """Compile a relay program with a MetaSchedule database. @@ -372,14 +341,8 @@ def compile_relay( The backend to use. Builtin backends: - "graph" - "vm" - opt_level : int - The optimization level of the compilation - pass_config : Mapping[str, Any] - The pass configuration executor : Optional[relay.backend.Executor] The executor to use in relay.build. It is not supported by RelayVM. - disabled_pass : Optional[Union[List[str], Set[str], Tuple[str]]] - The list of disabled passes Returns ------- @@ -390,16 +353,20 @@ def compile_relay( from tvm import relay # pylint: enable=import-outside-toplevel - mod, target, params, pass_config, executor = _normalize_params( - mod, target, params, pass_config, executor - ) + mod, target, params, executor = _normalize_params(mod, target, params, executor) + pass_ctx = transform.PassContext.current() + pass_config = dict(pass_ctx.config) pass_config.setdefault("relay.backend.use_meta_schedule_dispatch", True) + pass_config.setdefault("relay.backend.use_meta_schedule", True) + pass_config.setdefault("relay.backend.tir_converter", "default") with Profiler.timeit("PostTuningCompilation"): with target, _autotvm_silencer(), database: with transform.PassContext( - opt_level=opt_level, + opt_level=pass_ctx.opt_level, + required_pass=pass_ctx.required_pass, + disabled_pass=pass_ctx.disabled_pass, + instruments=pass_ctx.instruments, config=pass_config, - disabled_pass=disabled_pass, ): if backend == "graph": return relay.build(mod, target=target, params=params, executor=executor) diff --git a/tests/python/contrib/test_hexagon/metaschedule_e2e/test_resnet50_int8.py b/tests/python/contrib/test_hexagon/metaschedule_e2e/test_resnet50_int8.py index 1e01cb28a749..6ec1b4dd81c2 100644 --- a/tests/python/contrib/test_hexagon/metaschedule_e2e/test_resnet50_int8.py +++ b/tests/python/contrib/test_hexagon/metaschedule_e2e/test_resnet50_int8.py @@ -347,7 +347,7 @@ def schedule_conv2d_for_tune(sch: Schedule): else None ) - with tempfile.TemporaryDirectory() as work_dir: + with tempfile.TemporaryDirectory() as work_dir, pass_context: database = ms.relay_integration.tune_relay( mod=mod, target=TARGET_HEXAGON, @@ -382,15 +382,11 @@ def schedule_conv2d_for_tune(sch: Schedule): module_equality="ignore-ndarray", ) - # Add default options so that it still uses the base config. - pass_config["relay.backend.use_meta_schedule"] = True - pass_config["relay.backend.tir_converter"] = "default" return ms.relay_integration.compile_relay( database=database, mod=mod, target=TARGET_HEXAGON, params=params, - pass_config=pass_config, ) From 291fc7e4aaa63bdcac580c3c89bcbbe5b734c20b Mon Sep 17 00:00:00 2001 From: Tristan Konolige Date: Tue, 3 Jan 2023 14:51:24 -0800 Subject: [PATCH 2/2] formatting --- python/tvm/meta_schedule/relay_integration.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/tvm/meta_schedule/relay_integration.py b/python/tvm/meta_schedule/relay_integration.py index 8910dc17b202..ee6982bceffa 100644 --- a/python/tvm/meta_schedule/relay_integration.py +++ b/python/tvm/meta_schedule/relay_integration.py @@ -16,14 +16,14 @@ # under the License. """MetaSchedule-Relay integration""" from contextlib import contextmanager -from types import MappingProxyType -from typing import TYPE_CHECKING, Any, Dict, List, Mapping, Optional, Tuple, Union, Set +from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union # isort: off from typing_extensions import Literal # isort: on import numpy as np # type: ignore + from tvm import nd from tvm._ffi import get_global_func from tvm.ir import IRModule, transform