4 changes: 2 additions & 2 deletions apps/hexagon_launcher/README.md
@@ -118,7 +118,7 @@ mod, params = relay.frontend.from_tflite(
tflite_model, shape_dict=shape_dict, dtype_dict=dtype_dict
)

target = tvm.target.hexagon('v68', link_params=True)
target = tvm.target.hexagon('v68')
with tvm.transform.PassContext(opt_level=3):
lib = relay.build(mod, tvm.target.Target(target, host=target), params=params, mod_name="default")

@@ -172,7 +172,7 @@ A sample output JSON from running the Inception V3 model may look like

When using AoT, the `target` needs to be `llvm`:
```
aot_target = "llvm -keys=hexagon -link-params=0 -mattr=+hvxv69,+hvx-length128b,+hvx-qfloat,-hvx-ieee-fp -mcpu=hexagonv69 -mtriple=hexagon"
aot_target = "llvm -keys=hexagon -mattr=+hvxv69,+hvx-length128b,+hvx-qfloat,-hvx-ieee-fp -mcpu=hexagonv69 -mtriple=hexagon"
aot_host_target = aot_target
```
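
For orientation, a minimal sketch (not taken from this README) of how such an AoT target could feed into `relay.build` now that executor options live on an `Executor` object rather than on the target string. `mod` and `params` are assumed to come from the `from_tflite` import shown earlier, and the Executor/Runtime settings here are illustrative assumptions.

```python
import tvm
from tvm import relay
from tvm.relay.backend import Executor, Runtime

aot_target = (
    "llvm -keys=hexagon -mattr=+hvxv69,+hvx-length128b,+hvx-qfloat,-hvx-ieee-fp "
    "-mcpu=hexagonv69 -mtriple=hexagon"
)
aot_host_target = aot_target

# `mod` and `params` come from relay.frontend.from_tflite(...) as above.
with tvm.transform.PassContext(opt_level=3):
    lowered = relay.build(
        mod,
        params=params,
        target=tvm.target.Target(aot_target, host=aot_host_target),
        # AoT-related flags that used to ride on the target string are now
        # Executor options (illustrative values).
        executor=Executor("aot", {"unpacked-api": False, "interface-api": "packed"}),
        runtime=Runtime("cpp"),
    )
```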

2 changes: 1 addition & 1 deletion apps/howto_deploy/prepare_test_libs.py
@@ -33,7 +33,7 @@ def prepare_test_libs(base_path):
fadd_dylib.export_library(dylib_path)

# Compile library in system library mode
fadd_syslib = tvm.build(s, [A, B], "llvm --system-lib", name="addonesys")
fadd_syslib = tvm.build(s, [A, B], "llvm", name="addonesys")
syslib_path = os.path.join(base_path, "test_addone_sys.o")
fadd_syslib.save(syslib_path)

7 changes: 6 additions & 1 deletion apps/sgx/src/build_model.py
@@ -39,7 +39,12 @@ def main():
)

with tvm.transform.PassContext(opt_level=3):
graph, lib, params = relay.build(net, "llvm --system-lib", params=params)
graph, lib, params = relay.build(
net,
"llvm",
params=params,
runtime=tvm.relay.backend.Runtime("cpp", {"system-lib": True}),
)

build_dir = osp.abspath(sys.argv[1])
if not osp.isdir(build_dir):
9 changes: 7 additions & 2 deletions apps/wasm-standalone/wasm-graph/tools/build_graph_lib.py
@@ -72,10 +72,15 @@ def build_graph_lib(opt_level):
shape_dict = {input_name: img_data.shape}

mod, params = relay.frontend.from_onnx(onnx_model, shape_dict)
target = "llvm -mtriple=wasm32-unknown-unknown -mattr=+simd128 --system-lib"
target = "llvm -mtriple=wasm32-unknown-unknown -mattr=+simd128"
Contributor:
I think we need to keep --system-lib passed to the Executor config here. Can you audit the other places where you removed --system-lib and ensure the Executor passed to relay.build has it?

Member (Author):
I double-checked the places where I removed it and added it to the executor as needed. Please let me know if I missed anything.


with tvm.transform.PassContext(opt_level=opt_level):
factory = relay.build(mod, target=target, params=params)
factory = relay.build(
mod,
target=target,
params=params,
runtime=tvm.relay.backend.Runtime("cpp", {"system-lib": True}),
)

# Save the model artifacts to obj_file
obj_file = os.path.join(out_dir, "graph.o")
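
To make the audit discussed in the thread above concrete, here is a minimal, self-contained sketch (not part of the diff; a toy module stands in for the ONNX import, and an LLVM build with the WebAssembly backend is assumed) of where --system-lib lands after leaving the target string: it becomes a Runtime option, not an Executor option.

```python
import tvm
from tvm import relay
from tvm.relay.backend import Runtime

# Toy module standing in for the ONNX model imported by this script.
data = relay.var("data", shape=(1, 3, 224, 224), dtype="float32")
mod = tvm.IRModule.from_expr(relay.Function([data], relay.nn.relu(data)))

target = "llvm -mtriple=wasm32-unknown-unknown -mattr=+simd128"
with tvm.transform.PassContext(opt_level=3):
    factory = relay.build(
        mod,
        target=target,
        # Previously: "--system-lib" appended to the target string.
        runtime=Runtime("cpp", {"system-lib": True}),
    )
```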
2 changes: 1 addition & 1 deletion gallery/how_to/tune_with_autoscheduler/ci_logs/matmul.json
@@ -1,2 +1,2 @@
# Keep a valid schedule for demonstraction. This is used to prevent flasky errors in CI.
{"i": [["[\"matmul_add\", 1024, 1024, 1024, \"float32\"]", "llvm -keys=cpu -link-params=0", [18, 64, 64, 0, 0, 0, 0, 0]], [[], [["SP", 2, 0, 1024, [2, 1, 4], 1], ["SP", 2, 4, 1024, [1, 1, 8], 1], ["SP", 2, 8, 1024, [4], 1], ["RE", 2, [0, 4, 1, 5, 8, 2, 6, 9, 3, 7]], ["FSP", 4, 0, 0, 2], ["FSP", 4, 3, 1, 2], ["RE", 4, [0, 3, 1, 4, 2, 5]], ["CA", 2, 4, 3], ["FU", 4, [0, 1]], ["AN", 4, 0, 3], ["PR", 2, 0, "auto_unroll_max_step$8"], ["AN", 2, 9, 2], ["AN", 4, 4, 2]]]], "r": [[0.0044742], 0, 0.335558, 1607112214], "v": "v0.3"}
{"i": [["[\"matmul_add\", 1024, 1024, 1024, \"float32\"]", "llvm -keys=cpu", [18, 64, 64, 0, 0, 0, 0, 0]], [[], [["SP", 2, 0, 1024, [2, 1, 4], 1], ["SP", 2, 4, 1024, [1, 1, 8], 1], ["SP", 2, 8, 1024, [4], 1], ["RE", 2, [0, 4, 1, 5, 8, 2, 6, 9, 3, 7]], ["FSP", 4, 0, 0, 2], ["FSP", 4, 3, 1, 2], ["RE", 4, [0, 3, 1, 4, 2, 5]], ["CA", 2, 4, 3], ["FU", 4, [0, 1]], ["AN", 4, 0, 3], ["PR", 2, 0, "auto_unroll_max_step$8"], ["AN", 2, 9, 2], ["AN", 4, 4, 2]]]], "r": [[0.0044742], 0, 0.335558, 1607112214], "v": "v0.3"}

Large diffs are not rendered by default.

@@ -1,2 +1,2 @@
# Keep a valid schedule for demonstraction. This is used to prevent flasky errors in CI.
{"i": [["[\"sparse_dense\", 512, 512, 512, [9831, 16, 1], [9831], [33], \"float32\"]", "llvm -keys=cpu -link-params=0", [6, 64, 64, 0, 0, 0, 0, 0], "", 1, ["sparse_dense_bsr_512_512_512_16_1_0.60_W_data", "sparse_dense_bsr_512_512_512_16_1_0.60_W_indices", "sparse_dense_bsr_512_512_512_16_1_0.60_W_indptr"]], [[], [["CI", 8], ["CI", 6], ["SP", 5, 0, 512, [1, 8], 1], ["FSP", 9, 0, 2, 1], ["SP", 5, 3, 32, [32], 1], ["FSP", 9, 2, 4, 1], ["RE", 5, [0, 3, 1, 4, 6, 2, 5, 7]], ["RE", 9, [0, 2, 1, 3]], ["CA", 5, 9, 1], ["CI", 4], ["FU", 9, [0, 1]], ["AN", 9, 0, 3], ["PR", 5, 0, "auto_unroll_max_step$0"], ["AN", 9, 2, 2]]]], "r": [[0.000957008], 0, 0.605709, 1614689820], "v": "v0.6"}
{"i": [["[\"sparse_dense\", 512, 512, 512, [9831, 16, 1], [9831], [33], \"float32\"]", "llvm -keys=cpu", [6, 64, 64, 0, 0, 0, 0, 0], "", 1, ["sparse_dense_bsr_512_512_512_16_1_0.60_W_data", "sparse_dense_bsr_512_512_512_16_1_0.60_W_indices", "sparse_dense_bsr_512_512_512_16_1_0.60_W_indptr"]], [[], [["CI", 8], ["CI", 6], ["SP", 5, 0, 512, [1, 8], 1], ["FSP", 9, 0, 2, 1], ["SP", 5, 3, 32, [32], 1], ["FSP", 9, 2, 4, 1], ["RE", 5, [0, 3, 1, 4, 6, 2, 5, 7]], ["RE", 9, [0, 2, 1, 3]], ["CA", 5, 9, 1], ["CI", 4], ["FU", 9, [0, 1]], ["AN", 9, 0, 3], ["PR", 5, 0, "auto_unroll_max_step$0"], ["AN", 9, 2, 2]]]], "r": [[0.000957008], 0, 0.605709, 1614689820], "v": "v0.6"}
4 changes: 2 additions & 2 deletions gallery/how_to/tune_with_autotvm/tune_relay_x86.py
@@ -298,7 +298,7 @@ def tune_and_evaluate(tuning_opt):
#
# Evaluation of the network been tuned on graph level:
# Compile...
# Config for target=llvm -keys=cpu -link-params=0, workload=('dense_nopack.x86', ('TENSOR', (1, 512), 'float32'), ('TENSOR', (1000, 512), 'float32'), None, 'float32') is missing in ApplyGraphBest context. A fallback configuration is used, which may bring great performance regression.
# Config for target=llvm -keys=cpu -link-params=0, workload=('dense_pack.x86', ('TENSOR', (1, 512), 'float32'), ('TENSOR', (1000, 512), 'float32'), None, 'float32') is missing in ApplyGraphBest context. A fallback configuration is used, which may bring great performance regression.
# Config for target=llvm -keys=cpu, workload=('dense_nopack.x86', ('TENSOR', (1, 512), 'float32'), ('TENSOR', (1000, 512), 'float32'), None, 'float32') is missing in ApplyGraphBest context. A fallback configuration is used, which may bring great performance regression.
# Config for target=llvm -keys=cpu, workload=('dense_pack.x86', ('TENSOR', (1, 512), 'float32'), ('TENSOR', (1000, 512), 'float32'), None, 'float32') is missing in ApplyGraphBest context. A fallback configuration is used, which may bring great performance regression.
# Evaluate inference time cost...
# Mean inference time (std dev): 3.16 ms (0.03 ms)
4 changes: 2 additions & 2 deletions gallery/how_to/work_with_microtvm/micro_tvmc.sh
@@ -99,7 +99,7 @@ wget https://github.com/tensorflow/tflite-micro/raw/main/tensorflow/lite/micro/e
#
# bash
tvmc compile magic_wand.tflite \
--target='c -keys=cpu -link-params=0 -model=host' \
--target='c -keys=cpu -model=host' \
--runtime=crt \
--runtime-crt-system-lib 1 \
--executor='graph' \
@@ -111,7 +111,7 @@ tvmc compile magic_wand.tflite \
# bash
# This will generate a ``model.tar`` file which contains TVM compiler output files. To run this command for
# a different Zephyr device, you need to update ``target``. For instance, for ``nrf5340dk_nrf5340_cpuapp`` board
# the target is ``--target='c -keys=cpu -link-params=0 -model=nrf5340dk'``.
# the target is ``--target='c -keys=cpu -model=nrf5340dk'``.
#


2 changes: 0 additions & 2 deletions gallery/tutorial/auto_scheduler_matmul_x86.py
@@ -44,8 +44,6 @@
testing.utils.install_request_hook(depth=3)
# sphinx_gallery_end_ignore

import os

import numpy as np
import tvm
from tvm import te, auto_scheduler
2 changes: 1 addition & 1 deletion python/tvm/contrib/hexagon/pytest_plugin.py
@@ -245,7 +245,7 @@ def terminate_rpc_servers():

aot_host_target = tvm.testing.parameter(
"c",
"llvm -keys=hexagon -link-params=0 "
"llvm -keys=hexagon "
"-mattr=+hvxv68,+hvx-length128b,+hvx-qfloat,-hvx-ieee-fp "
"-mcpu=hexagonv68 -mtriple=hexagon",
)
78 changes: 0 additions & 78 deletions python/tvm/relay/build_module.py
@@ -274,69 +274,6 @@ def _build_module_no_factory(mod, target=None, target_host=None, params=None, mo
return _build_module_no_factory_impl(mod, target, target_host, params, mod_name)


def _reconstruct_from_deprecated_options(deprecated_params_target):
executor = None
runtime = None

deprecated_executor = None
deprecated_executor_args = {}
if "executor" in deprecated_params_target.attrs:
_deprecated_target_param_warning("Executor", "executor")
deprecated_executor = deprecated_params_target.attrs.get("executor", "graph")
if "interface-api" in deprecated_params_target.attrs:
_deprecated_target_sub_param_warning("Executor", "interface-api")
deprecated_executor_args.update(
{"interface-api": deprecated_params_target.attrs["interface-api"]}
)
if "unpacked-api" in deprecated_params_target.attrs:
_deprecated_target_sub_param_warning("Executor", "unpacked-api")
deprecated_executor_args.update(
{"unpacked-api": deprecated_params_target.attrs["unpacked-api"]}
)
if (
"link-params" in deprecated_params_target.attrs
and deprecated_params_target.attrs["link-params"]
):
_deprecated_target_sub_param_warning("Executor", "link-params")
if deprecated_executor != "aot":
deprecated_executor_args.update(
{"link-params": deprecated_params_target.attrs["link-params"]}
)
if deprecated_executor or deprecated_executor_args:
executor = Executor(deprecated_executor or "graph", deprecated_executor_args)

deprecated_runtime = None
deprecated_runtime_args = {}
if "runtime" in deprecated_params_target.attrs:
_deprecated_target_param_warning("Runtime", "runtime")
deprecated_runtime = deprecated_params_target.attrs.get("runtime", "cpp")
if deprecated_runtime == "c":
deprecated_runtime = "crt"
if "system-lib" in deprecated_params_target.attrs:
_deprecated_target_sub_param_warning("Runtime", "system-lib")
deprecated_runtime_args.update({"system-lib": deprecated_params_target.attrs["system-lib"]})
if deprecated_runtime or deprecated_runtime_args:
runtime = Runtime(deprecated_runtime or "cpp", deprecated_runtime_args)

return executor, runtime


def _deprecated_target_param_warning(registry, param):
warnings.warn(
f"Please use {registry} (tvm.relay.backend.{registry}) "
f"instead of deprecated Target parameter -{param}",
DeprecationWarning,
)


def _deprecated_target_sub_param_warning(registry, param):
warnings.warn(
f"Please use {registry} (tvm.relay.backend.{registry}) parameter {param} "
f"instead of deprecated Target parameter -{param}",
DeprecationWarning,
)


def build(
ir_mod,
target=None,
@@ -415,17 +352,6 @@ def build(
assert len(raw_targets) > 0
target_host = raw_targets[0].host

# All of this logic is to raise deprecation warnings for various parameters
# TODO(Mousius) Remove these after some time
deprecated_params_target = target_host or list(raw_targets)[0]
deprecated_executor, deprecated_runtime = _reconstruct_from_deprecated_options(
deprecated_params_target
)
if deprecated_executor:
executor = deprecated_executor
if deprecated_runtime:
runtime = deprecated_runtime

# If current dispatch context is fallback context (the default root context),
# then load pre-tuned parameters from TopHub
if isinstance(autotvm.DispatchContext.current, autotvm.FallbackContext):
@@ -756,9 +682,5 @@ def create_executor(kind="debug", mod=None, device=None, target="llvm", params=N
if kind == "vm":
return VMExecutor(mod, device, raw_targets)
if kind == "aot":
# The AOT requires the executor as a target attribute.
# (The compilation paths for the other executors currently do not always provide this
# attribute, hence the above generic assert is more forgiving).
assert "executor" in raw_targets[0].attrs
return AotExecutor(mod, device, raw_targets)
raise RuntimeError("unknown execution strategy: {0}".format(kind))
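
With the reconstruction shim removed, callers pass these options explicitly. A minimal sketch, assuming the current `relay.build` signature, of the replacement spelling for the flags this PR strips from target strings:

```python
import tvm
from tvm import relay
from tvm.relay.backend import Executor, Runtime

# Toy module so the example is self-contained.
x = relay.var("x", shape=(8,), dtype="float32")
mod = tvm.IRModule.from_expr(relay.Function([x], x + relay.const(1.0, "float32")))

# Old target attribute            -> new home
# -executor=graph|aot             -> Executor("graph") / Executor("aot")
# -interface-api / -unpacked-api  -> Executor("aot", {...})
# -link-params=1                  -> Executor("graph", {"link-params": True})
# -runtime=c                      -> Runtime("crt")
# -system-lib                     -> Runtime(..., {"system-lib": True})
lib = relay.build(
    mod,
    target="llvm",
    executor=Executor("graph", {"link-params": True}),
    runtime=Runtime("cpp", {"system-lib": True}),
)
```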
19 changes: 1 addition & 18 deletions python/tvm/target/target.py
@@ -636,8 +636,6 @@ def hexagon(cpu_ver="v66", **kwargs):
Whether to use QFloat HVX instructions.
use_ieee_fp : bool (default: False)
Whether to use IEEE HVX instructions
link_params : bool (default: False)
Whether to link graph parameters into the LLVM module.

Note: Floating point support in HVX requires LLVM 14+.
"""
@@ -671,7 +669,6 @@ def get_arch_version(cpu_ver):
"llvm_options": None,
"use_qfloat": arch_version >= 68,
"use_ieee_fp": False,
"link_params": False,
}
config.update(kwargs)

@@ -732,24 +729,10 @@ def create_llvm_options(cpu_ver, config): # pylint: disable=unused-argument
args = [s.replace("=", "@") for s in llvm_options.split()]
return "--llvm-options=" + ",".join(args)

# TVM target attributes string
def create_tvm_options(cpu_ver, config): # pylint: disable=unused-argument
"""Create TVM target features string."""

features = {
"link_params": "link-params",
}
opts = ""
for k in config:
if k in features:
opts += " --" + features[k] + "=" + str(config[k])
return opts

target_str = create_llvm_target(cpu_ver, config)
llvm_str = create_llvm_options(cpu_ver, config)
tvm_str = create_tvm_options(cpu_ver, config)

args_list = target_str.split() + llvm_str.split() + tvm_str.split()
args_list = target_str.split() + llvm_str.split()

return Target(" ".join(["hexagon"] + args_list))

38 changes: 6 additions & 32 deletions src/target/target_kind.cc
@@ -264,12 +264,7 @@ TVM_REGISTER_TARGET_KIND("llvm", kDLCPU)
.add_attr_option<String>("mtriple")
.add_attr_option<String>("mfloat-abi")
.add_attr_option<String>("mabi")
.add_attr_option<Bool>("system-lib")
.add_attr_option<String>("runtime")
.add_attr_option<Integer>("num-cores")
.add_attr_option<Bool>("link-params", Bool(false))
.add_attr_option<Bool>("unpacked-api")
.add_attr_option<String>("interface-api")
// Fast math flags, see https://llvm.org/docs/LangRef.html#fast-math-flags
.add_attr_option<Bool>("fast-math") // implies all the below
.add_attr_option<Bool>("fast-math-nnan")
@@ -286,23 +281,16 @@ TVM_REGISTER_TARGET_KIND("llvm", kDLCPU)
.set_target_parser(tvm::target::parsers::cpu::ParseTarget);

TVM_REGISTER_TARGET_KIND("c", kDLCPU)
.add_attr_option<Bool>("system-lib")
.add_attr_option<Bool>("link-params", Bool(false))
.add_attr_option<String>("runtime")
.add_attr_option<String>("mcpu")
.add_attr_option<String>("march")
.add_attr_option<String>("executor")
.add_attr_option<Integer>("workspace-byte-alignment")
.add_attr_option<Integer>("constants-byte-alignment")
.add_attr_option<Bool>("unpacked-api")
.add_attr_option<String>("interface-api")
.set_default_keys({"cpu"})
.set_target_parser(tvm::target::parsers::cpu::ParseTarget);

TVM_REGISTER_TARGET_KIND("cuda", kDLCUDA)
.add_attr_option<String>("mcpu")
.add_attr_option<String>("arch")
.add_attr_option<Bool>("system-lib")
.add_attr_option<Integer>("max_shared_memory_per_block")
.add_attr_option<Integer>("max_threads_per_block")
.add_attr_option<Integer>("thread_warp_size", Integer(32))
@@ -314,7 +302,6 @@ TVM_REGISTER_TARGET_KIND("cuda", kDLCUDA)
TVM_REGISTER_TARGET_KIND("nvptx", kDLCUDA)
.add_attr_option<String>("mcpu")
.add_attr_option<String>("mtriple")
.add_attr_option<Bool>("system-lib")
.add_attr_option<Integer>("max_num_threads", Integer(1024))
.add_attr_option<Integer>("thread_warp_size", Integer(32))
.set_default_keys({"cuda", "gpu"})
@@ -324,7 +311,6 @@ TVM_REGISTER_TARGET_KIND("rocm", kDLROCM)
.add_attr_option<String>("mcpu")
.add_attr_option<String>("mtriple")
.add_attr_option<Array<String>>("mattr")
.add_attr_option<Bool>("system-lib")
// TODO(masahi): Support querying from a target device
// On RDNA cards, thread_warp_size should be 32
.add_attr_option<Integer>("max_num_threads", Integer(256))
@@ -335,7 +321,6 @@ TVM_REGISTER_TARGET_KIND("rocm", kDLROCM)
.set_target_parser(UpdateROCmAttrs);

TVM_REGISTER_TARGET_KIND("opencl", kDLOpenCL)
.add_attr_option<Bool>("system-lib")
.add_attr_option<Integer>("max_num_threads", Integer(256))
.add_attr_option<Integer>("thread_warp_size", Integer(1))
.add_attr_option<Integer>("texture_spatial_limit", Integer(16384))
@@ -346,15 +331,13 @@ TVM_REGISTER_TARGET_KIND("opencl", kDLOpenCL)
// information about this limitation can be found here:
// https://developer.apple.com/documentation/metal/buffers/about_argument_buffers?language=objc
TVM_REGISTER_TARGET_KIND("metal", kDLMetal)
.add_attr_option<Bool>("system-lib")
.add_attr_option<Integer>("max_num_threads", Integer(256))
.add_attr_option<Integer>("thread_warp_size", Integer(16))
.add_attr_option<Integer>("max_function_args", Integer(31))
.set_default_keys({"metal", "gpu"});

TVM_REGISTER_TARGET_KIND("vulkan", kDLVulkan)
.add_attr_option<Array<String>>("mattr")
.add_attr_option<Bool>("system-lib")
// Feature support
.add_attr_option<Bool>("supports_float16")
.add_attr_option<Bool>("supports_float32", Bool(true))
@@ -393,39 +376,30 @@ TVM_REGISTER_TARGET_KIND("vulkan", kDLVulkan)
.set_default_keys({"vulkan", "gpu"});

TVM_REGISTER_TARGET_KIND("webgpu", kDLWebGPU)
.add_attr_option<Bool>("system-lib")
.add_attr_option<Integer>("max_num_threads", Integer(256))
.set_default_keys({"webgpu", "gpu"});

TVM_REGISTER_TARGET_KIND("sdaccel", kDLOpenCL)
.add_attr_option<Bool>("system-lib")
TVM_REGISTER_TARGET_KIND("sdaccel", kDLOpenCL) // line break
.set_default_keys({"sdaccel", "hls"});

TVM_REGISTER_TARGET_KIND("aocl", kDLAOCL)
.add_attr_option<Bool>("system-lib")
TVM_REGISTER_TARGET_KIND("aocl", kDLAOCL) // line break
.set_default_keys({"aocl", "hls"});

TVM_REGISTER_TARGET_KIND("aocl_sw_emu", kDLAOCL)
.add_attr_option<Bool>("system-lib")
TVM_REGISTER_TARGET_KIND("aocl_sw_emu", kDLAOCL) // line break
.set_default_keys({"aocl", "hls"});

TVM_REGISTER_TARGET_KIND("hexagon", kDLHexagon)
.add_attr_option<Array<String>>("mattr")
.add_attr_option<String>("mcpu")
.add_attr_option<String>("mtriple")
.add_attr_option<Bool>("system-lib")
.add_attr_option<Bool>("link-params", Bool(false))
.add_attr_option<Array<String>>("llvm-options")
.set_default_keys({"hexagon"});

TVM_REGISTER_TARGET_KIND("stackvm", kDLCPU) // line break
.add_attr_option<Bool>("system-lib");
TVM_REGISTER_TARGET_KIND("stackvm", kDLCPU);

TVM_REGISTER_TARGET_KIND("ext_dev", kDLExtDev) // line break
.add_attr_option<Bool>("system-lib");
TVM_REGISTER_TARGET_KIND("ext_dev", kDLExtDev);

TVM_REGISTER_TARGET_KIND("hybrid", kDLCPU) // line break
.add_attr_option<Bool>("system-lib");
TVM_REGISTER_TARGET_KIND("hybrid", kDLCPU);

TVM_REGISTER_TARGET_KIND("composite", kDLCPU) // line break
.add_attr_option<Array<Target>>("devices");