Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 35 additions & 17 deletions apps/benchmark/gpu_imagenet_bench.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
see README.md for the usage and results of this script.
"""
import argparse
import threading

import numpy as np

Expand All @@ -14,6 +15,26 @@
from util import get_network


def benchmark(network, target):
    """Compile one network with NNVM and print its mean GPU inference time.

    Parameters
    ----------
    network : str
        Name of the model understood by `get_network`.
    target : str or tvm target
        TVM compilation target (e.g. 'cuda', 'opencl').

    NOTE(review): reads the module-level globals `dtype` and `args.repeat`
    set in the `__main__` section — confirm when reusing this elsewhere.
    """
    net, params, input_shape, output_shape = get_network(network, batch_size=1)

    # Compile the graph at the highest optimization level.
    with nnvm.compiler.build_config(opt_level=3):
        graph, lib, params = nnvm.compiler.build(
            net, target=target, shape={'data': input_shape}, params=params, dtype=dtype)

    # Set up the graph runtime on the target device and feed random input.
    ctx = tvm.context(str(target), 0)
    module = runtime.create(graph, lib, ctx)
    random_input = np.random.uniform(size=input_shape).astype(dtype)
    module.set_input('data', tvm.nd.array(random_input))
    module.set_input(**params)

    # Time the "run" entry point; results are seconds, so scale to ms.
    timer = module.module.time_evaluator("run", ctx, number=1, repeat=args.repeat)
    times_ms = np.array(timer().results) * 1000
    mean_str = "%.2f ms" % np.mean(times_ms)
    std_str = "%.2f ms" % np.std(times_ms)
    print("%-20s %-19s (%s)" % (network, mean_str, std_str))


if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--network", type=str, choices=
Expand All @@ -29,6 +50,7 @@
parser.add_argument("--target", type=str,
choices=['cuda', 'opencl', 'rocm', 'nvptx', 'metal'], default='cuda',
help="The tvm compilation target")
parser.add_argument("--thread", type=int, default=1, help="The number of threads to be run.")
args = parser.parse_args()

dtype = 'float32'
Expand All @@ -44,20 +66,16 @@
print("%-20s %-20s" % ("Network Name", "Mean Inference Time (std dev)"))
print("--------------------------------------------------")
for network in networks:
net, params, input_shape, output_shape = get_network(network, batch_size=1)

with nnvm.compiler.build_config(opt_level=3):
graph, lib, params = nnvm.compiler.build(
net, target=target, shape={'data': input_shape}, params=params, dtype=dtype)

# create runtime
ctx = tvm.context(str(target), 0)
module = runtime.create(graph, lib, ctx)
data_tvm = tvm.nd.array((np.random.uniform(size=input_shape)).astype(dtype))
module.set_input('data', data_tvm)
module.set_input(**params)

# evaluate
ftimer = module.module.time_evaluator("run", ctx, number=1, repeat=args.repeat)
prof_res = np.array(ftimer().results) * 1000 # multiply 1000 for converting to millisecond
print("%-20s %-19s (%s)" % (network, "%.2f ms" % np.mean(prof_res), "%.2f ms" % np.std(prof_res)))
if args.thread == 1:
benchmark(network, target)
else:
threads = list()
for n in range(args.thread):
thread = threading.Thread(target=benchmark, args=([network, target]), name="thread%d" % n)
threads.append(thread)

for thread in threads:
thread.start()

for thread in threads:
thread.join()
2 changes: 1 addition & 1 deletion src/runtime/opencl/opencl_device_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -232,7 +232,6 @@ void OpenCLWorkspace::Init(const std::string& type_key, const std::string& devic
if (initialized_) return;
std::lock_guard<std::mutex> lock(this->mu);
if (initialized_) return;
initialized_ = true;
if (context != nullptr) return;
// matched platforms
std::vector<cl_platform_id> platform_ids = cl::GetPlatformIDs();
Expand Down Expand Up @@ -271,6 +270,7 @@ void OpenCLWorkspace::Init(const std::string& type_key, const std::string& devic
clCreateCommandQueue(this->context, did, 0, &err_code));
OPENCL_CHECK_ERROR(err_code);
}
initialized_ = true;
}

TVM_REGISTER_GLOBAL("device_api.opencl")
Expand Down