From 792210730a51b60c7a1e01245dcb4d5ef96c54cb Mon Sep 17 00:00:00 2001 From: Jorn Tuyls Date: Tue, 2 Mar 2021 06:00:35 -0800 Subject: [PATCH 01/10] Enable Vitis AI target through TVMC & change PassContext API's --- .../install/ubuntu_install_vitis_ai_core.sh | 0 docs/deploy/vitis_ai.rst | 42 +++++----- python/tvm/contrib/target/vitis_ai.py | 83 +++++++++++++------ python/tvm/driver/tvmc/autotuner.py | 2 +- python/tvm/driver/tvmc/common.py | 2 +- python/tvm/driver/tvmc/compiler.py | 2 +- python/tvm/driver/tvmc/composite_target.py | 6 ++ .../tvm/relay/op/contrib/arm_compute_lib.py | 2 +- python/tvm/relay/op/contrib/ethosn.py | 2 +- python/tvm/relay/op/contrib/vitis_ai.py | 68 +++++++++++---- .../contrib/vitis_ai/config_vitis_ai.cc | 34 ++++++++ .../contrib/vitis_ai/vitis_ai_runtime.cc | 1 + .../contrib/test_vitis_ai/infrastructure.py | 7 +- tests/python/driver/tvmc/test_compiler.py | 32 ++++++- .../driver/tvmc/test_composite_target.py | 1 + 15 files changed, 210 insertions(+), 74 deletions(-) mode change 100644 => 100755 docker/install/ubuntu_install_vitis_ai_core.sh diff --git a/docker/install/ubuntu_install_vitis_ai_core.sh b/docker/install/ubuntu_install_vitis_ai_core.sh old mode 100644 new mode 100755 diff --git a/docs/deploy/vitis_ai.rst b/docs/deploy/vitis_ai.rst index 7de8f58ce54f..c36a2a2d96eb 100755 --- a/docs/deploy/vitis_ai.rst +++ b/docs/deploy/vitis_ai.rst @@ -196,7 +196,7 @@ Hardware setup and docker build pip3 install -e . --user Edge (DPUCZDX8G) -^^^^^^^^^^^^^^^^ +~~~~~~~~~~~~~~~~~ For edge deployment we make use of two systems referred to as host and @@ -451,18 +451,16 @@ TVM. from tvm.contrib.target import vitis_ai from tvm.contrib import utils, graph_runtime from tvm.relay.build_module import bind_params_by_name - from tvm.relay.op.contrib.vitis_ai import annotation + from tvm.relay.op.contrib.vitis_ai import partition_for_vitis_ai After importing a convolutional neural network model using the usual Relay API's, annotate the Relay expression for the given Vitis-AI DPU target and partition the graph. .. code:: python - - mod["main"] = bind_params_by_name(mod["main"], params) - mod = annotation(mod, params, target) - mod = relay.transform.MergeCompilerRegions()(mod) - mod = relay.transform.PartitionGraph()(mod) + + target='DPUCADX8G' + mod = partition_for_vitis_ai(mod, params, target) Now, we can build the TVM runtime library for executing the model. The TVM target is 'llvm' as the operations that can't be handled by the DPU @@ -473,9 +471,8 @@ build call. .. code:: python tvm_target = 'llvm' - target='DPUCADX8G' - with tvm.transform.PassContext(opt_level=3, config= {'relay.ext.vitis_ai.options.target': target}): + with tvm.transform.PassContext(opt_level=3, config= {'relay.ext.vitis_ai.options': {'target': target}}): lib = relay.build(mod, tvm_target, params=params) As one more step before we can accelerate a model with Vitis-AI in TVM @@ -553,7 +550,7 @@ TVM. from tvm.contrib.target import vitis_ai from tvm.contrib import utils, graph_runtime from tvm.relay.build_module import bind_params_by_name - from tvm.relay.op.contrib.vitis_ai import annotation + from tvm.relay.op.contrib.vitis_ai import partition_for_vitis_ai After importing a convolutional neural network model using the usual Relay API's, annotate the Relay expression for the given Vitis-AI DPU @@ -585,11 +582,10 @@ target and partition the graph. relay.transform.FoldConstant()]) with tvm.transform.PassContext(opt_level=3): mod = seq(mod) - + + target = 'DPUCZDX8G-zcu104' # Annotate and partition the Relay expression for the given target - mod = annotation(mod, params, target) - mod = relay.transform.MergeCompilerRegions()(mod) - mod = relay.transform.PartitionGraph()(mod) + mod = partition_for_vitis_ai(mod, params, target) # After partitioning we recommend transforming the remaining convolutions # (that will be executed on CPU, if any) back to NCHW data layout @@ -618,11 +614,13 @@ can be included. .. code:: python tvm_target = 'llvm' - target='DPUCZDX8G-zcu104' export_rt_mod_file = "vitis_ai.rtmod" - - with tvm.transform.PassContext(opt_level=3, config= {'relay.ext.vitis_ai.options.target': target, - 'relay.ext.vitis_ai.options.export_runtime_module': export_rt_mod_file}): + + build_options = { + 'target': target, + 'export_runtime_module': export_rt_mod_file + } + with tvm.transform.PassContext(opt_level=3, config= {'relay.ext.vitis_ai.options': build_options}): lib = relay.build(mod, tvm_target, params=params) We will quantize and compile the model for execution on the DPU using on-the-fly @@ -663,9 +661,11 @@ in the TVM build. 'fcompile': contrib.cc.create_shared, 'cc': "/usr/aarch64-linux-gnu/bin/ld" } - - with tvm.transform.PassContext(opt_level=3, - config={'relay.ext.vitis_ai.options.load_runtime_module': export_rt_mod_file}): + + build_options = { + 'load_runtime_module': export_rt_mod_file + } + with tvm.transform.PassContext(opt_level=3, config={'relay.ext.vitis_ai.options': build_options}): lib_arm = relay.build(mod, tvm_target, params=params) lib_dpuv2.export_library('tvm_dpu_arm.so', **lib_kwargs) diff --git a/python/tvm/contrib/target/vitis_ai.py b/python/tvm/contrib/target/vitis_ai.py index f319fd799829..f83c81cca8d2 100644 --- a/python/tvm/contrib/target/vitis_ai.py +++ b/python/tvm/contrib/target/vitis_ai.py @@ -71,37 +71,66 @@ def vitis_ai_compiler(ref): pass_context = tvm.get_global_func("transform.GetCurrentPassContext")() - # The target Vitis-AI accelerator device - target = ( - str(pass_context.config["relay.ext.vitis_ai.options.target"]) - if "relay.ext.vitis_ai.options.target" in pass_context.config + cfg = ( + pass_context.config["relay.ext.vitis_ai.options"] + if "relay.ext.vitis_ai.options" in pass_context.config else None ) - # (Optional configs) The build and work directories to be used by Vitis-AI - vai_build_dir = ( - str(pass_context.config["relay.ext.vitis_ai.options.build_dir"]) - if "relay.ext.vitis_ai.options.build_dir" in pass_context.config - else tvm.contrib.utils.tempdir().relpath("") - ) - vai_work_dir = ( - str(pass_context.config["relay.ext.vitis_ai.options.work_dir"]) - if "relay.ext.vitis_ai.options.work_dir" in pass_context.config - else tvm.contrib.utils.tempdir().relpath("") - ) + # Backward compatibility with old pass context configs + if cfg is None: + warnings.warn( + "You are using a deprecated way of passing build configs (e.g." + " `relay.ext.vitis_ai.options.target`). Check out the Vitis AI " + " documentation here: https://tvm.apache.org/docs/deploy/vitis_ai.html" + " to switch to recommended way for passing build configs." + ) - # (Optional configs) Export and load PyXIR runtime module to file if provided. This is used to - # compile and quantize a model on the host and deploy it at the edge - export_runtime_module = ( - str(pass_context.config["relay.ext.vitis_ai.options.export_runtime_module"]) - if "relay.ext.vitis_ai.options.export_runtime_module" in pass_context.config - else "" - ) - load_runtime_module = ( - str(pass_context.config["relay.ext.vitis_ai.options.load_runtime_module"]) - if "relay.ext.vitis_ai.options.load_runtime_module" in pass_context.config - else "" - ) + # The target Vitis-AI accelerator device + target = ( + str(pass_context.config["relay.ext.vitis_ai.options.target"]) + if "relay.ext.vitis_ai.options.target" in pass_context.config + else None + ) + + # (Optional configs) The build and work directories to be used by Vitis-AI + vai_build_dir = ( + str(pass_context.config["relay.ext.vitis_ai.options.build_dir"]) + if "relay.ext.vitis_ai.options.build_dir" in pass_context.config + else tvm.contrib.utils.tempdir().relpath("") + ) + vai_work_dir = ( + str(pass_context.config["relay.ext.vitis_ai.options.work_dir"]) + if "relay.ext.vitis_ai.options.work_dir" in pass_context.config + else tvm.contrib.utils.tempdir().relpath("") + ) + + # (Optional configs) Export and load PyXIR runtime module to file if provided. This is used to + # compile and quantize a model on the host and deploy it at the edge + export_runtime_module = ( + str(pass_context.config["relay.ext.vitis_ai.options.export_runtime_module"]) + if "relay.ext.vitis_ai.options.export_runtime_module" in pass_context.config + else "" + ) + load_runtime_module = ( + str(pass_context.config["relay.ext.vitis_ai.options.load_runtime_module"]) + if "relay.ext.vitis_ai.options.load_runtime_module" in pass_context.config + else "" + ) + else: + target = cfg.target if cfg.target else None + # (Optional configs) The build and work directories to be used by Vitis AI + vai_build_dir = ( + cfg.build_dir if cfg.build_dir != "" else tvm.contrib.utils.tempdir().relpath("") + ) + + # (Optional configs) Export and load PyXIR runtime module to file if provided. This is used to + # compile and quantize a model on the host and deploy it at the edge + vai_work_dir = ( + cfg.work_dir if cfg.work_dir != "" else tvm.contrib.utils.tempdir().relpath("") + ) + export_runtime_module = cfg.export_runtime_module + load_runtime_module = cfg.load_runtime_module # Config checks if load_runtime_module and target is not None: diff --git a/python/tvm/driver/tvmc/autotuner.py b/python/tvm/driver/tvmc/autotuner.py index 187b7c5d2a31..3dfe7b11ed02 100644 --- a/python/tvm/driver/tvmc/autotuner.py +++ b/python/tvm/driver/tvmc/autotuner.py @@ -247,7 +247,7 @@ def drive_tune(args): for codegen_from_cli in extra_targets: codegen = composite_target.get_codegen_by_target(codegen_from_cli["name"]) partition_function = codegen["pass_pipeline"] - mod = partition_function(mod, params) + mod = partition_function(mod, params, **codegen_from_cli["opts"]) # min_repeat_ms should be: # a. the value provided by the user, if any, or diff --git a/python/tvm/driver/tvmc/common.py b/python/tvm/driver/tvmc/common.py index 71bf42ae1e5c..034f252d8235 100644 --- a/python/tvm/driver/tvmc/common.py +++ b/python/tvm/driver/tvmc/common.py @@ -133,7 +133,7 @@ def tokenize_target(target): target_pattern = ( r"(\-{0,2}[\w\-]+\=?" - r"(?:[\w\+\-]+(?:,[\w\+\-])*|[\'][\w\+\-,\s]+[\']|[\"][\w\+\-,\s]+[\"])*|,)" + r"(?:[\w\+\-.]+(?:,[\w\+\-])*|[\'][\w\+\-,\s]+[\']|[\"][\w\+\-,\s]+[\"])*|,)" ) return re.findall(target_pattern, target) diff --git a/python/tvm/driver/tvmc/compiler.py b/python/tvm/driver/tvmc/compiler.py index fc1805ee0ab4..661a8de1e12a 100644 --- a/python/tvm/driver/tvmc/compiler.py +++ b/python/tvm/driver/tvmc/compiler.py @@ -196,7 +196,7 @@ def compile_model( for codegen_from_cli in extra_targets: codegen = composite_target.get_codegen_by_target(codegen_from_cli["name"]) partition_function = codegen["pass_pipeline"] - mod = partition_function(mod, params) + mod = partition_function(mod, params, **codegen_from_cli["opts"]) if codegen["config_key"] is not None: config[codegen["config_key"]] = codegen_from_cli["opts"] diff --git a/python/tvm/driver/tvmc/composite_target.py b/python/tvm/driver/tvmc/composite_target.py index 0a2592685646..a0c4d7e00f4e 100644 --- a/python/tvm/driver/tvmc/composite_target.py +++ b/python/tvm/driver/tvmc/composite_target.py @@ -21,6 +21,8 @@ from tvm.relay.op.contrib.arm_compute_lib import partition_for_arm_compute_lib from tvm.relay.op.contrib.ethosn import partition_for_ethosn +from tvm.relay.op.contrib.vitis_ai import partition_for_vitis_ai +from tvm.contrib.target import vitis_ai from .common import TVMCException @@ -40,6 +42,10 @@ "config_key": "relay.ext.ethos-n.options", "pass_pipeline": partition_for_ethosn, }, + "vitis-ai": { + "config_key": "relay.ext.vitis_ai.options", + "pass_pipeline": partition_for_vitis_ai, + }, } diff --git a/python/tvm/relay/op/contrib/arm_compute_lib.py b/python/tvm/relay/op/contrib/arm_compute_lib.py index 139f25fef4fd..db61f2fe263e 100644 --- a/python/tvm/relay/op/contrib/arm_compute_lib.py +++ b/python/tvm/relay/op/contrib/arm_compute_lib.py @@ -43,7 +43,7 @@ def is_arm_compute_runtime_enabled(): return False -def partition_for_arm_compute_lib(mod, params=None): +def partition_for_arm_compute_lib(mod, params=None, **opts): """Partition the graph greedily offloading supported operators to Arm Compute Library. diff --git a/python/tvm/relay/op/contrib/ethosn.py b/python/tvm/relay/op/contrib/ethosn.py index 478a1ec46f26..2c63d63a36ef 100644 --- a/python/tvm/relay/op/contrib/ethosn.py +++ b/python/tvm/relay/op/contrib/ethosn.py @@ -46,7 +46,7 @@ def ethosn_available(): return Available.SW_AND_HW if hw else Available.SW_ONLY -def partition_for_ethosn(mod, params=None): +def partition_for_ethosn(mod, params=None, **opts): """Partition the graph greedily offloading supported operators to Arm Ethos-N NPU. diff --git a/python/tvm/relay/op/contrib/vitis_ai.py b/python/tvm/relay/op/contrib/vitis_ai.py index aaa9f99e61ed..679f3bb42152 100644 --- a/python/tvm/relay/op/contrib/vitis_ai.py +++ b/python/tvm/relay/op/contrib/vitis_ai.py @@ -24,8 +24,9 @@ from tvm import relay import tvm._ffi -from tvm.relay.expr import Tuple, TupleGetItem from tvm.relay import transform +from tvm.relay.expr import Tuple, TupleGetItem +from tvm.relay.build_module import bind_params_by_name from tvm.relay.op.annotation import compiler_begin, compiler_end @@ -33,9 +34,10 @@ class VitisAIAnnotationPass: """Responsible for annotating Relay expressions for Vitis-AI DPU accelerators""" - def __init__(self, compiler, relay_ids): + def __init__(self, compiler, target, params): self.compiler = compiler - self.relay_ids = relay_ids + self.target = target + self.params = params def transform_function(self, func, mod, ctx): """Transform function for annotating Relay module""" @@ -80,25 +82,61 @@ def visit_call(self, call): else: return super().visit_call(call) + xgraph = pyxir.frontend.tvm.from_relay(mod, self.params, postprocessing=None) + xgraph = pyxir.partition(xgraph, targets=[self.target]) + + layers = xgraph.get_layers() + relay_ids = [ + list(np.array(layer.attrs["relay_id"]).flatten()) + for layer in layers + if layer.target == self.target + ] + self.relay_ids = [item for sublist in relay_ids for item in sublist] + return Annotator().visit(func) def annotation(mod, params, target): - """Annotate Relay expression for Vitis-AI DPU accelerators""" + """Annotate Relay expression for offloading operators to Vitis AI DPU accelerators + NOTE: This function does the same as the next one (`partition_for_vitis_ai`) but is + still here for backward compatibility""" # We need type information for supporting models that contain operations that don't # have a Relay to XLayer translation mod = relay.transform.InferType()(mod) + mod = VitisAIAnnotationPass("vitis_ai", target, params)(mod) + return mod - xgraph = pyxir.frontend.tvm.from_relay(mod, params, postprocessing=None) - xgraph = pyxir.partition(xgraph, targets=[target]) - layers = xgraph.get_layers() - relay_ids = [ - list(np.array(layer.attrs["relay_id"]).flatten()) - for layer in layers - if layer.target == target - ] - relay_ids_flatten = [item for sublist in relay_ids for item in sublist] - mod = VitisAIAnnotationPass("vitis_ai", relay_ids_flatten)(mod) +def partition_for_vitis_ai(mod, params=None, target=None, **opts): + """Partition the Relay expression for offloading operators to Vitis AI DPU - return mod + Parameters + ---------- + mod : Module + The module to run passes on. + params : Optional[Dict[str, NDArray]] + Constant input parameters. + target : str + The DPU identifier (e.g. DPUCZDX8G-zcu104, DPUCADX8G) + + Returns + ------- + ret : annotated and partitioned module. + """ + + if target is None: + raise ValueError("Please pass Vitis AI DPU target to partitioning function") + + if params: + mod["main"] = bind_params_by_name(mod["main"], params) + + seq = tvm.transform.Sequential( + [ + transform.InferType(), + VitisAIAnnotationPass("vitis_ai", target, params), + transform.MergeCompilerRegions(), + transform.PartitionGraph(), + ] + ) + + return seq(mod) diff --git a/src/relay/backend/contrib/vitis_ai/config_vitis_ai.cc b/src/relay/backend/contrib/vitis_ai/config_vitis_ai.cc index f74b5306c5f4..78348910280e 100644 --- a/src/relay/backend/contrib/vitis_ai/config_vitis_ai.cc +++ b/src/relay/backend/contrib/vitis_ai/config_vitis_ai.cc @@ -29,6 +29,40 @@ namespace relay { namespace contrib { namespace vitis_ai { +/*! \brief Attributes to store the compiler options for Vitis AI */ +struct VitisAICompilerConfigNode : public tvm::AttrsNode { + String target; + String build_dir; + String work_dir; + String export_runtime_module; + String load_runtime_module; + TVM_DECLARE_ATTRS(VitisAICompilerConfigNode, "ext.attrs.VitisAICompilerConfigNode") { + TVM_ATTR_FIELD(target).describe("Vitis AI DPU target name").set_default(""); + TVM_ATTR_FIELD(build_dir) + .describe("Build directory to be used (optional, debug)") + .set_default(""); + TVM_ATTR_FIELD(work_dir) + .describe("Work directory to be used (optional, debug)") + .set_default(""); + TVM_ATTR_FIELD(export_runtime_module) + .describe("Export the Vitis AI runtime module to this file") + .set_default(""); + TVM_ATTR_FIELD(load_runtime_module) + .describe("Load the Vitis AI runtime module to this file") + .set_default(""); + } +}; + +class VitisAICompilerConfig : public Attrs { + public: + TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(VitisAICompilerConfig, Attrs, + VitisAICompilerConfigNode); +}; + +TVM_REGISTER_NODE_TYPE(VitisAICompilerConfigNode); +TVM_REGISTER_PASS_CONFIG_OPTION("relay.ext.vitis_ai.options", VitisAICompilerConfig); + +// Following config options are here for backward compatibility (deprecated API's) /*! \brief The target Vitis-AI accelerator device */ TVM_REGISTER_PASS_CONFIG_OPTION("relay.ext.vitis_ai.options.target", String); /*! \brief (Optional config) The build directory to be used by Vitis-AI */ diff --git a/src/runtime/contrib/vitis_ai/vitis_ai_runtime.cc b/src/runtime/contrib/vitis_ai/vitis_ai_runtime.cc index 37dc767d31af..fa1b3389bfeb 100755 --- a/src/runtime/contrib/vitis_ai/vitis_ai_runtime.cc +++ b/src/runtime/contrib/vitis_ai/vitis_ai_runtime.cc @@ -66,6 +66,7 @@ VitisAIRuntime::VitisAIRuntime(const std::string& symbol_name, const std::string pyxir::RunOptionsHolder run_options(new pyxir::runtime::RunOptions()); run_options->on_the_fly_quantization = true; run_options->build_dir = build_dir; + run_options->export_runtime_module_path = export_rt_mod_path_; if (!work_dir.empty()) run_options->work_dir = work_dir; rt_mod_ = pyxir::build_rt(xgraph, target, in_tensor_names_, out_tensor_names_, "vai", run_options); diff --git a/tests/python/contrib/test_vitis_ai/infrastructure.py b/tests/python/contrib/test_vitis_ai/infrastructure.py index df7836a37647..903a1a443010 100644 --- a/tests/python/contrib/test_vitis_ai/infrastructure.py +++ b/tests/python/contrib/test_vitis_ai/infrastructure.py @@ -31,7 +31,7 @@ from tvm import relay from tvm import runtime from tvm.relay import transform -from tvm.relay.op.contrib.vitis_ai import annotation +from tvm.relay.op.contrib.vitis_ai import partition_for_vitis_ai from tvm.relay.build_module import bind_params_by_name from tvm.contrib.target import vitis_ai from tvm.contrib import graph_runtime @@ -84,10 +84,7 @@ def build_module( opt_level=3, config={"relay.ext.vitis_ai.options.target": dpu_target} ): if enable_vitis_ai: - mod["main"] = bind_params_by_name(mod["main"], params) - mod = annotation(mod, params, dpu_target) - mod = transform.MergeCompilerRegions()(mod) - mod = transform.PartitionGraph()(mod) + mod = partition_for_vitis_ai(mod, params, dpu_target) tvm_op_count = get_cpu_op_count(mod) assert tvm_op_count == tvm_ops, "Got {} TVM operators, expected {}".format( tvm_op_count, tvm_ops diff --git a/tests/python/driver/tvmc/test_compiler.py b/tests/python/driver/tvmc/test_compiler.py index ae859298facd..43959256d3fa 100644 --- a/tests/python/driver/tvmc/test_compiler.py +++ b/tests/python/driver/tvmc/test_compiler.py @@ -17,6 +17,7 @@ import argparse import os import shutil +import importlib from os import path from unittest import mock @@ -29,6 +30,15 @@ from tvm.driver import tvmc +def vitis_ai_available(): + """Return whether Vitis AI tools are available""" + pyxir_spec = importlib.util.find_spec("pyxir") + if not tvm.get_global_func("tvm.vitis_ai_runtime.from_xgraph", True) or pyxir_spec is None: + print("Skip because Vitis AI tools are not available") + return False + return True + + def test_save_dumps(tmpdir_factory): tmpdir = tmpdir_factory.mktemp("data") dump_formats = {"relay": "fake relay", "ll": "fake llvm", "asm": "fake asm"} @@ -208,6 +218,26 @@ def test_compile_tflite_module_with_external_codegen(tflite_mobilenet_v1_1_quant assert type(dumps) is dict +@pytest.mark.skipif( + not vitis_ai_available(), + reason="--target=vitis-ai is not available. TVM built with 'USE_VITIS_AI OFF'", +) +def test_compile_tflite_module_with_external_codegen_vitis_ai(tflite_mobilenet_v1_1_quant): + pytest.importorskip("tflite") + + graph, lib, params, dumps = tvmc.compiler.compile_model( + tflite_mobilenet_v1_1_quant, + target="vitis-ai -target=DPUCZDX8G-zcu104 -export_runtime_module=vitis_ai.rtmod, llvm", + dump_code="relay", + ) + + # check for output types + assert type(graph) is str + assert type(lib) is tvm.runtime.module.Module + assert type(params) is dict + assert type(dumps) is dict + + @mock.patch("tvm.relay.build") @mock.patch("tvm.driver.tvmc.composite_target.get_codegen_by_target") @mock.patch("tvm.driver.tvmc.frontends.load_model") @@ -215,7 +245,7 @@ def test_compile_tflite_module_with_external_codegen(tflite_mobilenet_v1_1_quant def test_compile_check_configs_composite_target(mock_pc, mock_fe, mock_ct, mock_relay): mock_codegen = {} mock_codegen["config_key"] = "relay.ext.mock.options" - mock_codegen["pass_pipeline"] = lambda *args: None + mock_codegen["pass_pipeline"] = lambda *args, **kwargs: None mock_fe.return_value = (None, None) mock_ct.return_value = mock_codegen diff --git a/tests/python/driver/tvmc/test_composite_target.py b/tests/python/driver/tvmc/test_composite_target.py index cef8b117d989..0a0b45eeb970 100644 --- a/tests/python/driver/tvmc/test_composite_target.py +++ b/tests/python/driver/tvmc/test_composite_target.py @@ -34,6 +34,7 @@ def test_get_codegen_names(): names = tvmc.composite_target.get_codegen_names() assert "ethos-n77" in names + assert "vitis-ai" in names assert len(names) > 0 From 84fa50d2f28b1e64272a2471e35b928b74e50a48 Mon Sep 17 00:00:00 2001 From: Jorn Tuyls Date: Fri, 5 Mar 2021 18:16:10 +0000 Subject: [PATCH 02/10] Update python/tvm/contrib/target/vitis_ai.py Co-authored-by: Cody Yu --- python/tvm/contrib/target/vitis_ai.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/tvm/contrib/target/vitis_ai.py b/python/tvm/contrib/target/vitis_ai.py index f83c81cca8d2..dbcbfff32e30 100644 --- a/python/tvm/contrib/target/vitis_ai.py +++ b/python/tvm/contrib/target/vitis_ai.py @@ -121,7 +121,7 @@ def vitis_ai_compiler(ref): target = cfg.target if cfg.target else None # (Optional configs) The build and work directories to be used by Vitis AI vai_build_dir = ( - cfg.build_dir if cfg.build_dir != "" else tvm.contrib.utils.tempdir().relpath("") + cfg.build_dir if cfg.build_dir else tvm.contrib.utils.tempdir().relpath("") ) # (Optional configs) Export and load PyXIR runtime module to file if provided. This is used to From 7e38ce0bc27ca9c32a8f5568f608062f76c1ea49 Mon Sep 17 00:00:00 2001 From: Jorn Tuyls Date: Fri, 5 Mar 2021 18:16:18 +0000 Subject: [PATCH 03/10] Update python/tvm/contrib/target/vitis_ai.py Co-authored-by: Cody Yu --- python/tvm/contrib/target/vitis_ai.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/tvm/contrib/target/vitis_ai.py b/python/tvm/contrib/target/vitis_ai.py index dbcbfff32e30..ddb77e2aece6 100644 --- a/python/tvm/contrib/target/vitis_ai.py +++ b/python/tvm/contrib/target/vitis_ai.py @@ -127,7 +127,7 @@ def vitis_ai_compiler(ref): # (Optional configs) Export and load PyXIR runtime module to file if provided. This is used to # compile and quantize a model on the host and deploy it at the edge vai_work_dir = ( - cfg.work_dir if cfg.work_dir != "" else tvm.contrib.utils.tempdir().relpath("") + cfg.work_dir if cfg.work_dir else tvm.contrib.utils.tempdir().relpath("") ) export_runtime_module = cfg.export_runtime_module load_runtime_module = cfg.load_runtime_module From 188e53bb795edf03612c3c894af232f78602b6e2 Mon Sep 17 00:00:00 2001 From: Jorn Date: Sat, 6 Mar 2021 08:35:19 -0800 Subject: [PATCH 04/10] Change Vitis AI API to & address comments & fix linter issues --- docs/deploy/vitis_ai.rst | 48 +++++++++---------- python/tvm/contrib/target/vitis_ai.py | 40 ++++++++-------- python/tvm/driver/tvmc/composite_target.py | 2 +- python/tvm/relay/op/contrib/vitis_ai.py | 30 +++++++----- .../contrib/vitis_ai/config_vitis_ai.cc | 4 +- .../contrib/vitis_ai/vitis_ai_runtime.cc | 12 ++--- .../contrib/vitis_ai/vitis_ai_runtime.h | 4 +- tests/python/driver/tvmc/test_compiler.py | 3 +- 8 files changed, 74 insertions(+), 69 deletions(-) diff --git a/docs/deploy/vitis_ai.rst b/docs/deploy/vitis_ai.rst index c36a2a2d96eb..c4e9e3efab19 100755 --- a/docs/deploy/vitis_ai.rst +++ b/docs/deploy/vitis_ai.rst @@ -435,8 +435,8 @@ Cloud usage This section shows how to accelerate a convolutional neural network model in TVM with Vitis-AI on the cloud. -To be able to target the Vitis-AI cloud DPUCADX8G target we first have -to import the target in PyXIR. This PyXIR package is the interface being +To be able to target the Vitis-AI cloud DPUCADX8G we first have +to import the DPU target in PyXIR. This PyXIR package is the interface being used by TVM to integrate with the Vitis-AI stack. Additionaly, import the typical TVM and Relay modules and the Vitis-AI contrib module inside TVM. @@ -459,21 +459,21 @@ target and partition the graph. .. code:: python - target='DPUCADX8G' - mod = partition_for_vitis_ai(mod, params, target) + dpu = 'DPUCADX8G' + mod = partition_for_vitis_ai(mod, params, dpu) Now, we can build the TVM runtime library for executing the model. The TVM target is 'llvm' as the operations that can't be handled by the DPU -are executed on the CPU. The Vitis-AI target is DPUCADX8G as we are -targeting the cloud DPU and this target is passed as a config to the TVM +are executed on the CPU. The Vitis-AI DPU is DPUCADX8G as we are +targeting the cloud DPU and this DPU indetifier is passed as a config to the TVM build call. .. code:: python - tvm_target = 'llvm' + target = 'llvm' - with tvm.transform.PassContext(opt_level=3, config= {'relay.ext.vitis_ai.options': {'target': target}}): - lib = relay.build(mod, tvm_target, params=params) + with tvm.transform.PassContext(opt_level=3, config= {'relay.ext.vitis_ai.options': {'dpu': dpu}}): + lib = relay.build(mod, target, params=params) As one more step before we can accelerate a model with Vitis-AI in TVM we have to quantize and compile the model for execution on the DPU. We @@ -534,8 +534,8 @@ A complete ResNet 18 example can be found `here `__. Additionally, we +DPU's, see `edge DPU's info <#edge-requirements>`__. Additionally, we provide the 'export_runtime_module' config that points to a file to which we can export the Vitis-AI runtime module. We have to do this because we will first be compiling and quantizing the model on the host machine before building @@ -613,15 +613,15 @@ can be included. .. code:: python - tvm_target = 'llvm' + target = 'llvm' export_rt_mod_file = "vitis_ai.rtmod" build_options = { - 'target': target, + 'dpu': dpu, 'export_runtime_module': export_rt_mod_file } with tvm.transform.PassContext(opt_level=3, config= {'relay.ext.vitis_ai.options': build_options}): - lib = relay.build(mod, tvm_target, params=params) + lib = relay.build(mod, target, params=params) We will quantize and compile the model for execution on the DPU using on-the-fly quantization on the host machine. This makes use of TVM inference calls @@ -656,7 +656,7 @@ in the TVM build. .. code:: python # Export lib for aarch64 target - tvm_target = tvm.target.arm_cpu('ultra96') + target = tvm.target.arm_cpu('ultra96') lib_kwargs = { 'fcompile': contrib.cc.create_shared, 'cc': "/usr/aarch64-linux-gnu/bin/ld" @@ -666,7 +666,7 @@ in the TVM build. 'load_runtime_module': export_rt_mod_file } with tvm.transform.PassContext(opt_level=3, config={'relay.ext.vitis_ai.options': build_options}): - lib_arm = relay.build(mod, tvm_target, params=params) + lib_arm = relay.build(mod, target, params=params) lib_dpuv2.export_library('tvm_dpu_arm.so', **lib_kwargs) @@ -688,7 +688,7 @@ as root (execute ``su`` in terminal to log into root). You will see a warning about the 'cpu-tf' runtime not being found. This warning is expected on the board and can be ignored. Note also that you **shouldn't** import the - PyXIR targets in the run script (``import pyxir.contrib.target.DPUCZDX8G``). + PyXIR DPU targets in the run script (``import pyxir.contrib.target.DPUCZDX8G``). .. code:: python diff --git a/python/tvm/contrib/target/vitis_ai.py b/python/tvm/contrib/target/vitis_ai.py index ddb77e2aece6..ff1845c78194 100644 --- a/python/tvm/contrib/target/vitis_ai.py +++ b/python/tvm/contrib/target/vitis_ai.py @@ -36,12 +36,12 @@ def __init__(self, model_name, function): self.function = function self.params = {} - def convert_pyxir(self, target): + def convert_pyxir(self, dpu_target): """Convert Relay expression to PyXIR XGraph""" xgraph = pyxir.frontend.tvm.from_relay( self.function, params=self.params, postprocessing=None ) - xgraph = pyxir.partition(xgraph, targets=[target]) + xgraph = pyxir.partition(xgraph, targets=[dpu_target]) return xgraph def get_output_names(self): @@ -87,7 +87,7 @@ def vitis_ai_compiler(ref): ) # The target Vitis-AI accelerator device - target = ( + dpu_target = ( str(pass_context.config["relay.ext.vitis_ai.options.target"]) if "relay.ext.vitis_ai.options.target" in pass_context.config else None @@ -105,8 +105,8 @@ def vitis_ai_compiler(ref): else tvm.contrib.utils.tempdir().relpath("") ) - # (Optional configs) Export and load PyXIR runtime module to file if provided. This is used to - # compile and quantize a model on the host and deploy it at the edge + # (Optional configs) Export and load PyXIR runtime module to file if provided. This is + # used to compile and quantize a model on the host and deploy it at the edge export_runtime_module = ( str(pass_context.config["relay.ext.vitis_ai.options.export_runtime_module"]) if "relay.ext.vitis_ai.options.export_runtime_module" in pass_context.config @@ -118,26 +118,22 @@ def vitis_ai_compiler(ref): else "" ) else: - target = cfg.target if cfg.target else None + dpu_target = cfg.dpu if cfg.dpu else None # (Optional configs) The build and work directories to be used by Vitis AI - vai_build_dir = ( - cfg.build_dir if cfg.build_dir else tvm.contrib.utils.tempdir().relpath("") - ) + vai_build_dir = cfg.build_dir if cfg.build_dir else tvm.contrib.utils.tempdir().relpath("") - # (Optional configs) Export and load PyXIR runtime module to file if provided. This is used to - # compile and quantize a model on the host and deploy it at the edge - vai_work_dir = ( - cfg.work_dir if cfg.work_dir else tvm.contrib.utils.tempdir().relpath("") - ) + # (Optional configs) Export and load PyXIR runtime module to file if provided. This is + # used to compile and quantize a model on the host and deploy it at the edge + vai_work_dir = cfg.work_dir if cfg.work_dir else tvm.contrib.utils.tempdir().relpath("") export_runtime_module = cfg.export_runtime_module load_runtime_module = cfg.load_runtime_module # Config checks - if load_runtime_module and target is not None: + if load_runtime_module and dpu_target is not None: warnings.warn( - "Both `load_runtime_module` and `target` configs were specified." + "Both `load_runtime_module` and `dpu` configs were specified." " The `load_runtime_module` points to a prebuilt runtime module with" - " an internal target so the `target` config will be ignored" + " an internal DPU target so the `dpu` config will be ignored" ) if load_runtime_module and "relay.ext.vitis_ai.options.build_dir" in pass_context.config: warnings.warn( @@ -156,7 +152,7 @@ def vitis_ai_compiler(ref): if load_runtime_module == "": # Convert Relay expression into XGraph and do partitioning inside PyXIR builder = CodegenVitisAI(name, ref) - xgraph = builder.convert_pyxir(target) + xgraph = builder.convert_pyxir(dpu_target) output_relay_ids = builder.get_output_names() layers = xgraph.get_layers() @@ -170,15 +166,17 @@ def vitis_ai_compiler(ref): break if any([name == "unkown_name" for name in out_tensor_names]): raise ValueError( - "During codegeneration the loading of subexpression \ - failed due to output tensor name mismatch in Relay PyXIR interface." + "During codegeneration the loading of subexpression" + " failed due to output tensor name mismatch in Relay PyXIR interface." ) xgraph.meta_attrs["tvm_out_tensors"] = out_tensor_names xgraph_str = pyxir.get_xgraph_str(xgraph) runtime_func = "tvm.vitis_ai_runtime.from_xgraph" fcreate = tvm._ffi.get_global_func(runtime_func) - return fcreate(name, xgraph_str, target, vai_build_dir, vai_work_dir, export_runtime_module) + return fcreate( + name, xgraph_str, dpu_target, vai_build_dir, vai_work_dir, export_runtime_module + ) runtime_func = "tvm.vitis_ai_runtime.from_rt_mod" fcreate = tvm._ffi.get_global_func(runtime_func) diff --git a/python/tvm/driver/tvmc/composite_target.py b/python/tvm/driver/tvmc/composite_target.py index a0c4d7e00f4e..831dbca02cce 100644 --- a/python/tvm/driver/tvmc/composite_target.py +++ b/python/tvm/driver/tvmc/composite_target.py @@ -22,7 +22,7 @@ from tvm.relay.op.contrib.arm_compute_lib import partition_for_arm_compute_lib from tvm.relay.op.contrib.ethosn import partition_for_ethosn from tvm.relay.op.contrib.vitis_ai import partition_for_vitis_ai -from tvm.contrib.target import vitis_ai +from tvm.contrib.target import vitis_ai # pylint: disable=unused-import from .common import TVMCException diff --git a/python/tvm/relay/op/contrib/vitis_ai.py b/python/tvm/relay/op/contrib/vitis_ai.py index 679f3bb42152..a305be5e1e70 100644 --- a/python/tvm/relay/op/contrib/vitis_ai.py +++ b/python/tvm/relay/op/contrib/vitis_ai.py @@ -17,6 +17,7 @@ # pylint: disable=invalid-name, unused-argument, no-else-return, E1102 """Vitis-AI codegen annotation of supported operators""" +import warnings import numpy as np import pyxir @@ -34,9 +35,9 @@ class VitisAIAnnotationPass: """Responsible for annotating Relay expressions for Vitis-AI DPU accelerators""" - def __init__(self, compiler, target, params): + def __init__(self, compiler, dpu_target, params): self.compiler = compiler - self.target = target + self.dpu_target = dpu_target self.params = params def transform_function(self, func, mod, ctx): @@ -83,13 +84,13 @@ def visit_call(self, call): return super().visit_call(call) xgraph = pyxir.frontend.tvm.from_relay(mod, self.params, postprocessing=None) - xgraph = pyxir.partition(xgraph, targets=[self.target]) + xgraph = pyxir.partition(xgraph, targets=[self.dpu_target]) layers = xgraph.get_layers() relay_ids = [ list(np.array(layer.attrs["relay_id"]).flatten()) for layer in layers - if layer.target == self.target + if layer.target == self.dpu_target ] self.relay_ids = [item for sublist in relay_ids for item in sublist] @@ -97,17 +98,24 @@ def visit_call(self, call): def annotation(mod, params, target): - """Annotate Relay expression for offloading operators to Vitis AI DPU accelerators + """DEPRECATED + + Annotate Relay expression for offloading operators to Vitis AI DPU accelerators NOTE: This function does the same as the next one (`partition_for_vitis_ai`) but is still here for backward compatibility""" # We need type information for supporting models that contain operations that don't # have a Relay to XLayer translation + warnings.warn( + "tvm.relay.op.contrib.vitis_ai.annotation() is being deprecated." + " Please use tvm.relay.op.contrib.vitis_ai.partition_for_vitis_ai() instead. " + " Check out https://tvm.apache.org/docs/deploy/vitis_ai.html for documentation. " + ) mod = relay.transform.InferType()(mod) mod = VitisAIAnnotationPass("vitis_ai", target, params)(mod) return mod -def partition_for_vitis_ai(mod, params=None, target=None, **opts): +def partition_for_vitis_ai(mod, params=None, dpu=None, **opts): """Partition the Relay expression for offloading operators to Vitis AI DPU Parameters @@ -116,16 +124,16 @@ def partition_for_vitis_ai(mod, params=None, target=None, **opts): The module to run passes on. params : Optional[Dict[str, NDArray]] Constant input parameters. - target : str + dpu : str The DPU identifier (e.g. DPUCZDX8G-zcu104, DPUCADX8G) Returns ------- - ret : annotated and partitioned module. + ret : Module """ - if target is None: - raise ValueError("Please pass Vitis AI DPU target to partitioning function") + if dpu is None: + raise ValueError("Please pass Vitis AI DPU identifier to the partitioning function") if params: mod["main"] = bind_params_by_name(mod["main"], params) @@ -133,7 +141,7 @@ def partition_for_vitis_ai(mod, params=None, target=None, **opts): seq = tvm.transform.Sequential( [ transform.InferType(), - VitisAIAnnotationPass("vitis_ai", target, params), + VitisAIAnnotationPass("vitis_ai", dpu, params), transform.MergeCompilerRegions(), transform.PartitionGraph(), ] diff --git a/src/relay/backend/contrib/vitis_ai/config_vitis_ai.cc b/src/relay/backend/contrib/vitis_ai/config_vitis_ai.cc index 78348910280e..5426a2dc1e65 100644 --- a/src/relay/backend/contrib/vitis_ai/config_vitis_ai.cc +++ b/src/relay/backend/contrib/vitis_ai/config_vitis_ai.cc @@ -31,13 +31,13 @@ namespace vitis_ai { /*! \brief Attributes to store the compiler options for Vitis AI */ struct VitisAICompilerConfigNode : public tvm::AttrsNode { - String target; + String dpu; String build_dir; String work_dir; String export_runtime_module; String load_runtime_module; TVM_DECLARE_ATTRS(VitisAICompilerConfigNode, "ext.attrs.VitisAICompilerConfigNode") { - TVM_ATTR_FIELD(target).describe("Vitis AI DPU target name").set_default(""); + TVM_ATTR_FIELD(dpu).describe("Vitis AI DPU identifier").set_default(""); TVM_ATTR_FIELD(build_dir) .describe("Build directory to be used (optional, debug)") .set_default(""); diff --git a/src/runtime/contrib/vitis_ai/vitis_ai_runtime.cc b/src/runtime/contrib/vitis_ai/vitis_ai_runtime.cc index fa1b3389bfeb..f55d87d7cdde 100755 --- a/src/runtime/contrib/vitis_ai/vitis_ai_runtime.cc +++ b/src/runtime/contrib/vitis_ai/vitis_ai_runtime.cc @@ -49,7 +49,7 @@ VitisAIRuntime::VitisAIRuntime(const std::string& symbol_name, const Array const_names, const std::string& target, + const Array const_names, const std::string& dpu_target, const std::string& build_dir, const std::string& work_dir, const std::string& export_rt_mod_path) : symbol_name_(symbol_name), @@ -61,7 +61,7 @@ VitisAIRuntime::VitisAIRuntime(const std::string& symbol_name, const std::string in_tensor_names_ = xgraph->get_input_names(); out_tensor_names_ = xgraph->get_meta_attr("tvm_out_tensors").get_strings(); - pyxir::partition(xgraph, std::vector{target}, ""); + pyxir::partition(xgraph, std::vector{dpu_target}, ""); pyxir::RunOptionsHolder run_options(new pyxir::runtime::RunOptions()); run_options->on_the_fly_quantization = true; @@ -69,15 +69,15 @@ VitisAIRuntime::VitisAIRuntime(const std::string& symbol_name, const std::string run_options->export_runtime_module_path = export_rt_mod_path_; if (!work_dir.empty()) run_options->work_dir = work_dir; rt_mod_ = - pyxir::build_rt(xgraph, target, in_tensor_names_, out_tensor_names_, "vai", run_options); + pyxir::build_rt(xgraph, dpu_target, in_tensor_names_, out_tensor_names_, "vai", run_options); } Module VitisAIRuntimeCreate(const std::string& name, const std::string& xgraph_str, - const std::string& target, const std::string& build_dir, + const std::string& dpu_target, const std::string& build_dir, const std::string& work_dir, const std::string& export_rt_mod_path) { Array const_vars; - auto exec = make_object(name, xgraph_str, const_vars, target, build_dir, work_dir, - export_rt_mod_path); + auto exec = make_object(name, xgraph_str, const_vars, dpu_target, build_dir, + work_dir, export_rt_mod_path); return Module(exec); } diff --git a/src/runtime/contrib/vitis_ai/vitis_ai_runtime.h b/src/runtime/contrib/vitis_ai/vitis_ai_runtime.h index 1092bc0ba27b..cad3b5e5a7ff 100755 --- a/src/runtime/contrib/vitis_ai/vitis_ai_runtime.h +++ b/src/runtime/contrib/vitis_ai/vitis_ai_runtime.h @@ -62,14 +62,14 @@ class VitisAIRuntime : public ModuleNode { * \param symbol_name The name of the function. * \param xgraph_str serialized XGraph representation * \param const_names The names of each constant in the sub-graph. - * \param target The Vitis-AI device target (e.g. DPUCADX8G, DPUCZDX8G). + * \param dpu_target The Vitis-AI DPU target identifier (e.g. DPUCADX8G, DPUCZDX8G-zcu104). * \param build_dir The directory to be used for Vitis-AI build files. * \param work_dir The directory to be used for Vitis-AI work files. * \param export_rt_mod_path The path to the file to be used for exporting the * PyXIR runtime module. */ VitisAIRuntime(const std::string& symbol_name, const std::string& xgraph_str, - const Array const_names, const std::string& target, + const Array const_names, const std::string& dpu_target, const std::string& build_dir, const std::string& work_dir, const std::string& export_runtime_module_path); diff --git a/tests/python/driver/tvmc/test_compiler.py b/tests/python/driver/tvmc/test_compiler.py index 43959256d3fa..4578f5eeed77 100644 --- a/tests/python/driver/tvmc/test_compiler.py +++ b/tests/python/driver/tvmc/test_compiler.py @@ -34,7 +34,6 @@ def vitis_ai_available(): """Return whether Vitis AI tools are available""" pyxir_spec = importlib.util.find_spec("pyxir") if not tvm.get_global_func("tvm.vitis_ai_runtime.from_xgraph", True) or pyxir_spec is None: - print("Skip because Vitis AI tools are not available") return False return True @@ -227,7 +226,7 @@ def test_compile_tflite_module_with_external_codegen_vitis_ai(tflite_mobilenet_v graph, lib, params, dumps = tvmc.compiler.compile_model( tflite_mobilenet_v1_1_quant, - target="vitis-ai -target=DPUCZDX8G-zcu104 -export_runtime_module=vitis_ai.rtmod, llvm", + target="vitis-ai -dpu=DPUCZDX8G-zcu104 -export_runtime_module=vitis_ai.rtmod, llvm", dump_code="relay", ) From 5cf67bc00c31069d998ec85d4786a8de830ea597 Mon Sep 17 00:00:00 2001 From: Jorn Tuyls Date: Mon, 8 Mar 2021 10:38:50 +0000 Subject: [PATCH 05/10] Update docs/deploy/vitis_ai.rst Co-authored-by: Leandro Nunes --- docs/deploy/vitis_ai.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/deploy/vitis_ai.rst b/docs/deploy/vitis_ai.rst index c4e9e3efab19..5c57fd94c7af 100755 --- a/docs/deploy/vitis_ai.rst +++ b/docs/deploy/vitis_ai.rst @@ -465,7 +465,7 @@ target and partition the graph. Now, we can build the TVM runtime library for executing the model. The TVM target is 'llvm' as the operations that can't be handled by the DPU are executed on the CPU. The Vitis-AI DPU is DPUCADX8G as we are -targeting the cloud DPU and this DPU indetifier is passed as a config to the TVM +targeting the cloud DPU and this DPU identifier is passed as a config to the TVM build call. .. code:: python From ae1af0222d4360c3a362a01b2cf07b4e0b23cec8 Mon Sep 17 00:00:00 2001 From: Jorn Tuyls Date: Sun, 14 Mar 2021 10:01:12 +0000 Subject: [PATCH 06/10] Update docs/deploy/vitis_ai.rst Co-authored-by: Cody Yu --- docs/deploy/vitis_ai.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/deploy/vitis_ai.rst b/docs/deploy/vitis_ai.rst index 5c57fd94c7af..6d4d24aacd5a 100755 --- a/docs/deploy/vitis_ai.rst +++ b/docs/deploy/vitis_ai.rst @@ -196,7 +196,7 @@ Hardware setup and docker build pip3 install -e . --user Edge (DPUCZDX8G) -~~~~~~~~~~~~~~~~~ +~~~~~~~~~~~~~~~~ For edge deployment we make use of two systems referred to as host and From 563142cd7b930ff6d4fea6d990044feb06cee444 Mon Sep 17 00:00:00 2001 From: Jorn Date: Mon, 15 Mar 2021 04:01:44 -0700 Subject: [PATCH 07/10] Add Vitis AI initiliazation to separate init config in TVMC composite target registry --- python/tvm/contrib/target/vitis_ai_utils.py | 36 +++++++++++++++++ python/tvm/driver/tvmc/composite_target.py | 39 +++++++++++++++---- tests/python/driver/tvmc/test_compiler.py | 10 +---- .../driver/tvmc/test_composite_target.py | 21 +++++++++- 4 files changed, 87 insertions(+), 19 deletions(-) create mode 100644 python/tvm/contrib/target/vitis_ai_utils.py diff --git a/python/tvm/contrib/target/vitis_ai_utils.py b/python/tvm/contrib/target/vitis_ai_utils.py new file mode 100644 index 000000000000..1b2bcaa9dea5 --- /dev/null +++ b/python/tvm/contrib/target/vitis_ai_utils.py @@ -0,0 +1,36 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Vitis AI target utilities""" + +import importlib + +import tvm + + +def vitis_ai_available(): + """Return whether Vitis AI tools are available""" + pyxir_spec = importlib.util.find_spec("pyxir") + if not tvm.get_global_func("tvm.vitis_ai_runtime.from_xgraph", True) or pyxir_spec is None: + return False + return True + + +def init_for_vitis_ai(): + """Initialization function for the Vitis AI codegen""" + # We need to import the Vitis AI target module to make sure the codegen + # is registered + importlib.import_module("tvm.contrib.target.vitis_ai") diff --git a/python/tvm/driver/tvmc/composite_target.py b/python/tvm/driver/tvmc/composite_target.py index 39f6ca5322b7..7fe35b66b3ae 100644 --- a/python/tvm/driver/tvmc/composite_target.py +++ b/python/tvm/driver/tvmc/composite_target.py @@ -23,7 +23,7 @@ from tvm.relay.op.contrib.ethosn import partition_for_ethosn from tvm.relay.op.contrib.bnns import partition_for_bnns from tvm.relay.op.contrib.vitis_ai import partition_for_vitis_ai -from tvm.contrib.target import vitis_ai # pylint: disable=unused-import +from tvm.contrib.target.vitis_ai_utils import init_for_vitis_ai from .common import TVMCException @@ -31,9 +31,19 @@ # pylint: disable=invalid-name logger = logging.getLogger("TVMC") -# Global dictionary to map targets with the configuration key -# to be used in the PassContext (if any), and a function -# responsible for partitioning to that target. + +# Global dictionary to map targets +# +# Options +# ------- +# config_key : str +# The configuration key to be used in the PassContext (if any). +# pass_pipeline : Callable +# A function to transform a Module before compilation, mainly used +# for partitioning for the target currently. +# init : Callable (optional) +# A function for doing initialization for the target codegen. Will be +# called when the target info gets retrieved REGISTERED_CODEGEN = { "compute-library": { "config_key": None, @@ -49,7 +59,8 @@ }, "vitis-ai": { "config_key": "relay.ext.vitis_ai.options", - "pass_pipeline": partition_for_vitis_ai + "init": init_for_vitis_ai, + "pass_pipeline": partition_for_vitis_ai, }, } @@ -65,15 +76,27 @@ def get_codegen_names(): return list(REGISTERED_CODEGEN.keys()) -def get_codegen_by_target(name): +def get_codegen_by_target(name, call_init_function=True): """Return a codegen entry by name. + Parameters + ---------- + name : str + The name of the target for which the codegen info should be retrieved. + call_init_function : bool + Whether to call the initialization function before returning the codegen + info. This is used in tests to avoid calling into initialization functions + that might fail if packages are not installed. + Returns ------- dict - requested target information + requested target codegen information """ try: - return REGISTERED_CODEGEN[name] + target_info = REGISTERED_CODEGEN[name] + if call_init_function and "init" in target_info: + target_info["init"]() + return target_info except KeyError: raise TVMCException("Composite target %s is not defined in TVMC." % name) diff --git a/tests/python/driver/tvmc/test_compiler.py b/tests/python/driver/tvmc/test_compiler.py index 4578f5eeed77..5f0fa69051c5 100644 --- a/tests/python/driver/tvmc/test_compiler.py +++ b/tests/python/driver/tvmc/test_compiler.py @@ -17,7 +17,6 @@ import argparse import os import shutil -import importlib from os import path from unittest import mock @@ -25,19 +24,12 @@ import tvm +from tvm.contrib.target.vitis_ai_utils import vitis_ai_available from tvm.relay.op.contrib.ethosn import ethosn_available from tvm.driver import tvmc -def vitis_ai_available(): - """Return whether Vitis AI tools are available""" - pyxir_spec = importlib.util.find_spec("pyxir") - if not tvm.get_global_func("tvm.vitis_ai_runtime.from_xgraph", True) or pyxir_spec is None: - return False - return True - - def test_save_dumps(tmpdir_factory): tmpdir = tmpdir_factory.mktemp("data") dump_formats = {"relay": "fake relay", "ll": "fake llvm", "asm": "fake asm"} diff --git a/tests/python/driver/tvmc/test_composite_target.py b/tests/python/driver/tvmc/test_composite_target.py index 0a0b45eeb970..47cd3809f83f 100644 --- a/tests/python/driver/tvmc/test_composite_target.py +++ b/tests/python/driver/tvmc/test_composite_target.py @@ -29,6 +29,8 @@ from tvm.driver.tvmc.common import TVMCException +from tvm.contrib.target.vitis_ai_utils import vitis_ai_available + def test_get_codegen_names(): names = tvmc.composite_target.get_codegen_names() @@ -50,14 +52,29 @@ def test_invalid_codegen(): _ = tvmc.composite_target.get_codegen_by_target("invalid") +@pytest.mark.skipif( + not vitis_ai_available(), + reason="--target=vitis-ai is not available. TVM built with 'USE_VITIS_AI OFF'", +) +def test_vitis_ai_codegen_init(): + tvmc.composite_target.get_codegen_by_target("vitis-ai") + + def test_all_codegens_contain_pass_pipeline(): for name in tvmc.composite_target.get_codegen_names(): - codegen = tvmc.composite_target.get_codegen_by_target(name) + codegen = tvmc.composite_target.get_codegen_by_target(name, call_init_function=False) assert "pass_pipeline" in codegen, f"{name} does not contain a pass_pipeline" assert isfunction(codegen["pass_pipeline"]) def test_all_pass_pipelines_are_functions(): for name in tvmc.composite_target.get_codegen_names(): - codegen = tvmc.composite_target.get_codegen_by_target(name) + codegen = tvmc.composite_target.get_codegen_by_target(name, call_init_function=False) assert isfunction(codegen["pass_pipeline"]), f"pass_pipeline for {name} is not a function" + + +def test_all_initializations_are_functions(): + for name in tvmc.composite_target.get_codegen_names(): + codegen = tvmc.composite_target.get_codegen_by_target(name, call_init_function=False) + if "init" in codegen: + assert isfunction(codegen["init"]), f"init for {name} is not a function" From 6c12ed2917fec5708469b5af3d4816426b5b0c34 Mon Sep 17 00:00:00 2001 From: Jorn Date: Wed, 17 Mar 2021 01:37:06 -0700 Subject: [PATCH 08/10] Lazy load pyxir package in Vitis AI codegen to avoid hard dependency for TVMC --- python/tvm/contrib/target/vitis_ai.py | 99 +++++++++++++------ python/tvm/contrib/target/vitis_ai_utils.py | 36 ------- python/tvm/driver/tvmc/composite_target.py | 20 ++-- tests/python/driver/tvmc/test_compiler.py | 2 +- .../driver/tvmc/test_composite_target.py | 21 +--- 5 files changed, 77 insertions(+), 101 deletions(-) delete mode 100644 python/tvm/contrib/target/vitis_ai_utils.py diff --git a/python/tvm/contrib/target/vitis_ai.py b/python/tvm/contrib/target/vitis_ai.py index ff1845c78194..837e6604bb4c 100644 --- a/python/tvm/contrib/target/vitis_ai.py +++ b/python/tvm/contrib/target/vitis_ai.py @@ -19,30 +19,86 @@ """Utility to offload (sub-)models to Vitis-AI""" import warnings - -import pyxir -import pyxir.frontend.tvm +import importlib from tvm.relay.expr import Tuple, Call, TupleGetItem import tvm._ffi +# Placeholder for PyXIR module +pyxir = None + + +def vitis_ai_available(): + """Return whether Vitis AI tools are available""" + pyxir_spec = importlib.util.find_spec("pyxir") + if not tvm.get_global_func("tvm.vitis_ai_runtime.from_xgraph", True) or pyxir_spec is None: + return False + return True + class CodegenVitisAI: - """Traverse Relay expression and convert into PyXIR XGraph format""" + """Traverse Relay expression and convert into PyXIR XGraph format + + Parameters + ---------- + function : Function + The Relay function + dpu_target : str + The Vitis AI DPU target identifier + """ + + def __init__(self, function, dpu_target): + global pyxir + try: + if pyxir is None: + pyxir = __import__("pyxir") + __import__("pyxir.frontend.tvm") + except ImportError: + # add "from None" to silence + # "During handling of the above exception, another exception occurred" + raise ImportError( + "The pyxir package is required for the Vitis AI backend. " + "Please install it first. " + "Help: (https://tvm.apache.org/docs/deploy/vitis_ai.html) " + ) from None - def __init__(self, model_name, function): - self.model_name = model_name self.function = function + self.dpu_target = dpu_target self.params = {} - def convert_pyxir(self, dpu_target): - """Convert Relay expression to PyXIR XGraph""" + def build(self): + """ "Convert the Relay expression to a PyXIR XGraph to instantiate + the Vitis AI runtime + + Returns + ------- + xgraph_str : str + Serialized XGraph + """ xgraph = pyxir.frontend.tvm.from_relay( self.function, params=self.params, postprocessing=None ) - xgraph = pyxir.partition(xgraph, targets=[dpu_target]) - return xgraph + xgraph = pyxir.partition(xgraph, targets=[self.dpu_target]) + output_relay_ids = self.get_output_names() + layers = xgraph.get_layers() + + # Get the output tensor names using XGraph and output Relay ids + out_tensor_names = ["unknown_name"] * len(output_relay_ids) + for layer in layers: + if not layer.internal: + for relay_id in layer.attrs["relay_id"]: + if relay_id in output_relay_ids: + out_tensor_names[output_relay_ids.index(relay_id)] = layer.name + break + if any([name == "unkown_name" for name in out_tensor_names]): + raise ValueError( + "During codegeneration the loading of subexpression" + " failed due to output tensor name mismatch in Relay PyXIR interface." + ) + xgraph.meta_attrs["tvm_out_tensors"] = out_tensor_names + xgraph_str = pyxir.get_xgraph_str(xgraph) + return xgraph_str def get_output_names(self): """Get output names from Relay expression""" @@ -66,7 +122,6 @@ def vitis_ai_compiler(ref): """Create a Vitis-AI runtime from the provided Relay expression""" assert isinstance(ref, tvm.relay.function.Function) - out_tensor_names = [] name = str(ref.attrs.global_symbol) pass_context = tvm.get_global_func("transform.GetCurrentPassContext")() @@ -151,26 +206,8 @@ def vitis_ai_compiler(ref): # If load_runtime_module is not set, we will build the PyXIR runtime module from scratch if load_runtime_module == "": # Convert Relay expression into XGraph and do partitioning inside PyXIR - builder = CodegenVitisAI(name, ref) - xgraph = builder.convert_pyxir(dpu_target) - output_relay_ids = builder.get_output_names() - layers = xgraph.get_layers() - - # Get the output tensor names using XGraph and output Relay ids - out_tensor_names = ["unknown_name"] * len(output_relay_ids) - for layer in layers: - if not layer.internal: - for relay_id in layer.attrs["relay_id"]: - if relay_id in output_relay_ids: - out_tensor_names[output_relay_ids.index(relay_id)] = layer.name - break - if any([name == "unkown_name" for name in out_tensor_names]): - raise ValueError( - "During codegeneration the loading of subexpression" - " failed due to output tensor name mismatch in Relay PyXIR interface." - ) - xgraph.meta_attrs["tvm_out_tensors"] = out_tensor_names - xgraph_str = pyxir.get_xgraph_str(xgraph) + codegen = CodegenVitisAI(ref, dpu_target) + xgraph_str = codegen.build() runtime_func = "tvm.vitis_ai_runtime.from_xgraph" fcreate = tvm._ffi.get_global_func(runtime_func) diff --git a/python/tvm/contrib/target/vitis_ai_utils.py b/python/tvm/contrib/target/vitis_ai_utils.py deleted file mode 100644 index 1b2bcaa9dea5..000000000000 --- a/python/tvm/contrib/target/vitis_ai_utils.py +++ /dev/null @@ -1,36 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""Vitis AI target utilities""" - -import importlib - -import tvm - - -def vitis_ai_available(): - """Return whether Vitis AI tools are available""" - pyxir_spec = importlib.util.find_spec("pyxir") - if not tvm.get_global_func("tvm.vitis_ai_runtime.from_xgraph", True) or pyxir_spec is None: - return False - return True - - -def init_for_vitis_ai(): - """Initialization function for the Vitis AI codegen""" - # We need to import the Vitis AI target module to make sure the codegen - # is registered - importlib.import_module("tvm.contrib.target.vitis_ai") diff --git a/python/tvm/driver/tvmc/composite_target.py b/python/tvm/driver/tvmc/composite_target.py index 7fe35b66b3ae..ac1a41a0c4a9 100644 --- a/python/tvm/driver/tvmc/composite_target.py +++ b/python/tvm/driver/tvmc/composite_target.py @@ -19,11 +19,14 @@ """ import logging +# Make sure Vitis AI codegen is registered +import tvm.contrib.target.vitis_ai # pylint: disable=unused-import + from tvm.relay.op.contrib.arm_compute_lib import partition_for_arm_compute_lib from tvm.relay.op.contrib.ethosn import partition_for_ethosn from tvm.relay.op.contrib.bnns import partition_for_bnns from tvm.relay.op.contrib.vitis_ai import partition_for_vitis_ai -from tvm.contrib.target.vitis_ai_utils import init_for_vitis_ai + from .common import TVMCException @@ -41,9 +44,6 @@ # pass_pipeline : Callable # A function to transform a Module before compilation, mainly used # for partitioning for the target currently. -# init : Callable (optional) -# A function for doing initialization for the target codegen. Will be -# called when the target info gets retrieved REGISTERED_CODEGEN = { "compute-library": { "config_key": None, @@ -59,7 +59,6 @@ }, "vitis-ai": { "config_key": "relay.ext.vitis_ai.options", - "init": init_for_vitis_ai, "pass_pipeline": partition_for_vitis_ai, }, } @@ -76,17 +75,13 @@ def get_codegen_names(): return list(REGISTERED_CODEGEN.keys()) -def get_codegen_by_target(name, call_init_function=True): +def get_codegen_by_target(name): """Return a codegen entry by name. Parameters ---------- name : str The name of the target for which the codegen info should be retrieved. - call_init_function : bool - Whether to call the initialization function before returning the codegen - info. This is used in tests to avoid calling into initialization functions - that might fail if packages are not installed. Returns ------- @@ -94,9 +89,6 @@ def get_codegen_by_target(name, call_init_function=True): requested target codegen information """ try: - target_info = REGISTERED_CODEGEN[name] - if call_init_function and "init" in target_info: - target_info["init"]() - return target_info + return REGISTERED_CODEGEN[name] except KeyError: raise TVMCException("Composite target %s is not defined in TVMC." % name) diff --git a/tests/python/driver/tvmc/test_compiler.py b/tests/python/driver/tvmc/test_compiler.py index 5f0fa69051c5..9eb6320d029b 100644 --- a/tests/python/driver/tvmc/test_compiler.py +++ b/tests/python/driver/tvmc/test_compiler.py @@ -24,8 +24,8 @@ import tvm -from tvm.contrib.target.vitis_ai_utils import vitis_ai_available from tvm.relay.op.contrib.ethosn import ethosn_available +from tvm.contrib.target.vitis_ai import vitis_ai_available from tvm.driver import tvmc diff --git a/tests/python/driver/tvmc/test_composite_target.py b/tests/python/driver/tvmc/test_composite_target.py index 47cd3809f83f..0a0b45eeb970 100644 --- a/tests/python/driver/tvmc/test_composite_target.py +++ b/tests/python/driver/tvmc/test_composite_target.py @@ -29,8 +29,6 @@ from tvm.driver.tvmc.common import TVMCException -from tvm.contrib.target.vitis_ai_utils import vitis_ai_available - def test_get_codegen_names(): names = tvmc.composite_target.get_codegen_names() @@ -52,29 +50,14 @@ def test_invalid_codegen(): _ = tvmc.composite_target.get_codegen_by_target("invalid") -@pytest.mark.skipif( - not vitis_ai_available(), - reason="--target=vitis-ai is not available. TVM built with 'USE_VITIS_AI OFF'", -) -def test_vitis_ai_codegen_init(): - tvmc.composite_target.get_codegen_by_target("vitis-ai") - - def test_all_codegens_contain_pass_pipeline(): for name in tvmc.composite_target.get_codegen_names(): - codegen = tvmc.composite_target.get_codegen_by_target(name, call_init_function=False) + codegen = tvmc.composite_target.get_codegen_by_target(name) assert "pass_pipeline" in codegen, f"{name} does not contain a pass_pipeline" assert isfunction(codegen["pass_pipeline"]) def test_all_pass_pipelines_are_functions(): for name in tvmc.composite_target.get_codegen_names(): - codegen = tvmc.composite_target.get_codegen_by_target(name, call_init_function=False) + codegen = tvmc.composite_target.get_codegen_by_target(name) assert isfunction(codegen["pass_pipeline"]), f"pass_pipeline for {name} is not a function" - - -def test_all_initializations_are_functions(): - for name in tvmc.composite_target.get_codegen_names(): - codegen = tvmc.composite_target.get_codegen_by_target(name, call_init_function=False) - if "init" in codegen: - assert isfunction(codegen["init"]), f"init for {name} is not a function" From fd9b6945327860f844d76f1d73717da8c1b4270e Mon Sep 17 00:00:00 2001 From: Jorn Date: Sat, 3 Apr 2021 07:33:13 -0700 Subject: [PATCH 09/10] Fix TVMC Vitis AI test for compiler.compile_model API change --- tests/python/driver/tvmc/test_compiler.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/python/driver/tvmc/test_compiler.py b/tests/python/driver/tvmc/test_compiler.py index 4fa7f772f22d..8cd77b8cde4a 100644 --- a/tests/python/driver/tvmc/test_compiler.py +++ b/tests/python/driver/tvmc/test_compiler.py @@ -219,8 +219,10 @@ def test_compile_tflite_module_with_external_codegen(tflite_mobilenet_v1_1_quant def test_compile_tflite_module_with_external_codegen_vitis_ai(tflite_mobilenet_v1_1_quant): pytest.importorskip("tflite") + mod, params = tvmc.load(tflite_mobilenet_v1_1_quant) graph, lib, params, dumps = tvmc.compiler.compile_model( - tflite_mobilenet_v1_1_quant, + mod, + params, target="vitis-ai -dpu=DPUCZDX8G-zcu104 -export_runtime_module=vitis_ai.rtmod, llvm", dump_code="relay", ) From 719a94e2bd9a4050343cc8d8c5d5935fced6d919 Mon Sep 17 00:00:00 2001 From: Jorn Date: Tue, 6 Apr 2021 02:47:09 -0700 Subject: [PATCH 10/10] Lazy load pyxir package in Vitis AI partitioning pass --- python/tvm/relay/op/contrib/vitis_ai.py | 32 +++++++++++++++++++++---- 1 file changed, 28 insertions(+), 4 deletions(-) diff --git a/python/tvm/relay/op/contrib/vitis_ai.py b/python/tvm/relay/op/contrib/vitis_ai.py index a305be5e1e70..0c05c8db7435 100644 --- a/python/tvm/relay/op/contrib/vitis_ai.py +++ b/python/tvm/relay/op/contrib/vitis_ai.py @@ -20,9 +20,6 @@ import warnings import numpy as np -import pyxir -import pyxir.frontend.tvm - from tvm import relay import tvm._ffi from tvm.relay import transform @@ -30,12 +27,39 @@ from tvm.relay.build_module import bind_params_by_name from tvm.relay.op.annotation import compiler_begin, compiler_end +# Placeholder for PyXIR module +pyxir = None + @transform.function_pass(opt_level=0) class VitisAIAnnotationPass: - """Responsible for annotating Relay expressions for Vitis-AI DPU accelerators""" + """Responsible for annotating Relay expressions for Vitis-AI DPU accelerators + + Parameters + ---------- + compiler : str + The compiler name used for annotations (`vitis_ai`). + dpu_target : str + The Vitis AI DPU target identifier. + params : dict + A dictionary containing the module's parameters. + """ def __init__(self, compiler, dpu_target, params): + global pyxir + try: + if pyxir is None: + pyxir = __import__("pyxir") + __import__("pyxir.frontend.tvm") + except ImportError: + # add "from None" to silence + # "During handling of the above exception, another exception occurred" + raise ImportError( + "The pyxir package is required for the Vitis AI backend. " + "Please install it first. " + "Help: (https://tvm.apache.org/docs/deploy/vitis_ai.html) " + ) from None + self.compiler = compiler self.dpu_target = dpu_target self.params = params