From d63f6b48a6728b58b585856102c02841ccb4a1b9 Mon Sep 17 00:00:00 2001 From: Dmitriy Smirnov Date: Wed, 7 Apr 2021 12:27:12 +0100 Subject: [PATCH 1/4] [TVMC] --disable-pass option added to compile mode Added --disable-pass option to TVMC compile mode to disallow certain supplied passes in PassContext for the compiler. Change-Id: Iae1849d7b051ac9288509dc458a58788c865537a --- python/tvm/driver/tvmc/common.py | 23 +++++++++++++++++++++++ python/tvm/driver/tvmc/compiler.py | 22 +++++++++++++++++++--- 2 files changed, 42 insertions(+), 3 deletions(-) diff --git a/python/tvm/driver/tvmc/common.py b/python/tvm/driver/tvmc/common.py index 864c3a9bddb4..e4ff27c6fcd8 100644 --- a/python/tvm/driver/tvmc/common.py +++ b/python/tvm/driver/tvmc/common.py @@ -333,6 +333,29 @@ def tracker_host_port_from_cli(rpc_tracker_str): return rpc_hostname, rpc_port +def parse_disabled_pass(input_string): + """Parse an input string for disabled passes + + Parameters + ---------- + input_string: str + Possibly comma-separated string with the names of disabled passes + + Returns + ------- + list: a list of disabled passes. + """ + if input_string is not None: + pass_list = input_string.split(",") + nf = [_ for _ in pass_list if tvm.get_global_func("relay._transform." + _, True) is None] + if len(nf) > 0: + raise argparse.ArgumentTypeError( + "Following passes are not registered within tvm: " + str(nf) + ) + return pass_list + return None + + def parse_shape_string(inputs_string): """Parse an input shape dictionary string to a usable dictionary. diff --git a/python/tvm/driver/tvmc/compiler.py b/python/tvm/driver/tvmc/compiler.py index b8450750f115..eff262c0efb7 100644 --- a/python/tvm/driver/tvmc/compiler.py +++ b/python/tvm/driver/tvmc/compiler.py @@ -95,6 +95,12 @@ def add_compile_parser(subparsers): type=common.parse_shape_string, default=None, ) + parser.add_argument( + "--disabled-pass", + help="disable specific passes, comma-separated list of pass names", + type=common.parse_disabled_pass, + default=None, + ) def drive_compile(args): @@ -121,6 +127,7 @@ def drive_compile(args): None, args.tuning_records, args.desired_layout, + args.disabled_pass, ) if dumps: @@ -138,6 +145,7 @@ def compile_model( target_host=None, tuning_records=None, alter_layout=None, + disabled_pass=None, ): """Compile a model from a supported framework into a TVM module. @@ -167,6 +175,10 @@ def compile_model( The layout to convert the graph to. Note, the convert layout pass doesn't currently guarantee the whole of the graph will be converted to the chosen layout. + disabled_pass: str, optional + Comma-separated list of passes which needs to be disabled + during compilation + Returns ------- @@ -209,16 +221,20 @@ def compile_model( if use_autoscheduler: with auto_scheduler.ApplyHistoryBest(tuning_records): config["relay.backend.use_auto_scheduler"] = True - with tvm.transform.PassContext(opt_level=3, config=config): + with tvm.transform.PassContext( + opt_level=3, config=config, disabled_pass=disabled_pass + ): logger.debug("building relay graph with autoscheduler") graph_module = relay.build(mod, target=target, params=params) else: with autotvm.apply_history_best(tuning_records): - with tvm.transform.PassContext(opt_level=3, config=config): + with tvm.transform.PassContext( + opt_level=3, config=config, disabled_pass=disabled_pass + ): logger.debug("building relay graph with tuning records") graph_module = relay.build(mod, tvm_target, params=params) else: - with tvm.transform.PassContext(opt_level=3, config=config): + with tvm.transform.PassContext(opt_level=3, config=config, disabled_pass=disabled_pass): logger.debug("building relay graph (no tuning records provided)") graph_module = relay.build(mod, tvm_target, params=params) From 1bc2e7956ceab346e10dcbf915e91a1bde42f5f1 Mon Sep 17 00:00:00 2001 From: Dmitriy Smirnov Date: Mon, 12 Apr 2021 23:10:38 +0100 Subject: [PATCH 2/4] Added test, addressed requests Change-Id: If688f65441d3aa9967ab823adf899cfc704bd097 --- python/tvm/driver/tvmc/common.py | 28 ++++++++++--------- python/tvm/driver/tvmc/compiler.py | 6 ++-- tests/python/driver/tvmc/test_common.py | 34 +++++++++++++++++++++++ tests/python/driver/tvmc/test_compiler.py | 4 ++- 4 files changed, 55 insertions(+), 17 deletions(-) create mode 100644 tests/python/driver/tvmc/test_common.py diff --git a/python/tvm/driver/tvmc/common.py b/python/tvm/driver/tvmc/common.py index e4ff27c6fcd8..9976a999eb3f 100644 --- a/python/tvm/driver/tvmc/common.py +++ b/python/tvm/driver/tvmc/common.py @@ -333,27 +333,29 @@ def tracker_host_port_from_cli(rpc_tracker_str): return rpc_hostname, rpc_port -def parse_disabled_pass(input_string): - """Parse an input string for disabled passes +def parse_pass_list_str(input_string): + """Parse an input string for existing passes Parameters ---------- input_string: str - Possibly comma-separated string with the names of disabled passes + Possibly comma-separated string with the names of passes Returns ------- - list: a list of disabled passes. + list: a list of existing passes. """ - if input_string is not None: - pass_list = input_string.split(",") - nf = [_ for _ in pass_list if tvm.get_global_func("relay._transform." + _, True) is None] - if len(nf) > 0: - raise argparse.ArgumentTypeError( - "Following passes are not registered within tvm: " + str(nf) - ) - return pass_list - return None + pass_list = input_string.split(",") + missing_list = [ + p.strip() + for p in pass_list + if len(p.strip()) > 0 and tvm.get_global_func("relay._transform." + p.strip(), True) is None + ] + if len(missing_list) > 0: + raise argparse.ArgumentTypeError( + "Following passes are not registered within tvm: " + str(missing_list) + ) + return pass_list def parse_shape_string(inputs_string): diff --git a/python/tvm/driver/tvmc/compiler.py b/python/tvm/driver/tvmc/compiler.py index eff262c0efb7..4f485122aaf2 100644 --- a/python/tvm/driver/tvmc/compiler.py +++ b/python/tvm/driver/tvmc/compiler.py @@ -96,10 +96,10 @@ def add_compile_parser(subparsers): default=None, ) parser.add_argument( - "--disabled-pass", + "--disable-pass", help="disable specific passes, comma-separated list of pass names", - type=common.parse_disabled_pass, - default=None, + type=common.parse_pass_list_str, + default="", ) diff --git a/tests/python/driver/tvmc/test_common.py b/tests/python/driver/tvmc/test_common.py new file mode 100644 index 000000000000..4a3e709afe23 --- /dev/null +++ b/tests/python/driver/tvmc/test_common.py @@ -0,0 +1,34 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import argparse +import pytest +from tvm.driver import tvmc + + +def test_common_parse_pass_list_str(): + assert [""] == tvmc.common.parse_pass_list_str("") + assert ["FoldScaleAxis", "FuseOps"] == tvmc.common.parse_pass_list_str("FoldScaleAxis,FuseOps") + + with pytest.raises(argparse.ArgumentTypeError) as ate: + tvmc.common.parse_pass_list_str("MyYobaPass,MySuperYobaPass,FuseOps") + assert "MyYobaPass" in str(ate) + assert "MySuperYobaPass" in str(ate) + assert "FuseOps" not in str(ate) + + +if __name__ == "__main__": + pytest.main([__file__]) diff --git a/tests/python/driver/tvmc/test_compiler.py b/tests/python/driver/tvmc/test_compiler.py index 8cd77b8cde4a..24fa452d05c1 100644 --- a/tests/python/driver/tvmc/test_compiler.py +++ b/tests/python/driver/tvmc/test_compiler.py @@ -251,5 +251,7 @@ def test_compile_check_configs_composite_target(mock_pc, mock_fe, mock_ct, mock_ graph, lib, params, dumps = tvmc.compile(mod, params, target="mockcodegen -testopt=value, llvm") mock_pc.assert_called_once_with( - opt_level=3, config={"relay.ext.mock.options": {"testopt": "value"}} + opt_level=3, + config={"relay.ext.mock.options": {"testopt": "value"}}, + disabled_pass=None, ) From 1506f4e772c98d5199f6ffc39a1193e307b1e38e Mon Sep 17 00:00:00 2001 From: Dmitriy Smirnov Date: Tue, 13 Apr 2021 12:19:24 +0100 Subject: [PATCH 3/4] added printing of available passes Change-Id: I7a4706c03c0d64cade4977d431bcb25b3708f213 --- python/tvm/driver/tvmc/common.py | 12 ++++++++++-- tests/python/driver/tvmc/test_common.py | 7 ++++--- 2 files changed, 14 insertions(+), 5 deletions(-) diff --git a/python/tvm/driver/tvmc/common.py b/python/tvm/driver/tvmc/common.py index 9976a999eb3f..f0ae76a4f4a3 100644 --- a/python/tvm/driver/tvmc/common.py +++ b/python/tvm/driver/tvmc/common.py @@ -345,15 +345,23 @@ def parse_pass_list_str(input_string): ------- list: a list of existing passes. """ + _prefix = "relay._transform." pass_list = input_string.split(",") missing_list = [ p.strip() for p in pass_list - if len(p.strip()) > 0 and tvm.get_global_func("relay._transform." + p.strip(), True) is None + if len(p.strip()) > 0 and tvm.get_global_func(_prefix + p.strip(), True) is None ] if len(missing_list) > 0: + from tvm._ffi import registry + + available_list = [ + n[len(_prefix) :] for n in registry.list_global_func_names() if n.startswith(_prefix) + ] raise argparse.ArgumentTypeError( - "Following passes are not registered within tvm: " + str(missing_list) + "Following passes are not registered within tvm: {}. Available: {}.".format( + ", ".join(missing_list), ", ".join(sorted(available_list)) + ) ) return pass_list diff --git a/tests/python/driver/tvmc/test_common.py b/tests/python/driver/tvmc/test_common.py index 4a3e709afe23..5cac6a1378a5 100644 --- a/tests/python/driver/tvmc/test_common.py +++ b/tests/python/driver/tvmc/test_common.py @@ -25,9 +25,10 @@ def test_common_parse_pass_list_str(): with pytest.raises(argparse.ArgumentTypeError) as ate: tvmc.common.parse_pass_list_str("MyYobaPass,MySuperYobaPass,FuseOps") - assert "MyYobaPass" in str(ate) - assert "MySuperYobaPass" in str(ate) - assert "FuseOps" not in str(ate) + + assert "MyYobaPass" in str(ate.value) + assert "MySuperYobaPass" in str(ate.value) + assert "FuseOps" in str(ate.value) if __name__ == "__main__": From 85cc5f42c4d908b72b6c0002aae3a5a4ec426db6 Mon Sep 17 00:00:00 2001 From: Dmitriy Smirnov Date: Tue, 13 Apr 2021 13:09:05 +0100 Subject: [PATCH 4/4] C0415(import-outside-toplevel) Change-Id: I33d6f6f86d182de2e21e895ec2dfe9f11f5916dd --- python/tvm/driver/tvmc/common.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/python/tvm/driver/tvmc/common.py b/python/tvm/driver/tvmc/common.py index f0ae76a4f4a3..77ba1cb47cc8 100644 --- a/python/tvm/driver/tvmc/common.py +++ b/python/tvm/driver/tvmc/common.py @@ -29,7 +29,7 @@ from tvm import relay from tvm import transform - +from tvm._ffi import registry # pylint: disable=invalid-name logger = logging.getLogger("TVMC") @@ -353,8 +353,6 @@ def parse_pass_list_str(input_string): if len(p.strip()) > 0 and tvm.get_global_func(_prefix + p.strip(), True) is None ] if len(missing_list) > 0: - from tvm._ffi import registry - available_list = [ n[len(_prefix) :] for n in registry.list_global_func_names() if n.startswith(_prefix) ]