diff --git a/python/tvm/driver/tvmc/compiler.py b/python/tvm/driver/tvmc/compiler.py
index 2955df55432d..c24d36c432df 100644
--- a/python/tvm/driver/tvmc/compiler.py
+++ b/python/tvm/driver/tvmc/compiler.py
@@ -19,13 +19,14 @@
 """
 import logging
 import os.path
-from typing import Any, Optional, Dict, List, Union, Callable
+from typing import Any, Optional, Dict, List, Union, Callable, Sequence
 from pathlib import Path
 
 import tvm
 from tvm import autotvm, auto_scheduler
 from tvm import relay
 from tvm.driver.tvmc.registry import generate_registry_args, reconstruct_registry_entity
+from tvm.ir.instrument import PassInstrument
 from tvm.ir.memory_pools import WorkspaceMemoryPools
 from tvm.target import Target
 from tvm.relay.backend import Executor, Runtime
@@ -223,6 +224,7 @@ def compile_model(
     use_vm: bool = False,
     mod_name: Optional[str] = "default",
     workspace_pools: Optional[WorkspaceMemoryPools] = None,
+    instruments: Optional[Sequence[PassInstrument]] = None,
 ):
     """Compile a model from a supported framework into a TVM module.
 
@@ -277,6 +279,8 @@ def compile_model(
     workspace_pools: WorkspaceMemoryPools, optional
         Specification of WorkspacePoolInfo objects to be used as workspace memory
         in the compilation.
+    instruments: Optional[Sequence[PassInstrument]]
+        The list of pass instrument implementations.
 
     Returns
     -------
@@ -316,7 +320,10 @@ def compile_model(
             with auto_scheduler.ApplyHistoryBest(tuning_records):
                 config["relay.backend.use_auto_scheduler"] = True
                 with tvm.transform.PassContext(
-                    opt_level=opt_level, config=config, disabled_pass=disabled_pass
+                    opt_level=opt_level,
+                    config=config,
+                    disabled_pass=disabled_pass,
+                    instruments=instruments,
                 ):
                     logger.debug("building relay graph with autoscheduler")
                     graph_module = build(
@@ -332,7 +339,10 @@ def compile_model(
         else:
             with autotvm.apply_history_best(tuning_records):
                 with tvm.transform.PassContext(
-                    opt_level=opt_level, config=config, disabled_pass=disabled_pass
+                    opt_level=opt_level,
+                    config=config,
+                    disabled_pass=disabled_pass,
+                    instruments=instruments,
                 ):
                     logger.debug("building relay graph with tuning records")
                     graph_module = build(
@@ -347,7 +357,7 @@ def compile_model(
                     )
     else:
         with tvm.transform.PassContext(
-            opt_level=opt_level, config=config, disabled_pass=disabled_pass
+            opt_level=opt_level, config=config, disabled_pass=disabled_pass, instruments=instruments
         ):
             logger.debug("building relay graph (no tuning records provided)")
            graph_module = build(
diff --git a/tests/python/driver/tvmc/test_compiler.py b/tests/python/driver/tvmc/test_compiler.py
index 5535fc02249f..7cb50dd0e366 100644
--- a/tests/python/driver/tvmc/test_compiler.py
+++ b/tests/python/driver/tvmc/test_compiler.py
@@ -516,6 +516,7 @@ def test_compile_check_configs_composite_target(mock_pkg, mock_pc, mock_fe, mock
         config={"relay.ext.mock.options": {"testopt": "value"}},
         opt_level=3,
         disabled_pass=None,
+        instruments=None,
     )
     mock_pc.assert_has_calls(
         [
@@ -697,5 +698,27 @@ def test_compile_check_workspace_pools(mock_pkg, mock_fe, mock_relay):
     assert mock_relay.call_args_list[0][1]["workspace_memory_pools"] == memory_pools
 
 
+def test_compile_check_pass_instrument(keras_resnet50):
+    pytest.importorskip("tensorflow")
+
+    @tvm.instrument.pass_instrument
+    class PassesCounter:
+        def __init__(self):
+            self.run_before_count = 0
+            self.run_after_count = 0
+
+        def run_before_pass(self, mod, info):
+            self.run_before_count = self.run_before_count + 1
+
+        def run_after_pass(self, mod, info):
+            self.run_after_count = self.run_after_count + 1
+
+    passes_counter = PassesCounter()
+    tvmc_model = tvmc.load(keras_resnet50)
+    tvmc.compile(tvmc_model, target="llvm", instruments=[passes_counter])
+    assert passes_counter.run_after_count > 0
+    assert passes_counter.run_after_count == passes_counter.run_before_count
+
+
 if __name__ == "__main__":
     tvm.testing.main()
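
Usage sketch (not part of the change above): the new `instruments` argument accepts any pass instrument, including user-defined ones created with `tvm.instrument.pass_instrument`, mirroring the test added in this diff. The `PassLogger` class and the "model.onnx" path below are hypothetical placeholders; `tvm.instrument.pass_instrument`, `tvmc.load`, and `tvmc.compile` are existing APIs.

import tvm
from tvm.driver import tvmc


@tvm.instrument.pass_instrument
class PassLogger:
    """Hypothetical instrument: records the name of every pass that runs."""

    def __init__(self):
        self.passes = []

    def run_before_pass(self, mod, info):
        # `info` is the PassInfo of the pass about to run.
        self.passes.append(info.name)


pass_logger = PassLogger()
tvmc_model = tvmc.load("model.onnx")  # placeholder model file
tvmc.compile(tvmc_model, target="llvm", instruments=[pass_logger])
print(len(pass_logger.passes), "passes ran during compilation")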