From 0e05122fa6850f43a007ed680a7e3e8c91c48026 Mon Sep 17 00:00:00 2001 From: Siva Rama Krishna Reddy B Date: Mon, 30 Jan 2023 14:25:14 +0530 Subject: [PATCH 01/20] [DOCS][ADRENO] Improved Adreno documentation Unified single documentation for all types of usage with OpenCL as well as CLML backends. Detailed simplified usage (with docker environment and command line tools like tvmc) as well as advanced usage instructions via python based interface. --- docs/how_to/deploy/adreno.rst | 833 ++++++++++++++---- python/tvm/relay/op/contrib/__init__.py | 1 + python/tvm/relay/op/contrib/adreno.py | 85 ++ .../relay/opencl_texture/test_network.py | 24 +- 4 files changed, 755 insertions(+), 188 deletions(-) create mode 100644 python/tvm/relay/op/contrib/adreno.py diff --git a/docs/how_to/deploy/adreno.rst b/docs/how_to/deploy/adreno.rst index 7f4616fbf797..bdee7e597d7b 100644 --- a/docs/how_to/deploy/adreno.rst +++ b/docs/how_to/deploy/adreno.rst @@ -15,41 +15,60 @@ specific language governing permissions and limitations under the License. -Deploy to Adreno GPU -======================================= +Deploy to Adreno™ GPU +==================== -**Authors**: Daniil Barinov, Egor Churaev, Andrey Malyshev +**Authors**: Daniil Barinov, Egor Churaev, Andrey Malyshev, Siva Rama Krishna Introduction ------------ -Adreno is a series of graphics processing unit (GPU) semiconductor +Adreno™ is a series of graphics processing unit (GPU) semiconductor intellectual property cores developed by Qualcomm and used in many of their SoCs. -The Adreno GPU accelerates the rendering of complex geometries to +The Adreno™ GPU accelerates the rendering of complex geometries to deliver high-performance graphics and a rich user experience with low power consumption. -This guide will demonstrate :ref:`the benefits of using textures with Adreno`, -how to :ref:`build TVM with OpenCL` (needed by Adreno devices) and TVM RPC -enabled. It will also provide :ref:`example code` to better understand the differences in compiling and deploying models -for Adreno devices. +TVM supports deep learning acceleration on Adreno™ GPU by native OpenCL backend of TVM and +also through OpenCLML backend. Native OpenCL backend of TVM is enhanced to make it +Adreno™ friendly by incorporating texture memory usage and Adreno™ friendly layouts. +OpenCLML is an SDK release by Qualcomm that provides kernel acceleration library +for most of the deep learning operators. -.. _advantages_of_the_textures: +This guide is organized to demonstrate various design aspects of -Advantages of the Textures --------------------------- +- :ref:`OpenCL Backend Ehnahcements` +- :ref:`About OpenCLML` +- :ref:`Build and Deploy` -One of the Adreno's advantages is the clever handling of textures. At + + +.. how to :ref:`build TVM with OpenCL` (needed by Adreno™ devices) and TVM RPC +.. enabled. It will also provide :ref:`example code` to better understand the differences in compiling and deploying models +.. for Adreno™ devices. + + +.. _opencl_enhancements: + +OpenCL Backend Enhancements +--------------------------- + +OpenCL backend of TVM is enhanced to take advantage of Adreno™ specific features like +- Texture memory usage. +- Adreno™ friendly activation layouts. +- Brand new schedules to accelerate with above features. + +One of the Adreno™'s advantages is the clever handling of textures. At the moment, TVM is able to benefit from this by having texture support -for Adreno. The graph below shows the Adreno A5x architecture. +for Adreno™. The graph below shows the Adreno™ A5x architecture. -|High-level overview of the Adreno A5x architecture for OpenCL| +|High-level overview of the Adreno™ A5x architecture for OpenCL| -*Fig. 1 High-level overview of the Adreno A5x architecture for OpenCL* +*Fig. 1 High-level overview of the Adreno™ A5x architecture for OpenCL* -*source:* `OpenCL Optimization and Best Practices for Qualcomm Adreno GPUs `_ +*source:* `OpenCL Optimization and Best Practices for Qualcomm Adreno™ GPUs `_ Reasons of using textures: @@ -65,134 +84,667 @@ Reasons of using textures: Overall, with textures, it is possible to achieve a significant performance boost compared to OpenCL buffer based solutions. -.. _building_tvm_for_adreno: +.. _about_openclml: + +About OpenCLML +-------------- + +OpenCLML is a SDK released by Qualcomm that provides accelerated deep learning operators. +These operators are exposed as an extension "cl_qcom_ml_ops" to standard OpenCL specification. +Please refer `Accelerate your models with our OpenCL ML SDK `_ for more details. + +OpenCLML is integrated into TVM as a `BYOC `_ solution. +OpenCLML operators can use same context and the operatrors can be enqueued on same command queue if native OpenCL. +We took advantage of this to avoid any context switching over heads while fallback to native OpenCL. + + +.. _build_deploy: + +TVM for Adreno™ +--------------- + +This section gives instructions about various ways of building and deploying model +to Adreno™ target. Adreno™ is a remote target which is connected to the host via ADB connection. +Deploying the compiled model here require use some tools on host as well as on target. + +TVM has simplified user friendly command line based tools as well as +developer centric python API interface for various steps like auto tuning, building and deploying. + +TVM compilation process for remote devices has multiple stages listed below. + +**Model import:** +At this stage we import a model from well known frameworks like Tensorflow, PyTorch, ONNX ...etc. +This stage converts the given model into TVM's relay module format. Alternatively one can build a relay module manually +by using TVM's operator inventory too. TVM module generated here is a target independent representation of the graph. + +**Auto Tuning:** +At this stage we tune the TVM generated kernels specific to a target. Auto tuning process requires +target device availability and in case of a remote target like Adreno™ on Android device we use RPC Setup for communication. +Later sections in this guide will detail about RPC Setup for Android device. Auto tuning is not a necessary step for +compilation of a model. It is necessary for acheiving best performance out of TVM generated kernels. + +**Compilation:** +At this stage we compile the model for specific target. Given we auto tuned the module in previous stage, +TVM compilation make use of the tuning log for genetrating best performing kernels. TVM compilation process produces artifacts +containing kernel shared lib, graph definition in json format and parameters binary file in TVM specific format. + +**Deploy (or test run) on Target:** +At this stage we run the TVM compilation output on the target. Deployment is possible from python +environment using RPC Setup and also using TVM's native tool which is native binary cross compiled for Android. +At this stage we can run the compiled model on Android target and unit test output correctness and performance aspects. + +**Aplication Integration:** +This stage is all about integrating TVM compiled model in applications. Here we discuss about +interfacing tvm runtime from Android (cpp native environment or from JNI) for setting input and getting output. + +**Advanced Usage:** +This section advanced user interests like viewing generated source code, altering precision of the module ...etc. + -Building TVM for Adreno ------------------------ +This tutorial covers all the above aspects as part of below sections. -This section gives instructions on how to build the Android part of TVM -with OpenCL and TVM RPC Server in order to deploy models on Adreno. +- :ref:`Development environment` +- :ref:`RPC Setup` +- :ref:`Commandline tools` +- :ref:`Python interface` +- :ref:`Application Integration` +- :ref:`Advanced Usage` -Since the process of building TVM for Adreno is exactly the same as the -process of building TVM for Android, please refer to these instructions: -`TVM RPC -Server `_. +.. _development_environment: + + +Development Environment Setup : Automatic +----------------------------------------- +TVM ships a predefined docker container environment with all prerequisites to get started quickly. +You may also refer to :ref:`Manual Environment Setup` for more control on the dependencies. + +For docker setup the pre requisite is just docker tool availabilty on host. + +Below commands can build a docker image for adreno. + +:: + + ./docker/build.sh ci_adreno + docker tag tvm.ci_adreno ci_adreno + + +Now we can build both host and target utils with below command. + +:: + + ./tests/scripts/ci.py adreno -i + +To build TVM with OpenCLML SDK we need export the OpenCLML SDK as shown below while building + +:: + + export ADRENO_OPENCL= + ./tests/scripts/ci.py adreno -i + +On successful compilation this leaves us into a docker shell. The build leaves two folders + +* build-adreno: The host side TVM compiler build. +* build-adreno-target : Contains the android target components + + * libtvm_runtime.so : TVM runtime library + * tvm_rpc : The rpc runtime environment tool + * rtvm : A native stand alone tool + +While using docker environment the android device is shared with host. Hence, it is required +to have adb version "1.0.41" on the host as the docker used the same version. + +We can check adb devices availability inside docker environment too. + +:: -Since there are many required packages for Android, you can use the official Docker Image to build TVM. -For more information refer to this guide: `Deploy the Pretrained Model on Android `_. + user@ci-adreno-fpeqs:~$ adb devices + List of devices attached + aaaabbbb device + ccccdddd device -**Prerequisites**: Android NDK and Android Debug Bridge must -be installed, the desired device must have OpenCL support and Android part of TVM must be built: +.. _manual_setup: + +Development Environment Setup : Manual +-------------------------------------- + +Manual build process require building of host and target components. + +Below command will configure the build the host compiler + +:: + + mkdir -p build + cd build + cp ../cmake/config.cmake . + + echo set\(USE_OPENCL ON\) >> config.cmake + echo set\(USE_RPC ON\) >> config.cmake + echo set\(USE_GRAPH_EXECUTOR ON\) >> config.cmake + echo set\(USE_LIBBACKTRACE AUTO\) >> config.cmake + echo set\(USE_LLVM ON\) >> config.cmake + +Additionally we can push below config entry to compile with OpenCLML support. + +:: + + export ADRENO_OPENCL= + echo set\(USE_CLML ${ADRENO_OPENCL}\) >> config.cmake + +now we can build as shown below + +:: + + cmake .. + make + +Finally we can export python path as + +:: + + export PYTHONPATH=$PWD:/python + python3 -c "import tvm" # Verify tvm python package + + +Now, we can configure and build the target components with below configuration +Target build require Android NDK to be installed. - Read documentation about *Android NDK installation* here: https://developer.android.com/ndk - To get access to adb tools you can see *Android Debug Bridge installation* here: https://developer.android.com/studio/command-line/adb -You can also build the android part of TVM locally. From the root -folder of TVM: :: - mkdir build_android - cd build_android - cmake .. -DUSE_OPENCL=ON -DCMAKE_TOOLCHAIN_FILE=${ANDROID_NDK_HOME}/build/cmake/android.toolchain.cmake -DANDROID_ABI=arm64-v8a -DANDROID_NATIVE_API_LEVEL=android-28 -DCMAKE_FIND_ROOT_PATH_MODE_PACKAGE=ON -DANDROID_STL=c++_static -DUSE_CPP_RPC=ON - make -jN tvm_runtime tvm_rpc + mkdir -p build-adreno + cd build-adreno + cp ../cmake/config.cmake . + echo set\(USE_MICRO OFF\) >> config.cmake + echo set\(USE_OPENCL ON\) >> config.cmake + echo set\(USE_RPC ON\) >> config.cmake + echo set\(USE_CPP_RPC ON\) >> config.cmake + echo set\(USE_CPP_RTVM ON\) >> config.cmake + echo set\(USE_GRAPH_EXECUTOR ON\) >> config.cmake + echo set\(USE_LIBBACKTRACE AUTO\) >> config.cmake + echo set\(USE_KALLOC_ALIGNMENT 32\) >> config.cmake -where **N** is the number of cores available on your *CPU*. + echo set\(ANDROID_ABI arm64-v8a\) >> config.cmake + echo set\(ANDROID_PLATFORM android-28\) >> config.cmake + echo set\(MACHINE_NAME aarch64-linux-gnu\) >> config.cmake -At this stage you have built TVM for Adreno. +Additionally we can push below config to compile with OpenCLML support. -.. _build_and_deploy_model_for_adreno: +:: -Build and deploy model for Adreno ---------------------------------- + export ADRENO_OPENCL= + echo set\(USE_CLML "${ADRENO_OPENCL}"\) >> config.cmake + echo set\(USE_CLML_GRAPH_EXECUTOR "${ADRENO_OPENCL}"\) >> config.cmake -In this section we will focus on target, needed to compile and deploy models for Adreno, demonstrate -the differences in generated kernels with and without textures and, in addition, the -possibility of choosing a different precision for model compilation will -be considered. +For Android target build ANDROID_NDK_HOME is a dependency and we should have the same in the enviromnet variable. +Below commands will build Adreno™ target components -For the complete step-py-step process of compiling and deploying models on -Adreno, including selection of precision, running the inference of the -model, getting the predictions, and measuring the performance please refer to this tutorial: `How To Deploy model on Adreno `_ +:: -|Android deployment pipeline| + cmake -DCMAKE_TOOLCHAIN_FILE="${ANDROID_NDK_HOME}/build/cmake/android.toolchain.cmake" \ + -DANDROID_ABI=arm64-v8a \ + -DANDROID_PLATFORM=android-28 \ + -DCMAKE_SYSTEM_VERSION=1 \ + -DCMAKE_FIND_ROOT_PATH="${ADRENO_OPENCL}" \ + -DCMAKE_FIND_ROOT_PATH_MODE_PROGRAM=NEVER \ + -DCMAKE_FIND_ROOT_PATH_MODE_LIBRARY=ONLY \ + -DCMAKE_CXX_COMPILER="${ANDROID_NDK_HOME}/toolchains/llvm/prebuilt/linux-x86_64/bin/aarch64-linux-android28-clang++" \ + -DCMAKE_C_COMPILER="${ANDROID_NDK_HOME}/toolchains/llvm/prebuilt/linux-x86_64/bin/aarch64-linux-android28-clang" \ + -DMACHINE_NAME="aarch64-linux-gnu" .. -*Fig.2 Deployment pipeline on Adreno devices* + make tvm_runtime tvm_rpc rtvm -The figure above demonstrates a generalized pipeline for deploying and running neural network models on android devices. -As can be seen from the figure, the compiled model has a set_input() and a run() methods, -which *prepare the inputs* for inference and *execute the inference* on the remote device using the Graph Executor runtime module. -Adreno target -~~~~~~~~~~~~~ +.. _rpc_setup: -Normally, when compiling models for Android using OpenCL, the -corresponding target is used +RPC Setup +--------- -.. code:: python +RPC Setup allows remote target access over TCP/IP networking interface. RPC Setup is essential for auto tuning stage as tuning +involves running of auto generated kernels on real device and optimize the same by using machine learning approach. Please refer +`Auto-Tune with Templates and AutoTVM `_ got more details about AutoTVM. + +RPC Setup is also useful to deply the compiled model to a remote device from python interface or ```tvmc``` tool from host device. + +RPC Setup has multiple components as listed below. + +**TVM Tracker:** +TVM tracker is a host side daemon that manages remote devices and serve them to host side applications. Applications +can connect to this tracker and acquire a remote device handle to communicate. + +**TVM RPC:** +TVM RPC is a native application that runs on the remote device (Android in our case) and registers itself to the TVM Tracker +running on the host. + + +Hence, for RPC based setup we will have above components running on host and target device. Below sections explain how to setup the same +manually and also inside docker using automated tools. + +**Automated RPC Setup:** +Here we will explain how to setup RPC in docker environment. + +Below command launches tracker in docker environment, where docker listens on port 9120. + +:: + + ./tests/scripts/ci.py adreno -i # Launch a new shell on the anreno docker + source tests/scripts/setup-adreno-env.sh -e tracker -p 9120 + +Now, the below comand can run TVM RPC on remote android device with id "abcdefgh". + + +:: + + ./tests/scripts/ci.py adreno -i # Launch a new shell on adreno docker. + source tests/scripts/setup-adreno-env.sh -e device -p 9120 -d abcdefgh + + +**Manual RPC Setup:** + +Below command in manual setup starts the tracker on port 9120 + +:: - target="opencl" + python3 -m tvm.exec.rpc_tracker --host "0.0.0.0" --port "9120" -Using Adreno, we want to get all the benefits of textures, so we have to -use the following target to generate texture leveraging kernels +TVM RPC launch on Android device require some environment setup due to Android device is connected via ADB interface and we need to re-route +TCP/IP communication over ADB interface. Below commands will do necessary setup and run tvm_rpc on remote device. + +:: + + # Set android device to use + export ANDROID_SERIAL=abcdefgh + # Create a temporary folder on remote device. + adb shell "mkdir -p /data/local/tmp/tvm_ci" + # Copy tvm_rpc and it's dependency to remote device + adb push build-adreno-target/tvm_rpc /data/local/tmp/tvm_test/tvm_rpc + adb push build-adreno-target/libtvm_runtime.so /data/local/tmp/tvm_test + # Forward port 9120 from target to host + adb reverse tcp:9210 tcp:9120 + # tvm_rpc by default listens on ports starting from 5000 for incoming connections. + # Hence, reroute connections to these ports on host to remore device. + adb forward tcp:5000 tcp:5000 + adb forward tcp:5001 tcp:5001 + adb forward tcp:5002 tcp:5002 + # Finally launch rpc_daemon on remote device with identity key as "android" + adb shell "cd /data/local/tmp/tvm_test; killall -9 tvm_rpc; sleep 2; LD_LIBRARY_PATH=/data/local/tmp/tvm_test/ ./tvm_rpc server --host=0.0.0.0 --port=5000 --port-end=5010 --tracker=127.0.0.1:9120 --key=android" + +Upon successfull running this remote device will be available on tracker which can be queried as below. + +:: + + python3 -m tvm.exec.query_rpc_tracker --port 9120 + Tracker address 127.0.0.1:9120 + Server List + ------------------------------ + server-address key + ------------------------------ + 127.0.0.1:5000 server:android + ------------------------------ + + Queue Status + ------------------------------- + key total free pending + ------------------------------- + android 1 1 0 + ------------------------------- + +This concludes RPC Setup and we have rpc-tracker available on host 127.0.0.1 (rpc-tracker) and port 9120 (rpc-port). + + +.. _commandline_interface: + +Commandline Tools +----------------- + +Here we describe entire compilation process using command line tools. TVM has command line utility "tvmc" to perform +model import, auto tuning, compilation and deply over rpc. "tvmc" has many options to explore and try. + +**Model Import & Tuning:** +Use the below command to import a model from any framework and auto tune the same. +Here we use a model from Keras and it uses RPC setup for tuning and finally generates tuning log file +"keras-resnet50.log". + +:: + + python3 -m tvm.driver.tvmc tune --target="opencl -device=adreno" \ + --target-host="llvm -mtriple=aarch64-linux-gnu" \ + resnet50.h5 -o \ + keras-resnet50.log \ + --early-stopping 0 --repeat 30 --rpc-key android \ + --rpc-tracker 127.0.0.1:9120 --trials 1024 \ + --tuning-records keras-resnet50-records.log --tuner xgb + +**Model Compilation:** + +Use below command for compiling the model and produce TVM compiler outputs. + +:: + + python3 -m tvm.driver.tvmc compile \ + --cross-compiler ${ANDROID_NDK_HOME}/toolchains/llvm/prebuilt/linux-x86_64/bin/aarch64-linux-android28-clang \ + --target="opencl, llvm" --target-llvm-mtriple aarch64-linux-gnu --target-opencl-device adreno \ + --tuning-records keras-resnet50.log -o keras-resnet50.tar resnet50.h5 + +While enabled OpenCLML offloading we nee dto add target "clml" as shown below. Tuning log is valid for OpenCLML offloading also +as the OpenCL path is fallback option for any operator didn't go through OpenCLML path. The tuning log will be used for such operators. + +:: + + python3 -m tvm.driver.tvmc compile \ + --cross-compiler ${ANDROID_NDK_HOME}/toolchains/llvm/prebuilt/linux-x86_64/bin/aarch64-linux-android28-clang \ + --target="opencl, clml, llvm" --target-llvm-mtriple aarch64-linux-gnu --target-opencl-device adreno \ + --tuning-records keras-resnet50.log -o keras-resnet50.tar resnet50.h5 + +On success ful compilation above commands produce "keras-resnet50.tar". It is a compressed archive with kernel shared lib, graph json and params binary. + +**Deploy & Run on Target:** + +Running the compiled model on Android target is possible in RPC way as well as native deployment. + +We can use below tvmc command to deploy on remore target via RPC based setup. + +:: + + python3 -m tvm.driver.tvmc run --device="cl" keras-resnet50.tar \ + --rpc-key android --rpc-tracker 127.0.0.1:9120 --print-time + +tvmc based run has more option to initialize the input in various modes line fill, random ..etc. + + +TVM also supports "rtvm" tool to run the model narivelu on ADB shell. The build process produced this tool under build-adreno-target. +Please refer to `rtvm `_ for more details about this tool. + + +.. _python_interface: + +This section explains importing, auto tuning, compiling and running a model using python interface.\ +TVM has a high level interface through tvmc abstraction as well as relay api. We will discuss about both of these in details. + +Unlike command line interface python interface starts with model importing. Model importing converts the models from any framework +to a relay module. Relay module will be used across the auto tuning, compilation stages. + +**TVMC Interface:** + +TVMC interface can be accessed as shown below to import, compile and run a model. .. code:: python - target="opencl -device=adreno" + from tvm.driver import tvmc + from tvm.driver.tvmc.model import TVMCPackage + + # Convert a model from any framework to a tvm relay module. + # tvmc.load supports models from any framework (like tensorflow saves_model, onnx, tflite ..etc) and auto detects the filetype. + tvmc_model = tvmc.load("resnet50.h5") + + # tvmc_model consists of tvmc_mode.mod which is relay module and tvmc_model.params which parms of the module. + + # Now, the below api can be used for autotuning the model for any target. Tuning required RPC setup and please refer to + # :ref:`RPC Setup` for the same. + + tvmc.tune( + tvmc_model, + target="opencl -device=adreno", + output="keras-resnet50.log", + tuning_records="keras-resnet50-records.log", + target_host="llvm -mtriple=aarch64-linux-gnu" + rpc_tracker="127.0.0.1:9120", + rpc_key=android, + repeat=30, + trials=1024, + early_stopping=0, + ) + + # Compilation to produce tvm artifacts -Let's write a simple model with one convolutional (conv2d) layer and take a look at generated kernels for these -two targets + tvmc_package = tvmc.compile( + tvmc_model, + target="opencl -device=adreno", + target_host="llvm -mtriple=aarch64-linux-gnu", + cross="/android_ndk}/toolchains/llvm/prebuilt/linux-x86_64/bin/aarch64-linux-android28-clang", + tuning_records="keras-resnet50.log", + ) + + # tvmc_package consists of tvmc_package.lib_path, tvmc_package.graph, tvmc_package.params + + # Altrernatively, we can ave the cmpilation output and save it as a TVMCPackage. + # This way avoids loading of compiled module without compiling again. + + tvmc.compile( + tvmc_model, + target="opencl -device=adreno", + target_host="llvm -mtriple=aarch64-linux-gnu", + cross="/android_ndk/toolchains/llvm/prebuilt/linux-x86_64/bin/aarch64-linux-android28-clang", + tuning_records="keras-resnet50.log", + package_path="keras-resnet50.tar" + ) + # Load the compiled package + tvmc_package = TVMCPackage(package_path=module_file) + + # Saved TVMPackage is nothing but tar archive with mod.so, mod.json and mod.params. + + # Deploy and run the compiled model on RPC + # Prepare input data dict + input_data = tvm.nd.array((np.random.uniform(size=(1, 229, 229, 3))).astype("float32")) + input_dict = {"input": input_data} + + # Run on RPC setup + result = tvmc.run( + tvmc_package, + device="cl", + rpc_key="android", + hostname="127.0.0.1", + port=9120, + inputs=input_dict + ) + + # result is a dictionary of outputs. + + +tvmc compiled package can be used for native deploy also using "rtvm" utility. +Please refer to `rtvm `_ for more details about this tool. + +Also, please refer to tvmc documentation for more details about the api interface. + +**Relay Interface:** + +Relay api interface gives lower level api access to the tvm compiler interface. +Relay interface follows tvmc kind os a flow where we produce TVM module first followed by auto tuning, compilation and deployment. + +Below example explains about relay interface usage .. code:: python import tvm from tvm import relay + from tvm.relay.op.contrib import clml import numpy as np - input_shape=(1, 56, 56, 32) - filter_shape=(3, 3, 32, 64) - filter = np.random.rand(*filter_shape) + from tensorflow.keras.applications import InceptionV3 + import tensorflow as tf + + target = "opencl -device=adreno" + target_host = "llvm -mtriple=arm64-linux-android" + + # We first need to get a handle for a model from any framework. + # In this example we will prepare a keras InceptionV3 model + tf.keras.backend.clear_session() + keras_net = InceptionV3( + include_top=True, weights=None, input_shape=(299, 299, 3), classes=1000 + ) + input_info = {inceptionV3.input_names[0]: (1, 3, 299, 299)} + input_data = {inceptionV3.input_names[0], np.random.uniform(-1, -1, (1, 3, 299, 299)).astype("float32")} + from tensorflow.keras.layers import Input + from tensorflow.keras.models import Model + def get_bottom_top_model(model, layer_name): + layer = model.get_layer(layer_name) + bottom_input = model.layers[0].input + bottom_output = layer.output + bottom_model = Model(bottom_input, bottom_output) + return bottom_model + keras_model = get_bottom_top_model(keras_net, "predictions") + ref_output = keras_model.predict(data["input_1"].transpose(0, 2, 3, 1)) + + # Now we have a keras_model with input "input_1" with shape (1, 3, 299,299), output "predictions" and a reference output ref_output. + + # Lets import the model and get a relay module. TVM has frontend api for various frameworks under relay.frontend and now for keras + # model import we have relay.frontend.from_keras api. + mod, params = relay.frontend.from_keras(keras_model, input_info, layout="NCHW") + + # With relay module mod and parameters params we can not fo for tuning followed by compilation. + # The below few instructions can auto tune the relay module with xgboost being the tuner algorithm. + + # Auto Tuning process involces stages of extracting the tasks, defining tuning congiguration and + # tuning each task for best performing kernel configuration. + + # Auto Tuning Stage 1: Extract tunable tasks + tasks = autotvm.task.extract_from_program( + net, target=target, target_host=target_host, params=params + ) + + # Auto Tuning Stage 2: Define tuning configuration + tune_log = "adreno-resnet50.log" + tmp_log_file = tune_log + ".tmp" + measure_option = autotvm.measure_option( + builder=autotvm.LocalBuilder(build_func=ndk.create_shared, timeout=15), # Build the test kernel locally + runner=autotvm.RPCRunner( # The runner would be on a remote device. + "android", # RPC Key + host="127.0.0.1", # Tracker host + port=9120, # Tracker port + number=3, # Number of runs before averaging + timeout=600, # RPC Timeout + ), + ), + n_trail = 1024 # Number of iteration of training before choosing the best kernel config + early_stopping=False, # Do we apply early stopping when the loss is not minimizing + + # Iterate through each task and call the tuner + from tvm.autotvm.tuner import XGBTuner + for i, tsk in enumerate(reversed(tasks)): + tuner_obj = XGBTuner(tsk, loss_type="rank") + + tsk_trial = min(n_trial, len(tsk.config_space)) + tuner_obj.tune( + n_trial=tsk_trial, + early_stopping=early_stopping, + measure_option=measure_option, + callbacks=[ + autotvm.callback.progress_bar(tsk_trial, prefix=prefix), + autotvm.callback.log_to_file(tmp_log_file), + ], + ) + # Pick the best performing kerl configurations from the overall log. + autotvm.record.pick_best(tmp_log_file, log_filename) + + + # Given we have relay module and it's best performing kernel configurations + # We can now go for compilation with tuned log or without tuning log if auto tuning is not enabled. + + if os.path.exists(tune_log): + with autotvm.apply_history_best(tune_log): + with tvm.transform.PassContext(opt_level=3): + # Enable CLML partitioning if required. + net = clml.partition_for_clml(net, params) + + lib = relay.build( + net, target=tvm.target.Target(target, host=target_host), params=params + ) + else: + with tvm.transform.PassContext(opt_level=3): + # Enable CLML partitioning if required. + net = clml.partition_for_clml(net, params) + lib = relay.build( + net, target=tvm.target.Target(target, host=target_host), params=params + ) - dtype="float32" - input = tvm.relay.var("input", shape=input_shape, dtype=dtype) - weight = tvm.relay.var("weight", shape=filter_shape, dtype=dtype) - D = relay.nn.conv2d(input, weight, padding=(1, 1), data_layout="NHWC", kernel_layout="HWIO", out_dtype=dtype) + # Compilation results a lib module and it has everything required to deploy on target. + # We can save the compiler artifacts as shoun below and reload them later without entire compilation. + lib.export_library("mod.so", ndk.create_shared) + with open("mod.json", "w") as fo: + fo.write(graph.json()) + with open("mod.params", "wb") as fo: + fo.write(runtime.save_param_dict(params)) - mod = relay.Function([input, weight], D) - params = { - "weight": tvm.nd.array(filter) - } + # We can prepare TVMPackage from above files by art archiveing the same. + # The tar archive can be used with tvmc tool or tvmc api interfae to deploy and run. + # The tar archive can be used with "rtvm" tool also for native deploy on target device. -Now compile our model with the classic OpenCL target and print its modules: + # Now, lets look at deploying the compiled tvm artifact on remote target and run + tmp = tempdir() + filename = "%s.so" % network + lib.export_library(tmp.relpath(filename), ndk.create_shared) -.. code:: python + # connect to remote device + tracker = tvm.rpc.connect_tracker("127.0.0.1", 9120) + remote = tracker.request("android") + dev = remote.device(str(target), 0) + remote.upload(tmp.relpath(filename)) + rlib = remote.load_module(filename) - target="opencl" + # Create Graph runtime module on remote device + module = runtime.GraphModule(rlib["default"](dev)) + # Set input + module.set_input("input_1", input_data["input_1"]) + # Get output + output = module.get_output(0) - with tvm.transform.PassContext(opt_level=3): - graph, lib, params = relay.build_module.build(mod, target, params=params) - print(lib.imported_modules[0].get_source()) -Notice that the generated convolution kernel has pointers in -the initialization of the function. The kernels generated with the above target are buffer-based. +.. _application_integration: + +Aplication Integration: +---------------------- + +TVM compilation output is represented as module shared lib (mod.so), graph json(mod.json) and params (mod.params). +Archived representation of TVMPackage is also contains the same. + +In general a CPP/C based interface will be sufficient for any Android application integration. + +TVM natively expose c_runtime_api for loading a TVM compiled module and run the same. + +Alternatively one may refer to `cpp_rtvm `_ +tvm_runner interface too for further simplified version of the same. -.. code:: c - __kernel void tvmgen_default_fused_nn_conv2d_kernel0(__global float* restrict p0, __global double* restrict p1, __global float* restrict conv2d_nhwc) { - // body.. +.. _advanced_usage: -Now take a look at “opencl -device=adreno” target: +Advanced Usage: +--------------- + +This section details some of the advanced usage and additional information whihc using Adreno™ target on TVM. + +Generated Source Inspection +~~~~~~~~~~~~~~~~~~~~~~~~~~~ +Apart from standard tvm compilation artifacts kernel library (mod.so), graph (mod.json) and params (mod.params) +we can also generate opencl kernel source, clml offloaded graph ...etc from lib handle as shown below. +TVM compilation output is organized as a TVM module and many other TVM modules imported into it. + +Below snippet can dump CLML sub graphs in json format. .. code:: python - target="opencl -device=adreno" + # Look for "clml" typed module impoted. + clml_modules = list(filter(lambda mod: mod.type_key == "clml", lib.get_lib().imported_modules)) + # Loop throught all clml sub graphs and dump the json formatted CLML sub graphs. + for cmod in clml_modules: + print("CLML Src:", cmod.get_source()) - with tvm.transform.PassContext(opt_level=3): - graph, lib, params = relay.build_module.build(mod, target, params=params) - print(lib.imported_modules[0].get_source()) -The kernels generated this way is actually working with 2d arrays, leveraging textures +Similarly, below snippet can extract opencl kernel source from the compiled TVM module. + +.. code:: python + + # Similarly we can dump open kernel source too as shown below + # Look for "opencl" typed module impoted. + opencl_modules = list(filter(lambda mod: mod.type_key == "opencl", lib.get_lib().imported_modules)) + # Now dump open cource for each opencl targetted sub graph. + for omod in opencl_modules: + print("OpenCL Src:", omod.get_source()) + + +Inspecting above code for target device "opencl --device=adreno" shows texture usage (image2d_t) as shown below. .. code:: c @@ -214,28 +766,14 @@ We can choose from *float16*, *float16_acc32* (Mixed Precision), *float32* (stan To leverage the GPU hardware capabilities and utilize the benefits of half precision computation and memory management, we can convert an original model having floating points operation to a model operating with half precision. Choosing lower precision will positively affect the performance of the model, but it may also have a decrease in the accuracy of the model. -To do the conversion you need to write a simple conversion function and specify the *dtype* value of "float16" before calling the function: + +To do the conversion you need to call adreno specific transformation API as soon relay module is generated through any frontend: .. code:: python - def convert_to_dtype(mod, dtype): - # downcast to float16 - if dtype == "float16": - global conv2d_acc = "float16" - from tvm.ir import IRModule - mod = IRModule.from_expr(mod) - seq = tvm.transform.Sequential( - [ - relay.transform.InferType(), - relay.transform.ToMixedPrecision() - ] - ) - with tvm.transform.PassContext(opt_level=3): - mod = seq(mod) - return mod - - dtype="float16" - mod = convert_to_dtype(mod["main"], dtype) + from tvm.relay.op.contrib import adreno + adreno.convert_to_dtype(mod["main"], "float16") + We then can compile our model in any convinient way @@ -246,6 +784,7 @@ We then can compile our model in any convinient way mod, target_host=target_host, target=target, params=params ) + **float16_acc32 (Mixed Precision)** ToMixedPrecision pass traverse over the network and split network to clusters of ops dealing with float or float16 data types. @@ -255,75 +794,21 @@ The clusters are defined by three types of operations: - Operations never be converted to the float16 data type This list is defined in the ToMixedPrecision implementation here `relay/transform/mixed_precision.py `_ -and can be overridden by user +and can be overridden by user. -In some cases, we want higher precision in accumulation than the input data. -This is supported, for example, for conv2d and dense operations. To override accumulation type you need to register -function with ``@register_mixed_precision_conversion`` decorator to modify parameters of ``ToMixedPrecision`` conversion - -.. code:: python +The ``ToMixedPrecision`` method is a pass to convert an FP32 relay graph into an FP16 version (with +FP16 or FP32 accumulation dtypes). Doing this transformation is useful for reducing model size +as it halves the expected size of the weights (FP16_acc16 case). - from tvm.relay.op import register_mixed_precision_conversion - - conv2d_acc = "float32" - - # Pick a priority > 10 to overwrite defaults, higher priorities take precedence - @register_mixed_precision_conversion("nn.conv2d", level=11) - def conv2d_mixed_precision_rule(call_node: "relay.Call", mixed_precision_type: str): - global conv2d_acc - return [ - # always do main calculation in mixed_precision_type - relay.transform.mixed_precision.MIXED_PRECISION_ALWAYS, - # the dtype for the accumulator - conv2d_acc, - # the output dtype for the operation (usually fp16) - mixed_precision_type, - ] - - # Same for dense - @register_mixed_precision_conversion("nn.dense", level=11) - def conv2d_mixed_precision_rule(call_node: "relay.Call", mixed_precision_type: str): - global conv2d_acc - return [ - relay.transform.mixed_precision.MIXED_PRECISION_ALWAYS, - conv2d_acc, - mixed_precision_type, - ] - -Now we need to modify the conversion function by adding some logical "forks" and ToMixedPrecision() call, -then create a Relay graph from desired model in any convinient way and obtain **mod** (which is IR representation of the model), -after which we can convert it to the required **dtype** and then assemble our model sequentialy +ToMixedPrecision pass usage is simplified into a simple call as shown below for usage. .. code:: python - def convert_to_dtype(mod, dtype): - # downcast to float16 - if dtype == "float16" or dtype == "float16_acc32": - global conv2d_acc - conv2d_acc = "float16" if dtype == "float16" else "float32" - from tvm.ir import IRModule - mod = IRModule.from_expr(mod) - seq = tvm.transform.Sequential( - [ - relay.transform.InferType(), - relay.transform.ToMixedPrecision() - ] - ) - with tvm.transform.PassContext( - config={"relay.ToMixedPrecision.keep_orig_output_dtype": True}, - opt_level=3): - mod = seq(mod) - return mod + from tvm.relay.op.contrib import adreno + adreno.convert_to_dtype(mod["main"], "float16_acc32") - dtype="float16_acc32" - mod = convert_to_dtype(mod["main"], dtype) - dtype = "float32" if dtype == "float32" else "float16" -The ``ToMixedPrecision`` method is a pass to convert an FP32 relay graph into an FP16 version (with -FP16 or FP32 accumulation dtypes). Doing this transformation is useful for reducing model size -as it halves the expected size of the weights (FP16_acc16 case). - -From this point onwards, we can compile our model as normal +We then can compile our model in any convinient way .. code:: python @@ -332,5 +817,5 @@ From this point onwards, we can compile our model as normal mod, target_host=target_host, target=target, params=params ) -.. |High-level overview of the Adreno A5x architecture for OpenCL| image:: https://raw.githubusercontent.com/tlc-pack/web-data/main/images/how-to/adreno_architecture.png +.. |High-level overview of the Adreno™ A5x architecture for OpenCL| image:: https://raw.githubusercontent.com/tlc-pack/web-data/main/images/how-to/adreno_architecture.png .. |Android deployment pipeline| image:: https://raw.githubusercontent.com/tlc-pack/web-data/main/images/how-to/android_deployment_pipeline.jpg diff --git a/python/tvm/relay/op/contrib/__init__.py b/python/tvm/relay/op/contrib/__init__.py index 01708e8452bd..104a48cb54bb 100644 --- a/python/tvm/relay/op/contrib/__init__.py +++ b/python/tvm/relay/op/contrib/__init__.py @@ -27,3 +27,4 @@ from .tensorrt import * from .cutlass import * from .clml import * +from .adreno import * diff --git a/python/tvm/relay/op/contrib/adreno.py b/python/tvm/relay/op/contrib/adreno.py new file mode 100644 index 000000000000..965b9b387c9a --- /dev/null +++ b/python/tvm/relay/op/contrib/adreno.py @@ -0,0 +1,85 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=invalid-name, unused-argument +"""Adreno specific helpers.""" +import tvm + +from tvm import relay +from tvm.ir import IRModule + +acc_dtype = "float32" + + +def mixed_precision_rule(call_node: "relay.Call", mixed_precision_type: str): + global acc_dtype + return [ + relay.transform.mixed_precision.MIXED_PRECISION_ALWAYS, + acc_dtype, + mixed_precision_type, + ] + + +class AdrenoMixedPrecision(object): + """Temporarily changes attr of ops to enable FP32 accumulation .""" + + def __init__(self): + """Saves the required info for RAII pattern usage. + + Parameters + ---------- + acc_dtype : atr + accumulation dtype. + """ + self.older_attr = {} + self.ops = ["nn.conv2d", "nn.dense"] + self.attr_key = "FTVMMixedPrecisionConversionType" + + def __enter__(self): + for op_name in self.ops: + op = relay.op.get(op_name) + self.older_attr[op_name] = op.get_attr(self.attr_key) + op.reset_attr(self.attr_key) + op.set_attr(self.attr_key, mixed_precision_rule) + return self + + def __exit__(self, ptype, value, trace): + for op_name in self.ops: + op = relay.op.get(op_name) + op.reset_attr(self.attr_key) + if self.older_attr[op_name]: + op.set_attr(self.attr_key, self.older_attr[op_name]) + + +def convert_to_dtype(mod, dtype): + """Converts the operator datatypes""" + + global acc_dtype + if dtype in ["float16", "float16_acc32"]: + acc_dtype = "float16" if dtype == "float16" else "float32" + + mod = IRModule.from_expr(mod) + with AdrenoMixedPrecision(): + seq = tvm.transform.Sequential( + [relay.transform.InferType(), relay.transform.ToMixedPrecision()] + ) + with tvm.transform.PassContext( + config={"relay.ToMixedPrecision.keep_orig_output_dtype": True}, opt_level=3 + ): + mod = seq(mod) + else: + print("Warn: Invald dtype conversion to ", dtype) + return mod diff --git a/tests/python/relay/opencl_texture/test_network.py b/tests/python/relay/opencl_texture/test_network.py index 46ee79697ea6..47bd82a2d1f1 100644 --- a/tests/python/relay/opencl_texture/test_network.py +++ b/tests/python/relay/opencl_texture/test_network.py @@ -22,31 +22,20 @@ import tvm from tvm import relay from tvm.contrib import utils +from tvm.relay.op.contrib import adreno from tvm.relay import testing from tvm.relay.op import register_mixed_precision_conversion from utils.adreno_utils import build_run_compare, get_model, gpu_preprocess -def convert_to_fp16(mod, dtype): - from tvm.ir import IRModule - - mod = IRModule.from_expr(mod) - seq = tvm.transform.Sequential( - [relay.transform.InferType(), relay.transform.ToMixedPrecision()] - ) - with tvm.transform.PassContext(opt_level=3): - mod = seq(mod) - return mod - - def _test_mobilenet_v1(remote, target, dtype): mod, params, inputs, dtypes = get_model( "https://github.com/mlcommons/mobile_models/raw/main/v0_7/tflite/mobilenet_edgetpu_224_1.0_float.tflite", "mobilenet_edgetpu_224_1.0_float.tflite", "tflite", ) - if dtype == "float16": - mod = convert_to_fp16(mod["main"], dtype) + if dtype == "float16" or dtype == "float16_acc32": + mod = adreno.convert_to_dtype(mod["main"], dtype) build_run_compare(remote, mod, params, inputs, dtypes, target, []) @@ -65,5 +54,12 @@ def test_mobilenet_v1_fp32(remote, target): _test_mobilenet_v1(remote, target, "float32") +@pytest.mark.skip(reason="See https://github.com/apache/tvm/issues/13443") +@tvm.testing.requires_opencl +@tvm.testing.parametrize_targets("opencl -device=adreno") +def test_mobilenet_v1_fp16_acc32(remote, target): + _test_mobilenet_v1(remote, target, "float16_acc32") + + if __name__ == "__main__": tvm.testing.main() From ecde695d985817d6561055329a5b50bfdce7583c Mon Sep 17 00:00:00 2001 From: Siva Date: Wed, 1 Feb 2023 12:50:53 +0530 Subject: [PATCH 02/20] Update docs/how_to/deploy/adreno.rst Co-authored-by: Egor Churaev --- docs/how_to/deploy/adreno.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/how_to/deploy/adreno.rst b/docs/how_to/deploy/adreno.rst index bdee7e597d7b..066590f6956b 100644 --- a/docs/how_to/deploy/adreno.rst +++ b/docs/how_to/deploy/adreno.rst @@ -16,7 +16,7 @@ under the License. Deploy to Adreno™ GPU -==================== +===================== **Authors**: Daniil Barinov, Egor Churaev, Andrey Malyshev, Siva Rama Krishna From 9194d9e0a1f26f1a9388e469e83a72cdc94d285f Mon Sep 17 00:00:00 2001 From: Siva Date: Wed, 1 Feb 2023 12:52:35 +0530 Subject: [PATCH 03/20] Update docs/how_to/deploy/adreno.rst Co-authored-by: Egor Churaev --- docs/how_to/deploy/adreno.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/how_to/deploy/adreno.rst b/docs/how_to/deploy/adreno.rst index 066590f6956b..7b8ee553c9b9 100644 --- a/docs/how_to/deploy/adreno.rst +++ b/docs/how_to/deploy/adreno.rst @@ -94,7 +94,7 @@ These operators are exposed as an extension "cl_qcom_ml_ops" to standard OpenCL Please refer `Accelerate your models with our OpenCL ML SDK `_ for more details. OpenCLML is integrated into TVM as a `BYOC `_ solution. -OpenCLML operators can use same context and the operatrors can be enqueued on same command queue if native OpenCL. +OpenCLML operators can use same context and can be enqueued on same command queue as used in native OpenCL. We took advantage of this to avoid any context switching over heads while fallback to native OpenCL. From d4c1c7c498c418f0aaf3e37a1fdbc6694a4573b7 Mon Sep 17 00:00:00 2001 From: Siva Rama Krishna Reddy B Date: Wed, 1 Feb 2023 20:03:53 +0530 Subject: [PATCH 04/20] * review --- docs/how_to/deploy/adreno.rst | 252 +++-------------- .../deploy_models/deploy_model_on_adreno.py | 254 +++++++++++------- .../deploy_model_on_adreno_tvmc.py | 186 +++++++++++++ 3 files changed, 368 insertions(+), 324 deletions(-) create mode 100644 gallery/how_to/deploy_models/deploy_model_on_adreno_tvmc.py diff --git a/docs/how_to/deploy/adreno.rst b/docs/how_to/deploy/adreno.rst index 7b8ee553c9b9..657def61a110 100644 --- a/docs/how_to/deploy/adreno.rst +++ b/docs/how_to/deploy/adreno.rst @@ -44,12 +44,6 @@ This guide is organized to demonstrate various design aspects of - :ref:`Build and Deploy` - -.. how to :ref:`build TVM with OpenCL` (needed by Adreno™ devices) and TVM RPC -.. enabled. It will also provide :ref:`example code` to better understand the differences in compiling and deploying models -.. for Adreno™ devices. - - .. _opencl_enhancements: OpenCL Backend Enhancements @@ -84,6 +78,29 @@ Reasons of using textures: Overall, with textures, it is possible to achieve a significant performance boost compared to OpenCL buffer based solutions. +In general we specify target as ``target="opencl"`` for a regular OpenCL based target which generates the kernels as shown below. + +.. code:: c + + __kernel void tvmgen_default_fused_nn_conv2d_kernel0(__global float* restrict p0, __global double* restrict p1, __global float* restrict conv2d_nhwc) { + // body.. + +Above OpenCL kernel definition has ``__global float*`` poniters which are essestially OpenCL ``buffer`` objects. + +When enabled texture based enhancements by modifying target definition as ``target="opencl -device=adreno"`` we can see the generated +kernels using texture backed OpenCL image objects as shown below. + +.. code:: c + + __kernel void tvmgen_default_fused_nn_conv2d_kernel0(__write_only image2d_t pad_temp_global_texture, __read_only image2d_t p0) { + // body.. + +*image2d_t* is a built-in OpenCL types that represents two-dimensional image object and provides several additional functions. +When we use *image2d_t* we read *4 elements at one time*, and it helps to utilize hardware in a more efficient way. + +Please refer to :ref:`Advanced Usage` for more details about generation and inspection of kernel sources. + + .. _about_openclml: About OpenCLML @@ -454,7 +471,7 @@ We can use below tvmc command to deploy on remore target via RPC based setup. tvmc based run has more option to initialize the input in various modes line fill, random ..etc. -TVM also supports "rtvm" tool to run the model narivelu on ADB shell. The build process produced this tool under build-adreno-target. +TVM also supports "rtvm" tool to run the model narively on ADB shell. The build process produced this tool under build-adreno-target. Please refer to `rtvm `_ for more details about this tool. @@ -468,233 +485,26 @@ to a relay module. Relay module will be used across the auto tuning, compilation **TVMC Interface:** -TVMC interface can be accessed as shown below to import, compile and run a model. - -.. code:: python - - from tvm.driver import tvmc - from tvm.driver.tvmc.model import TVMCPackage - - # Convert a model from any framework to a tvm relay module. - # tvmc.load supports models from any framework (like tensorflow saves_model, onnx, tflite ..etc) and auto detects the filetype. - tvmc_model = tvmc.load("resnet50.h5") - - # tvmc_model consists of tvmc_mode.mod which is relay module and tvmc_model.params which parms of the module. - - # Now, the below api can be used for autotuning the model for any target. Tuning required RPC setup and please refer to - # :ref:`RPC Setup` for the same. - - tvmc.tune( - tvmc_model, - target="opencl -device=adreno", - output="keras-resnet50.log", - tuning_records="keras-resnet50-records.log", - target_host="llvm -mtriple=aarch64-linux-gnu" - rpc_tracker="127.0.0.1:9120", - rpc_key=android, - repeat=30, - trials=1024, - early_stopping=0, - ) - - # Compilation to produce tvm artifacts - - tvmc_package = tvmc.compile( - tvmc_model, - target="opencl -device=adreno", - target_host="llvm -mtriple=aarch64-linux-gnu", - cross="/android_ndk}/toolchains/llvm/prebuilt/linux-x86_64/bin/aarch64-linux-android28-clang", - tuning_records="keras-resnet50.log", - ) - - # tvmc_package consists of tvmc_package.lib_path, tvmc_package.graph, tvmc_package.params - - # Altrernatively, we can ave the cmpilation output and save it as a TVMCPackage. - # This way avoids loading of compiled module without compiling again. - - tvmc.compile( - tvmc_model, - target="opencl -device=adreno", - target_host="llvm -mtriple=aarch64-linux-gnu", - cross="/android_ndk/toolchains/llvm/prebuilt/linux-x86_64/bin/aarch64-linux-android28-clang", - tuning_records="keras-resnet50.log", - package_path="keras-resnet50.tar" - ) - # Load the compiled package - tvmc_package = TVMCPackage(package_path=module_file) - - # Saved TVMPackage is nothing but tar archive with mod.so, mod.json and mod.params. - - # Deploy and run the compiled model on RPC - # Prepare input data dict - input_data = tvm.nd.array((np.random.uniform(size=(1, 229, 229, 3))).astype("float32")) - input_dict = {"input": input_data} - - # Run on RPC setup - result = tvmc.run( - tvmc_package, - device="cl", - rpc_key="android", - hostname="127.0.0.1", - port=9120, - inputs=input_dict - ) - - # result is a dictionary of outputs. - +TVMC interface can be accessed as shown below to import, compile and run a model. Please refer to the tutorial for the same +`How To Deploy model on Adreno using TVMC `_ tvmc compiled package can be used for native deploy also using "rtvm" utility. -Please refer to `rtvm `_ for more details about this tool. Also, please refer to tvmc documentation for more details about the api interface. **Relay Interface:** Relay api interface gives lower level api access to the tvm compiler interface. -Relay interface follows tvmc kind os a flow where we produce TVM module first followed by auto tuning, compilation and deployment. - -Below example explains about relay interface usage +Relay interface follows tvmc kind of a flow where we produce TVM module first followed by auto tuning, compilation and deployment. -.. code:: python - - import tvm - from tvm import relay - from tvm.relay.op.contrib import clml - import numpy as np - - from tensorflow.keras.applications import InceptionV3 - import tensorflow as tf - - target = "opencl -device=adreno" - target_host = "llvm -mtriple=arm64-linux-android" - - # We first need to get a handle for a model from any framework. - # In this example we will prepare a keras InceptionV3 model - tf.keras.backend.clear_session() - keras_net = InceptionV3( - include_top=True, weights=None, input_shape=(299, 299, 3), classes=1000 - ) - input_info = {inceptionV3.input_names[0]: (1, 3, 299, 299)} - input_data = {inceptionV3.input_names[0], np.random.uniform(-1, -1, (1, 3, 299, 299)).astype("float32")} - from tensorflow.keras.layers import Input - from tensorflow.keras.models import Model - def get_bottom_top_model(model, layer_name): - layer = model.get_layer(layer_name) - bottom_input = model.layers[0].input - bottom_output = layer.output - bottom_model = Model(bottom_input, bottom_output) - return bottom_model - keras_model = get_bottom_top_model(keras_net, "predictions") - ref_output = keras_model.predict(data["input_1"].transpose(0, 2, 3, 1)) - - # Now we have a keras_model with input "input_1" with shape (1, 3, 299,299), output "predictions" and a reference output ref_output. - - # Lets import the model and get a relay module. TVM has frontend api for various frameworks under relay.frontend and now for keras - # model import we have relay.frontend.from_keras api. - mod, params = relay.frontend.from_keras(keras_model, input_info, layout="NCHW") - - # With relay module mod and parameters params we can not fo for tuning followed by compilation. - # The below few instructions can auto tune the relay module with xgboost being the tuner algorithm. - - # Auto Tuning process involces stages of extracting the tasks, defining tuning congiguration and - # tuning each task for best performing kernel configuration. - - # Auto Tuning Stage 1: Extract tunable tasks - tasks = autotvm.task.extract_from_program( - net, target=target, target_host=target_host, params=params - ) - - # Auto Tuning Stage 2: Define tuning configuration - tune_log = "adreno-resnet50.log" - tmp_log_file = tune_log + ".tmp" - measure_option = autotvm.measure_option( - builder=autotvm.LocalBuilder(build_func=ndk.create_shared, timeout=15), # Build the test kernel locally - runner=autotvm.RPCRunner( # The runner would be on a remote device. - "android", # RPC Key - host="127.0.0.1", # Tracker host - port=9120, # Tracker port - number=3, # Number of runs before averaging - timeout=600, # RPC Timeout - ), - ), - n_trail = 1024 # Number of iteration of training before choosing the best kernel config - early_stopping=False, # Do we apply early stopping when the loss is not minimizing - - # Iterate through each task and call the tuner - from tvm.autotvm.tuner import XGBTuner - for i, tsk in enumerate(reversed(tasks)): - tuner_obj = XGBTuner(tsk, loss_type="rank") - - tsk_trial = min(n_trial, len(tsk.config_space)) - tuner_obj.tune( - n_trial=tsk_trial, - early_stopping=early_stopping, - measure_option=measure_option, - callbacks=[ - autotvm.callback.progress_bar(tsk_trial, prefix=prefix), - autotvm.callback.log_to_file(tmp_log_file), - ], - ) - # Pick the best performing kerl configurations from the overall log. - autotvm.record.pick_best(tmp_log_file, log_filename) - - - # Given we have relay module and it's best performing kernel configurations - # We can now go for compilation with tuned log or without tuning log if auto tuning is not enabled. - - if os.path.exists(tune_log): - with autotvm.apply_history_best(tune_log): - with tvm.transform.PassContext(opt_level=3): - # Enable CLML partitioning if required. - net = clml.partition_for_clml(net, params) - - lib = relay.build( - net, target=tvm.target.Target(target, host=target_host), params=params - ) - else: - with tvm.transform.PassContext(opt_level=3): - # Enable CLML partitioning if required. - net = clml.partition_for_clml(net, params) - lib = relay.build( - net, target=tvm.target.Target(target, host=target_host), params=params - ) - - # Compilation results a lib module and it has everything required to deploy on target. - # We can save the compiler artifacts as shoun below and reload them later without entire compilation. - lib.export_library("mod.so", ndk.create_shared) - with open("mod.json", "w") as fo: - fo.write(graph.json()) - with open("mod.params", "wb") as fo: - fo.write(runtime.save_param_dict(params)) - - # We can prepare TVMPackage from above files by art archiveing the same. - # The tar archive can be used with tvmc tool or tvmc api interfae to deploy and run. - # The tar archive can be used with "rtvm" tool also for native deploy on target device. - - # Now, lets look at deploying the compiled tvm artifact on remote target and run - tmp = tempdir() - filename = "%s.so" % network - lib.export_library(tmp.relpath(filename), ndk.create_shared) - - # connect to remote device - tracker = tvm.rpc.connect_tracker("127.0.0.1", 9120) - remote = tracker.request("android") - dev = remote.device(str(target), 0) - remote.upload(tmp.relpath(filename)) - rlib = remote.load_module(filename) - - # Create Graph runtime module on remote device - module = runtime.GraphModule(rlib["default"](dev)) - # Set input - module.set_input("input_1", input_data["input_1"]) - # Get output - output = module.get_output(0) +Please refer to the tutorial `How To Deploy model on Adreno `_ +for a step by step explanation of the same. .. _application_integration: Aplication Integration: ----------------------- +----------------------- TVM compilation output is represented as module shared lib (mod.so), graph json(mod.json) and params (mod.params). Archived representation of TVMPackage is also contains the same. @@ -713,7 +523,7 @@ tvm_runner interface too for further simplified version of the same. Advanced Usage: --------------- -This section details some of the advanced usage and additional information whihc using Adreno™ target on TVM. +This section details some of the advanced usage and additional information while using Adreno™ target on TVM. Generated Source Inspection ~~~~~~~~~~~~~~~~~~~~~~~~~~~ diff --git a/gallery/how_to/deploy_models/deploy_model_on_adreno.py b/gallery/how_to/deploy_models/deploy_model_on_adreno.py index c120c5339b62..db2f78f69417 100644 --- a/gallery/how_to/deploy_models/deploy_model_on_adreno.py +++ b/gallery/how_to/deploy_models/deploy_model_on_adreno.py @@ -18,9 +18,9 @@ """ .. _tutorial-deploy-model-on-adreno: -Deploy the Pretrained Model on Adreno -======================================= -**Author**: Daniil Barinov +Deploy the Pretrained Model on Adreno™ +====================================== +**Author**: Daniil Barinov, Siva Rama Krishna This article is a step-by-step tutorial to deploy pretrained Pytorch ResNet-18 model on Adreno (on different precisions). @@ -115,6 +115,67 @@ # android 1 1 0 # ---------------------------------- +################################################################# +# Configuration +# ------------- + +import os +import torch +import torchvision +import tvm +from tvm import te +from tvm import relay, rpc +from tvm.contrib import utils, ndk +from tvm.contrib import graph_executor +from tvm.relay.op.contrib import clml +from tvm import autotvm + +# Adreno devices are efficient with float16 compared to float32 +# Given the expected output doesn't effect by lowering precision +# it's advisable to use lower precision. +# We have a helper API to make the precision conversion simple and +# it supports dtype with "float16" and "float16_acc32" modes. +# Let's choose "float16_acc32" for this example. + +dtype = "float16_acc32" + +# Specify Adreno target before compiling to generate texture +# leveraging kernels and get all the benefits of textures +# Note: This generated example running on our x86 server for demonstration. +# If running it on the Android device, we need to +# specify its instruction set. Set :code:`local_demo` to False if you want +# to run this tutorial with a real device over rpc. +local_demo = True + +# by default on CPU target will execute. +# select 'cpu', 'opencl' and 'opencl -device=adreno' +test_target = "cpu" + +# Change target configuration. +# Run `adb shell cat /proc/cpuinfo` to find the arch. +arch = "arm64" +target = tvm.target.Target("llvm -mtriple=%s-linux-android" % arch) + +# Auto tuning is compute and time taking task, hence disabling for default run. Please enable it if required. +is_tuning = False +tune_log = "adreno-resnet18.log" + +# To enable OpenCLML accelerated operator library. +enable_clml = False + +################################################################# +# Get a PyTorch Model +# ------------------- +# Get resnet18 from torchvision models +model_name = "resnet18" +model = getattr(torchvision.models, model_name)(pretrained=True) +model = model.eval() + +# We grab the TorchScripted model via tracing +input_shape = [1, 3, 224, 224] +input_data = torch.randn(input_shape) +scripted_model = torch.jit.trace(model, input_data).eval() + ################################################################# # Load a test image # ----------------- @@ -146,85 +207,24 @@ img = np.expand_dims(img, 0) ################################################################# -# Load pretrained Pytorch model -# ----------------------------- -# Create a Relay graph from a Pytorch ResNet-18 model -import os -import torch -import torchvision -import tvm -from tvm import te -from tvm import relay, rpc -from tvm.contrib import utils, ndk -from tvm.contrib import graph_executor - -model_name = "resnet18" -model = getattr(torchvision.models, model_name)(pretrained=True) -model = model.eval() - -# We grab the TorchScripted model via tracing -input_shape = [1, 3, 224, 224] -input_data = torch.randn(input_shape) -scripted_model = torch.jit.trace(model, input_data).eval() - +# Convert PyTorch model to Relay module +# ------------------------------------- +# TVM has frontend api for various frameworks under relay.frontend and now +# for pytorch model import we have relay.frontend.from_pytorch api. # Input name can be arbitrary input_name = "input0" shape_list = [(input_name, img.shape)] + mod, params = relay.frontend.from_pytorch(scripted_model, shape_list) ################################################################# # Precisions # ---------- -# Since TVM support Mixed Precision, we need to register mixed_precision_conversion: -from tvm.relay.op import register_mixed_precision_conversion - -conv2d_acc = "float32" - - -@register_mixed_precision_conversion("nn.conv2d", level=11) -def conv2d_mixed_precision_rule(call_node: "relay.Call", mixed_precision_type: str): - global conv2d_acc - return [ - relay.transform.mixed_precision.MIXED_PRECISION_ALWAYS, - conv2d_acc, - mixed_precision_type, - ] - - -@register_mixed_precision_conversion("nn.dense", level=11) -def conv2d_mixed_precision_rule(call_node: "relay.Call", mixed_precision_type: str): - global conv2d_acc - return [ - relay.transform.mixed_precision.MIXED_PRECISION_ALWAYS, - conv2d_acc, - mixed_precision_type, - ] +from tvm.relay.op.contrib import adreno +adreno.convert_to_dtype(mod["main"], dtype) -################################################################# -# and also define the conversion function itself -def convert_to_dtype(mod, dtype): - # downcast to float16 - if dtype == "float16" or dtype == "float16_acc32": - global conv2d_acc - conv2d_acc = "float16" if dtype == "float16" else "float32" - from tvm.ir import IRModule - - mod = IRModule.from_expr(mod) - seq = tvm.transform.Sequential( - [relay.transform.InferType(), relay.transform.ToMixedPrecision()] - ) - with tvm.transform.PassContext(opt_level=3): - mod = seq(mod) - return mod - - -################################################################# -# Let's choose "float16_acc32" for example. -dtype = "float16_acc32" -mod = convert_to_dtype(mod["main"], dtype) dtype = "float32" if dtype == "float32" else "float16" - print(mod) ################################################################# @@ -233,46 +233,96 @@ def convert_to_dtype(mod, dtype): # You can also use "float16" or "float32" precisions as other dtype options. ################################################################# -# Compile the model with relay -# ---------------------------- -# Specify Adreno target before compiling to generate texture -# leveraging kernels and get all the benefits of textures -# Note: This generated example running on our x86 server for demonstration. -# If running it on the Android device, we need to -# specify its instruction set. Set :code:`local_demo` to False if you want -# to run this tutorial with a real device. +# Prepare TVM Target +# ------------------ -local_demo = True +if local_demo: + target = tvm.target.Target("llvm") +elif test_target.find("opencl"): + target = tvm.target.Target(test_target, host=target) -# by default on CPU target will execute. -# select 'cpu', 'opencl' and 'vulkan' -test_target = "cpu" +################################################################## +# AutoTuning +# ---------- +# The below few instructions can auto tune the relay module with xgboost being the tuner algorithm. -# Change target configuration. -# Run `adb shell cat /proc/cpuinfo` to find the arch. -arch = "arm64" -target = tvm.target.Target("llvm -mtriple=%s-linux-android" % arch) +# Auto Tuning process involces stages of extracting the tasks, defining tuning congiguration and +# tuning each task for best performing kernel configuration. -if local_demo: - target = tvm.target.Target("llvm") -elif test_target == "opencl": - target = tvm.target.Target("opencl", host=target) -elif test_target == "vulkan": - target = tvm.target.Target("vulkan", host=target) +# Get RPC related settings. +rpc_tracker_host = os.environ.get("TVM_TRACKER_HOST", "127.0.0.1") +rpc_tracker_port = int(os.environ.get("TVM_TRACKER_PORT", 9190)) +key = "android" + +if is_tuning: + # Auto Tuning Stage 1: Extract tunable tasks + tasks = autotvm.task.extract_from_program( + mod, target=test_target, target_host=target, params=params + ) + + # Auto Tuning Stage 2: Define tuning configuration + tmp_log_file = tune_log + ".tmp" + measure_option = autotvm.measure_option( + builder=autotvm.LocalBuilder( + build_func=ndk.create_shared, timeout=15 + ), # Build the test kernel locally + runner=autotvm.RPCRunner( # The runner would be on a remote device. + key, # RPC Key + host=rpc_tracker_host, # Tracker host + port=int(rpc_tracker_port), # Tracker port + number=3, # Number of runs before averaging + timeout=600, # RPC Timeout + ), + ) + n_trial = 1024 # Number of iteration of training before choosing the best kernel config + early_stopping = False # Do we apply early stopping when the loss is not minimizing + + # Iterate through each task and call the tuner + from tvm.autotvm.tuner import XGBTuner + + for i, tsk in enumerate(reversed(tasks[:3])): + print("Task:", tsk) + prefix = "[Task %2d/%2d] " % (i + 1, len(tasks)) + tuner_obj = XGBTuner(tsk, loss_type="rank") + + tsk_trial = min(n_trial, len(tsk.config_space)) + tuner_obj.tune( + n_trial=tsk_trial, + early_stopping=early_stopping, + measure_option=measure_option, + callbacks=[ + autotvm.callback.progress_bar(tsk_trial, prefix=prefix), + autotvm.callback.log_to_file(tmp_log_file), + ], + ) + # Pick the best performing kerl configurations from the overall log. + autotvm.record.pick_best(tmp_log_file, tune_log) -with tvm.transform.PassContext(opt_level=3): - lib = relay.build(mod, target=target, params=params) +################################################################# +# Enable OpenCLML Offloading +# -------------------------- +# OpenCLML offloading will try to accelerate supported operators +# by using OpenCLML proprietory operator library. +if not local_demo and enable_clml: + mod = clml.partition_for_clml(mod, params) + +################################################################# +# Compilation +# ----------- +# Use tuning cache if exists. +if os.path.exists(tune_log): + with autotvm.apply_history_best(tune_log): + with tvm.transform.PassContext(opt_level=3): + lib = relay.build(mod, target=target, params=params) +else: + with tvm.transform.PassContext(opt_level=3): + lib = relay.build(mod, target=target, params=params) ################################################################# # Deploy the Model Remotely by RPC # -------------------------------- # Using RPC you can deploy the model from host # machine to the remote Adreno device - -rpc_tracker_host = os.environ.get("TVM_TRACKER_HOST", "127.0.0.1") -rpc_tracker_port = int(os.environ.get("TVM_TRACKER_PORT", 9190)) -key = "android" - if local_demo: remote = rpc.LocalSession() else: @@ -282,10 +332,8 @@ def convert_to_dtype(mod, dtype): if local_demo: dev = remote.cpu(0) -elif test_target == "opencl": +elif test_target.find("opencl"): dev = remote.cl(0) -elif test_target == "vulkan": - dev = remote.vulkan(0) else: dev = remote.cpu(0) diff --git a/gallery/how_to/deploy_models/deploy_model_on_adreno_tvmc.py b/gallery/how_to/deploy_models/deploy_model_on_adreno_tvmc.py new file mode 100644 index 000000000000..193ef37c6fbe --- /dev/null +++ b/gallery/how_to/deploy_models/deploy_model_on_adreno_tvmc.py @@ -0,0 +1,186 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +""" +.. _tutorial-deploy-model-on-adreno-tvmc: + +Deploy the Pretrained Model on Adreno™ with tvmc Interface +========================================================== +**Author**: Siva Rama Krishna + +This article is a step-by-step tutorial to deploy pretrained Keras resnet50 model on Adreno™. + +Besides that, you should have TVM built for Android. +See the following instructions on how to build it and setup RPC environment. + +`Deploy to Adreno GPU `_ + +""" + +import os +import tvm +import numpy as np +from tvm import relay +from tvm.driver import tvmc +from tvm.driver.tvmc.model import TVMCPackage +from tvm.contrib import utils + +################################################################# +# Configuration +# ------------- +# Specify Adreno target before compiling to generate texture +# leveraging kernels and get all the benefits of textures +# Note: This generated example running on our x86 server for demonstration. +# If running it on the Android device, we need to +# specify its instruction set. Set :code:`local_demo` to False if you want +# to run this tutorial with a real device over rpc. +local_demo = True + +# by default on CPU target will execute. +# select 'llvm', 'opencl' and 'opencl -device=adreno' +target = "llvm" + +# Change target configuration. +# Run `adb shell cat /proc/cpuinfo` to find the arch. +arch = "arm64" +target_host = "llvm -mtriple=%s-linux-android" % arch + +# Auto tuning is compute and time taking task, hence disabling for default run. Please enable it if required. +is_tuning = False +tune_log = "adreno-resnet50.log" + +# To enable OpenCLML accelerated operator library. +enable_clml = False +cross_compiler = "/opt/android-sdk-linux/ndk/21.3.6528147/toolchains/llvm/prebuilt/linux-x86_64/bin/aarch64-linux-android28-clang" + +####################################################################### +# Make a Keras Resnet50 Model +# --------------------------- + +from tensorflow.keras.applications.resnet50 import ResNet50 + +tmp_path = utils.tempdir() +model_file_name = tmp_path.relpath("resnet50.h5") + +model = ResNet50(include_top=True, weights="imagenet", input_shape=(224, 224, 3), classes=1000) +model.save(model_file_name) + + +####################################################################### +# Load Model +# ---------- +# Convert a model from any framework to a tvm relay module. +# tvmc.load supports models from any framework (like tensorflow saves_model, onnx, tflite ..etc) and auto detects the filetype. + +tvmc_model = tvmc.load(model_file_name) + +print(tvmc_model.mod) + +# tvmc_model consists of tvmc_mode.mod which is relay module and tvmc_model.params which parms of the module. + +####################################################################### +# AutoTuning +# ---------- +# Now, the below api can be used for autotuning the model for any target. +# Tuning required RPC setup and please refer to +# `Deploy to Adreno GPU `_ + +rpc_tracker_host = os.environ.get("TVM_TRACKER_HOST", "127.0.0.1") +rpc_tracker_port = int(os.environ.get("TVM_TRACKER_PORT", 9190)) +rpc_key = "android" +rpc_tracker = rpc_tracker_host + ":" + str(rpc_tracker_port) + + +if is_tuning: + tvmc.tune( + tvmc_model, + target=target, + tuning_records=tune_log, + target_host=target_host, + hostname=rpc_tracker_host, + port=rpc_tracker_port, + rpc_key=rpc_key, + tuner="xgb", + repeat=30, + trials=3, + early_stopping=0, + ) + +####################################################################### +# Compilation +# ----------- +# Compilation to produce tvm artifacts + + +if not enable_clml: + if local_demo: + tvmc_package = tvmc.compile( + tvmc_model, + target=target, + ) + else: + tvmc_package = tvmc.compile( + tvmc_model, + target=target, + target_host=target_host, + cross=cross_compiler, + tuning_records=tune_log, + ) +else: + # Altrernatively, we can save the compilation output and save it as a TVMCPackage. + # This way avoids loading of compiled module without compiling again. + target = target + ", clml" + pkg_path = tmp_path.relpath("keras-resnet50.tar") + tvmc.compile( + tvmc_model, + target=target, + target_host=target_host, + cross=cross_compiler, + tuning_records=tune_log, + package_path=pkg_path, + ) + + # Load the compiled package + tvmc_package = TVMCPackage(package_path=pkg_path) + +# tvmc_package consists of tvmc_package.lib_path, tvmc_package.graph, tvmc_package.params +# Saved TVMPackage is nothing but tar archive with mod.so, mod.json and mod.params. + + +####################################################################### +# Deploy & Run +# ------------ +# Deploy and run the compiled model on RPC +# Prepare input data dict +input_data = tvm.nd.array((np.random.uniform(size=(1, 224, 224, 3))).astype("float32")) +input_dict = {"input_1": input_data} + +# Run on RPC setup +if local_demo: + result = tvmc.run(tvmc_package, device="cpu", inputs=input_dict) +else: + result = tvmc.run( + tvmc_package, + device="cl", + rpc_key=rpc_key, + hostname=rpc_tracker_host, + port=rpc_tracker_port, + inputs=input_dict, + ) + +# result is a dictionary of outputs. +print("Result:", result) From f1b369f73cc1dad11b162e911b52c22ea7a213ba Mon Sep 17 00:00:00 2001 From: Siva Rama Krishna Reddy B Date: Thu, 2 Feb 2023 20:25:32 +0530 Subject: [PATCH 05/20] * fix --- .../how_to/deploy_models/deploy_model_on_adreno_tvmc.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/gallery/how_to/deploy_models/deploy_model_on_adreno_tvmc.py b/gallery/how_to/deploy_models/deploy_model_on_adreno_tvmc.py index 193ef37c6fbe..973663d1f20c 100644 --- a/gallery/how_to/deploy_models/deploy_model_on_adreno_tvmc.py +++ b/gallery/how_to/deploy_models/deploy_model_on_adreno_tvmc.py @@ -165,13 +165,11 @@ # Deploy & Run # ------------ # Deploy and run the compiled model on RPC -# Prepare input data dict -input_data = tvm.nd.array((np.random.uniform(size=(1, 224, 224, 3))).astype("float32")) -input_dict = {"input_1": input_data} +# Let tvmc fill inputs using random # Run on RPC setup if local_demo: - result = tvmc.run(tvmc_package, device="cpu", inputs=input_dict) + result = tvmc.run(tvmc_package, device="cpu", fill_mode="random") else: result = tvmc.run( tvmc_package, @@ -179,7 +177,7 @@ rpc_key=rpc_key, hostname=rpc_tracker_host, port=rpc_tracker_port, - inputs=input_dict, + fill_mode="random", ) # result is a dictionary of outputs. From 3f9609bab9f47ef6be45781a966f7684bba82b34 Mon Sep 17 00:00:00 2001 From: Siva Rama Krishna Reddy B Date: Mon, 13 Feb 2023 12:33:48 +0530 Subject: [PATCH 06/20] * review --- docs/how_to/deploy/adreno.rst | 77 +++++-------------- .../deploy_models/deploy_model_on_adreno.py | 61 ++++++++++++--- .../deploy_model_on_adreno_tvmc.py | 16 +++- tests/scripts/task_build_adreno_bins.sh | 1 - tests/scripts/task_config_build_adreno.sh | 1 - 5 files changed, 82 insertions(+), 74 deletions(-) diff --git a/docs/how_to/deploy/adreno.rst b/docs/how_to/deploy/adreno.rst index 657def61a110..b0c0bb8822e9 100644 --- a/docs/how_to/deploy/adreno.rst +++ b/docs/how_to/deploy/adreno.rst @@ -127,7 +127,12 @@ Deploying the compiled model here require use some tools on host as well as on t TVM has simplified user friendly command line based tools as well as developer centric python API interface for various steps like auto tuning, building and deploying. -TVM compilation process for remote devices has multiple stages listed below. + +|Adreno deployment pipeline| + +*Fig.2 Build and Deployment pipeline on Adreno devices* + +The figure above demonstrates a generalized pipeline for various stages listed below. **Model import:** At this stage we import a model from well known frameworks like Tensorflow, PyTorch, ONNX ...etc. @@ -150,7 +155,7 @@ At this stage we run the TVM compilation output on the target. Deployment is pos environment using RPC Setup and also using TVM's native tool which is native binary cross compiled for Android. At this stage we can run the compiled model on Android target and unit test output correctness and performance aspects. -**Aplication Integration:** +**Application Integration:** This stage is all about integrating TVM compiled model in applications. Here we discuss about interfacing tvm runtime from Android (cpp native environment or from JNI) for setting input and getting output. @@ -234,7 +239,6 @@ Below command will configure the build the host compiler cd build cp ../cmake/config.cmake . - echo set\(USE_OPENCL ON\) >> config.cmake echo set\(USE_RPC ON\) >> config.cmake echo set\(USE_GRAPH_EXECUTOR ON\) >> config.cmake echo set\(USE_LIBBACKTRACE AUTO\) >> config.cmake @@ -258,7 +262,7 @@ Finally we can export python path as :: - export PYTHONPATH=$PWD:/python + export PYTHONPATH=$TVM_HOME/python:${PYTHONPATH} python3 -c "import tvm" # Verify tvm python package @@ -274,7 +278,6 @@ Target build require Android NDK to be installed. mkdir -p build-adreno cd build-adreno cp ../cmake/config.cmake . - echo set\(USE_MICRO OFF\) >> config.cmake echo set\(USE_OPENCL ON\) >> config.cmake echo set\(USE_RPC ON\) >> config.cmake echo set\(USE_CPP_RPC ON\) >> config.cmake @@ -342,12 +345,12 @@ manually and also inside docker using automated tools. **Automated RPC Setup:** Here we will explain how to setup RPC in docker environment. -Below command launches tracker in docker environment, where docker listens on port 9120. +Below command launches tracker in docker environment, where tracker listens on port 9190. :: ./tests/scripts/ci.py adreno -i # Launch a new shell on the anreno docker - source tests/scripts/setup-adreno-env.sh -e tracker -p 9120 + source tests/scripts/setup-adreno-env.sh -e tracker -p 9190 Now, the below comand can run TVM RPC on remote android device with id "abcdefgh". @@ -355,60 +358,16 @@ Now, the below comand can run TVM RPC on remote android device with id "abcdefgh :: ./tests/scripts/ci.py adreno -i # Launch a new shell on adreno docker. - source tests/scripts/setup-adreno-env.sh -e device -p 9120 -d abcdefgh + source tests/scripts/setup-adreno-env.sh -e device -p 9190 -d abcdefgh **Manual RPC Setup:** -Below command in manual setup starts the tracker on port 9120 - -:: - - python3 -m tvm.exec.rpc_tracker --host "0.0.0.0" --port "9120" - -TVM RPC launch on Android device require some environment setup due to Android device is connected via ADB interface and we need to re-route -TCP/IP communication over ADB interface. Below commands will do necessary setup and run tvm_rpc on remote device. - -:: - - # Set android device to use - export ANDROID_SERIAL=abcdefgh - # Create a temporary folder on remote device. - adb shell "mkdir -p /data/local/tmp/tvm_ci" - # Copy tvm_rpc and it's dependency to remote device - adb push build-adreno-target/tvm_rpc /data/local/tmp/tvm_test/tvm_rpc - adb push build-adreno-target/libtvm_runtime.so /data/local/tmp/tvm_test - # Forward port 9120 from target to host - adb reverse tcp:9210 tcp:9120 - # tvm_rpc by default listens on ports starting from 5000 for incoming connections. - # Hence, reroute connections to these ports on host to remore device. - adb forward tcp:5000 tcp:5000 - adb forward tcp:5001 tcp:5001 - adb forward tcp:5002 tcp:5002 - # Finally launch rpc_daemon on remote device with identity key as "android" - adb shell "cd /data/local/tmp/tvm_test; killall -9 tvm_rpc; sleep 2; LD_LIBRARY_PATH=/data/local/tmp/tvm_test/ ./tvm_rpc server --host=0.0.0.0 --port=5000 --port-end=5010 --tracker=127.0.0.1:9120 --key=android" - -Upon successfull running this remote device will be available on tracker which can be queried as below. - -:: - - python3 -m tvm.exec.query_rpc_tracker --port 9120 - Tracker address 127.0.0.1:9120 - Server List - ------------------------------ - server-address key - ------------------------------ - 127.0.0.1:5000 server:android - ------------------------------ - - Queue Status - ------------------------------- - key total free pending - ------------------------------- - android 1 1 0 - ------------------------------- +Please refer to the tutorial +`How To Deploy model on Adreno `_ +for manual RPC environment setup. -This concludes RPC Setup and we have rpc-tracker available on host 127.0.0.1 (rpc-tracker) and port 9120 (rpc-port). +This concludes RPC Setup and we have rpc-tracker available on host 127.0.0.1 (rpc-tracker) and port 9190 (rpc-port). .. _commandline_interface: @@ -431,7 +390,7 @@ Here we use a model from Keras and it uses RPC setup for tuning and finally gene resnet50.h5 -o \ keras-resnet50.log \ --early-stopping 0 --repeat 30 --rpc-key android \ - --rpc-tracker 127.0.0.1:9120 --trials 1024 \ + --rpc-tracker 127.0.0.1:9190 --trials 1024 \ --tuning-records keras-resnet50-records.log --tuner xgb **Model Compilation:** @@ -466,7 +425,7 @@ We can use below tvmc command to deploy on remore target via RPC based setup. :: python3 -m tvm.driver.tvmc run --device="cl" keras-resnet50.tar \ - --rpc-key android --rpc-tracker 127.0.0.1:9120 --print-time + --rpc-key android --rpc-tracker 127.0.0.1:9190 --print-time tvmc based run has more option to initialize the input in various modes line fill, random ..etc. @@ -628,4 +587,4 @@ We then can compile our model in any convinient way ) .. |High-level overview of the Adreno™ A5x architecture for OpenCL| image:: https://raw.githubusercontent.com/tlc-pack/web-data/main/images/how-to/adreno_architecture.png -.. |Android deployment pipeline| image:: https://raw.githubusercontent.com/tlc-pack/web-data/main/images/how-to/android_deployment_pipeline.jpg +.. |Adreno deployment pipeline| image:: https://raw.githubusercontent.com/tlc-pack/web-data/main/images/how-to/Adreno-Deployment-Pipeline.jpg diff --git a/gallery/how_to/deploy_models/deploy_model_on_adreno.py b/gallery/how_to/deploy_models/deploy_model_on_adreno.py index db2f78f69417..657374ebfe1e 100644 --- a/gallery/how_to/deploy_models/deploy_model_on_adreno.py +++ b/gallery/how_to/deploy_models/deploy_model_on_adreno.py @@ -53,11 +53,17 @@ # # adb devices # +# Set the android device to use +# +# .. code-block:: bash +# +# export ANDROID_SERIAL= +# # Then to upload these two files to the device you should use: # # .. code-block:: bash # -# adb -s push {libtvm_runtime.so,tvm_rpc} /data/local/tmp +# adb push {libtvm_runtime.so,tvm_rpc} /data/local/tmp # # At this moment you will have «libtvm_runtime.so» and «tvm_rpc» on path /data/local/tmp on your device. # Sometimes cmake can’t find «libc++_shared.so». Use: @@ -70,7 +76,7 @@ # # .. code-block:: bash # -# adb -s push libc++_shared.so /data/local/tmp +# adb push libc++_shared.so /data/local/tmp # # We are now ready to run the TVM RPC Server. # Launch rpc_tracker with following line in 1st console: @@ -83,12 +89,12 @@ # # .. code-block:: bash # -# adb -s reverse tcp:9190 tcp:9190 -# adb -s forward tcp:9090 tcp:9090 -# adb -s forward tcp:9091 tcp:9091 -# adb -s forward tcp:9092 tcp:9092 -# adb -s forward tcp:9093 tcp:9093 -# adb -s shell LD_LIBRARY_PATH=/data/local/tmp /data/local/tmp/tvm_rpc server --host=0.0.0.0 --port=9090 --tracker=127.0.0.1:9190 --key=android --port-end=9190 +# adb reverse tcp:9190 tcp:9190 +# adb forward tcp:5000 tcp:5000 +# adb forward tcp:5002 tcp:5001 +# adb forward tcp:5003 tcp:5002 +# adb forward tcp:5004 tcp:5003 +# adb shell LD_LIBRARY_PATH=/data/local/tmp /data/local/tmp/tvm_rpc server --host=0.0.0.0 --port=5000 --tracker=127.0.0.1:9190 --key=android --port-end=5100 # # Before proceeding to compile and infer model, specify TVM_TRACKER_HOST and TVM_TRACKER_PORT # @@ -130,6 +136,10 @@ from tvm.relay.op.contrib import clml from tvm import autotvm +# Below are set of configuration that controls the behaviour of this script like +# local run or device run, target definitions, dtype setting and auto tuning enablement. +# Change these settings as needed if required. + # Adreno devices are efficient with float16 compared to float32 # Given the expected output doesn't effect by lowering precision # it's advisable to use lower precision. @@ -156,7 +166,8 @@ arch = "arm64" target = tvm.target.Target("llvm -mtriple=%s-linux-android" % arch) -# Auto tuning is compute and time taking task, hence disabling for default run. Please enable it if required. +# Auto tuning is compute intensive and time taking task, +# hence disabling for default run. Please enable it if required. is_tuning = False tune_log = "adreno-resnet18.log" @@ -220,6 +231,19 @@ ################################################################# # Precisions # ---------- + +# Adreno devices are efficient with float16 compared to float32 +# Given the expected output doesn't effect by lowering precision +# it's advisable to use lower precision. + +# TVM support Mixed Precision through ToMixedPrecision transformation pass. +# We may need to register precision rules like precision type, accumultation +# datatype ...etc. for the required operators to override the default settings. +# The below helper api simplifies the precision conversions across the module. +# Now it supports dtypes "float16" and "float16_acc32". + +# dtype is set to "float16_acc32" in configuration section above. + from tvm.relay.op.contrib import adreno adreno.convert_to_dtype(mod["main"], dtype) @@ -236,6 +260,12 @@ # Prepare TVM Target # ------------------ +# This generated example running on our x86 server for demonstration. + +# To deply and tun on real target over RPC please set :code:`local_demo` to False in above configuration sestion. +# Also, :code:`test_target` is set to :code:`llvm` as this example to make compatible for x86 demonstration. +# Please change it to :code:`opencl` or :code:`opencl -device=adreno` for RPC target in configuration above. + if local_demo: target = tvm.target.Target("llvm") elif test_target.find("opencl"): @@ -254,6 +284,10 @@ rpc_tracker_port = int(os.environ.get("TVM_TRACKER_PORT", 9190)) key = "android" +# Auto tuning is compute intensive and time taking task. +# It is set to False in above configuration as this script runs in x86 for demonstration. +# Please to set :code:`is_tuning` to True to enable auto tuning. + if is_tuning: # Auto Tuning Stage 1: Extract tunable tasks tasks = autotvm.task.extract_from_program( @@ -275,9 +309,9 @@ ), ) n_trial = 1024 # Number of iteration of training before choosing the best kernel config - early_stopping = False # Do we apply early stopping when the loss is not minimizing + early_stopping = False # Can be enabled to stop tuning while the loss is not minimizing. - # Iterate through each task and call the tuner + # Auto Tuning Stage 3: Iterate through the tasks and tune. from tvm.autotvm.tuner import XGBTuner for i, tsk in enumerate(reversed(tasks[:3])): @@ -295,7 +329,7 @@ autotvm.callback.log_to_file(tmp_log_file), ], ) - # Pick the best performing kerl configurations from the overall log. + # Auto Tuning Stage 4: Pick the best performing configurations from the overall log. autotvm.record.pick_best(tmp_log_file, tune_log) ################################################################# @@ -303,6 +337,9 @@ # -------------------------- # OpenCLML offloading will try to accelerate supported operators # by using OpenCLML proprietory operator library. + +# By default :code:`enable_clml` is set to False in above configuration section. + if not local_demo and enable_clml: mod = clml.partition_for_clml(mod, params) diff --git a/gallery/how_to/deploy_models/deploy_model_on_adreno_tvmc.py b/gallery/how_to/deploy_models/deploy_model_on_adreno_tvmc.py index 973663d1f20c..143677502121 100644 --- a/gallery/how_to/deploy_models/deploy_model_on_adreno_tvmc.py +++ b/gallery/how_to/deploy_models/deploy_model_on_adreno_tvmc.py @@ -65,7 +65,10 @@ # To enable OpenCLML accelerated operator library. enable_clml = False -cross_compiler = "/opt/android-sdk-linux/ndk/21.3.6528147/toolchains/llvm/prebuilt/linux-x86_64/bin/aarch64-linux-android28-clang" +cross_compiler = ( + os.environ["ANDROID_NDK_HOME"] + + "/toolchains/llvm/prebuilt/linux-x86_64/bin/aarch64-linux-android28-clang" +) ####################################################################### # Make a Keras Resnet50 Model @@ -104,6 +107,12 @@ rpc_key = "android" rpc_tracker = rpc_tracker_host + ":" + str(rpc_tracker_port) +# Auto tuning is compute intensive and time taking task. +# It is set to False in above configuration as this script runs in x86 for demonstration. +# Please to set :code:`is_tuning` to True to enable auto tuning. + +# Also, :code:`test_target` is set to :code:`llvm` as this example to make compatible for x86 demonstration. +# Please change it to :code:`opencl` or :code:`opencl -device=adreno` for RPC target in configuration above. if is_tuning: tvmc.tune( @@ -125,6 +134,11 @@ # ----------- # Compilation to produce tvm artifacts +# This generated example running on our x86 server for demonstration. +# To deply and tun on real target over RPC please set :code:`local_demo` to False in above configuration sestion. + +# OpenCLML offloading will try to accelerate supported operators by using OpenCLML proprietory operator library. +# By default :code:`enable_clml` is set to False in above configuration section. if not enable_clml: if local_demo: diff --git a/tests/scripts/task_build_adreno_bins.sh b/tests/scripts/task_build_adreno_bins.sh index f65794106ee3..87f50367440c 100755 --- a/tests/scripts/task_build_adreno_bins.sh +++ b/tests/scripts/task_build_adreno_bins.sh @@ -28,7 +28,6 @@ cd ${output_directory} cp ../cmake/config.cmake . -echo set\(USE_MICRO OFF\) >> config.cmake if [ -f "${ADRENO_OPENCL}/CL/cl_qcom_ml_ops.h" ] ; then echo set\(USE_CLML "${ADRENO_OPENCL}"\) >> config.cmake echo set\(USE_CLML_GRAPH_EXECUTOR "${ADRENO_OPENCL}"\) >> config.cmake diff --git a/tests/scripts/task_config_build_adreno.sh b/tests/scripts/task_config_build_adreno.sh index d378b5f842b5..62e6ffecbced 100755 --- a/tests/scripts/task_config_build_adreno.sh +++ b/tests/scripts/task_config_build_adreno.sh @@ -23,7 +23,6 @@ mkdir -p "$BUILD_DIR" cd "$BUILD_DIR" cp ../cmake/config.cmake . -echo set\(USE_OPENCL ON\) >> config.cmake if [ -f "${ADRENO_OPENCL}/CL/cl_qcom_ml_ops.h" ] ; then echo set\(USE_CLML ${ADRENO_OPENCL}\) >> config.cmake fi From 585f5f633b4dbd66c5bbe1c53ebae32f7925ae6f Mon Sep 17 00:00:00 2001 From: Siva Date: Tue, 14 Feb 2023 14:20:59 +0530 Subject: [PATCH 07/20] Update docs/how_to/deploy/adreno.rst Co-authored-by: Egor Churaev --- docs/how_to/deploy/adreno.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/how_to/deploy/adreno.rst b/docs/how_to/deploy/adreno.rst index b0c0bb8822e9..92d1db482589 100644 --- a/docs/how_to/deploy/adreno.rst +++ b/docs/how_to/deploy/adreno.rst @@ -462,7 +462,7 @@ for a step by step explanation of the same. .. _application_integration: -Aplication Integration: +Application Integration: ----------------------- TVM compilation output is represented as module shared lib (mod.so), graph json(mod.json) and params (mod.params). From e4417a040df0572a9eb06be3e03b5fb3143b115c Mon Sep 17 00:00:00 2001 From: Siva Date: Tue, 14 Feb 2023 14:41:59 +0530 Subject: [PATCH 08/20] Update gallery/how_to/deploy_models/deploy_model_on_adreno.py Co-authored-by: Egor Churaev --- gallery/how_to/deploy_models/deploy_model_on_adreno.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gallery/how_to/deploy_models/deploy_model_on_adreno.py b/gallery/how_to/deploy_models/deploy_model_on_adreno.py index 657374ebfe1e..559454f7c442 100644 --- a/gallery/how_to/deploy_models/deploy_model_on_adreno.py +++ b/gallery/how_to/deploy_models/deploy_model_on_adreno.py @@ -53,7 +53,7 @@ # # adb devices # -# Set the android device to use +# Set the android device to use, if you have several devices connected to your computer. # # .. code-block:: bash # From ad82f8c515bdfd446cf2e3e6cd9e799bb3fa8e8e Mon Sep 17 00:00:00 2001 From: Siva Rama Krishna Reddy B Date: Tue, 14 Feb 2023 14:55:36 +0530 Subject: [PATCH 09/20] * review --- docs/how_to/deploy/adreno.rst | 85 +++++++++++++++++++---------------- 1 file changed, 46 insertions(+), 39 deletions(-) diff --git a/docs/how_to/deploy/adreno.rst b/docs/how_to/deploy/adreno.rst index 92d1db482589..29784fbf16d0 100644 --- a/docs/how_to/deploy/adreno.rst +++ b/docs/how_to/deploy/adreno.rst @@ -107,7 +107,7 @@ About OpenCLML -------------- OpenCLML is a SDK released by Qualcomm that provides accelerated deep learning operators. -These operators are exposed as an extension "cl_qcom_ml_ops" to standard OpenCL specification. +These operators are exposed as an extension ``cl_qcom_ml_ops`` to standard OpenCL specification. Please refer `Accelerate your models with our OpenCL ML SDK `_ for more details. OpenCLML is integrated into TVM as a `BYOC `_ solution. @@ -213,7 +213,7 @@ On successful compilation this leaves us into a docker shell. The build leaves t * rtvm : A native stand alone tool While using docker environment the android device is shared with host. Hence, it is required -to have adb version "1.0.41" on the host as the docker used the same version. +to have adb version ``1.0.41`` on the host as the docker used the same version. We can check adb devices availability inside docker environment too. @@ -239,9 +239,13 @@ Below command will configure the build the host compiler cd build cp ../cmake/config.cmake . + # Enable RPC capability to communicate to remote device. echo set\(USE_RPC ON\) >> config.cmake + # We use graph executor for any host(x86) side verification of the model. echo set\(USE_GRAPH_EXECUTOR ON\) >> config.cmake + # Enable backtrace if possible for more ebug information on any crash. echo set\(USE_LIBBACKTRACE AUTO\) >> config.cmake + # The target_host will be llvm. echo set\(USE_LLVM ON\) >> config.cmake Additionally we can push below config entry to compile with OpenCLML support. @@ -278,14 +282,22 @@ Target build require Android NDK to be installed. mkdir -p build-adreno cd build-adreno cp ../cmake/config.cmake . + # Enable OpenCL backend. echo set\(USE_OPENCL ON\) >> config.cmake + # Enable RPC functionality. echo set\(USE_RPC ON\) >> config.cmake + # Build tvm_rpc tool that runs on target device. echo set\(USE_CPP_RPC ON\) >> config.cmake + # Build native rtvm deploy tool. echo set\(USE_CPP_RTVM ON\) >> config.cmake + # We use graph executor for deploying on devices like Android. echo set\(USE_GRAPH_EXECUTOR ON\) >> config.cmake + # Backtrace enablement if possible. echo set\(USE_LIBBACKTRACE AUTO\) >> config.cmake + # Adreno supports 32bit alignment for OpenCL allocations rather 64bit. echo set\(USE_KALLOC_ALIGNMENT 32\) >> config.cmake + # Android build related defines. echo set\(ANDROID_ABI arm64-v8a\) >> config.cmake echo set\(ANDROID_PLATFORM android-28\) >> config.cmake echo set\(MACHINE_NAME aarch64-linux-gnu\) >> config.cmake @@ -298,7 +310,7 @@ Additionally we can push below config to compile with OpenCLML support. echo set\(USE_CLML "${ADRENO_OPENCL}"\) >> config.cmake echo set\(USE_CLML_GRAPH_EXECUTOR "${ADRENO_OPENCL}"\) >> config.cmake -For Android target build ANDROID_NDK_HOME is a dependency and we should have the same in the enviromnet variable. +For Android target build ``ANDROID_NDK_HOME`` is a dependency and we should have the same in the enviromnet variable. Below commands will build Adreno™ target components :: @@ -326,7 +338,7 @@ RPC Setup allows remote target access over TCP/IP networking interface. RPC Setu involves running of auto generated kernels on real device and optimize the same by using machine learning approach. Please refer `Auto-Tune with Templates and AutoTVM `_ got more details about AutoTVM. -RPC Setup is also useful to deply the compiled model to a remote device from python interface or ```tvmc``` tool from host device. +RPC Setup is also useful to deply the compiled model to a remote device from python interface or ``tvmc`` tool from host device. RPC Setup has multiple components as listed below. @@ -352,7 +364,7 @@ Below command launches tracker in docker environment, where tracker listens on p ./tests/scripts/ci.py adreno -i # Launch a new shell on the anreno docker source tests/scripts/setup-adreno-env.sh -e tracker -p 9190 -Now, the below comand can run TVM RPC on remote android device with id "abcdefgh". +Now, the below comand can run TVM RPC on remote android device with id ``abcdefgh``. :: @@ -367,7 +379,7 @@ Please refer to the tutorial `How To Deploy model on Adreno `_ for manual RPC environment setup. -This concludes RPC Setup and we have rpc-tracker available on host 127.0.0.1 (rpc-tracker) and port 9190 (rpc-port). +This concludes RPC Setup and we have rpc-tracker available on host ``127.0.0.1`` (rpc-tracker) and port ``9190`` (rpc-port). .. _commandline_interface: @@ -375,13 +387,15 @@ This concludes RPC Setup and we have rpc-tracker available on host 127.0.0.1 (rp Commandline Tools ----------------- -Here we describe entire compilation process using command line tools. TVM has command line utility "tvmc" to perform -model import, auto tuning, compilation and deply over rpc. "tvmc" has many options to explore and try. +Here we describe entire compilation process using command line tools. TVM has command line utility +`tvmc `_ to perform +model import, auto tuning, compilation and deply over rpc. +`tvmc `_ has many options to explore and try. **Model Import & Tuning:** Use the below command to import a model from any framework and auto tune the same. Here we use a model from Keras and it uses RPC setup for tuning and finally generates tuning log file -"keras-resnet50.log". +``keras-resnet50.log``. :: @@ -404,7 +418,7 @@ Use below command for compiling the model and produce TVM compiler outputs. --target="opencl, llvm" --target-llvm-mtriple aarch64-linux-gnu --target-opencl-device adreno \ --tuning-records keras-resnet50.log -o keras-resnet50.tar resnet50.h5 -While enabled OpenCLML offloading we nee dto add target "clml" as shown below. Tuning log is valid for OpenCLML offloading also +While enabled OpenCLML offloading we need to add target ``clml`` as shown below. Tuning log is valid for OpenCLML offloading also as the OpenCL path is fallback option for any operator didn't go through OpenCLML path. The tuning log will be used for such operators. :: @@ -414,7 +428,8 @@ as the OpenCL path is fallback option for any operator didn't go through OpenCLM --target="opencl, clml, llvm" --target-llvm-mtriple aarch64-linux-gnu --target-opencl-device adreno \ --tuning-records keras-resnet50.log -o keras-resnet50.tar resnet50.h5 -On success ful compilation above commands produce "keras-resnet50.tar". It is a compressed archive with kernel shared lib, graph json and params binary. +On successful compilation, above command produce ``keras-resnet50.tar``. +It is a compressed archive with kernel shared lib(mod.so), graph json(mod.json) and params binary(mod.params). **Deploy & Run on Target:** @@ -427,10 +442,11 @@ We can use below tvmc command to deploy on remore target via RPC based setup. python3 -m tvm.driver.tvmc run --device="cl" keras-resnet50.tar \ --rpc-key android --rpc-tracker 127.0.0.1:9190 --print-time -tvmc based run has more option to initialize the input in various modes line fill, random ..etc. +`tvmc `_ based run has more option +to initialize the input in various modes like fill, random ..etc. -TVM also supports "rtvm" tool to run the model narively on ADB shell. The build process produced this tool under build-adreno-target. +TVM also supports ``rtvm`` tool to run the model narively on ADB shell. The build process produced this tool under build-adreno-target. Please refer to `rtvm `_ for more details about this tool. @@ -447,9 +463,10 @@ to a relay module. Relay module will be used across the auto tuning, compilation TVMC interface can be accessed as shown below to import, compile and run a model. Please refer to the tutorial for the same `How To Deploy model on Adreno using TVMC `_ -tvmc compiled package can be used for native deploy also using "rtvm" utility. +tvmc compiled package can be used for native deploy also using ``rtvm`` utility. -Also, please refer to tvmc documentation for more details about the api interface. +Also, please refer to `tvmc `_ + documentation for more details about the api interface. **Relay Interface:** @@ -459,10 +476,9 @@ Relay interface follows tvmc kind of a flow where we produce TVM module first fo Please refer to the tutorial `How To Deploy model on Adreno `_ for a step by step explanation of the same. - .. _application_integration: -Application Integration: +Application Integration ----------------------- TVM compilation output is represented as module shared lib (mod.so), graph json(mod.json) and params (mod.params). @@ -470,17 +486,17 @@ Archived representation of TVMPackage is also contains the same. In general a CPP/C based interface will be sufficient for any Android application integration. -TVM natively expose c_runtime_api for loading a TVM compiled module and run the same. +TVM natively expose ``c_runtime_api`` for loading a TVM compiled module and run the same. Alternatively one may refer to `cpp_rtvm `_ -tvm_runner interface too for further simplified version of the same. +``TVMRunner`` interface too for further simplified version of the same. .. _advanced_usage: -Advanced Usage: ---------------- +Advanced Usage +-------------- This section details some of the advanced usage and additional information while using Adreno™ target on TVM. @@ -494,9 +510,9 @@ Below snippet can dump CLML sub graphs in json format. .. code:: python - # Look for "clml" typed module impoted. + # Look for "clml" typed module imported. clml_modules = list(filter(lambda mod: mod.type_key == "clml", lib.get_lib().imported_modules)) - # Loop throught all clml sub graphs and dump the json formatted CLML sub graphs. + # Loop through all clml sub graphs and dump the json formatted CLML sub graphs. for cmod in clml_modules: print("CLML Src:", cmod.get_source()) @@ -506,23 +522,13 @@ Similarly, below snippet can extract opencl kernel source from the compiled TVM .. code:: python # Similarly we can dump open kernel source too as shown below - # Look for "opencl" typed module impoted. + # Look for "opencl" typed module imported. opencl_modules = list(filter(lambda mod: mod.type_key == "opencl", lib.get_lib().imported_modules)) - # Now dump open cource for each opencl targetted sub graph. + # Now dump kernel source for each OpenCL targetted sub graph. for omod in opencl_modules: print("OpenCL Src:", omod.get_source()) -Inspecting above code for target device "opencl --device=adreno" shows texture usage (image2d_t) as shown below. - -.. code:: c - - __kernel void tvmgen_default_fused_nn_conv2d_kernel0(__write_only image2d_t pad_temp_global_texture, __read_only image2d_t p0) { - // body.. - -*image2d_t* is a built-in OpenCL types that represents two-dimensional image object and provides several additional functions. -When we use *image2d_t* we read *4 elements at one time*, and it helps to utilize hardware in a more efficient way. - Precisions ~~~~~~~~~~ The right choice of precision for a specific workload can greatly increase the efficiency of the solution, @@ -536,13 +542,14 @@ To leverage the GPU hardware capabilities and utilize the benefits of half preci we can convert an original model having floating points operation to a model operating with half precision. Choosing lower precision will positively affect the performance of the model, but it may also have a decrease in the accuracy of the model. -To do the conversion you need to call adreno specific transformation API as soon relay module is generated through any frontend: +To do the conversion you need to call adreno specific transformation API as soon as relay module is generated through any frontend. .. code:: python from tvm.relay.op.contrib import adreno adreno.convert_to_dtype(mod["main"], "float16") +``tvm.relay.op.contrib.adreno.convert_to_dtype`` is simplified API over ``ToMixedPrecision`` pass to get desired precision. We then can compile our model in any convinient way @@ -556,10 +563,10 @@ We then can compile our model in any convinient way **float16_acc32 (Mixed Precision)** -ToMixedPrecision pass traverse over the network and split network to clusters of ops dealing with float or float16 data types. +``ToMixedPrecision`` pass traverse over the network and split network to clusters of ops dealing with float or float16 data types. The clusters are defined by three types of operations: - Operations always be converted into float16 data type -- Operations which can be converted if they follow by converted cluster +- Operations which can be converted if they followed by converted cluster - Operations never be converted to the float16 data type This list is defined in the ToMixedPrecision implementation here `relay/transform/mixed_precision.py `_ @@ -569,7 +576,7 @@ The ``ToMixedPrecision`` method is a pass to convert an FP32 relay graph into an FP16 or FP32 accumulation dtypes). Doing this transformation is useful for reducing model size as it halves the expected size of the weights (FP16_acc16 case). -ToMixedPrecision pass usage is simplified into a simple call as shown below for usage. +``ToMixedPrecision`` pass usage is simplified into a simple call as shown below for usage. .. code:: python From 9cf4dd5ff42a6aa04deb19c030ada7e9b68363a7 Mon Sep 17 00:00:00 2001 From: Siva Rama Krishna Reddy B Date: Wed, 15 Feb 2023 09:42:29 +0530 Subject: [PATCH 10/20] * ci error --- gallery/how_to/deploy_models/deploy_model_on_adreno_tvmc.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gallery/how_to/deploy_models/deploy_model_on_adreno_tvmc.py b/gallery/how_to/deploy_models/deploy_model_on_adreno_tvmc.py index 143677502121..b54ac1b2c6e7 100644 --- a/gallery/how_to/deploy_models/deploy_model_on_adreno_tvmc.py +++ b/gallery/how_to/deploy_models/deploy_model_on_adreno_tvmc.py @@ -66,7 +66,7 @@ # To enable OpenCLML accelerated operator library. enable_clml = False cross_compiler = ( - os.environ["ANDROID_NDK_HOME"] + os.getenv("ANDROID_NDK_HOME", "") + "/toolchains/llvm/prebuilt/linux-x86_64/bin/aarch64-linux-android28-clang" ) From 233e157730a287d63664d3975ec94ce21c273b0d Mon Sep 17 00:00:00 2001 From: Siva Rama Krishna Reddy B Date: Thu, 16 Feb 2023 11:21:04 +0530 Subject: [PATCH 11/20] * review --- docs/how_to/deploy/adreno.rst | 19 ++++++++++++------- 1 file changed, 12 insertions(+), 7 deletions(-) diff --git a/docs/how_to/deploy/adreno.rst b/docs/how_to/deploy/adreno.rst index 29784fbf16d0..c51663f63beb 100644 --- a/docs/how_to/deploy/adreno.rst +++ b/docs/how_to/deploy/adreno.rst @@ -453,17 +453,18 @@ Please refer to `rtvm `_ .. _python_interface: This section explains importing, auto tuning, compiling and running a model using python interface.\ -TVM has a high level interface through tvmc abstraction as well as relay api. We will discuss about both of these in details. - -Unlike command line interface python interface starts with model importing. Model importing converts the models from any framework -to a relay module. Relay module will be used across the auto tuning, compilation stages. +TVM has a high level interface through ``tvmc`` abstraction as well as low level relay api. We will discuss about both of these in details. **TVMC Interface:** -TVMC interface can be accessed as shown below to import, compile and run a model. Please refer to the tutorial for the same +While using ``tvmc`` python interface we first load a model that produces ``TVMCModel``. ``TVMCModel`` will be used for Auto Tuning to produce tuning cache. +Compilation process uses ``TVMCModel`` and tuning cache (optional) to produce ``TVMCPackage``. Now, ``TVMCPackage`` will be saved to file system or +can be used to deploy and run on target device. + +Please refer to the tutorial for the same `How To Deploy model on Adreno using TVMC `_ -tvmc compiled package can be used for native deploy also using ``rtvm`` utility. +Saved ``TVMCPackage`` can be used for native deployment using ``rtvm`` utility too. Also, please refer to `tvmc `_ documentation for more details about the api interface. @@ -471,7 +472,11 @@ Also, please refer to `tvmc `_ for a step by step explanation of the same. From e71cb76b6bbfeab77353b9c87d3b376401d5831f Mon Sep 17 00:00:00 2001 From: Siva Rama Krishna Reddy B Date: Fri, 17 Feb 2023 11:59:33 +0530 Subject: [PATCH 12/20] * tvmc precision options --- docs/how_to/deploy/adreno.rst | 7 +++++++ python/tvm/relay/op/contrib/adreno.py | 15 ++++++++++++++- 2 files changed, 21 insertions(+), 1 deletion(-) diff --git a/docs/how_to/deploy/adreno.rst b/docs/how_to/deploy/adreno.rst index c51663f63beb..0ea3203a737d 100644 --- a/docs/how_to/deploy/adreno.rst +++ b/docs/how_to/deploy/adreno.rst @@ -565,6 +565,9 @@ We then can compile our model in any convinient way mod, target_host=target_host, target=target, params=params ) +While using ``tvmc`` python interface, argument ``pre_build_hooks=[adreno.mixed_precision_fp16]`` enables precision conversion to float16. +Similarly, ``tvmc`` command line interface option ``--pre-build-hooks "adreno.mixed_precision_fp16]"`` does the same. + **float16_acc32 (Mixed Precision)** @@ -598,5 +601,9 @@ We then can compile our model in any convinient way mod, target_host=target_host, target=target, params=params ) +While using ``tvmc`` python interface, argument ``pre_build_hooks=[adreno.mixed_precision_fp16_acc32]`` enables precision conversion to float16. +Similarly, ``tvmc`` command line interface option ``--pre-build-hooks "adreno.mixed_precision_fp16_acc32]"`` does the same. + + .. |High-level overview of the Adreno™ A5x architecture for OpenCL| image:: https://raw.githubusercontent.com/tlc-pack/web-data/main/images/how-to/adreno_architecture.png .. |Adreno deployment pipeline| image:: https://raw.githubusercontent.com/tlc-pack/web-data/main/images/how-to/Adreno-Deployment-Pipeline.jpg diff --git a/python/tvm/relay/op/contrib/adreno.py b/python/tvm/relay/op/contrib/adreno.py index 965b9b387c9a..ddec2fdf3670 100644 --- a/python/tvm/relay/op/contrib/adreno.py +++ b/python/tvm/relay/op/contrib/adreno.py @@ -38,7 +38,6 @@ class AdrenoMixedPrecision(object): def __init__(self): """Saves the required info for RAII pattern usage. - Parameters ---------- acc_dtype : atr @@ -83,3 +82,17 @@ def convert_to_dtype(mod, dtype): else: print("Warn: Invald dtype conversion to ", dtype) return mod + + +@tvm.register_func("adreno.mixed_precision_fp16") +def mixed_precision_hook_fp16(mod, params): + """TVMC hook api""" + + return convert_to_dtype(mod["main"], "float16") + + +@tvm.register_func("adreno.mixed_precision_fp16_acc32") +def mixed_precision_hook_fp16_acc32(mod, params): + """TVMC hook api""" + + return convert_to_dtype(mod["main"], "float16_acc32") From 68874a1e64a297e141abc243f87c9f4a2230d4e2 Mon Sep 17 00:00:00 2001 From: Siva Rama Krishna Reddy B Date: Fri, 17 Feb 2023 14:12:59 +0530 Subject: [PATCH 13/20] * Deploy section enhanced according to pipe line diagram. --- docs/how_to/deploy/adreno.rst | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/docs/how_to/deploy/adreno.rst b/docs/how_to/deploy/adreno.rst index 0ea3203a737d..6ed7715e7001 100644 --- a/docs/how_to/deploy/adreno.rst +++ b/docs/how_to/deploy/adreno.rst @@ -442,13 +442,21 @@ We can use below tvmc command to deploy on remore target via RPC based setup. python3 -m tvm.driver.tvmc run --device="cl" keras-resnet50.tar \ --rpc-key android --rpc-tracker 127.0.0.1:9190 --print-time -`tvmc `_ based run has more option +`tvmc `_ based run has more options to initialize the input in various modes like fill, random ..etc. +``tvmc`` based deployment generally a quick verification of compiled model on target from remote host via RPC setup. -TVM also supports ``rtvm`` tool to run the model narively on ADB shell. The build process produced this tool under build-adreno-target. +Production generally uses native deploymenmt environment like Android JNI or CPP native environments. +Here we need to use cross compiled ``tvm_runtime`` interface to deploy the tvm compilation output, i.e. ``TVMPackage``. + +TVM has a standalone tool ``rtvm`` to deploy and run the model natively on ADB shell. The build process produces this tool under build-adreno-target. Please refer to `rtvm `_ for more details about this tool. +While integrating inside existing Android application TVM has multiple options. For JNI or CPP native we may use `C Runtime API `_ +You may refer to ``rtvm``'s simplified interface `TVMRunner `_ also. + +Additionally, TVM also supports Java interface through `TVM4J `_ .. _python_interface: From b2bb23b3cabdfab7b28136e1ecb6bd6f01a2d3a1 Mon Sep 17 00:00:00 2001 From: Siva Rama Krishna Reddy B Date: Thu, 23 Feb 2023 11:44:47 +0530 Subject: [PATCH 14/20] * env setup helper --- tests/scripts/setup-adreno-env.sh | 104 ++++++++++++++++++++++++++++++ 1 file changed, 104 insertions(+) create mode 100755 tests/scripts/setup-adreno-env.sh diff --git a/tests/scripts/setup-adreno-env.sh b/tests/scripts/setup-adreno-env.sh new file mode 100755 index 000000000000..01b8d78a9eed --- /dev/null +++ b/tests/scripts/setup-adreno-env.sh @@ -0,0 +1,104 @@ +#!/usr/bin/env bash +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + + +ENVIRONMENT="default" +RPC_PORT="" +ADB_SERIAL="" + +while [[ $# -gt 0 ]]; do + case $1 in + -e|--environment) + ENVIRONMENT="$2" + shift # past argument + shift # past value + ;; + -p|--rpc-port) + RPC_PORT="$2" + shift # past argument + shift # past value + ;; + -d|--android-device) + ADB_SERIAL="$2" + shift # past argument + shift # past value + ;; + -*|--*) + echo "Unknown option $1" + echo "Usage: source setup-adreno-env.sh -e -p -d " + return 1 + ;; + *) + ;; + esac +done + +echo "ENVIRONMENT = ${ENVIRONMENT}" +echo "RPC_PORT = ${RPC_PORT}" +echo "ADB_SERIAL = ${ADB_SERIAL}" + + +function def_environment() { + source tests/scripts/setup-pytest-env.sh + export PYTHONPATH=${PYTHONPATH}:${TVM_PATH}/apps/extension/python + export LD_LIBRARY_PATH="build:${LD_LIBRARY_PATH:-}" + export TVM_TRACKER_HOST=127.0.0.1 + export TVM_TRACKER_PORT=$RPC_PORT + export RPC_DEVICE_KEY="android" + export RPC_TARGET="adreno" + export TVM_NDK_CC="${ANDROID_NDK_HOME}/toolchains/llvm/prebuilt/linux-x86_64/bin/aarch64-linux-android28-clang" +} + +def_environment + +case ${ENVIRONMENT} in + + "tracker") + echo "Starting Tracker on port :${TVM_TRACKER_PORT}" + def_environment + python3 -m tvm.exec.rpc_tracker --host "${TVM_TRACKER_HOST}" --port "${TVM_TRACKER_PORT}" + ;; + + "device") + echo "Running RPC on device : ${ADB_SERIAL} with key $RPC_DEVICE_KEY" + def_environment + export ANDROID_SERIAL=${ADB_SERIAL} + + adb shell "mkdir -p /data/local/tmp/tvm_ci" + adb push build-adreno-target/tvm_rpc /data/local/tmp/tvm_ci/tvm_rpc_ci + adb push build-adreno-target/libtvm_runtime.so /data/local/tmp/tvm_ci + + adb reverse tcp:${TVM_TRACKER_PORT} tcp:${TVM_TRACKER_PORT} + adb forward tcp:5000 tcp:5000 + adb forward tcp:5001 tcp:5001 + adb forward tcp:5002 tcp:5002 + adb shell "cd /data/local/tmp/tvm_ci; killall -9 tvm_rpc_ci; sleep 2; LD_LIBRARY_PATH=/data/local/tmp/tvm_ci/ ./tvm_rpc_ci server --host=0.0.0.0 --port=5000 --port-end=5010 --tracker=127.0.0.1:${TVM_TRACKER_PORT} --key=${RPC_DEVICE_KEY}" + ;; + + "default") + def_environment + echo "Setting dev environment with Tracker Port : $TVM_TRACKER_HOST} and the available devices are" + python3 -m tvm.exec.query_rpc_tracker --port ${TVM_TRACKER_PORT} + ;; + + *) + echo "Unknown environment $ENVIRONMENT" + echo "Usage: source setup-adreno-env.sh -e -p -d " + return 1 + ;; +esac From f6e9a4d77df7ef075bad77ce038a2326b69fa385 Mon Sep 17 00:00:00 2001 From: Siva Rama Krishna Reddy B Date: Mon, 6 Mar 2023 09:12:40 +0530 Subject: [PATCH 15/20] * changes according to tvmc enhancements. --- docs/how_to/deploy/adreno.rst | 71 +++++++++++--- .../deploy_models/deploy_model_on_adreno.py | 24 +++-- python/tvm/relay/op/contrib/__init__.py | 1 - python/tvm/relay/op/contrib/adreno.py | 98 ------------------- .../relay/opencl_texture/test_network.py | 24 +++-- 5 files changed, 92 insertions(+), 126 deletions(-) delete mode 100644 python/tvm/relay/op/contrib/adreno.py diff --git a/docs/how_to/deploy/adreno.rst b/docs/how_to/deploy/adreno.rst index 6ed7715e7001..a88db8ccb44b 100644 --- a/docs/how_to/deploy/adreno.rst +++ b/docs/how_to/deploy/adreno.rst @@ -559,12 +559,21 @@ To do the conversion you need to call adreno specific transformation API as soon .. code:: python - from tvm.relay.op.contrib import adreno - adreno.convert_to_dtype(mod["main"], "float16") + from tvm.driver.tvmc.transform import apply_graph_transforms + mod = apply_graph_transforms( + mod, + { + "mixed_precision": True, + "mixed_precision_ops": ["nn.conv2d", "nn.dense"], + "mixed_precision_calculation_type": "float16", + "mixed_precision_acc_type": "float16", + }, + ) -``tvm.relay.op.contrib.adreno.convert_to_dtype`` is simplified API over ``ToMixedPrecision`` pass to get desired precision. -We then can compile our model in any convinient way +``tvm.driver.tvmc.transform.apply_graph_transforms`` is simplified API over ``ToMixedPrecision`` pass to get desired precision. + +We can then compile our model in any convinient way .. code:: python @@ -573,8 +582,23 @@ We then can compile our model in any convinient way mod, target_host=target_host, target=target, params=params ) -While using ``tvmc`` python interface, argument ``pre_build_hooks=[adreno.mixed_precision_fp16]`` enables precision conversion to float16. -Similarly, ``tvmc`` command line interface option ``--pre-build-hooks "adreno.mixed_precision_fp16]"`` does the same. +While using ``tvmc`` python interface, the below arguments enables precision conversion to float16. + +.. code:: python + + mixed_precision = True, + mixed_precision_ops = ["nn.conv2d", "nn.dense"], + mixed_precision_calculation_type = "float16", + mixed_precision_acc_type = "float16" + +Similarly, ``tvmc`` command line interface option bas below listed options. + +.. code:: bash + + --mixed-precision + --mixed-precision-ops nn.conv2d nn.dense + --mixed-precision-calculation-type float16 + --mixed-precision-acc-type float16 **float16_acc32 (Mixed Precision)** @@ -596,11 +620,21 @@ as it halves the expected size of the weights (FP16_acc16 case). .. code:: python - from tvm.relay.op.contrib import adreno - adreno.convert_to_dtype(mod["main"], "float16_acc32") + from tvm.driver.tvmc.transform import apply_graph_transforms + mod = apply_graph_transforms( + mod, + { + "mixed_precision": True, + "mixed_precision_ops": ["nn.conv2d", "nn.dense"], + "mixed_precision_calculation_type": "float16", + "mixed_precision_acc_type": "float32", + }, + ) -We then can compile our model in any convinient way +``tvm.driver.tvmc.transform.apply_graph_transforms`` is simplified API over ``ToMixedPrecision`` pass to get desired precision. + +We can then compile our model in any convinient way .. code:: python @@ -609,8 +643,23 @@ We then can compile our model in any convinient way mod, target_host=target_host, target=target, params=params ) -While using ``tvmc`` python interface, argument ``pre_build_hooks=[adreno.mixed_precision_fp16_acc32]`` enables precision conversion to float16. -Similarly, ``tvmc`` command line interface option ``--pre-build-hooks "adreno.mixed_precision_fp16_acc32]"`` does the same. +While using ``tvmc`` python interface, the below arguments enables precision conversion to float16. + +.. code:: python + + mixed_precision = True, + mixed_precision_ops = ["nn.conv2d", "nn.dense"], + mixed_precision_calculation_type = "float16", + mixed_precision_acc_type = "float32" + +Similarly, ``tvmc`` command line interface option bas below listed options. + +.. code:: bash + + --mixed-precision + --mixed-precision-ops nn.conv2d nn.dense + --mixed-precision-calculation-type float16 + --mixed-precision-acc-type float32 .. |High-level overview of the Adreno™ A5x architecture for OpenCL| image:: https://raw.githubusercontent.com/tlc-pack/web-data/main/images/how-to/adreno_architecture.png diff --git a/gallery/how_to/deploy_models/deploy_model_on_adreno.py b/gallery/how_to/deploy_models/deploy_model_on_adreno.py index 559454f7c442..c2ba189a6715 100644 --- a/gallery/how_to/deploy_models/deploy_model_on_adreno.py +++ b/gallery/how_to/deploy_models/deploy_model_on_adreno.py @@ -145,9 +145,10 @@ # it's advisable to use lower precision. # We have a helper API to make the precision conversion simple and # it supports dtype with "float16" and "float16_acc32" modes. -# Let's choose "float16_acc32" for this example. +# Let's choose "float16" for calculation and "float32" for accumulation. -dtype = "float16_acc32" +calculation_dtype = "float16" +acc_dtype = "float32" # Specify Adreno target before compiling to generate texture # leveraging kernels and get all the benefits of textures @@ -240,16 +241,21 @@ # We may need to register precision rules like precision type, accumultation # datatype ...etc. for the required operators to override the default settings. # The below helper api simplifies the precision conversions across the module. -# Now it supports dtypes "float16" and "float16_acc32". -# dtype is set to "float16_acc32" in configuration section above. +# Calculation dtype is set to "float16" and accumulation dtype is set to "float32" +# in configuration section above. -from tvm.relay.op.contrib import adreno +from tvm.driver.tvmc.transform import apply_graph_transforms -adreno.convert_to_dtype(mod["main"], dtype) - -dtype = "float32" if dtype == "float32" else "float16" -print(mod) +mod = apply_graph_transforms( + mod, + { + "mixed_precision": True, + "mixed_precision_ops": ["nn.conv2d", "nn.dense"], + "mixed_precision_calculation_type": calculation_dtype, + "mixed_precision_acc_type": acc_dtype, + }, +) ################################################################# # As you can see in the IR, the architecture now contains cast operations, which are diff --git a/python/tvm/relay/op/contrib/__init__.py b/python/tvm/relay/op/contrib/__init__.py index 104a48cb54bb..01708e8452bd 100644 --- a/python/tvm/relay/op/contrib/__init__.py +++ b/python/tvm/relay/op/contrib/__init__.py @@ -27,4 +27,3 @@ from .tensorrt import * from .cutlass import * from .clml import * -from .adreno import * diff --git a/python/tvm/relay/op/contrib/adreno.py b/python/tvm/relay/op/contrib/adreno.py deleted file mode 100644 index ddec2fdf3670..000000000000 --- a/python/tvm/relay/op/contrib/adreno.py +++ /dev/null @@ -1,98 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# pylint: disable=invalid-name, unused-argument -"""Adreno specific helpers.""" -import tvm - -from tvm import relay -from tvm.ir import IRModule - -acc_dtype = "float32" - - -def mixed_precision_rule(call_node: "relay.Call", mixed_precision_type: str): - global acc_dtype - return [ - relay.transform.mixed_precision.MIXED_PRECISION_ALWAYS, - acc_dtype, - mixed_precision_type, - ] - - -class AdrenoMixedPrecision(object): - """Temporarily changes attr of ops to enable FP32 accumulation .""" - - def __init__(self): - """Saves the required info for RAII pattern usage. - Parameters - ---------- - acc_dtype : atr - accumulation dtype. - """ - self.older_attr = {} - self.ops = ["nn.conv2d", "nn.dense"] - self.attr_key = "FTVMMixedPrecisionConversionType" - - def __enter__(self): - for op_name in self.ops: - op = relay.op.get(op_name) - self.older_attr[op_name] = op.get_attr(self.attr_key) - op.reset_attr(self.attr_key) - op.set_attr(self.attr_key, mixed_precision_rule) - return self - - def __exit__(self, ptype, value, trace): - for op_name in self.ops: - op = relay.op.get(op_name) - op.reset_attr(self.attr_key) - if self.older_attr[op_name]: - op.set_attr(self.attr_key, self.older_attr[op_name]) - - -def convert_to_dtype(mod, dtype): - """Converts the operator datatypes""" - - global acc_dtype - if dtype in ["float16", "float16_acc32"]: - acc_dtype = "float16" if dtype == "float16" else "float32" - - mod = IRModule.from_expr(mod) - with AdrenoMixedPrecision(): - seq = tvm.transform.Sequential( - [relay.transform.InferType(), relay.transform.ToMixedPrecision()] - ) - with tvm.transform.PassContext( - config={"relay.ToMixedPrecision.keep_orig_output_dtype": True}, opt_level=3 - ): - mod = seq(mod) - else: - print("Warn: Invald dtype conversion to ", dtype) - return mod - - -@tvm.register_func("adreno.mixed_precision_fp16") -def mixed_precision_hook_fp16(mod, params): - """TVMC hook api""" - - return convert_to_dtype(mod["main"], "float16") - - -@tvm.register_func("adreno.mixed_precision_fp16_acc32") -def mixed_precision_hook_fp16_acc32(mod, params): - """TVMC hook api""" - - return convert_to_dtype(mod["main"], "float16_acc32") diff --git a/tests/python/relay/opencl_texture/test_network.py b/tests/python/relay/opencl_texture/test_network.py index 47bd82a2d1f1..1d0e996f9f97 100644 --- a/tests/python/relay/opencl_texture/test_network.py +++ b/tests/python/relay/opencl_texture/test_network.py @@ -22,20 +22,30 @@ import tvm from tvm import relay from tvm.contrib import utils -from tvm.relay.op.contrib import adreno from tvm.relay import testing from tvm.relay.op import register_mixed_precision_conversion from utils.adreno_utils import build_run_compare, get_model, gpu_preprocess -def _test_mobilenet_v1(remote, target, dtype): +def _test_mobilenet_v1(remote, target, calc_dtype, acc_dtype): mod, params, inputs, dtypes = get_model( "https://github.com/mlcommons/mobile_models/raw/main/v0_7/tflite/mobilenet_edgetpu_224_1.0_float.tflite", "mobilenet_edgetpu_224_1.0_float.tflite", "tflite", ) - if dtype == "float16" or dtype == "float16_acc32": - mod = adreno.convert_to_dtype(mod["main"], dtype) + if calc_dtype == "float16": + from tvm.driver.tvmc.transform import apply_graph_transforms + + mod = apply_graph_transforms( + mod, + { + "mixed_precision": True, + "mixed_precision_ops": ["nn.conv2d", "nn.dense"], + "mixed_precision_calculation_type": calc_dtype, + "mixed_precision_acc_type": acc_dtype, + }, + ) + build_run_compare(remote, mod, params, inputs, dtypes, target, []) @@ -44,21 +54,21 @@ def _test_mobilenet_v1(remote, target, dtype): @tvm.testing.parametrize_targets("opencl -device=adreno") @pytest.mark.skipif(tvm.testing.utils.IS_IN_CI, reason="CI doesn't support fp16(half datatypes)") def test_mobilenet_v1_fp16(remote, target): - _test_mobilenet_v1(remote, target, "float16") + _test_mobilenet_v1(remote, target, "float16", "float16") @pytest.mark.skip(reason="See https://github.com/apache/tvm/issues/13443") @tvm.testing.requires_opencl @tvm.testing.parametrize_targets("opencl -device=adreno") def test_mobilenet_v1_fp32(remote, target): - _test_mobilenet_v1(remote, target, "float32") + _test_mobilenet_v1(remote, target, "float32", "float32") @pytest.mark.skip(reason="See https://github.com/apache/tvm/issues/13443") @tvm.testing.requires_opencl @tvm.testing.parametrize_targets("opencl -device=adreno") def test_mobilenet_v1_fp16_acc32(remote, target): - _test_mobilenet_v1(remote, target, "float16_acc32") + _test_mobilenet_v1(remote, target, "float16", "float32") if __name__ == "__main__": From f83cdd01cc47298b8613d9c7dc583c659b9470b0 Mon Sep 17 00:00:00 2001 From: Siva Date: Thu, 9 Mar 2023 15:56:15 +0530 Subject: [PATCH 16/20] Update docs/how_to/deploy/adreno.rst Co-authored-by: Egor Churaev --- docs/how_to/deploy/adreno.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/how_to/deploy/adreno.rst b/docs/how_to/deploy/adreno.rst index a88db8ccb44b..b2cb4574ef83 100644 --- a/docs/how_to/deploy/adreno.rst +++ b/docs/how_to/deploy/adreno.rst @@ -475,7 +475,7 @@ Please refer to the tutorial for the same Saved ``TVMCPackage`` can be used for native deployment using ``rtvm`` utility too. Also, please refer to `tvmc `_ - documentation for more details about the api interface. +documentation for more details about the api interface. **Relay Interface:** From cf1c6fe1d2e6bb814e15319e6b8faae2802d6d6b Mon Sep 17 00:00:00 2001 From: Siva Date: Thu, 9 Mar 2023 15:56:42 +0530 Subject: [PATCH 17/20] Update docs/how_to/deploy/adreno.rst Co-authored-by: Egor Churaev --- docs/how_to/deploy/adreno.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/how_to/deploy/adreno.rst b/docs/how_to/deploy/adreno.rst index b2cb4574ef83..c90fdebb29da 100644 --- a/docs/how_to/deploy/adreno.rst +++ b/docs/how_to/deploy/adreno.rst @@ -474,7 +474,7 @@ Please refer to the tutorial for the same Saved ``TVMCPackage`` can be used for native deployment using ``rtvm`` utility too. -Also, please refer to `tvmc `_ +Also, please refer to `tvmc `_ documentation for more details about the api interface. **Relay Interface:** From 255e0b19b89718dc17b570f9c0983472de0ba4b4 Mon Sep 17 00:00:00 2001 From: Siva Rama Krishna Reddy B Date: Thu, 9 Mar 2023 16:19:59 +0530 Subject: [PATCH 18/20] * review comments --- docs/how_to/deploy/adreno.rst | 4 ++-- tests/scripts/setup-adreno-env.sh | 23 ++++++++++++++++------- 2 files changed, 18 insertions(+), 9 deletions(-) diff --git a/docs/how_to/deploy/adreno.rst b/docs/how_to/deploy/adreno.rst index c90fdebb29da..7989346b95a0 100644 --- a/docs/how_to/deploy/adreno.rst +++ b/docs/how_to/deploy/adreno.rst @@ -456,8 +456,6 @@ Please refer to `rtvm `_ While integrating inside existing Android application TVM has multiple options. For JNI or CPP native we may use `C Runtime API `_ You may refer to ``rtvm``'s simplified interface `TVMRunner `_ also. -Additionally, TVM also supports Java interface through `TVM4J `_ - .. _python_interface: This section explains importing, auto tuning, compiling and running a model using python interface.\ @@ -489,6 +487,8 @@ This library module will be used to create graph runtime to deploy and run on ta Please refer to the tutorial `How To Deploy model on Adreno `_ for a step by step explanation of the same. +Additionally, TVM also supports Java interface through `TVM4J `_ + .. _application_integration: Application Integration diff --git a/tests/scripts/setup-adreno-env.sh b/tests/scripts/setup-adreno-env.sh index 01b8d78a9eed..fcb09d7632b4 100755 --- a/tests/scripts/setup-adreno-env.sh +++ b/tests/scripts/setup-adreno-env.sh @@ -17,10 +17,18 @@ # under the License. -ENVIRONMENT="default" +ENVIRONMENT="" RPC_PORT="" ADB_SERIAL="" +function usage() { + echo "Helper script to setp the environment for Tracker, RPC Device and for application" + echo "Usage (Help) : source setup-adreno-env.sh -h" + echo "Usage (Tracker): source setup-adreno-env.sh -e tracker -p " + echo "Usage (Device): source setup-adreno-env.sh -e device -p -d " + echo "Usage (Default/Application): source setup-adreno-env.sh -e default -p " +} + while [[ $# -gt 0 ]]; do case $1 in -e|--environment) @@ -38,10 +46,13 @@ while [[ $# -gt 0 ]]; do shift # past argument shift # past value ;; + -h|--help) + usage + return 0 + ;; -*|--*) - echo "Unknown option $1" - echo "Usage: source setup-adreno-env.sh -e -p -d " - return 1 + usage + return 0 ;; *) ;; @@ -97,8 +108,6 @@ case ${ENVIRONMENT} in ;; *) - echo "Unknown environment $ENVIRONMENT" - echo "Usage: source setup-adreno-env.sh -e -p -d " - return 1 + usage ;; esac From 53856454f7889d70c40ab0bc4a3a9d143d7cf1f4 Mon Sep 17 00:00:00 2001 From: Siva Rama Krishna Reddy B Date: Fri, 10 Mar 2023 07:41:37 +0530 Subject: [PATCH 19/20] * lint error --- docs/how_to/deploy/adreno.rst | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/docs/how_to/deploy/adreno.rst b/docs/how_to/deploy/adreno.rst index 7989346b95a0..8a395cad1d9a 100644 --- a/docs/how_to/deploy/adreno.rst +++ b/docs/how_to/deploy/adreno.rst @@ -388,9 +388,9 @@ Commandline Tools ----------------- Here we describe entire compilation process using command line tools. TVM has command line utility -`tvmc `_ to perform +`tvmc `_ to perform model import, auto tuning, compilation and deply over rpc. -`tvmc `_ has many options to explore and try. +`tvmc `_ has many options to explore and try. **Model Import & Tuning:** Use the below command to import a model from any framework and auto tune the same. @@ -442,7 +442,7 @@ We can use below tvmc command to deploy on remore target via RPC based setup. python3 -m tvm.driver.tvmc run --device="cl" keras-resnet50.tar \ --rpc-key android --rpc-tracker 127.0.0.1:9190 --print-time -`tvmc `_ based run has more options +`tvmc `_ based run has more options to initialize the input in various modes like fill, random ..etc. ``tvmc`` based deployment generally a quick verification of compiled model on target from remote host via RPC setup. From 908d8e037054b0548a7954f602de42455b45b807 Mon Sep 17 00:00:00 2001 From: Siva Rama Krishna Reddy B Date: Sat, 11 Mar 2023 08:24:53 +0530 Subject: [PATCH 20/20] * review --- docs/how_to/deploy/adreno.rst | 10 ++++++++++ tests/scripts/setup-adreno-env.sh | 10 +++++----- 2 files changed, 15 insertions(+), 5 deletions(-) diff --git a/docs/how_to/deploy/adreno.rst b/docs/how_to/deploy/adreno.rst index 8a395cad1d9a..ed016a3ff744 100644 --- a/docs/how_to/deploy/adreno.rst +++ b/docs/how_to/deploy/adreno.rst @@ -372,6 +372,13 @@ Now, the below comand can run TVM RPC on remote android device with id ``abcdefg ./tests/scripts/ci.py adreno -i # Launch a new shell on adreno docker. source tests/scripts/setup-adreno-env.sh -e device -p 9190 -d abcdefgh +Further, below command can be used to query the RPC setup details on any other docker terminals. + +:: + + ./tests/scripts/ci.py adreno -i # Launch a new shell on adreno docker. + source tests/scripts/setup-adreno-env.sh -e query -p 9190 + **Manual RPC Setup:** @@ -458,6 +465,9 @@ You may refer to ``rtvm``'s simplified interface `TVMRunner " echo "Usage (Device): source setup-adreno-env.sh -e device -p -d " - echo "Usage (Default/Application): source setup-adreno-env.sh -e default -p " + echo "Usage (Query): source setup-adreno-env.sh -e query -p " } while [[ $# -gt 0 ]]; do @@ -67,8 +67,8 @@ echo "ADB_SERIAL = ${ADB_SERIAL}" function def_environment() { source tests/scripts/setup-pytest-env.sh export PYTHONPATH=${PYTHONPATH}:${TVM_PATH}/apps/extension/python - export LD_LIBRARY_PATH="build:${LD_LIBRARY_PATH:-}" - export TVM_TRACKER_HOST=127.0.0.1 + export LD_LIBRARY_PATH="${TVM_PATH}/build:${LD_LIBRARY_PATH}" + export TVM_TRACKER_HOST=0.0.0.0 export TVM_TRACKER_PORT=$RPC_PORT export RPC_DEVICE_KEY="android" export RPC_TARGET="adreno" @@ -101,7 +101,7 @@ case ${ENVIRONMENT} in adb shell "cd /data/local/tmp/tvm_ci; killall -9 tvm_rpc_ci; sleep 2; LD_LIBRARY_PATH=/data/local/tmp/tvm_ci/ ./tvm_rpc_ci server --host=0.0.0.0 --port=5000 --port-end=5010 --tracker=127.0.0.1:${TVM_TRACKER_PORT} --key=${RPC_DEVICE_KEY}" ;; - "default") + "query") def_environment echo "Setting dev environment with Tracker Port : $TVM_TRACKER_HOST} and the available devices are" python3 -m tvm.exec.query_rpc_tracker --port ${TVM_TRACKER_PORT}