From 545781ef7c34f51f0aae00206fc9a762a0947229 Mon Sep 17 00:00:00 2001 From: ibsidorenko Date: Mon, 26 Dec 2022 16:35:58 +0300 Subject: [PATCH 1/2] [MetaSchedule] Add "disabled_pass" option in tuning API Now there is no way to disable passes in MetaShedule tuner. This commit adds new parameter "disabled_pass" in tuning API (tune_relay/compile_relay). It can be used for different experiments and non default behavoir. --- python/tvm/meta_schedule/relay_integration.py | 17 +++++++++++++++-- 1 file changed, 15 insertions(+), 2 deletions(-) diff --git a/python/tvm/meta_schedule/relay_integration.py b/python/tvm/meta_schedule/relay_integration.py index 0b8705aafea9..876dba106c38 100644 --- a/python/tvm/meta_schedule/relay_integration.py +++ b/python/tvm/meta_schedule/relay_integration.py @@ -17,7 +17,7 @@ """MetaSchedule-Relay integration""" from contextlib import contextmanager from types import MappingProxyType -from typing import TYPE_CHECKING, Any, Dict, List, Mapping, Optional, Tuple, Union +from typing import TYPE_CHECKING, Any, Dict, List, Mapping, Optional, Tuple, Union, Set # isort: off from typing_extensions import Literal @@ -120,6 +120,7 @@ def extract_tasks( ), executor: Optional["relay.backend.Executor"] = None, module_equality: str = "structural", + disabled_pass: Optional[Union[List[str], Set[str], Tuple[str]]] = None, ) -> List[ExtractedTask]: """Extract tuning tasks from a relay program. @@ -147,6 +148,8 @@ def extract_tasks( given module. The "ignore-ndarray" varint is used for the extracted blocks or in case no anchor block is found. For the definition of the anchor block, see tir/analysis/analysis.py. + disabled_pass : Optional[Union[List[str], Set[str], Tuple[str]]] + The list of disabled passes Returns ------- @@ -171,6 +174,7 @@ def extract_tasks( with transform.PassContext( opt_level=opt_level, config=pass_config, + disabled_pass=disabled_pass, ): return list(_extract_task(mod, target, params, module_equality)) @@ -250,6 +254,7 @@ def tune_relay( seed: Optional[int] = None, module_equality: str = "structural", num_tuning_cores: Union[Literal["physical", "logical"], int] = "physical", + disabled_pass: Optional[Union[List[str], Set[str], Tuple[str]]] = None, ) -> Database: """Tune a Relay program. @@ -299,6 +304,8 @@ def tune_relay( For the definition of the anchor block, see tir/analysis/analysis.py. num_tuning_cores : Union[Literal["physical", "logical"], int] The number of CPU cores to use during tuning. + disabled_pass : Optional[Union[List[str], Set[str], Tuple[str]]] + The list of disabled passes during tasks extraction Returns ------- @@ -306,7 +313,9 @@ def tune_relay( The database that contains the tuning records """ tasks, task_weights = extracted_tasks_to_tune_contexts( - extracted_tasks=extract_tasks(mod, target, params, module_equality=module_equality), + extracted_tasks=extract_tasks( + mod, target, params, module_equality=module_equality, disabled_pass=disabled_pass + ), work_dir=work_dir, space=space, strategy=strategy, @@ -345,6 +354,7 @@ def compile_relay( } ), executor: Optional["relay.backend.Executor"] = None, + disabled_pass: Optional[Union[List[str], Set[str], Tuple[str]]] = None, ): """Compile a relay program with a MetaSchedule database. @@ -368,6 +378,8 @@ def compile_relay( The pass configuration executor : Optional[relay.backend.Executor] The executor to use in relay.build. It is not supported by RelayVM. + disabled_pass : Optional[Union[List[str], Set[str], Tuple[str]]] + The list of disabled passes Returns ------- @@ -387,6 +399,7 @@ def compile_relay( with transform.PassContext( opt_level=opt_level, config=pass_config, + disabled_pass=disabled_pass, ): if backend == "graph": return relay.build(mod, target=target, params=params, executor=executor) From d7702d26abee16ce90ff6136f94878eaeb38842f Mon Sep 17 00:00:00 2001 From: ibsidorenko Date: Wed, 28 Dec 2022 17:48:00 +0300 Subject: [PATCH 2/2] Add unit test for 'disabled_pass' parameter in MetaScheduler tuner This commit adds unit test for 'disabled_pass' parameter in MetaSchedule tuner. Test should throw TVMError exception. That's why it is marked as XFAIL. --- .../test_meta_schedule_relay_integration.py | 45 +++++++++++++++++++ 1 file changed, 45 insertions(+) diff --git a/tests/python/unittest/test_meta_schedule_relay_integration.py b/tests/python/unittest/test_meta_schedule_relay_integration.py index 76d6323f309a..d3731cfa1be8 100644 --- a/tests/python/unittest/test_meta_schedule_relay_integration.py +++ b/tests/python/unittest/test_meta_schedule_relay_integration.py @@ -826,5 +826,50 @@ def test_anchor_tuning_cpu_link_params(): np.testing.assert_allclose(ref, out, atol=1e-3) +@pytest.mark.xfail(raises=tvm.error.TVMError) +def test_disabled_pass_param(): + """ + Check 'disabled_pass' parameter in tune_relay. Should throw exception in + case of correct work. + """ + data_shape = [1, 4, 16, 16] + weight_shape = [32, 4, 2, 2] + + data = relay.var("data", shape=data_shape, dtype="uint8") + weight = relay.var("weight", shape=weight_shape, dtype="int8") + + op = relay.qnn.op.conv2d( + data, + weight, + input_zero_point=relay.const(0), + kernel_zero_point=relay.const(0), + input_scale=relay.const(0.7), + kernel_scale=relay.const(0.3), + kernel_size=[2, 2], + channels=32, + ) + mod = tvm.IRModule.from_expr(op) + + weight_np = np.random.randint(-10, 10, size=weight_shape).astype("int8") + params = {"weight": weight_np} + + executor = relay.backend.Executor("graph", {"link-params": True}) + mod = mod.with_attr("executor", executor) + + with tempfile.TemporaryDirectory() as work_dir: + database = ms.relay_integration.tune_relay( + mod=mod, + target="llvm --num-cores=4", + params=params, + work_dir=work_dir, + max_trials_global=4, + strategy="replay-trace", + disabled_pass=["qnn.Legalize"], + ) + + # Test failed, otherwise we can not reach this point. + pytest.fail("'disabled_pass' argument does not work") + + if __name__ == "__main__": tvm.testing.main()