Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 14 additions & 1 deletion apps/benchmark/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -134,10 +134,23 @@ python3 gpu_imagenet_bench.py --model gfx900 --target rocm
Adreno benchmarks are automated over the docker - [ci_adreno](https://github.com/apache/tvm/blob/main/docker/Dockerfile.ci_adreno).
The Adreno docker container shares the Android devices attached to the host. It is advised to keep the host adb version the same as the one in docker, which is ```1.0.41```

Below command runs all the benchmarks over given Android device.
The command below runs all the benchmarks (OpenCL native and CLML SDK) on the given Android device.
```bash
export ANDROID_SERIAL=<ADB ID>
./tests/scripts/ci.py adreno -b
```
The command below runs all OpenCL native benchmarks on the given Android device.
```bash
export ANDROID_SERIAL=<ADB ID>
./tests/scripts/ci.py adreno -n
```
CLML SDK benchmarks require the CLML SDK path to be exported, and the SDK version should match the target device's SDK version.

The command below runs all CLML SDK benchmarks on the given Android device.
```bash
export ADRENO_OPENCL=<CLML SDK PATH>
export ANDROID_SERIAL=<ADB ID>
./tests/scripts/ci.py adreno -c
```

Note: The tuning cache is picked up implicitly from the tophub repo for all the benchmarks and is tuned on Snapdragon Gen 1.
282 changes: 282 additions & 0 deletions apps/benchmark/adreno/adreno_gpu_bench_clml.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,282 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Benchmark script for various models on Adreno GPU.
"""
import argparse

import numpy as np

import os
import sys
import tvm
from tvm import te
from tvm.relay import testing
from tvm.contrib.utils import tempdir
from tvm.relay.op.contrib import clml
import tvm.contrib.graph_executor as runtime
from tvm import relay
from tvm import autotvm
from tvm.contrib import utils, ndk


def get_network(name, batch_size, dtype="float32"):
    """Build the Relay workload for one of the benchmark networks.

    Parameters
    ----------
    name: str
        Network identifier, e.g. 'resnet-18', 'vgg-16', 'densenet-121',
        'inception_v3', 'mobilenet', 'squeezenet_v1.1' or 'mxnet'.
    batch_size: int
        Batch size used for the workload and for both returned shapes.
    dtype: str
        Data type forwarded to the workload builder.

    Returns
    -------
    net: tvm.IRModule
        The relay module holding the network definition.
    params: dict
        The parameters returned together with the network definition.
    input_shape: tuple
        (batch_size, 3, 299, 299) for 'inception_v3',
        (batch_size, 3, 224, 224) for every other network.
    output_shape: tuple
        (batch_size, 1000)

    Raises
    ------
    ValueError
        If `name` does not select any of the known networks.
    """
    # inception_v3 is the only network here that takes 299x299 images.
    side = 299 if name == "inception_v3" else 224
    input_shape = (batch_size, 3, side, side)
    output_shape = (batch_size, 1000)

    # Keyword arguments shared by every tvm.relay.testing workload builder.
    common = {"batch_size": batch_size, "dtype": dtype}

    if name == "mobilenet":
        net, params = testing.mobilenet.get_workload(**common)
    elif name == "inception_v3":
        net, params = testing.inception_v3.get_workload(**common)
    elif "resnet" in name:
        # Names look like "resnet-<layers>".
        depth = int(name.split("-")[1])
        net, params = testing.resnet.get_workload(num_layers=depth, **common)
    elif "vgg" in name:
        # Names look like "vgg-<layers>".
        depth = int(name.split("-")[1])
        net, params = testing.vgg.get_workload(num_layers=depth, **common)
    elif "densenet" in name:
        # Names look like "densenet-<size>".
        depth = int(name.split("-")[1])
        net, params = testing.densenet.get_workload(densenet_size=depth, **common)
    elif "squeezenet" in name:
        # Names look like "squeezenet_v<version>".
        release = name.split("_v")[1]
        net, params = testing.squeezenet.get_workload(version=release, **common)
    elif name == "mxnet":
        # an example for mxnet model
        from mxnet.gluon.model_zoo.vision import get_model

        gluon_block = get_model("resnet18_v1", pretrained=True)
        imported, params = relay.frontend.from_mxnet(
            gluon_block, shape={"data": input_shape}, dtype=dtype
        )
        main_fn = imported["main"]
        # Wrap the imported body in a softmax before rebuilding the module.
        with_softmax = relay.Function(
            main_fn.params,
            relay.nn.softmax(main_fn.body),
            None,
            main_fn.type_params,
            main_fn.attrs,
        )
        net = tvm.IRModule.from_expr(with_softmax)
    else:
        raise ValueError("Unsupported network: " + name)

    return net, params, input_shape, output_shape


def print_progress(msg):
    """Write a progress message followed by a carriage return and flush.

    The message is written to standard output without a newline, so the next
    write starts again at the beginning of the same line.

    Parameters
    ----------
    msg: str
        The message to print
    """
    stream = sys.stdout
    stream.write(msg + "\r")
    stream.flush()


def tune_tasks(
    tasks,
    measure_option,
    n_trial=1024,
    early_stopping=None,
    log_filename="tuning.log",
):
    """Tune every AutoTVM task and keep the best records in `log_filename`.

    Each task is tuned with an XGBTuner (``loss_type="rank"``). All measured
    records go to a scratch file named ``log_filename + ".tmp"``; once every
    task has been tuned the best records are picked from it into
    `log_filename`.

    Parameters
    ----------
    tasks: list
        AutoTVM tasks to tune. They are processed in reverse order.
    measure_option: dict
        Measurement options forwarded to the tuner.
    n_trial: int
        Upper bound on trials per task; capped by the size of the task's
        config space.
    early_stopping: int or None
        Early-stopping setting forwarded to the tuner.
    log_filename: str
        File receiving the best records.
    """
    from tvm.autotvm.tuner import XGBTuner

    tmp_log_file = log_filename + ".tmp"

    # Start from an empty scratch file so that records left behind by an
    # earlier (possibly interrupted) run are not mixed into this run's result.
    if os.path.exists(tmp_log_file):
        os.remove(tmp_log_file)

    n_tasks = len(tasks)
    for i, tsk in enumerate(reversed(tasks)):
        print("Task: ", tsk)
        prefix = "[Task %2d/%2d] " % (i + 1, n_tasks)
        tuner_obj = XGBTuner(tsk, loss_type="rank")

        # Never request more trials than the config space can offer.
        tsk_trial = min(n_trial, len(tsk.config_space))
        tuner_obj.tune(
            n_trial=tsk_trial,
            early_stopping=early_stopping,
            measure_option=measure_option,
            callbacks=[
                autotvm.callback.progress_bar(tsk_trial, prefix=prefix),
                autotvm.callback.log_to_file(tmp_log_file),
            ],
        )

    # Nothing was recorded (e.g. no tasks): there is nothing to pick from.
    if os.path.exists(tmp_log_file):
        autotvm.record.pick_best(tmp_log_file, log_filename)
        # The scratch file has served its purpose; do not leave it behind.
        os.remove(tmp_log_file)


def evaluate_network(network, target, target_host, dtype, repeat):
    """Build one network with CLML offloading, run it remotely and time it.

    Reads the RPC settings (``rpc_key``, ``host``, ``port``) and the ``tune``
    switch from the module-level ``args`` namespace.

    Parameters
    ----------
    network: str
        Network name understood by ``get_network``.
    target: str
        Device target the model is built for.
    target_host: str
        Host target used for the build and for task extraction.
    dtype: str
        Data type of the network and of the random input.
    repeat: int
        Number of repeated measurements taken by the time evaluator.

    Returns
    -------
    tuple
        (mean, standard deviation) of the measured run times in milliseconds.
    """
    print_progress(network)
    net, params, input_shape, output_shape = get_network(network, batch_size=1, dtype=dtype)

    # Auto Tuning
    tune_log = "adreno-" + network + "-" + dtype + ".log"
    builder = autotvm.LocalBuilder(build_func=ndk.create_shared, timeout=15)
    runner = autotvm.RPCRunner(
        args.rpc_key,
        host=args.host,
        port=args.port,
        number=3,
        timeout=600,
    )
    tuning_options = {
        "log_filename": tune_log,
        "early_stopping": None,
        "measure_option": autotvm.measure_option(builder=builder, runner=runner),
    }
    if args.tune:
        tasks = autotvm.task.extract_from_program(
            net, target=target, target_host=target_host, params=params
        )
        tune_tasks(tasks, **tuning_options)

    print_progress("%-20s building..." % network)

    def partition_and_build(module):
        # CLML partitioning and the relay build run under one opt_level=3 context.
        with tvm.transform.PassContext(opt_level=3):
            partitioned = clml.partition_for_clml(module, params)
            return relay.build(
                partitioned, target=tvm.target.Target(target, host=target_host), params=params
            )

    # Build the tuning log
    if os.path.exists(tune_log):
        # A tuning log for this network/dtype exists: build with its best records.
        with autotvm.apply_history_best(tune_log):
            lib = partition_and_build(net)
    else:
        lib = partition_and_build(net)

    tmp = tempdir()

    filename = "%s.so" % network
    local_lib_path = tmp.relpath(filename)
    lib.export_library(local_lib_path, ndk.create_shared)

    # upload library and params
    print_progress("%-20s uploading..." % network)

    # connect to remote device
    tracker = tvm.rpc.connect_tracker(args.host, args.port)
    remote = tracker.request(args.rpc_key)

    dev = remote.device(str(target), 0)
    remote.upload(local_lib_path)

    rlib = remote.load_module(filename)
    module = runtime.GraphModule(rlib["default"](dev))
    random_input = np.random.uniform(size=input_shape).astype(dtype)
    module.set_input("data", tvm.nd.array(random_input))

    # evaluate
    print_progress("%-20s evaluating..." % network)
    ftimer = module.module.time_evaluator("run", dev, number=1, repeat=repeat)
    prof_res = np.array(ftimer().results) * 1000  # multiply 1000 for converting to millisecond
    mean_ms = np.mean(prof_res)
    std_ms = np.std(prof_res)
    print(
        "%-20s %-19s (%s)"
        % (network + "-" + dtype, "%.2f ms" % mean_ms, "%.2f ms" % std_ms)
    )
    return (mean_ms, std_ms)


def str2bool(value):
    """Convert a command-line token into a bool.

    ``type=bool`` must not be used with argparse: ``bool("False")`` is True
    because any non-empty string is truthy.

    Parameters
    ----------
    value: str or bool
        Token such as "true"/"false", "yes"/"no", "1"/"0" (case-insensitive).

    Returns
    -------
    bool
        The parsed value.

    Raises
    ------
    argparse.ArgumentTypeError
        If the token is not a recognised boolean spelling.
    """
    if isinstance(value, bool):
        return value
    token = value.strip().lower()
    if token in ("true", "t", "yes", "y", "1"):
        return True
    # The empty string stays False, as it was with the former ``type=bool``.
    if token in ("false", "f", "no", "n", "0", ""):
        return False
    raise argparse.ArgumentTypeError("Boolean value expected, got %r" % (value,))


def build_arg_parser():
    """Create the command-line parser of the CLML benchmark script.

    Returns
    -------
    argparse.ArgumentParser
        Parser providing --network, --host, --port, --rpc-key, --repeat and
        --tune.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--network",
        type=str,
        choices=[
            "resnet-18",
            "resnet-34",
            "resnet-50",
            "vgg-16",
            "vgg-19",
            "densenet-121",
            "inception_v3",
            "mobilenet",
            "squeezenet_v1.0",
            "squeezenet_v1.1",
        ],
        help="The name of neural network",
    )
    parser.add_argument("--host", type=str, default="127.0.0.1")
    parser.add_argument("--port", type=int, default=9190)
    parser.add_argument("--rpc-key", type=str, default="android")
    parser.add_argument("--repeat", type=int, default=30)
    # Accepts a bare "--tune" as well as "--tune True" / "--tune False".
    parser.add_argument(
        "--tune",
        type=str2bool,
        nargs="?",
        const=True,
        default=False,
        help="Run auto-tuning before building (true/false)",
    )
    return parser


if __name__ == "__main__":
    # evaluate_network reads this module-level namespace, so it must stay global.
    args = build_arg_parser().parse_args()

    if args.network is None:
        # Default sweep. "vgg-16" and "vgg-19" can be selected with --network
        # but are not part of the default list.
        networks = [
            "resnet-18",
            "resnet-34",
            "resnet-50",
            "densenet-121",
            "inception_v3",
            "mobilenet",
            "squeezenet_v1.0",
            "squeezenet_v1.1",
        ]
    else:
        networks = [args.network]

    target = "opencl"
    target_host = "llvm -mtriple=arm64-linux-android"

    print("--------------------------------------------------")
    print("%-20s %-20s" % ("Network Name", "Mean Inference Time (std dev)"))
    print("--------------------------------------------------")

    results = {}

    # Every network is measured in float32 first, then in float16.
    for network in networks:
        for dtype in ("float32", "float16"):
            results[network + "-" + dtype] = evaluate_network(
                network, target, target_host, dtype, args.repeat
            )

    print("----------------------------------------------------------------------")
    print("%-30s %-30s" % ("Network Name", "Mean Inference Time (std dev)"))
    print("----------------------------------------------------------------------")
    for key, val in results.items():
        print("%-30s %-30s (%s)" % (key, "%.2f ms" % val[0], "%.2f ms" % val[1]))
5 changes: 5 additions & 0 deletions apps/benchmark/adreno/bench.sh
Original file line number Diff line number Diff line change
Expand Up @@ -55,5 +55,10 @@ if [ "texture" == $1 ] ; then
python3 apps/benchmark/adreno/adreno_gpu_bench_texture.py --host ${TVM_TRACKER_HOST} --port ${TVM_TRACKER_PORT} --rpc-key ${RPC_DEVICE_KEY}
fi

# Run the CLML benchmark script when the first script argument is "clml".
# "$1" is quoted so that a missing argument compares as an empty string
# instead of producing a `[` syntax error.
if [ "clml" == "$1" ] ; then
    python3 apps/benchmark/adreno/adreno_gpu_bench_clml.py --host "${TVM_TRACKER_HOST}" --port "${TVM_TRACKER_PORT}" --rpc-key "${RPC_DEVICE_KEY}"
fi


# Terminate the processes whose PIDs are held in TRACKER_PID and DEVICE_PID
# (both are set earlier in this script, outside the lines shown here).
kill ${TRACKER_PID}
kill ${DEVICE_PID}
32 changes: 32 additions & 0 deletions python/tvm/relay/op/contrib/clml.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,12 @@
import tvm

from tvm import relay
from tvm.ir import Op
from tvm._ffi import register_func
from tvm.relay import transform
from tvm.relay.build_module import bind_params_by_name
from tvm.relay.expr_functor import ExprMutator
from tvm.relay.expr import Call, TupleGetItem

from ...dataflow_pattern import wildcard, is_op, is_constant, is_tuple_get_item, is_tuple
from .register import register_pattern_table
Expand All @@ -48,6 +51,33 @@ def is_clml_runtime_enabled():
return False


class RemoveDropout(ExprMutator):
    """
    Mutator that strips nn.dropout calls out of an expression.

    Every ``TupleGetItem(nn.dropout(x, ...), 0)`` is replaced by ``x``;
    all other tuple accesses are returned as visited.
    """

    def visit_tuple_getitem(self, op: TupleGetItem) -> relay.expr.Expr:
        rewritten = super().visit_tuple_getitem(op)
        # Only field 0 of a tuple is ever replaced.
        if rewritten.index != 0:
            return rewritten

        producer = rewritten.tuple_value
        from_dropout = (
            isinstance(producer, Call)
            and isinstance(producer.op, Op)
            and producer.op.name == "nn.dropout"
        )
        # Forward the dropout call's first argument in place of the tuple access.
        return producer.args[0] if from_dropout else rewritten


@transform.function_pass(opt_level=0)
class RemoveDropoutPass:
    """Function pass (opt_level=0) that runs RemoveDropout over a function."""

    def transform_function(
        self, func: relay.function.Function, mod: tvm.IRModule, _: tvm.transform.PassContext
    ) -> relay.function.Function:
        # A fresh mutator is created for every function that is transformed.
        mutator = RemoveDropout()
        return mutator.visit(func)


def partition_for_clml(mod, params=None):
"""Partition the graph greedily offloading supported
operators to CLML Library.
Expand All @@ -70,6 +100,7 @@ def partition_for_clml(mod, params=None):
seq = tvm.transform.Sequential(
[
transform.InferType(),
RemoveDropoutPass(),
transform.FoldConstant(),
transform.MergeComposite(clml_pattern_table()),
transform.AnnotateTarget("clml", False),
Expand Down Expand Up @@ -289,6 +320,7 @@ def check_default_op(extract):
("clml.global_max_pool2d", is_op("nn.global_max_pool2d")(wildcard()), check_default_op),
("clml.relu", is_op("nn.relu")(wildcard()), check_default_op),
("clml.clip", is_op("clip")(wildcard()), check_default_op),
("clml.batch_flatten", is_op("nn.batch_flatten")(wildcard()), check_default_op),
]


Expand Down
Loading