From 3ebf0cfb4e9734be43b546471951824c9956d79e Mon Sep 17 00:00:00 2001 From: Luke Hutton Date: Thu, 2 Mar 2023 20:03:54 +0000 Subject: [PATCH 1/2] [TVMC] Add option to dump TIR code to file Currently dumps the TIR code after the final phase of lowering (phase 3 from https://github.com/apache/tvm/blob/665dd413bc85d14f7836324daf7cc0dd9281c85a/gallery/how_to/extend_tvm/low_level_custom_pass.py#L152) before codegen. This is done by running a pass to capture the TIR module as it is not saved during the build. The result is saved to file similar to the other code dumps. Change-Id: I6da7df8f87ba4d91215adc679d926d4f0ca89019 --- python/tvm/driver/tvmc/compiler.py | 55 ++++++++++++++++------- tests/python/driver/tvmc/test_compiler.py | 13 +++++- 2 files changed, 52 insertions(+), 16 deletions(-) diff --git a/python/tvm/driver/tvmc/compiler.py b/python/tvm/driver/tvmc/compiler.py index eec80820cdb1..2e7934a99255 100644 --- a/python/tvm/driver/tvmc/compiler.py +++ b/python/tvm/driver/tvmc/compiler.py @@ -71,7 +71,7 @@ def add_compile_parser(subparsers, _, json_params): "--dump-code", metavar="FORMAT", default="", - help="comma separated list of formats to export the input model, e.g. 'asm,ll,relay'.", + help="comma separated list of formats to export the input model, e.g. 'asm,ll,tir,relay'.", ) parser.add_argument( "--model-format", @@ -254,9 +254,9 @@ def compile_model( output_format : str What format to use when saving the function library. Must be one of "so" or "tar". When compiling for a remote device without a cross compiler, "tar" will likely work better. - dump_code : list, optional + dump_code : list[str], optional Dump the generated code for the specified source types, on - the requested target. + the requested target. Choose from: ["asm", "ll", "tir", "relay"]. target_host : str, optional The target of the host machine if host-side code needs to be generated. @@ -290,7 +290,15 @@ def compile_model( """ mod, params = tvmc_model.mod, tvmc_model.params + if dump_code is None: + dump_code = [] + if not isinstance(dump_code, list): + dump_code = [dump_code] + dumps = {} + config = parse_configs(pass_context_configs) + if "tir" in dump_code: + config, dumps = add_tir_to_dumps(config, dumps) tvm_target, extra_targets = target_from_cli(target, additional_target_options) tvm_target, target_host = Target.canon_target_and_host(tvm_target, target_host) @@ -366,20 +374,16 @@ def compile_model( ) # Generate output dump files with sources - if dump_code is None: - dump_code = [] - if not isinstance(dump_code, list): - dump_code = [dump_code] - dumps = {} for source_type in dump_code: - if use_vm: - lib = graph_module.lib + if source_type == "relay": + dumps[source_type] = str(mod) + elif source_type == "tir": + dumps[source_type] = "\n".join(dumps[source_type]) else: - lib = graph_module.get_lib() - # TODO lib.get_source call have inconsistent behavior for unsupported - # formats (@leandron). - source = str(mod) if source_type == "relay" else lib.get_source(source_type) - dumps[source_type] = source + lib = graph_module.lib if use_vm else graph_module.get_lib() + # TODO lib.get_source call have inconsistent behavior for unsupported + # formats (@leandron). + dumps[source_type] = lib.get_source(source_type) # Create a new tvmc model package object from the graph definition. package_path = tvmc_model.export_package( @@ -440,6 +444,27 @@ def build( ) +def add_tir_to_dumps(config, dumps): + """ + Creates a debug pass that dumps TIR functions as a list of strings. + """ + key = "tir" + phase = 3 # final TIR phase before codegen + dumps[key] = [] + + @tvm.tir.transform.prim_func_pass(opt_level=0) + def _dump_tir_pass(tir_func, _, __): + dumps[key].append(str(tir_func)) + return tir_func + + tir_lower_passes = config.get("tir.add_lower_pass", []) + tir_lower_passes.append((phase, _dump_tir_pass)) + if tir_lower_passes: + config["tir.add_lower_pass"] = tir_lower_passes + + return config, dumps + + def save_dumps(module_name: str, dumps: Dict[str, str], dump_root: str = "."): """ Serialize dump files to the disk. diff --git a/tests/python/driver/tvmc/test_compiler.py b/tests/python/driver/tvmc/test_compiler.py index 3a3f297729fd..525f75750ec6 100644 --- a/tests/python/driver/tvmc/test_compiler.py +++ b/tests/python/driver/tvmc/test_compiler.py @@ -40,11 +40,12 @@ def test_save_dumps(tmpdir_factory): tmpdir = tmpdir_factory.mktemp("data") - dump_formats = {"relay": "fake relay", "ll": "fake llvm", "asm": "fake asm"} + dump_formats = {"relay": "fake relay", "tir": "fake tir", "ll": "fake llvm", "asm": "fake asm"} tvmc.compiler.save_dumps("fake_module", dump_formats, dump_root=tmpdir) assert path.exists("{}/{}".format(tmpdir, "fake_module.ll")) assert path.exists("{}/{}".format(tmpdir, "fake_module.asm")) + assert path.exists("{}/{}".format(tmpdir, "fake_module.tir")) assert path.exists("{}/{}".format(tmpdir, "fake_module.relay")) @@ -87,6 +88,16 @@ def test_compile_tflite_module(use_vm, tflite_mobilenet_v1_1_quant): verify_compile_tflite_module(tflite_mobilenet_v1_1_quant, shape_dict, use_vm=use_vm) +def test_tir_dump(tflite_mobilenet_v1_1_quant): + pytest.importorskip("tflite") + tvmc_model = tvmc.load(tflite_mobilenet_v1_1_quant) + tvmc_package = tvmc.compile(tvmc_model, target="llvm", dump_code="tir") + dumps_path = tvmc_package.package_path + ".tir" + assert os.path.exists(dumps_path) + with open(dumps_path) as f: + assert "tir" in f.read() + + # This test will be skipped if the AArch64 cross-compilation toolchain is not installed. @pytest.mark.skipif( not shutil.which("aarch64-linux-gnu-gcc"), reason="cross-compilation toolchain not installed" From a6a29fa2b692d865501d7d91210590ed419e5287 Mon Sep 17 00:00:00 2001 From: Luke Hutton Date: Fri, 3 Mar 2023 15:16:51 +0000 Subject: [PATCH 2/2] add multi dump test case and remove unecessary check Change-Id: I0d1dc761b66c9f41592992580e2ccc0bad8f4769 --- python/tvm/driver/tvmc/compiler.py | 3 +-- tests/python/driver/tvmc/test_compiler.py | 14 +++++++++++++- 2 files changed, 14 insertions(+), 3 deletions(-) diff --git a/python/tvm/driver/tvmc/compiler.py b/python/tvm/driver/tvmc/compiler.py index 2e7934a99255..2c98085d7dfe 100644 --- a/python/tvm/driver/tvmc/compiler.py +++ b/python/tvm/driver/tvmc/compiler.py @@ -459,8 +459,7 @@ def _dump_tir_pass(tir_func, _, __): tir_lower_passes = config.get("tir.add_lower_pass", []) tir_lower_passes.append((phase, _dump_tir_pass)) - if tir_lower_passes: - config["tir.add_lower_pass"] = tir_lower_passes + config["tir.add_lower_pass"] = tir_lower_passes return config, dumps diff --git a/tests/python/driver/tvmc/test_compiler.py b/tests/python/driver/tvmc/test_compiler.py index 525f75750ec6..6bcf19056df3 100644 --- a/tests/python/driver/tvmc/test_compiler.py +++ b/tests/python/driver/tvmc/test_compiler.py @@ -88,7 +88,7 @@ def test_compile_tflite_module(use_vm, tflite_mobilenet_v1_1_quant): verify_compile_tflite_module(tflite_mobilenet_v1_1_quant, shape_dict, use_vm=use_vm) -def test_tir_dump(tflite_mobilenet_v1_1_quant): +def test_single_tir_dump(tflite_mobilenet_v1_1_quant): pytest.importorskip("tflite") tvmc_model = tvmc.load(tflite_mobilenet_v1_1_quant) tvmc_package = tvmc.compile(tvmc_model, target="llvm", dump_code="tir") @@ -98,6 +98,18 @@ def test_tir_dump(tflite_mobilenet_v1_1_quant): assert "tir" in f.read() +def test_code_dumps(tflite_mobilenet_v1_1_quant): + pytest.importorskip("tflite") + tvmc_model = tvmc.load(tflite_mobilenet_v1_1_quant) + dump_code = ["asm", "ll", "tir", "relay"] + tvmc_package = tvmc.compile(tvmc_model, target="llvm", dump_code=dump_code) + for ext in dump_code: + dumps_path = tvmc_package.package_path + "." + ext + assert os.path.exists(dumps_path) + with open(dumps_path) as f: + assert len(f.read()) > 0 + + # This test will be skipped if the AArch64 cross-compilation toolchain is not installed. @pytest.mark.skipif( not shutil.which("aarch64-linux-gnu-gcc"), reason="cross-compilation toolchain not installed"