diff --git a/python/tvm/driver/tvmc/common.py b/python/tvm/driver/tvmc/common.py index 864c3a9bddb4..77ba1cb47cc8 100644 --- a/python/tvm/driver/tvmc/common.py +++ b/python/tvm/driver/tvmc/common.py @@ -29,7 +29,7 @@ from tvm import relay from tvm import transform - +from tvm._ffi import registry # pylint: disable=invalid-name logger = logging.getLogger("TVMC") @@ -333,6 +333,37 @@ def tracker_host_port_from_cli(rpc_tracker_str): return rpc_hostname, rpc_port +def parse_pass_list_str(input_string): + """Parse an input string for existing passes + + Parameters + ---------- + input_string: str + Possibly comma-separated string with the names of passes + + Returns + ------- + list: a list of existing passes. + """ + _prefix = "relay._transform." + pass_list = input_string.split(",") + missing_list = [ + p.strip() + for p in pass_list + if len(p.strip()) > 0 and tvm.get_global_func(_prefix + p.strip(), True) is None + ] + if len(missing_list) > 0: + available_list = [ + n[len(_prefix) :] for n in registry.list_global_func_names() if n.startswith(_prefix) + ] + raise argparse.ArgumentTypeError( + "Following passes are not registered within tvm: {}. Available: {}.".format( + ", ".join(missing_list), ", ".join(sorted(available_list)) + ) + ) + return pass_list + + def parse_shape_string(inputs_string): """Parse an input shape dictionary string to a usable dictionary. diff --git a/python/tvm/driver/tvmc/compiler.py b/python/tvm/driver/tvmc/compiler.py index b8450750f115..4f485122aaf2 100644 --- a/python/tvm/driver/tvmc/compiler.py +++ b/python/tvm/driver/tvmc/compiler.py @@ -95,6 +95,12 @@ def add_compile_parser(subparsers): type=common.parse_shape_string, default=None, ) + parser.add_argument( + "--disable-pass", + help="disable specific passes, comma-separated list of pass names", + type=common.parse_pass_list_str, + default="", + ) def drive_compile(args): @@ -121,6 +127,7 @@ def drive_compile(args): None, args.tuning_records, args.desired_layout, + args.disabled_pass, ) if dumps: @@ -138,6 +145,7 @@ def compile_model( target_host=None, tuning_records=None, alter_layout=None, + disabled_pass=None, ): """Compile a model from a supported framework into a TVM module. @@ -167,6 +175,10 @@ def compile_model( The layout to convert the graph to. Note, the convert layout pass doesn't currently guarantee the whole of the graph will be converted to the chosen layout. + disabled_pass: str, optional + Comma-separated list of passes which needs to be disabled + during compilation + Returns ------- @@ -209,16 +221,20 @@ def compile_model( if use_autoscheduler: with auto_scheduler.ApplyHistoryBest(tuning_records): config["relay.backend.use_auto_scheduler"] = True - with tvm.transform.PassContext(opt_level=3, config=config): + with tvm.transform.PassContext( + opt_level=3, config=config, disabled_pass=disabled_pass + ): logger.debug("building relay graph with autoscheduler") graph_module = relay.build(mod, target=target, params=params) else: with autotvm.apply_history_best(tuning_records): - with tvm.transform.PassContext(opt_level=3, config=config): + with tvm.transform.PassContext( + opt_level=3, config=config, disabled_pass=disabled_pass + ): logger.debug("building relay graph with tuning records") graph_module = relay.build(mod, tvm_target, params=params) else: - with tvm.transform.PassContext(opt_level=3, config=config): + with tvm.transform.PassContext(opt_level=3, config=config, disabled_pass=disabled_pass): logger.debug("building relay graph (no tuning records provided)") graph_module = relay.build(mod, tvm_target, params=params) diff --git a/tests/python/driver/tvmc/test_common.py b/tests/python/driver/tvmc/test_common.py new file mode 100644 index 000000000000..5cac6a1378a5 --- /dev/null +++ b/tests/python/driver/tvmc/test_common.py @@ -0,0 +1,35 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import argparse +import pytest +from tvm.driver import tvmc + + +def test_common_parse_pass_list_str(): + assert [""] == tvmc.common.parse_pass_list_str("") + assert ["FoldScaleAxis", "FuseOps"] == tvmc.common.parse_pass_list_str("FoldScaleAxis,FuseOps") + + with pytest.raises(argparse.ArgumentTypeError) as ate: + tvmc.common.parse_pass_list_str("MyYobaPass,MySuperYobaPass,FuseOps") + + assert "MyYobaPass" in str(ate.value) + assert "MySuperYobaPass" in str(ate.value) + assert "FuseOps" in str(ate.value) + + +if __name__ == "__main__": + pytest.main([__file__]) diff --git a/tests/python/driver/tvmc/test_compiler.py b/tests/python/driver/tvmc/test_compiler.py index 8cd77b8cde4a..24fa452d05c1 100644 --- a/tests/python/driver/tvmc/test_compiler.py +++ b/tests/python/driver/tvmc/test_compiler.py @@ -251,5 +251,7 @@ def test_compile_check_configs_composite_target(mock_pc, mock_fe, mock_ct, mock_ graph, lib, params, dumps = tvmc.compile(mod, params, target="mockcodegen -testopt=value, llvm") mock_pc.assert_called_once_with( - opt_level=3, config={"relay.ext.mock.options": {"testopt": "value"}} + opt_level=3, + config={"relay.ext.mock.options": {"testopt": "value"}}, + disabled_pass=None, )