Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 14 additions & 4 deletions python/tvm/driver/tvmc/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,14 @@
"""
import logging
import os.path
from typing import Any, Optional, Dict, List, Union, Callable
from typing import Any, Optional, Dict, List, Union, Callable, Sequence
from pathlib import Path

import tvm
from tvm import autotvm, auto_scheduler
from tvm import relay
from tvm.driver.tvmc.registry import generate_registry_args, reconstruct_registry_entity
from tvm.ir.instrument import PassInstrument
from tvm.ir.memory_pools import WorkspaceMemoryPools
from tvm.target import Target
from tvm.relay.backend import Executor, Runtime
Expand Down Expand Up @@ -223,6 +224,7 @@ def compile_model(
use_vm: bool = False,
mod_name: Optional[str] = "default",
workspace_pools: Optional[WorkspaceMemoryPools] = None,
instruments: Optional[Sequence[PassInstrument]] = None,
):
"""Compile a model from a supported framework into a TVM module.

Expand Down Expand Up @@ -277,6 +279,8 @@ def compile_model(
workspace_pools: WorkspaceMemoryPools, optional
Specification of WorkspacePoolInfo objects to be used as workspace memory in the
compilation.
instruments: Optional[Sequence[PassInstrument]]
The list of pass instrument implementations.

Returns
-------
Expand Down Expand Up @@ -316,7 +320,10 @@ def compile_model(
with auto_scheduler.ApplyHistoryBest(tuning_records):
config["relay.backend.use_auto_scheduler"] = True
with tvm.transform.PassContext(
opt_level=opt_level, config=config, disabled_pass=disabled_pass
opt_level=opt_level,
config=config,
disabled_pass=disabled_pass,
instruments=instruments,
):
logger.debug("building relay graph with autoscheduler")
graph_module = build(
Expand All @@ -332,7 +339,10 @@ def compile_model(
else:
with autotvm.apply_history_best(tuning_records):
with tvm.transform.PassContext(
opt_level=opt_level, config=config, disabled_pass=disabled_pass
opt_level=opt_level,
config=config,
disabled_pass=disabled_pass,
instruments=instruments,
):
logger.debug("building relay graph with tuning records")
graph_module = build(
Expand All @@ -347,7 +357,7 @@ def compile_model(
)
else:
with tvm.transform.PassContext(
opt_level=opt_level, config=config, disabled_pass=disabled_pass
opt_level=opt_level, config=config, disabled_pass=disabled_pass, instruments=instruments
):
logger.debug("building relay graph (no tuning records provided)")
graph_module = build(
Expand Down
23 changes: 23 additions & 0 deletions tests/python/driver/tvmc/test_compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -516,6 +516,7 @@ def test_compile_check_configs_composite_target(mock_pkg, mock_pc, mock_fe, mock
config={"relay.ext.mock.options": {"testopt": "value"}},
opt_level=3,
disabled_pass=None,
instruments=None,
)
mock_pc.assert_has_calls(
[
Expand Down Expand Up @@ -697,5 +698,27 @@ def test_compile_check_workspace_pools(mock_pkg, mock_fe, mock_relay):
assert mock_relay.call_args_list[0][1]["workspace_memory_pools"] == memory_pools


def test_compile_check_pass_instrument(keras_resnet50):
pytest.importorskip("tensorflow")

@tvm.instrument.pass_instrument
class PassesCounter:
def __init__(self):
self.run_before_count = 0
self.run_after_count = 0

def run_before_pass(self, mod, info):
self.run_before_count = self.run_before_count + 1

def run_after_pass(self, mod, info):
self.run_after_count = self.run_after_count + 1

passes_counter = PassesCounter()
tvmc_model = tvmc.load(keras_resnet50)
tvmc.compile(tvmc_model, target="llvm", instruments=[passes_counter])
assert passes_counter.run_after_count > 0
assert passes_counter.run_after_count == passes_counter.run_before_count


if __name__ == "__main__":
tvm.testing.main()