From b53ed98002f21981b9bb006b755e6c38aa182560 Mon Sep 17 00:00:00 2001
From: Siva Rama Krishna Reddy B
Date: Wed, 4 Jan 2023 16:11:49 +0530
Subject: [PATCH 1/2] [BENCHMARKS][CLML] Adreno benchmarks with CLML BYOC path
 added

Various benchmarks are enabled for the CLML BYOC backend on Adreno GPUs.
The networks resnet-18, resnet-34, resnet-50, densenet-121, inception_v3,
mobilenetv1, squeezenet_v1.0 and squeezenet_v1.1 are added with FP16 and
FP32 dtypes.
---
 apps/benchmark/README.md                      |  15 +-
 .../benchmark/adreno/adreno_gpu_bench_clml.py | 282 ++++++++++++++++++
 apps/benchmark/adreno/bench.sh                |   5 +
 python/tvm/relay/op/contrib/clml.py           |  31 ++
 src/runtime/contrib/clml/clml_runtime.cc      |  47 ++-
 tests/scripts/ci.py                           |  16 +
 6 files changed, 391 insertions(+), 5 deletions(-)
 create mode 100755 apps/benchmark/adreno/adreno_gpu_bench_clml.py

diff --git a/apps/benchmark/README.md b/apps/benchmark/README.md
index ccac79df47d8..44c54b1cf297 100644
--- a/apps/benchmark/README.md
+++ b/apps/benchmark/README.md
@@ -134,10 +134,23 @@ python3 gpu_imagenet_bench.py --model gfx900 --target rocm
 Adreno benchmarks are automated over the docker - [ci_adreno](https://github.com/apache/tvm/blob/main/docker/Dockerfile.ci_adreno).
 The Adreno docker shares the Android devices from the host. It is advised to keep the host adb version the same as the docker's, which is ```1.0.41```
 
-Below command runs all the benchmarks over given Android device.
+The command below runs all benchmarks (OpenCL native and CLML SDK) on the given Android device.
 ```bash
 export ANDROID_SERIAL=<device-id>
 ./tests/scripts/ci.py adreno -b
 ```
+The command below runs all OpenCL native benchmarks on the given Android device.
+```bash
+export ANDROID_SERIAL=<device-id>
+./tests/scripts/ci.py adreno -n
+```
+CLML SDK benchmarks require the CLML SDK path to be exported, and the SDK version should match the target device's SDK version.
+
+The command below runs all CLML SDK benchmarks on the given Android device.
+```bash
+export ADRENO_OPENCL=<clml-sdk-path>
+export ANDROID_SERIAL=<device-id>
+./tests/scripts/ci.py adreno -c
+```
 
 Note: The tuning cache is implicit through the tophub repo for all benchmarks and is tuned on Snapdragon Gen 1.

diff --git a/apps/benchmark/adreno/adreno_gpu_bench_clml.py b/apps/benchmark/adreno/adreno_gpu_bench_clml.py
new file mode 100755
index 000000000000..17c483fe2c76
--- /dev/null
+++ b/apps/benchmark/adreno/adreno_gpu_bench_clml.py
@@ -0,0 +1,282 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Benchmark script for various models on Adreno GPU.
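+
+Example invocation, assuming an RPC tracker is already running and an Android
+device is registered under the key ``android`` (these values match the
+argparse defaults below):
+
+    python3 adreno_gpu_bench_clml.py --host 127.0.0.1 --port 9190 --rpc-key android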
+""" +import argparse + +import numpy as np + +import os +import sys +import tvm +from tvm import te +from tvm.relay import testing +from tvm.contrib.utils import tempdir +from tvm.relay.op.contrib import clml +import tvm.contrib.graph_executor as runtime +from tvm import relay +from tvm import autotvm +from tvm.contrib import utils, ndk + + +def get_network(name, batch_size, dtype="float32"): + """Get the symbol definition and random weight of a network + + Parameters + ---------- + name: str + The name of the network, can be 'resnet-18', 'resnet-50', 'vgg-16', 'inception_v3', 'mobilenet', ... + batch_size: int + batch size + dtype: str + Data type + + Returns + ------- + net: tvm.IRModule + The relay function of network definition + params: dict + The random parameters for benchmark + input_shape: tuple + The shape of input tensor + output_shape: tuple + The shape of output tensor + """ + input_shape = (batch_size, 3, 224, 224) + output_shape = (batch_size, 1000) + + if name == "mobilenet": + net, params = testing.mobilenet.get_workload(batch_size=batch_size, dtype=dtype) + elif name == "inception_v3": + input_shape = (batch_size, 3, 299, 299) + net, params = testing.inception_v3.get_workload(batch_size=batch_size, dtype=dtype) + elif "resnet" in name: + n_layer = int(name.split("-")[1]) + net, params = testing.resnet.get_workload( + num_layers=n_layer, batch_size=batch_size, dtype=dtype + ) + elif "vgg" in name: + n_layer = int(name.split("-")[1]) + net, params = testing.vgg.get_workload( + num_layers=n_layer, batch_size=batch_size, dtype=dtype + ) + elif "densenet" in name: + n_layer = int(name.split("-")[1]) + net, params = testing.densenet.get_workload( + densenet_size=n_layer, batch_size=batch_size, dtype=dtype + ) + elif "squeezenet" in name: + version = name.split("_v")[1] + net, params = testing.squeezenet.get_workload( + batch_size=batch_size, version=version, dtype=dtype + ) + elif name == "mxnet": + # an example for mxnet model + from mxnet.gluon.model_zoo.vision import get_model + + block = get_model("resnet18_v1", pretrained=True) + net, params = relay.frontend.from_mxnet(block, shape={"data": input_shape}, dtype=dtype) + net = net["main"] + net = relay.Function( + net.params, relay.nn.softmax(net.body), None, net.type_params, net.attrs + ) + net = tvm.IRModule.from_expr(net) + else: + raise ValueError("Unsupported network: " + name) + + return net, params, input_shape, output_shape + + +def print_progress(msg): + """print progress message + + Parameters + ---------- + msg: str + The message to print + """ + sys.stdout.write(msg + "\r") + sys.stdout.flush() + + +def tune_tasks( + tasks, + measure_option, + n_trial=1024, + early_stopping=None, + log_filename="tuning.log", +): + from tvm.autotvm.tuner import XGBTuner + + tmp_log_file = log_filename + ".tmp" + + for i, tsk in enumerate(reversed(tasks)): + print("Task: ", tsk) + prefix = "[Task %2d/%2d] " % (i + 1, len(tasks)) + tuner_obj = XGBTuner(tsk, loss_type="rank") + + tsk_trial = min(n_trial, len(tsk.config_space)) + tuner_obj.tune( + n_trial=tsk_trial, + early_stopping=early_stopping, + measure_option=measure_option, + callbacks=[ + autotvm.callback.progress_bar(tsk_trial, prefix=prefix), + autotvm.callback.log_to_file(tmp_log_file), + ], + ) + + autotvm.record.pick_best(tmp_log_file, log_filename) + + +def evaluate_network(network, target, target_host, dtype, repeat): + print_progress(network) + net, params, input_shape, output_shape = get_network(network, batch_size=1, dtype=dtype) + + # Auto Tuning + tune_log = 
"adreno-" + network + "-" + dtype + ".log" + tuning_options = { + "log_filename": tune_log, + "early_stopping": None, + "measure_option": autotvm.measure_option( + builder=autotvm.LocalBuilder(build_func=ndk.create_shared, timeout=15), + runner=autotvm.RPCRunner( + args.rpc_key, + host=args.host, + port=args.port, + number=3, + timeout=600, + ), + ), + } + if args.tune: + tasks = autotvm.task.extract_from_program( + net, target=target, target_host=target_host, params=params + ) + tune_tasks(tasks, **tuning_options) + + print_progress("%-20s building..." % network) + + # Build the tuning log + if os.path.exists(tune_log): + with autotvm.apply_history_best(tune_log): + with tvm.transform.PassContext(opt_level=3): + net = clml.partition_for_clml(net, params) + lib = relay.build( + net, target=tvm.target.Target(target, host=target_host), params=params + ) + else: + with tvm.transform.PassContext(opt_level=3): + net = clml.partition_for_clml(net, params) + + lib = relay.build( + net, target=tvm.target.Target(target, host=target_host), params=params + ) + + tmp = tempdir() + + filename = "%s.so" % network + lib.export_library(tmp.relpath(filename), ndk.create_shared) + + # upload library and params + print_progress("%-20s uploading..." % network) + + # connect to remote device + tracker = tvm.rpc.connect_tracker(args.host, args.port) + remote = tracker.request(args.rpc_key) + + dev = remote.device(str(target), 0) + remote.upload(tmp.relpath(filename)) + + rlib = remote.load_module(filename) + module = runtime.GraphModule(rlib["default"](dev)) + data_tvm = tvm.nd.array((np.random.uniform(size=input_shape)).astype(dtype)) + module.set_input("data", data_tvm) + + # evaluate + print_progress("%-20s evaluating..." % network) + ftimer = module.module.time_evaluator("run", dev, number=1, repeat=repeat) + prof_res = np.array(ftimer().results) * 1000 # multiply 1000 for converting to millisecond + print( + "%-20s %-19s (%s)" + % (network + "-" + dtype, "%.2f ms" % np.mean(prof_res), "%.2f ms" % np.std(prof_res)) + ) + return (np.mean(prof_res), np.std(prof_res)) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--network", + type=str, + choices=[ + "resnet-18", + "resnet-34", + "resnet-50", + "vgg-16", + "vgg-19", + "densenet-121", + "inception_v3", + "mobilenet", + "squeezenet_v1.0", + "squeezenet_v1.1", + ], + help="The name of neural network", + ) + parser.add_argument("--host", type=str, default="127.0.0.1") + parser.add_argument("--port", type=int, default=9190) + parser.add_argument("--rpc-key", type=str, default="android") + parser.add_argument("--repeat", type=int, default=30) + parser.add_argument("--tune", type=bool, default=False) + args = parser.parse_args() + + if args.network is None: + networks = [ + "resnet-18", + "resnet-34", + "resnet-50", + # "vgg-16", + # "vgg-19", + "densenet-121", + "inception_v3", + "mobilenet", + "squeezenet_v1.0", + "squeezenet_v1.1", + ] + else: + networks = [args.network] + + target = "opencl" + target_host = "llvm -mtriple=arm64-linux-android" + + print("--------------------------------------------------") + print("%-20s %-20s" % ("Network Name", "Mean Inference Time (std dev)")) + print("--------------------------------------------------") + + results = {} + + for network in networks: + ftime = evaluate_network(network, target, target_host, "float32", args.repeat) + results[network + "-float32"] = ftime + ftime = evaluate_network(network, target, target_host, "float16", args.repeat) + results[network + "-float16"] 
+    print("----------------------------------------------------------------------")
+    print("%-30s %-30s" % ("Network Name", "Mean Inference Time (std dev)"))
+    print("----------------------------------------------------------------------")
+    for key, val in results.items():
+        print("%-30s %-30s (%s)" % (key, "%.2f ms" % val[0], "%.2f ms" % val[1]))
diff --git a/apps/benchmark/adreno/bench.sh b/apps/benchmark/adreno/bench.sh
index 7d46685b8654..7f9adeea5251 100755
--- a/apps/benchmark/adreno/bench.sh
+++ b/apps/benchmark/adreno/bench.sh
@@ -55,5 +55,10 @@ if [ "texture" == $1 ] ; then
 python3 apps/benchmark/adreno/adreno_gpu_bench_texture.py --host ${TVM_TRACKER_HOST} --port ${TVM_TRACKER_PORT} --rpc-key ${RPC_DEVICE_KEY}
 fi
 
+if [ "clml" == $1 ] ; then
+python3 apps/benchmark/adreno/adreno_gpu_bench_clml.py --host ${TVM_TRACKER_HOST} --port ${TVM_TRACKER_PORT} --rpc-key ${RPC_DEVICE_KEY}
+fi
+
+
 kill ${TRACKER_PID}
 kill ${DEVICE_PID}
diff --git a/python/tvm/relay/op/contrib/clml.py b/python/tvm/relay/op/contrib/clml.py
index 02e4f62bed24..0107a1891f18 100644
--- a/python/tvm/relay/op/contrib/clml.py
+++ b/python/tvm/relay/op/contrib/clml.py
@@ -19,9 +19,12 @@
 import tvm
 
 from tvm import relay
+from tvm.ir import Op
 from tvm._ffi import register_func
 from tvm.relay import transform
 from tvm.relay.build_module import bind_params_by_name
+from tvm.relay.expr_functor import ExprMutator
+from tvm.relay.expr import Call, TupleGetItem
 
 from ...dataflow_pattern import wildcard, is_op, is_constant, is_tuple_get_item, is_tuple
 from .register import register_pattern_table
@@ -47,6 +50,32 @@ def is_clml_runtime_enabled():
         return check_enabled()
     return False
 
+class RemoveDropout(ExprMutator):
+    """
+    Removes all nn.dropout from an expr.
+    """
+
+    def visit_tuple_getitem(self, op: TupleGetItem) -> relay.expr.Expr:
+        visit = super().visit_tuple_getitem(op)
+        if visit.index != 0:
+            return visit
+        if (
+            isinstance(visit.tuple_value, Call)
+            and isinstance(visit.tuple_value.op, Op)
+            and visit.tuple_value.op.name == "nn.dropout"
+            and visit.index == 0
+        ):
+            return visit.tuple_value.args[0]
+        return visit
+
+
+@transform.function_pass(opt_level=0)
+class RemoveDropoutPass:
+    def transform_function(
+        self, func: relay.function.Function, mod: tvm.IRModule, _: tvm.transform.PassContext
+    ) -> relay.function.Function:
+        return RemoveDropout().visit(func)
+
 
 def partition_for_clml(mod, params=None):
     """Partition the graph greedily offloading supported
@@ -70,6 +99,7 @@ def partition_for_clml(mod, params=None):
     seq = tvm.transform.Sequential(
         [
             transform.InferType(),
+            RemoveDropoutPass(),
             transform.FoldConstant(),
             transform.MergeComposite(clml_pattern_table()),
             transform.AnnotateTarget("clml", False),
@@ -289,6 +319,7 @@ def check_default_op(extract):
         ("clml.global_max_pool2d", is_op("nn.global_max_pool2d")(wildcard()), check_default_op),
         ("clml.relu", is_op("nn.relu")(wildcard()), check_default_op),
         ("clml.clip", is_op("clip")(wildcard()), check_default_op),
+        ("clml.batch_flatten", is_op("nn.batch_flatten")(wildcard()), check_default_op),
     ]
 
 
diff --git a/src/runtime/contrib/clml/clml_runtime.cc b/src/runtime/contrib/clml/clml_runtime.cc
index 6396fce4858b..c9ceeabad680 100644
--- a/src/runtime/contrib/clml/clml_runtime.cc
+++ b/src/runtime/contrib/clml/clml_runtime.cc
@@ -259,9 +259,14 @@ class CLMLRuntime : public JSONRuntimeBase {
         layer_.in_placeholder[i]->memory = static_cast<cl_mem>(
             ((cl::BufferDescriptor*)const_cast<DLTensor*>(data_entry_[eid])->data)->buffer);
         cl_event cpy_evt = NULL;
+        cl_event *evt = &cpy_evt;
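+        // When profiling is enabled, point evt at a slot in evts so this copy is
+        // timed along with the CLML ops; otherwise the local cpy_evt is not tracked.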
+        if (workspace->IsProfiling(tentry->device)) {
+          evts.resize(evts.size() + 1);
+          evt = &(evts.back());
+        }
         result = h_ClmlIntf->clEnqueueCopyMLTensorDataQCOM(
             queue, layer_.in_placeholder[i]->tensor, layer_.in_placeholder[i]->memory,
-            layer_.inputs[i]->tensor, layer_.inputs[i]->memory, 0, NULL, &cpy_evt);
+            layer_.inputs[i]->tensor, layer_.inputs[i]->memory, 0, NULL, evt);
         ICHECK(result == CL_SUCCESS) << "clEnqueueCopyMLTensorDataQCOM:" << result;
       } else {
         DLDataType tvm_dtype = const_cast<DLTensor*>(data_entry_[eid])->dtype;
@@ -277,7 +282,8 @@ class CLMLRuntime : public JSONRuntimeBase {
     }
 
     for (size_t i = 0; i < this->layer_.function.size(); ++i) {
-      if (getenv("CLML_PROFILING")) {
+      // Make CLML subgraphs accounted for by the OpenCLTimerNode.
+      if (getenv("CLML_PROFILING") || workspace->IsProfiling(tentry->device)) {
         evts.resize(evts.size() + 1);
         cl_event* evt = &(evts.back());
         result = h_ClmlIntf->clEnqueueMLOpQCOM(queue, this->layer_.function[i],
@@ -317,10 +323,14 @@ class CLMLRuntime : public JSONRuntimeBase {
         layer_.out_placeholder[i]->memory = static_cast<cl_mem>(
             ((cl::BufferDescriptor*)const_cast<DLTensor*>(data_entry_[eid])->data)->buffer);
         cl_event cpy_evt = NULL;
+        cl_event *evt = &cpy_evt;
+        if (workspace->IsProfiling(tentry->device)) {
+          evts.resize(evts.size() + 1);
+          evt = &(evts.back());
+        }
         result = h_ClmlIntf->clEnqueueCopyMLTensorDataQCOM(
             queue, layer_.outputs[i]->tensor, layer_.outputs[i]->memory,
-            layer_.out_placeholder[i]->tensor, layer_.out_placeholder[i]->memory, 0, NULL,
-            &cpy_evt);
+            layer_.out_placeholder[i]->tensor, layer_.out_placeholder[i]->memory, 0, NULL, evt);
         ICHECK(result == CL_SUCCESS) << "clEnqueueCopyMLTensorDataQCOM:" << result;
       } else {
         DLDataType tvm_dtype = const_cast<DLTensor*>(data_entry_[eid])->dtype;
@@ -407,6 +417,10 @@ class CLMLRuntime : public JSONRuntimeBase {
       auto out = CreatePadLayer(&layer_, node);
       this->layer_.storage_map.insert({nid, std::make_pair(out, node)});
       this->layer_.func_outs.push_back(out);
+    } else if ("nn.batch_flatten" == op_name) {
+      auto out = CreateBatchFlattenLayer(&layer_, node);
+      this->layer_.storage_map.insert({nid, std::make_pair(out, node)});
+      this->layer_.func_outs.push_back(out);
     } else if ("clip" == op_name) {
       auto out = CreateClipLayer(&layer_, node);
       this->layer_.storage_map.insert({nid, std::make_pair(out, node)});
@@ -1070,6 +1084,31 @@ class CLMLRuntime : public JSONRuntimeBase {
     return output;
   }
 
+  /*!
+   * \brief Create a Batch Flatten layer.
+   *
+   * \param layer The CLML layer to build, containing inputs, outputs and the CLML output.
+   * \param node The JSON representation of the operator.
+   */
+  std::shared_ptr<cl_ml_tensor_memory_desc_qcom> CreateBatchFlattenLayer(
+      CachedLayer* layer, const JSONGraphNode& node) {
+    cl_int result = 0;
+    cl_ml_op_qcom op = NULL;
+    DLDataType tvm_dtype = node.GetOpDataType()[0];
+    cl_channel_type cl_dtype = MakeCLDataType(tvm_dtype);
+    auto input = MakeCLMLTensorFromJSONEntry(node.GetInputs()[0], {}, CL_TENSOR_LAYOUT_OPTIMAL_QCOM,
+                                             cl_dtype);
+    auto output = MakeCLMLTensorFromJSONNode(node, CL_TENSOR_LAYOUT_OPTIMAL_QCOM, cl_dtype);
+
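+    // batch_flatten is lowered to a CLML reshape: the input tensor is reshaped
+    // to the 2-D (N, C*H*W) output shape carried by the JSON node.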
+    result = h_ClmlIntf->clCreateMLOpReshapeQCOM(workspace->context, 0, input->tensor,
+                                                 output->tensor, &op, tuning_cache);
+    ICHECK(op && result == CL_SUCCESS) << "Reshape Error:" << result;
+
+    layer_.func_ins.push_back(input);
+    layer->function.push_back(op);
+    return output;
+  }
+
   /*!
    * \brief Create a Reshape layer.
    *
diff --git a/tests/scripts/ci.py b/tests/scripts/ci.py
index 756b269d0e50..7b402c02d0b1 100755
--- a/tests/scripts/ci.py
+++ b/tests/scripts/ci.py
@@ -728,12 +728,28 @@ def add_subparser(
             ],
         ),
         "benchmarks": (
+            "run Adreno Benchmarks (Native OpenCL, CLML SDK)",
+            [
+                "./apps/benchmark/adreno/bench.sh texture "
+                + os.environ.get("ANDROID_SERIAL", ""),
+                "./apps/benchmark/adreno/bench.sh clml "
+                + os.environ.get("ANDROID_SERIAL", ""),
+            ],
+        ),
+        "nativebenchmarks": (
             "run Adreno Texture Benchmarks",
             [
                 "./apps/benchmark/adreno/bench.sh texture "
                 + os.environ.get("ANDROID_SERIAL", ""),
             ],
         ),
+        "clmlbenchmarks": (
+            "run Adreno CLML SDK Benchmarks",
+            [
+                "./apps/benchmark/adreno/bench.sh clml "
+                + os.environ.get("ANDROID_SERIAL", ""),
+            ],
+        ),
     },
 ),
]

From ecf948a547243d2815f6b5b055013bb392f59520 Mon Sep 17 00:00:00 2001
From: Siva Rama Krishna Reddy B
Date: Wed, 4 Jan 2023 20:30:55 +0530
Subject: [PATCH 2/2] Fix lint errors

---
 python/tvm/relay/op/contrib/clml.py      | 1 +
 src/runtime/contrib/clml/clml_runtime.cc | 4 ++--
 tests/scripts/ci.py                      | 6 ++----
 3 files changed, 5 insertions(+), 6 deletions(-)

diff --git a/python/tvm/relay/op/contrib/clml.py b/python/tvm/relay/op/contrib/clml.py
index 0107a1891f18..77882917b1ad 100644
--- a/python/tvm/relay/op/contrib/clml.py
+++ b/python/tvm/relay/op/contrib/clml.py
@@ -50,6 +50,7 @@ def is_clml_runtime_enabled():
         return check_enabled()
     return False
 
+
 class RemoveDropout(ExprMutator):
     """
     Removes all nn.dropout from an expr.
diff --git a/src/runtime/contrib/clml/clml_runtime.cc b/src/runtime/contrib/clml/clml_runtime.cc
index c9ceeabad680..1fb694a91201 100644
--- a/src/runtime/contrib/clml/clml_runtime.cc
+++ b/src/runtime/contrib/clml/clml_runtime.cc
@@ -259,7 +259,7 @@ class CLMLRuntime : public JSONRuntimeBase {
         layer_.in_placeholder[i]->memory = static_cast<cl_mem>(
             ((cl::BufferDescriptor*)const_cast<DLTensor*>(data_entry_[eid])->data)->buffer);
         cl_event cpy_evt = NULL;
-        cl_event *evt = &cpy_evt;
+        cl_event* evt = &cpy_evt;
         if (workspace->IsProfiling(tentry->device)) {
           evts.resize(evts.size() + 1);
           evt = &(evts.back());
@@ -323,7 +323,7 @@ class CLMLRuntime : public JSONRuntimeBase {
         layer_.out_placeholder[i]->memory = static_cast<cl_mem>(
             ((cl::BufferDescriptor*)const_cast<DLTensor*>(data_entry_[eid])->data)->buffer);
         cl_event cpy_evt = NULL;
-        cl_event *evt = &cpy_evt;
+        cl_event* evt = &cpy_evt;
         if (workspace->IsProfiling(tentry->device)) {
           evts.resize(evts.size() + 1);
           evt = &(evts.back());
diff --git a/tests/scripts/ci.py b/tests/scripts/ci.py
index 7b402c02d0b1..700febd353d0 100755
--- a/tests/scripts/ci.py
+++ b/tests/scripts/ci.py
@@ -732,8 +732,7 @@ def add_subparser(
             [
                 "./apps/benchmark/adreno/bench.sh texture "
                 + os.environ.get("ANDROID_SERIAL", ""),
-                "./apps/benchmark/adreno/bench.sh clml "
-                + os.environ.get("ANDROID_SERIAL", ""),
+                "./apps/benchmark/adreno/bench.sh clml " + os.environ.get("ANDROID_SERIAL", ""),
             ],
         ),
         "nativebenchmarks": (
@@ -746,8 +745,7 @@ def add_subparser(
             "run Adreno CLML SDK Benchmarks",
             [
-                "./apps/benchmark/adreno/bench.sh clml "
-                + os.environ.get("ANDROID_SERIAL", ""),
+                "./apps/benchmark/adreno/bench.sh clml " + os.environ.get("ANDROID_SERIAL", ""),
             ],
         ),
     },