Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 32 additions & 37 deletions python/tvm/autotvm/measure/measure_methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,9 +65,9 @@ def __init__(self, timeout=10, n_parallel=None, build_func='default'):

if isinstance(build_func, str):
if build_func == 'default':
build_func = default_build_func
build_func = make_build_func()
elif build_func == 'ndk':
build_func = android_ndk_build_func
build_func = make_build_func("so", ndk.create_shared)
else:
raise ValueError("Invalid build_func" + build_func)

Expand Down Expand Up @@ -349,46 +349,41 @@ def _build_func_common(measure_input, check_gpu=None, cuda_arch=None, build_opti
return func, tuple((get_const_tuple(x.shape), x.dtype) for x in args)


def default_build_func(measure_input, tmp_dir, **kwargs):
def make_build_func(file_ext="tar", fcompile=None, **export_library_kwargs):
"""
Default build func. This can work for cuda, opencl, llvm backend
make_build_func which creates a build func.

Parameters
----------
measure_input: MeasureInput
The input of measurement
tmp_dir: str
The path of temporary directory to export generated library
file_ext: str
filename extension of the exported shared library
fcompile: function
Compilation function to create dynamic library
export_library_kwargs: optional
Additional arguments passed to fcompile
"""
tic = time.time()
try:
filename = os.path.join(tmp_dir, "tmp_func_%0x.tar" % getrandbits(64))
func, arg_info = _build_func_common(measure_input, **kwargs)
func.export_library(filename)
except Exception as e: # pylint: disable=broad-except
return BuildResult(None, None, e, time.time() - tic)
return BuildResult(filename, arg_info, None, time.time() - tic)


def android_ndk_build_func(measure_input, tmp_dir, **kwargs):
    """
    Build a measurement candidate and export it as a ``.so`` using the ndk.

    Parameters
    ----------
    measure_input: MeasureInput
        The input of measurement
    tmp_dir: str
        The path of temporary directory to export generated library
    kwargs: optional
        Additional arguments passed to _build_func_common

    Returns
    -------
    BuildResult
        Built from (filename, arg_info, None, elapsed seconds) on success, or
        from (None, None, exception, elapsed seconds) if building or
        exporting raised.
    """
    start = time.time()
    try:
        # A random 64-bit hex suffix keeps concurrently exported files apart.
        filename = os.path.join(tmp_dir, "tmp_func_%0x.so" % getrandbits(64))
        built, arg_info = _build_func_common(measure_input, **kwargs)
        built.export_library(filename, ndk.create_shared)
    except Exception as err:  # pylint: disable=broad-except
        # Failures are reported through the result rather than propagated.
        return BuildResult(None, None, err, time.time() - start)
    else:
        return BuildResult(filename, arg_info, None, time.time() - start)
def build_func(measure_input, tmp_dir, **kwargs):
    """
    Build a measurement candidate and export it as a library file.

    The file extension, the compile function and the extra export arguments
    (``file_ext``, ``fcompile``, ``export_library_kwargs``) are read from the
    enclosing make_build_func scope.

    Parameters
    ----------
    measure_input: MeasureInput
        The input of measurement
    tmp_dir: str
        The path of temporary directory to export generated library
    kwargs: optional
        Additional arguments passed to _build_func_common

    Returns
    -------
    BuildResult
        Built from (filename, arg_info, None, elapsed seconds) on success, or
        from (None, None, exception, elapsed seconds) if building or
        exporting raised.
    """
    start = time.time()
    try:
        # A random 64-bit hex stem keeps concurrently exported files apart.
        basename = "tmp_func_%0x.%s" % (getrandbits(64), file_ext)
        filename = os.path.join(tmp_dir, basename)
        built, arg_info = _build_func_common(measure_input, **kwargs)
        built.export_library(filename, fcompile, **export_library_kwargs)
    except Exception as err:  # pylint: disable=broad-except
        # Failures are reported through the result rather than propagated.
        return BuildResult(None, None, err, time.time() - start)
    else:
        return BuildResult(filename, arg_info, None, time.time() - start)
return build_func


def run_through_rpc(measure_input, build_result,
Expand Down
31 changes: 30 additions & 1 deletion tutorials/autotvm/tune_relay_mobile_gpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,27 @@ def get_network(name, batch_size):
# This target host is used for cross compilation. You can query it by :code:`gcc -v` on your device.
target_host = 'llvm -target=aarch64-linux-gnu'

# Use Cross-Compiler to export lib
#
# False - use default build_func - lib will be exported as .tar file containing object file (.o file).
# Edge device needs build-essential package to be installed (gcc, g++, ld)
#
# True - use cross compiler to export lib as .so file.
# Cross-compiler must be installed on tuner box and specified below ("cc" parameter).
# Edge device does not need build-essential package (gcc, g++, ld) to be installed
use_cross_compiler_to_export_lib = False

# Cross-Compiler which will be used to create shared library (.so file) for edge device
# To install Cross-Compiler on tuner box use the following commands:
#
# for ARMv7 gnueabihf (hf stands for HardFloat)
# $ apt install gcc-arm-linux-gnueabihf g++-arm-linux-gnueabihf
#
# for ARMv8 aarch64
# $ apt install gcc-aarch64-linux-gnu g++-aarch64-linux-gnu
cc = 'aarch64-linux-gnu-g++'
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We need to add follow-up instructions, as part of a note block, on how to install such a cross compiler and how to set the following flags to use the cross gcc. The option should default to False (because the on-device compiler is still the easiest choice in many cases).

Copy link
Contributor Author

@apivovarov apivovarov Mar 29, 2019

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I updated the PR:

  • set use_cross_compiler_to_export_lib=False by default
  • added instructions on how to install armv7/8 cross-compilers

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.



# Also replace this with the device key in your tracker
device_key = 'rk3399'

Expand All @@ -185,6 +206,13 @@ def get_network(name, batch_size):
log_file = "%s.%s.log" % (device_key, network)
dtype = 'float32'

# Pick how the tuned library is built and exported:
#   'ndk'     - android devices, built with the ndk
#   custom    - export a .so with the cross compiler named by `cc`
#   'default' - the builder's default build function
if use_android:
    build_func = 'ndk'
elif use_cross_compiler_to_export_lib:
    from tvm.autotvm.measure.measure_methods import make_build_func
    build_func = make_build_func('so', cc=cc)
else:
    build_func = 'default'

tuning_option = {
'log_filename': log_file,

Expand All @@ -194,7 +222,8 @@ def get_network(name, batch_size):

'measure_option': autotvm.measure_option(
builder=autotvm.LocalBuilder(
build_func='ndk' if use_android else 'default'),
build_func=build_func,
),
runner=autotvm.RPCRunner(
device_key, host='0.0.0.0', port=9190,
number=10,
Expand Down