718 changes: 529 additions & 189 deletions docs/how_to/deploy/adreno.rst

Large diffs are not rendered by default.

309 changes: 200 additions & 109 deletions gallery/how_to/deploy_models/deploy_model_on_adreno.py
@@ -18,9 +18,9 @@
"""
.. _tutorial-deploy-model-on-adreno:

Deploy the Pretrained Model on Adreno
=======================================
**Author**: Daniil Barinov
Deploy the Pretrained Model on Adreno
======================================
**Author**: Daniil Barinov, Siva Rama Krishna

This article is a step-by-step tutorial to deploy a pretrained PyTorch ResNet-18 model on Adreno (in different precisions).

@@ -53,11 +53,17 @@
#
# adb devices
#
# If you have several devices connected to your computer, set the Android device to use.
#
# .. code-block:: bash
#
# export ANDROID_SERIAL=<device-hash>
#
# Then to upload these two files to the device you should use:
#
# .. code-block:: bash
#
# adb -s <device_hash> push {libtvm_runtime.so,tvm_rpc} /data/local/tmp
# adb push {libtvm_runtime.so,tvm_rpc} /data/local/tmp
#
# At this point you will have «libtvm_runtime.so» and «tvm_rpc» under /data/local/tmp on your device.
# Sometimes cmake can’t find «libc++_shared.so». Use:
@@ -70,7 +76,7 @@
#
# .. code-block:: bash
#
# adb -s <device_hash> push libc++_shared.so /data/local/tmp
# adb push libc++_shared.so /data/local/tmp
#
# We are now ready to run the TVM RPC Server.
# Launch the RPC tracker with the following line in the 1st console:
@@ -83,12 +89,12 @@
#
# .. code-block:: bash
#
# adb -s <device_hash> reverse tcp:9190 tcp:9190
# adb -s <device_hash> forward tcp:9090 tcp:9090
# adb -s <device_hash> forward tcp:9091 tcp:9091
# adb -s <device_hash> forward tcp:9092 tcp:9092
# adb -s <device_hash> forward tcp:9093 tcp:9093
# adb -s <device_hash> shell LD_LIBRARY_PATH=/data/local/tmp /data/local/tmp/tvm_rpc server --host=0.0.0.0 --port=9090 --tracker=127.0.0.1:9190 --key=android --port-end=9190
# adb reverse tcp:9190 tcp:9190
# adb forward tcp:5000 tcp:5000
# adb forward tcp:5001 tcp:5001
# adb forward tcp:5002 tcp:5002
# adb forward tcp:5003 tcp:5003
# adb shell LD_LIBRARY_PATH=/data/local/tmp /data/local/tmp/tvm_rpc server --host=0.0.0.0 --port=5000 --tracker=127.0.0.1:9190 --key=android --port-end=5100
#
# Before proceeding to compile and infer the model, specify TVM_TRACKER_HOST and TVM_TRACKER_PORT
#
@@ -115,6 +121,73 @@
# android 1 1 0
# ----------------------------------
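#
# The tracker status can be re-queried at any time with TVM's tracker query
# tool (assuming the default tracker port; adjust --port if you changed it):
#
# .. code-block:: bash
#
#   python3 -m tvm.exec.query_rpc_tracker --host=127.0.0.1 --port=9190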

#################################################################
# Configuration
# -------------

import os
import torch
import torchvision
import tvm
from tvm import te
from tvm import relay, rpc
from tvm.contrib import utils, ndk
from tvm.contrib import graph_executor
from tvm.relay.op.contrib import clml
from tvm import autotvm

# Below is a set of configuration options that control the behaviour of this
# script, such as local or device run, target definitions, dtype setting, and
# auto-tuning enablement. Change these settings as needed.

# Adreno devices are efficient with float16 compared to float32.
# Given that the expected output is not affected by lowering the precision,
# it's advisable to use lower precision.
# We have a helper API to make the precision conversion simple; it
# supports the "float16" and "float16_acc32" dtype modes.
# Let's choose "float16" for calculation and "float32" for accumulation.

calculation_dtype = "float16"
acc_dtype = "float32"

# Specify the Adreno target before compiling to generate texture-leveraging
# kernels and get all the benefits of textures.
# Note: This generated example runs on our x86 server for demonstration.
# To run it on an Android device, we need to specify its instruction set.
# Set :code:`local_demo` to False if you want to run this tutorial on a real
# device over RPC.
local_demo = True

# By default the CPU target will execute.
# Select from 'cpu', 'opencl', or 'opencl -device=adreno'.
test_target = "cpu"

# Change target configuration.
# Run `adb shell cat /proc/cpuinfo` to find the arch.
arch = "arm64"
target = tvm.target.Target("llvm -mtriple=%s-linux-android" % arch)

# Auto tuning is a compute-intensive and time-consuming task, hence it is
# disabled for the default run. Please enable it if required.
is_tuning = False
tune_log = "adreno-resnet18.log"

# Set to True to enable the OpenCLML accelerated operator library.
enable_clml = False

#################################################################
# Get a PyTorch Model
# -------------------
# Get resnet18 from torchvision models
model_name = "resnet18"
model = getattr(torchvision.models, model_name)(pretrained=True)
model = model.eval()

# We grab the TorchScripted model via tracing
input_shape = [1, 3, 224, 224]
input_data = torch.randn(input_shape)
scripted_model = torch.jit.trace(model, input_data).eval()
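
# As a quick sanity check (a small addition, not part of the original flow),
# run the traced module on the dummy input; a ResNet-18 classifier should
# produce a (1, 1000) logits tensor.
with torch.no_grad():
    ref_output = scripted_model(input_data)
print("TorchScript output shape:", tuple(ref_output.shape))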

#################################################################
# Load a test image
# -----------------
@@ -146,133 +219,153 @@
img = np.expand_dims(img, 0)
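
# Quick shape check (a small addition): after preprocessing, the input should
# be NCHW and match the traced model's expected shape of (1, 3, 224, 224).
print("Input tensor shape:", img.shape)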

#################################################################
# Load pretrained Pytorch model
# -----------------------------
# Create a Relay graph from a Pytorch ResNet-18 model
import os
import torch
import torchvision
import tvm
from tvm import te
from tvm import relay, rpc
from tvm.contrib import utils, ndk
from tvm.contrib import graph_executor

model_name = "resnet18"
model = getattr(torchvision.models, model_name)(pretrained=True)
model = model.eval()

# We grab the TorchScripted model via tracing
input_shape = [1, 3, 224, 224]
input_data = torch.randn(input_shape)
scripted_model = torch.jit.trace(model, input_data).eval()

# Convert PyTorch model to Relay module
# -------------------------------------
# TVM has frontend APIs for various frameworks under relay.frontend; for
# PyTorch model import we use the relay.frontend.from_pytorch API.
# The input name can be arbitrary.
input_name = "input0"
shape_list = [(input_name, img.shape)]

mod, params = relay.frontend.from_pytorch(scripted_model, shape_list)
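
# Optional sanity check (a small addition): printing the main function shows
# the imported operators together with their shapes and dtypes.
print(mod["main"])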

#################################################################
# Precisions
# ----------
# Since TVM support Mixed Precision, we need to register mixed_precision_conversion:
from tvm.relay.op import register_mixed_precision_conversion

conv2d_acc = "float32"


@register_mixed_precision_conversion("nn.conv2d", level=11)
def conv2d_mixed_precision_rule(call_node: "relay.Call", mixed_precision_type: str):
global conv2d_acc
return [
relay.transform.mixed_precision.MIXED_PRECISION_ALWAYS,
conv2d_acc,
mixed_precision_type,
]


@register_mixed_precision_conversion("nn.dense", level=11)
def conv2d_mixed_precision_rule(call_node: "relay.Call", mixed_precision_type: str):
global conv2d_acc
return [
relay.transform.mixed_precision.MIXED_PRECISION_ALWAYS,
conv2d_acc,
mixed_precision_type,
]
# Adreno devices are efficient with float16 compared to float32.
# Given that the expected output is not affected by lowering the precision,
# it's advisable to use lower precision.

# TVM supports mixed precision through the ToMixedPrecision transformation pass.
# We may need to register precision rules, such as the precision type,
# accumulation datatype, etc., for the required operators to override the
# default settings. The helper API below simplifies the precision conversions
# across the module.

#################################################################
# and also define the conversion function itself
def convert_to_dtype(mod, dtype):
# downcast to float16
if dtype == "float16" or dtype == "float16_acc32":
global conv2d_acc
conv2d_acc = "float16" if dtype == "float16" else "float32"
from tvm.ir import IRModule

mod = IRModule.from_expr(mod)
seq = tvm.transform.Sequential(
[relay.transform.InferType(), relay.transform.ToMixedPrecision()]
)
with tvm.transform.PassContext(opt_level=3):
mod = seq(mod)
return mod
# The calculation dtype is set to "float16" and the accumulation dtype to
# "float32" in the configuration section above.

from tvm.driver.tvmc.transform import apply_graph_transforms

#################################################################
# Let's choose "float16_acc32" for example.
dtype = "float16_acc32"
mod = convert_to_dtype(mod["main"], dtype)
dtype = "float32" if dtype == "float32" else "float16"

print(mod)
mod = apply_graph_transforms(
mod,
{
"mixed_precision": True,
"mixed_precision_ops": ["nn.conv2d", "nn.dense"],
"mixed_precision_calculation_type": calculation_dtype,
"mixed_precision_acc_type": acc_dtype,
},
)

#################################################################
# As you can see in the IR, the architecture now contains cast operations, which are
# needed to convert to FP16 precision.
# You can also use "float16" or "float32" precisions as other dtype options.
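
# A small sketch to confirm this (added for illustration, using the public
# relay.analysis visitor API): count the cast operations that the
# mixed-precision pass introduced into the main function.
def count_cast_ops(mod):
    casts = [0]

    def fvisit(expr):
        # Count calls to the Relay "cast" operator.
        if isinstance(expr, relay.Call) and isinstance(expr.op, tvm.ir.Op):
            if expr.op.name == "cast":
                casts[0] += 1

    relay.analysis.post_order_visit(mod["main"], fvisit)
    return casts[0]


print("cast ops after mixed-precision conversion:", count_cast_ops(mod))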

#################################################################
# Compile the model with relay
# ----------------------------
# Specify Adreno target before compiling to generate texture
# leveraging kernels and get all the benefits of textures
# Note: This generated example running on our x86 server for demonstration.
# If running it on the Android device, we need to
# specify its instruction set. Set :code:`local_demo` to False if you want
# to run this tutorial with a real device.
# Prepare TVM Target
# ------------------

local_demo = True
# This generated example runs on our x86 server for demonstration.

# by default on CPU target will execute.
# select 'cpu', 'opencl' and 'vulkan'
test_target = "cpu"

# Change target configuration.
# Run `adb shell cat /proc/cpuinfo` to find the arch.
arch = "arm64"
target = tvm.target.Target("llvm -mtriple=%s-linux-android" % arch)
# To deploy and run on a real target over RPC, please set :code:`local_demo` to False in the configuration section above.
# Also, :code:`test_target` is set to :code:`llvm` to keep this example compatible with the x86 demonstration.
# Please change it to :code:`opencl` or :code:`opencl -device=adreno` for an RPC target in the configuration above.

if local_demo:
target = tvm.target.Target("llvm")
elif test_target == "opencl":
target = tvm.target.Target("opencl", host=target)
elif test_target == "vulkan":
target = tvm.target.Target("vulkan", host=target)
elif test_target.find("opencl"):
target = tvm.target.Target(test_target, host=target)
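
# For a quick check (a small addition), print the composed target; with an RPC
# device it should show the device target together with the LLVM host target.
print("Compilation target:", target)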

with tvm.transform.PassContext(opt_level=3):
lib = relay.build(mod, target=target, params=params)
##################################################################
# AutoTuning
# ----------
# The few instructions below can auto-tune the relay module, with XGBoost as the tuner algorithm.

#################################################################
# Deploy the Model Remotely by RPC
# --------------------------------
# Using RPC you can deploy the model from host
# machine to the remote Adreno device
# The auto-tuning process involves extracting the tasks, defining a tuning
# configuration, and tuning each task for its best performing kernel configuration.

# Get RPC related settings.
rpc_tracker_host = os.environ.get("TVM_TRACKER_HOST", "127.0.0.1")
rpc_tracker_port = int(os.environ.get("TVM_TRACKER_PORT", 9190))
key = "android"

# Auto tuning is a compute-intensive and time-consuming task.
# It is set to False in the configuration above, as this script runs on x86 for demonstration.
# Please set :code:`is_tuning` to True to enable auto tuning.

if is_tuning:
# Auto Tuning Stage 1: Extract tunable tasks
tasks = autotvm.task.extract_from_program(
mod, target=test_target, target_host=target, params=params
)

# Auto Tuning Stage 2: Define tuning configuration
tmp_log_file = tune_log + ".tmp"
measure_option = autotvm.measure_option(
builder=autotvm.LocalBuilder(
build_func=ndk.create_shared, timeout=15
), # Build the test kernel locally
runner=autotvm.RPCRunner( # The runner would be on a remote device.
key, # RPC Key
host=rpc_tracker_host, # Tracker host
port=int(rpc_tracker_port), # Tracker port
number=3, # Number of runs before averaging
timeout=600, # RPC Timeout
),
)
n_trial = 1024  # Number of training iterations before choosing the best kernel config
early_stopping = False  # Can be enabled to stop tuning early when the loss is no longer improving.

# Auto Tuning Stage 3: Iterate through the tasks and tune.
from tvm.autotvm.tuner import XGBTuner

for i, tsk in enumerate(reversed(tasks[:3])):
print("Task:", tsk)
prefix = "[Task %2d/%2d] " % (i + 1, len(tasks))
tuner_obj = XGBTuner(tsk, loss_type="rank")

tsk_trial = min(n_trial, len(tsk.config_space))
tuner_obj.tune(
n_trial=tsk_trial,
early_stopping=early_stopping,
measure_option=measure_option,
callbacks=[
autotvm.callback.progress_bar(tsk_trial, prefix=prefix),
autotvm.callback.log_to_file(tmp_log_file),
],
)
# Auto Tuning Stage 4: Pick the best performing configurations from the overall log.
autotvm.record.pick_best(tmp_log_file, tune_log)
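
# The tuning log can be inspected afterwards (a sketch using the public
# autotvm record API); each entry pairs a measured configuration with its
# measured costs.
if os.path.exists(tune_log):
    for record_inp, record_res in autotvm.record.load_from_file(tune_log):
        print("tuned task:", record_inp.task.name, "costs:", record_res.costs)
        break  # show just the first record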

#################################################################
# Enable OpenCLML Offloading
# --------------------------
# OpenCLML offloading will try to accelerate supported operators
# by using the OpenCLML proprietary operator library.

# By default :code:`enable_clml` is set to False in the configuration section above.

if not local_demo and enable_clml:
mod = clml.partition_for_clml(mod, params)
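
# After partitioning (an illustrative sketch, assuming the partitioned global
# functions carry the compiler name in their symbol), check how many subgraphs
# were handed to CLML:
if not local_demo and enable_clml:
    clml_funcs = [gv.name_hint for gv in mod.get_global_vars() if "clml" in gv.name_hint]
    print("Relay functions offloaded to CLML:", len(clml_funcs))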

#################################################################
# Compilation
# -----------
# Use the tuning cache if it exists.
if os.path.exists(tune_log):
with autotvm.apply_history_best(tune_log):
with tvm.transform.PassContext(opt_level=3):
lib = relay.build(mod, target=target, params=params)
else:
with tvm.transform.PassContext(opt_level=3):
lib = relay.build(mod, target=target, params=params)

#################################################################
# Deploy the Model Remotely by RPC
# --------------------------------
# Using RPC you can deploy the model from the host
# machine to the remote Adreno device.
if local_demo:
remote = rpc.LocalSession()
else:
@@ -282,10 +375,8 @@ def convert_to_dtype(mod, dtype):

if local_demo:
dev = remote.cpu(0)
elif test_target == "opencl":
elif test_target.find("opencl"):
dev = remote.cl(0)
elif test_target == "vulkan":
dev = remote.vulkan(0)
else:
dev = remote.cpu(0)
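
# What typically follows (a minimal sketch of the standard flow; the exact
# deployment code sits in the collapsed part of this diff): export the compiled
# library, upload it over the RPC session, and run it with the graph executor.
temp = utils.tempdir()
dso_binary_path = temp.relpath("dev_lib.so")
fcompile = ndk.create_shared if not local_demo else None  # NDK compiler for device builds
lib.export_library(dso_binary_path, fcompile=fcompile)
remote.upload(dso_binary_path)
rlib = remote.load_module("dev_lib.so")
module = graph_executor.GraphModule(rlib["default"](dev))
# Assuming the graph input remains float32 after the mixed-precision pass:
module.set_input(input_name, tvm.nd.array(img))
module.run()
tvm_output = module.get_output(0)
print("TVM output shape:", tvm_output.shape)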
