diff --git a/python/tvm/driver/tvmc/autotuner.py b/python/tvm/driver/tvmc/autotuner.py
index f9ba427ffaa6..98293e596b5d 100644
--- a/python/tvm/driver/tvmc/autotuner.py
+++ b/python/tvm/driver/tvmc/autotuner.py
@@ -389,110 +389,115 @@ def tune_model(
     # model is fixed. For now, creating a clone avoids the issue.
     mod = deepcopy(tvmc_model.mod)
     params = tvmc_model.params
-    if tuning_records is None:
-        tuning_records = tvmc_model.default_tuning_records_path()
-
-    for codegen_from_cli in extra_targets:
-        codegen = composite_target.get_codegen_by_target(codegen_from_cli["name"])
-        partition_function = codegen["pass_pipeline"]
-        mod = partition_function(mod, params, **codegen_from_cli["opts"])
-
-    # min_repeat_ms should be:
-    # a. the value provided by the user, if any, or
-    # b. 0ms in case target is "cpu"; otherwise 1000ms
-    if min_repeat_ms is None:
-        min_repeat_ms = 0 if target.keys[0] == "cpu" else 1000
-        logger.info("Default --min-repeat-ms for this target is %s", min_repeat_ms)
-
-    if rpc_key:
-        if hostname is None or port is None:
-            raise TVMCException(
-                "You must provide a hostname and port to connect to a remote RPC device."
-            )
-        if isinstance(port, str):
-            port = int(port)
-
-        logger.info("Tuning will be performed on device %s at %s:%d.", rpc_key, hostname, port)
-
-        runner_ctor = auto_scheduler.RPCRunner if enable_autoscheduler else autotvm.RPCRunner
-        runner = runner_ctor(
-            key=rpc_key,
-            host=hostname,
-            port=port,
-            number=number,
-            repeat=repeat,
-            n_parallel=parallel,
-            timeout=timeout,
-            min_repeat_ms=min_repeat_ms,
-        )
-    else:
-        logger.info("Starting localhost tuning.")
-        runner_ctor = (
-            auto_scheduler.LocalRPCMeasureContext if enable_autoscheduler else autotvm.LocalRunner
-        )
-        local_server = runner_ctor(
-            number=number,
-            repeat=repeat,
-            timeout=timeout,
-            min_repeat_ms=min_repeat_ms,
-        )
-        # For autoscheduling on some devices, we need to maintain a LocalRPCMeasureContext object.
-        if enable_autoscheduler:
-            runner = local_server.runner
+    with tvm.transform.PassContext(opt_level=3):
+        if tuning_records is None:
+            tuning_records = tvmc_model.default_tuning_records_path()
+
+        for codegen_from_cli in extra_targets:
+            codegen = composite_target.get_codegen_by_target(codegen_from_cli["name"])
+            partition_function = codegen["pass_pipeline"]
+            mod = partition_function(mod, params, **codegen_from_cli["opts"])
+
+        # min_repeat_ms should be:
+        # a. the value provided by the user, if any, or
+        # b. 0ms in case target is "cpu"; otherwise 1000ms
+        if min_repeat_ms is None:
+            min_repeat_ms = 0 if target.keys[0] == "cpu" else 1000
+            logger.info("Default --min-repeat-ms for this target is %s", min_repeat_ms)
+
+        if rpc_key:
+            if hostname is None or port is None:
+                raise TVMCException(
+                    "You must provide a hostname and port to connect to a remote RPC device."
+                )
+            if isinstance(port, str):
+                port = int(port)
+
+            logger.info("Tuning will be performed on device %s at %s:%d.", rpc_key, hostname, port)
+
+            runner_ctor = auto_scheduler.RPCRunner if enable_autoscheduler else autotvm.RPCRunner
+            runner = runner_ctor(
+                key=rpc_key,
+                host=hostname,
+                port=port,
+                number=number,
+                repeat=repeat,
+                n_parallel=parallel,
+                timeout=timeout,
+                min_repeat_ms=min_repeat_ms,
+            )
         else:
-            runner = local_server
+            logger.info("Starting localhost tuning.")
+            runner_ctor = (
+                auto_scheduler.LocalRPCMeasureContext
+                if enable_autoscheduler
+                else autotvm.LocalRunner
+            )
+            local_server = runner_ctor(
+                number=number,
+                repeat=repeat,
+                timeout=timeout,
+                min_repeat_ms=min_repeat_ms,
+            )
 
-    if enable_autoscheduler:
+            # For autoscheduling on some devices, we need to maintain a
+            # LocalRPCMeasureContext object.
+            if enable_autoscheduler:
+                runner = local_server.runner
+            else:
+                runner = local_server
 
-        tasks, weights = autoscheduler_get_tuning_tasks(
-            mod=mod,
-            params=params,
-            target=target,
-            alter_layout=desired_layout,
-            hardware_params=hardware_params,
-            include_simple_tasks=include_simple_tasks,
-        )
+        if enable_autoscheduler:
 
-        # Create the autoscheduler tuning options
-        tuning_options = auto_scheduler.TuningOptions(
-            num_measure_trials=trials,
-            measure_callbacks=[auto_scheduler.RecordToFile(tuning_records)],
-            runner=runner,
-            early_stopping=early_stopping,
-        )
+            tasks, weights = autoscheduler_get_tuning_tasks(
+                mod=mod,
+                params=params,
+                target=target,
+                alter_layout=desired_layout,
+                hardware_params=hardware_params,
+                include_simple_tasks=include_simple_tasks,
+            )
+
+            # Create the autoscheduler tuning options
+            tuning_options = auto_scheduler.TuningOptions(
+                num_measure_trials=trials,
+                measure_callbacks=[auto_scheduler.RecordToFile(tuning_records)],
+                runner=runner,
+                early_stopping=early_stopping,
+            )
 
-        logger.info("Autoscheduling with configuration: %s", tuning_options)
+            logger.info("Autoscheduling with configuration: %s", tuning_options)
 
-        # Schedule the tasks (i.e., produce a schedule for each task)
-        schedule_tasks(tasks, weights, tuning_options, prior_records, log_estimated_latency)
-    else:
-        tasks = autotvm_get_tuning_tasks(
-            mod=mod,
-            params=params,
-            target=target,
-            alter_layout=desired_layout,
-        )
+            # Schedule the tasks (i.e., produce a schedule for each task)
+            schedule_tasks(tasks, weights, tuning_options, prior_records, log_estimated_latency)
+        else:
+            tasks = autotvm_get_tuning_tasks(
+                mod=mod,
+                params=params,
+                target=target,
+                alter_layout=desired_layout,
+            )
 
-        # In autotvm, trials is specified per task. We can convert the per-model input
-        # provided to per-task trials by dividing by the number of tasks.
-        trials = int(trials / max(len(tasks), 1))
-        logger.info("Autotuning with %d trials per task.", trials)
-
-        tuning_options = {
-            "tuner": tuner,
-            "trials": trials,
-            "early_stopping": early_stopping,
-            "measure_option": autotvm.measure_option(
-                builder=autotvm.LocalBuilder(build_func="default"), runner=runner
-            ),
-            "tuning_records": prior_records,
-        }
-        logger.info("Autotuning with configuration: %s", tuning_options)
-
-        tune_tasks(tasks, tuning_records, **tuning_options)
-
-    return tuning_records
+            # In autotvm, trials is specified per task. We can convert the per-model input
+            # provided to per-task trials by dividing by the number of tasks.
+            trials = int(trials / max(len(tasks), 1))
+            logger.info("Autotuning with %d trials per task.", trials)
+
+            tuning_options = {
+                "tuner": tuner,
+                "trials": trials,
+                "early_stopping": early_stopping,
+                "measure_option": autotvm.measure_option(
+                    builder=autotvm.LocalBuilder(build_func="default"), runner=runner
+                ),
+                "tuning_records": prior_records,
+            }
+            logger.info("Autotuning with configuration: %s", tuning_options)
+
+            tune_tasks(tasks, tuning_records, **tuning_options)
+
+        return tuning_records
 
 
 def autotvm_get_tuning_tasks(
diff --git a/python/tvm/driver/tvmc/compiler.py b/python/tvm/driver/tvmc/compiler.py
index c24d36c432df..eec80820cdb1 100644
--- a/python/tvm/driver/tvmc/compiler.py
+++ b/python/tvm/driver/tvmc/compiler.py
@@ -292,39 +292,42 @@ def compile_model(
 
     config = parse_configs(pass_context_configs)
 
-    if desired_layout:
-        mod = convert_graph_layout(mod, desired_layout)
-
     tvm_target, extra_targets = target_from_cli(target, additional_target_options)
     tvm_target, target_host = Target.canon_target_and_host(tvm_target, target_host)
 
+    partition_functions = []
+    partition_opts = []
     for codegen_from_cli in extra_targets:
         codegen = composite_target.get_codegen_by_target(codegen_from_cli["name"])
-        partition_function = codegen["pass_pipeline"]
-
+        partition_functions.append(codegen["pass_pipeline"])
+        partition_opts.append(codegen_from_cli["opts"])
         if codegen["config_key"] is not None:
             config[codegen["config_key"]] = codegen_from_cli["opts"]
-        with tvm.transform.PassContext(config=config):
-            mod = partition_function(mod, params, mod_name=mod_name, **codegen_from_cli["opts"])
-
-    if tuning_records and os.path.exists(tuning_records):
-        logger.debug("tuning records file provided: %s", tuning_records)
-
-        use_autoscheduler = True
-        try:
-            auto_scheduler.load_records(tuning_records)
-        except tvm._ffi.base.TVMError:
-            use_autoscheduler = False
-
-        if use_autoscheduler:
-            with auto_scheduler.ApplyHistoryBest(tuning_records):
-                config["relay.backend.use_auto_scheduler"] = True
-                with tvm.transform.PassContext(
-                    opt_level=opt_level,
-                    config=config,
-                    disabled_pass=disabled_pass,
-                    instruments=instruments,
-                ):
+
+    with tvm.transform.PassContext(
+        opt_level=opt_level,
+        config=config,
+        disabled_pass=disabled_pass,
+        instruments=instruments,
+    ):
+        if desired_layout:
+            mod = convert_graph_layout(mod, desired_layout)
+
+        for partition_function, opts in zip(partition_functions, partition_opts):
+            mod = partition_function(mod, params, mod_name=mod_name, **opts)
+
+        if tuning_records and os.path.exists(tuning_records):
+            logger.debug("tuning records file provided: %s", tuning_records)
+
+            use_autoscheduler = True
+            try:
+                auto_scheduler.load_records(tuning_records)
+            except tvm._ffi.base.TVMError:
+                use_autoscheduler = False
+
+            if use_autoscheduler:
+                with auto_scheduler.ApplyHistoryBest(tuning_records):
+                    config["relay.backend.use_auto_scheduler"] = True
                     logger.debug("building relay graph with autoscheduler")
                     graph_module = build(
                         mod,
@@ -336,14 +339,8 @@ def compile_model(
                         mod_name=mod_name,
                         workspace_pools=workspace_pools,
                     )
-        else:
-            with autotvm.apply_history_best(tuning_records):
-                with tvm.transform.PassContext(
-                    opt_level=opt_level,
-                    config=config,
-                    disabled_pass=disabled_pass,
-                    instruments=instruments,
-                ):
+            else:
+                with autotvm.apply_history_best(tuning_records):
                     logger.debug("building relay graph with tuning records")
                     graph_module = build(
                         mod,
@@ -355,10 +352,7 @@ def compile_model(
                         mod_name=mod_name,
                         workspace_pools=workspace_pools,
                     )
-    else:
-        with tvm.transform.PassContext(
-            opt_level=opt_level, config=config, disabled_pass=disabled_pass, instruments=instruments
-        ):
+        else:
             logger.debug("building relay graph (no tuning records provided)")
             graph_module = build(
                 mod,
@@ -371,32 +365,32 @@ def compile_model(
                 workspace_pools=workspace_pools,
             )
 
-    # Generate output dump files with sources
-    if dump_code is None:
-        dump_code = []
-    if not isinstance(dump_code, list):
-        dump_code = [dump_code]
-    dumps = {}
-    for source_type in dump_code:
-        if use_vm:
-            lib = graph_module.lib
-        else:
-            lib = graph_module.get_lib()
-        # TODO lib.get_source call have inconsistent behavior for unsupported
-        #      formats (@leandron).
-        source = str(mod) if source_type == "relay" else lib.get_source(source_type)
-        dumps[source_type] = source
-
-    # Create a new tvmc model package object from the graph definition.
-    package_path = tvmc_model.export_package(
-        graph_module, package_path, cross, cross_options, output_format
-    )
+        # Generate output dump files with sources
+        if dump_code is None:
+            dump_code = []
+        if not isinstance(dump_code, list):
+            dump_code = [dump_code]
+        dumps = {}
+        for source_type in dump_code:
+            if use_vm:
+                lib = graph_module.lib
+            else:
+                lib = graph_module.get_lib()
+            # TODO lib.get_source call have inconsistent behavior for unsupported
+            #      formats (@leandron).
+            source = str(mod) if source_type == "relay" else lib.get_source(source_type)
+            dumps[source_type] = source
+
+        # Create a new tvmc model package object from the graph definition.
+        package_path = tvmc_model.export_package(
+            graph_module, package_path, cross, cross_options, output_format
+        )
 
-    # Write dumps to file.
-    if dumps:
-        save_dumps(package_path, dumps)
+        # Write dumps to file.
+        if dumps:
+            save_dumps(package_path, dumps)
 
-    return TVMCPackage(package_path)
+        return TVMCPackage(package_path)
 
 
 def build(
diff --git a/python/tvm/driver/tvmc/transform.py b/python/tvm/driver/tvmc/transform.py
index 51c9e52f21d6..8527c48b6b04 100644
--- a/python/tvm/driver/tvmc/transform.py
+++ b/python/tvm/driver/tvmc/transform.py
@@ -54,10 +54,7 @@ def convert_graph_layout(mod, desired_layout):
         ]
     )
 
-    with transform.PassContext(opt_level=3):
-        try:
-            return seq(mod)
-        except Exception as err:
-            raise TVMCException(
-                "Error converting layout to {0}: {1}".format(desired_layout, str(err))
-            )
+    try:
+        return seq(mod)
+    except Exception as err:
+        raise TVMCException("Error converting layout to {0}: {1}".format(desired_layout, str(err)))
diff --git a/tests/python/driver/tvmc/conftest.py b/tests/python/driver/tvmc/conftest.py
index 8009448bff77..e0dbeebf9871 100644
--- a/tests/python/driver/tvmc/conftest.py
+++ b/tests/python/driver/tvmc/conftest.py
@@ -23,6 +23,8 @@
 
 from PIL import Image
 
+import tvm
+from tvm import relay
 from tvm.driver import tvmc
 
 from tvm.contrib.download import download_testdata
@@ -284,3 +286,17 @@ def @main(%data : Tensor[(1, 3, 64, 64), uint8], %weight : Tensor[(3, 3, 5, 5),
     with open(file_path, "w") as relay_text:
         relay_text.write(RELAY_MODEL)
     return file_path
+
+
+@pytest.fixture(scope="session")
+def relay_conv2d():
+    """
+    Simple conv2d Relay implementation.
+    """
+    dtype = "float32"
+
+    x = relay.var("x", shape=(1, 4, 2, 2), dtype=dtype)
+    weight = relay.const(np.random.uniform(size=(2, 4, 2, 2)), dtype=dtype)
+    x = relay.nn.conv2d(x, weight)
+    func = relay.Function(relay.analysis.free_vars(x), x)
+    return tvm.IRModule.from_expr(func)
diff --git a/tests/python/driver/tvmc/test_autotuner.py b/tests/python/driver/tvmc/test_autotuner.py
index 7c05ff804fa4..eb6550e40cdc 100644
--- a/tests/python/driver/tvmc/test_autotuner.py
+++ b/tests/python/driver/tvmc/test_autotuner.py
@@ -23,6 +23,7 @@
 from os import path
 from pathlib import Path
 
+import tvm
 from tvm import autotvm
 from tvm.driver import tvmc
 
@@ -191,3 +192,18 @@ def test_tune_rpc_tracker_parsing(mock_load_model, mock_tune_model, mock_auto_sc
     assert "10.0.0.1" == kwargs["hostname"]
     assert "port" in kwargs
     assert 9999 == kwargs["port"]
+
+
+@mock.patch("tvm.transform.PassContext", return_value=tvm.transform.PassContext())
+def test_autotune_pass_context(mock_pc, onnx_mnist, tmpdir_factory):
+    """
+    Check that the pass context while tuning is as expected.
+    """
+    pytest.importorskip("onnx")
+
+    tmpdir_name = tmpdir_factory.mktemp("data")
+    _tuner_test_helper(onnx_mnist, "gridsearch", tmpdir_name)
+
+    # AutoTVM overrides the pass context later in the pipeline to disable AlterOpLayout
+    assert mock_pc.call_count == 2
+    assert mock_pc.call_args_list[0][1]["opt_level"] == 3
diff --git a/tests/python/driver/tvmc/test_compiler.py b/tests/python/driver/tvmc/test_compiler.py
index 7cb50dd0e366..3a3f297729fd 100644
--- a/tests/python/driver/tvmc/test_compiler.py
+++ b/tests/python/driver/tvmc/test_compiler.py
@@ -508,10 +508,7 @@ def test_compile_check_configs_composite_target(mock_pkg, mock_pc, mock_fe, mock
     tvmc_model = tvmc.load("no_file_needed")
     tvmc.compile(tvmc_model, target="mockcodegen -testopt=value, llvm")
 
-    assert mock_pc.call_count == 2
-    codegen_partition_context = mock.call(
-        config={"relay.ext.mock.options": {"testopt": "value"}},
-    )
+    assert mock_pc.call_count == 1
     codegen_compile_context = mock.call(
         config={"relay.ext.mock.options": {"testopt": "value"}},
         opt_level=3,
@@ -520,9 +517,6 @@ def test_compile_check_configs_composite_target(mock_pkg, mock_pc, mock_fe, mock
     )
     mock_pc.assert_has_calls(
         [
-            codegen_partition_context,
-            codegen_partition_context.__enter__(),
-            codegen_partition_context.__exit__(None, None, None),
             codegen_compile_context,
             codegen_compile_context.__enter__(),
             codegen_compile_context.__exit__(None, None, None),
diff --git a/tests/python/driver/tvmc/test_frontends.py b/tests/python/driver/tvmc/test_frontends.py
index c1a3be67c208..718babd15c29 100644
--- a/tests/python/driver/tvmc/test_frontends.py
+++ b/tests/python/driver/tvmc/test_frontends.py
@@ -297,7 +297,8 @@ def test_compile_tflite_module_nhwc_to_nchw(tflite_mobilenet_v1_1_quant):
     before = tvmc_model.mod
 
     expected_layout = "NCHW"
-    after = tvmc.transform.convert_graph_layout(before, expected_layout)
+    with tvm.transform.PassContext(opt_level=3):
+        after = tvmc.transform.convert_graph_layout(before, expected_layout)
 
     layout_transform_calls = []
 
@@ -322,7 +323,8 @@ def test_compile_onnx_module_nchw_to_nhwc(onnx_resnet50):
     before = tvmc_model.mod
 
     expected_layout = "NHWC"
-    after = tvmc.transform.convert_graph_layout(before, expected_layout)
+    with tvm.transform.PassContext(opt_level=3):
+        after = tvmc.transform.convert_graph_layout(before, expected_layout)
 
     layout_transform_calls = []
 
@@ -347,7 +349,8 @@ def test_compile_paddle_module_nchw_to_nhwc(paddle_resnet50):
     before = tvmc_model.mod
 
     expected_layout = "NHWC"
-    after = tvmc.transform.convert_graph_layout(before, expected_layout)
+    with tvm.transform.PassContext(opt_level=3):
+        after = tvmc.transform.convert_graph_layout(before, expected_layout)
 
     layout_transform_calls = []
 
@@ -372,7 +375,9 @@ def test_compile_tflite_module__same_layout__nhwc_to_nhwc(tflite_mobilenet_v1_1_
     before = tvmc_model.mod
 
     expected_layout = "NHWC"
-    after = tvmc.transform.convert_graph_layout(before, expected_layout)
+
+    with tvm.transform.PassContext(opt_level=3):
+        after = tvmc.transform.convert_graph_layout(before, expected_layout)
 
     layout_transform_calls = []
 
@@ -397,7 +402,9 @@ def test_compile_onnx_module__same_layout__nchw_to_nchw(onnx_resnet50):
     before = tvmc_model.mod
 
     expected_layout = "NCHW"
-    after = tvmc.transform.convert_graph_layout(before, expected_layout)
+
+    with tvm.transform.PassContext(opt_level=3):
+        after = tvmc.transform.convert_graph_layout(before, expected_layout)
 
     layout_transform_calls = []
 
diff --git a/tests/python/driver/tvmc/test_transform.py b/tests/python/driver/tvmc/test_transform.py
index 98a0210a1bb6..98bd3b5f98a3 100644
--- a/tests/python/driver/tvmc/test_transform.py
+++ b/tests/python/driver/tvmc/test_transform.py
@@ -14,43 +14,60 @@
 #  KIND, either express or implied.  See the License for the
 #  specific language governing permissions and limitations
 #  under the License.
-import pytest
-import numpy as np
+
+from unittest.mock import MagicMock
 
 import tvm
 from tvm import relay
+from tvm.ir.instrument import pass_instrument
 from tvm.driver.tvmc.transform import convert_graph_layout
 
 
-def test_layout_transform():
+def test_layout_transform_fold_constant(relay_conv2d):
     """
     Test layout is correctly transformed and constant folding is applied.
     """
-    dtype = "int8"
-    iinfo = np.iinfo(dtype)
-    data_min = iinfo.min
-    data_max = iinfo.max
-
-    x = relay.var("x", shape=(1, 4, 2, 2), dtype=dtype)
-    weight = relay.const(
-        np.random.randint(data_min, data_max, size=(2, 4, 2, 2), dtype=dtype), dtype=dtype
-    )
-    x = relay.nn.conv2d(x, weight)
-    func = relay.Function(relay.analysis.free_vars(x), x)
-    mod = tvm.IRModule.from_expr(func)
+    desired_layout = "NHWC"
+
+    @pass_instrument
+    class CollectPassNames:
+        def __init__(self):
+            self.names = []
+
+        def run_after_pass(self, _, info):
+            self.names.append(info.name)
+
+    pass_names = CollectPassNames()
+    with tvm.transform.PassContext(opt_level=3, instruments=[pass_names]):
+        convert_graph_layout(relay_conv2d, desired_layout)
 
+    names = pass_names.names
+    assert "ConvertLayout" in names
+    assert "FoldConstant" in names
+    assert names.index("ConvertLayout") < names.index("FoldConstant")
+
+
+def test_layout_transform_convert_layout_pass_args(relay_conv2d, monkeypatch):
+    """
+    Check that the desired layouts argument passed to the ConvertLayout pass is
+    as expected when a desired layout is provided.
+    """
     desired_layout = "NHWC"
-    mod = convert_graph_layout(mod, desired_layout)
 
-    main_expr = mod["main"].body
-    conv = main_expr.args[0]
-    assert conv.op.name == "nn.conv2d"
-    assert conv.attrs["data_layout"] == "NHWC"
-    assert conv.attrs["kernel_layout"] == "HWIO"
+    mock_convert_layout = MagicMock()
+    mock_convert_layout.return_value = relay.transform.ConvertLayout({})
+    monkeypatch.setattr(relay.transform, "ConvertLayout", mock_convert_layout)
+
+    with tvm.transform.PassContext(opt_level=3):
+        convert_graph_layout(relay_conv2d, desired_layout)
 
-    # Ensure transform has been folded into the constant
-    weights = conv.args[1]
-    assert isinstance(weights, relay.expr.Constant)
+    mock_convert_layout.assert_called_once_with(
+        {
+            "nn.conv2d": ["NHWC", "default"],
+            "nn.conv2d_transpose": ["NHWC", "default"],
+            "qnn.conv2d": ["NHWC", "default"],
+        }
+    )
 
 
 if __name__ == "__main__":
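
Note on the calling convention this patch introduces: convert_graph_layout no longer
opens a PassContext of its own, so callers are now expected to provide one, exactly as
the updated tests above do. A minimal usage sketch, assuming `mod` is an already-loaded
Relay IRModule (the variable name is illustrative, not part of the patch):

    import tvm
    from tvm.driver.tvmc.transform import convert_graph_layout

    # The caller now owns the pass context; opt_level=3 mirrors the context that
    # convert_graph_layout previously created internally.
    with tvm.transform.PassContext(opt_level=3):
        mod = convert_graph_layout(mod, "NHWC")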