From bc1d14a00eab14df4783205852d8d0bd2ef3ade7 Mon Sep 17 00:00:00 2001 From: Yineng Zhang Date: Wed, 19 Oct 2022 17:08:38 +0800 Subject: [PATCH 1/3] [tvmc] add instruments for PassContext --- python/tvm/driver/tvmc/compiler.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/python/tvm/driver/tvmc/compiler.py b/python/tvm/driver/tvmc/compiler.py index 2955df55432d..2af7be4c20cf 100644 --- a/python/tvm/driver/tvmc/compiler.py +++ b/python/tvm/driver/tvmc/compiler.py @@ -223,6 +223,7 @@ def compile_model( use_vm: bool = False, mod_name: Optional[str] = "default", workspace_pools: Optional[WorkspaceMemoryPools] = None, + instruments = None, ): """Compile a model from a supported framework into a TVM module. @@ -277,6 +278,8 @@ def compile_model( workspace_pools: WorkspaceMemoryPools, optional Specification of WorkspacePoolInfo objects to be used as workspace memory in the compilation. + instruments: Optional[Sequence[PassInstrument]] + The list of pass instrument implementations. Returns ------- @@ -316,7 +319,7 @@ def compile_model( with auto_scheduler.ApplyHistoryBest(tuning_records): config["relay.backend.use_auto_scheduler"] = True with tvm.transform.PassContext( - opt_level=opt_level, config=config, disabled_pass=disabled_pass + opt_level=opt_level, config=config, disabled_pass=disabled_pass, instruments=instruments ): logger.debug("building relay graph with autoscheduler") graph_module = build( @@ -332,7 +335,7 @@ def compile_model( else: with autotvm.apply_history_best(tuning_records): with tvm.transform.PassContext( - opt_level=opt_level, config=config, disabled_pass=disabled_pass + opt_level=opt_level, config=config, disabled_pass=disabled_pass, instruments=instruments ): logger.debug("building relay graph with tuning records") graph_module = build( @@ -347,7 +350,7 @@ def compile_model( ) else: with tvm.transform.PassContext( - opt_level=opt_level, config=config, disabled_pass=disabled_pass + opt_level=opt_level, config=config, disabled_pass=disabled_pass, instruments=instruments ): logger.debug("building relay graph (no tuning records provided)") graph_module = build( From 74fd3b91f0a49796f19474cf97f3270914cea765 Mon Sep 17 00:00:00 2001 From: Yineng Zhang Date: Wed, 19 Oct 2022 19:25:59 +0800 Subject: [PATCH 2/3] fix review --- python/tvm/driver/tvmc/compiler.py | 15 +++++++++++---- tests/python/driver/tvmc/test_compiler.py | 22 ++++++++++++++++++++++ 2 files changed, 33 insertions(+), 4 deletions(-) diff --git a/python/tvm/driver/tvmc/compiler.py b/python/tvm/driver/tvmc/compiler.py index 2af7be4c20cf..c24d36c432df 100644 --- a/python/tvm/driver/tvmc/compiler.py +++ b/python/tvm/driver/tvmc/compiler.py @@ -19,13 +19,14 @@ """ import logging import os.path -from typing import Any, Optional, Dict, List, Union, Callable +from typing import Any, Optional, Dict, List, Union, Callable, Sequence from pathlib import Path import tvm from tvm import autotvm, auto_scheduler from tvm import relay from tvm.driver.tvmc.registry import generate_registry_args, reconstruct_registry_entity +from tvm.ir.instrument import PassInstrument from tvm.ir.memory_pools import WorkspaceMemoryPools from tvm.target import Target from tvm.relay.backend import Executor, Runtime @@ -223,7 +224,7 @@ def compile_model( use_vm: bool = False, mod_name: Optional[str] = "default", workspace_pools: Optional[WorkspaceMemoryPools] = None, - instruments = None, + instruments: Optional[Sequence[PassInstrument]] = None, ): """Compile a model from a supported framework into a TVM module. @@ -319,7 +320,10 @@ def compile_model( with auto_scheduler.ApplyHistoryBest(tuning_records): config["relay.backend.use_auto_scheduler"] = True with tvm.transform.PassContext( - opt_level=opt_level, config=config, disabled_pass=disabled_pass, instruments=instruments + opt_level=opt_level, + config=config, + disabled_pass=disabled_pass, + instruments=instruments, ): logger.debug("building relay graph with autoscheduler") graph_module = build( @@ -335,7 +339,10 @@ def compile_model( else: with autotvm.apply_history_best(tuning_records): with tvm.transform.PassContext( - opt_level=opt_level, config=config, disabled_pass=disabled_pass, instruments=instruments + opt_level=opt_level, + config=config, + disabled_pass=disabled_pass, + instruments=instruments, ): logger.debug("building relay graph with tuning records") graph_module = build( diff --git a/tests/python/driver/tvmc/test_compiler.py b/tests/python/driver/tvmc/test_compiler.py index 5535fc02249f..1ead0fb30a49 100644 --- a/tests/python/driver/tvmc/test_compiler.py +++ b/tests/python/driver/tvmc/test_compiler.py @@ -697,5 +697,27 @@ def test_compile_check_workspace_pools(mock_pkg, mock_fe, mock_relay): assert mock_relay.call_args_list[0][1]["workspace_memory_pools"] == memory_pools +def test_compile_check_pass_instrument(keras_resnet50): + pytest.importorskip("tensorflow") + + @tvm.instrument.pass_instrument + class PassesCounter: + def __init__(self): + self.run_before_count = 0 + self.run_after_count = 0 + + def run_before_pass(self, mod, info): + self.run_before_count = self.run_before_count + 1 + + def run_after_pass(self, mod, info): + self.run_after_count = self.run_after_count + 1 + + passes_counter = PassesCounter() + tvmc_model = tvmc.load(keras_resnet50) + tvmc.compile(tvmc_model, target="llvm", instruments=[passes_counter]) + assert passes_counter.run_after_count > 0 + assert passes_counter.run_after_count == passes_counter.run_before_count + + if __name__ == "__main__": tvm.testing.main() From 5f70d095c5dc6e29b9077623d44225de56572588 Mon Sep 17 00:00:00 2001 From: Yineng Zhang Date: Thu, 20 Oct 2022 13:26:16 +0800 Subject: [PATCH 3/3] fix test error --- tests/python/driver/tvmc/test_compiler.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/python/driver/tvmc/test_compiler.py b/tests/python/driver/tvmc/test_compiler.py index 1ead0fb30a49..7cb50dd0e366 100644 --- a/tests/python/driver/tvmc/test_compiler.py +++ b/tests/python/driver/tvmc/test_compiler.py @@ -516,6 +516,7 @@ def test_compile_check_configs_composite_target(mock_pkg, mock_pc, mock_fe, mock config={"relay.ext.mock.options": {"testopt": "value"}}, opt_level=3, disabled_pass=None, + instruments=None, ) mock_pc.assert_has_calls( [