diff --git a/include/tvm/target/target_kind.h b/include/tvm/target/target_kind.h
index 19bcce3116b2..58ff4437f0a7 100644
--- a/include/tvm/target/target_kind.h
+++ b/include/tvm/target/target_kind.h
@@ -440,6 +440,19 @@ constexpr const char* kIsExternalCodegen = "is_external_codegen";
  */
 constexpr const char* kRelayToTIR = "RelayToTIR";
 
+/*!
+ * \brief String representation of the host's target architecture.
+ *
+ * Currently this is set to "arm_cpu" on Arm®-based host architectures and "cpu"
+ * (which is synonymous with x86) everywhere else.
+ *
+ * TODO(@FranklandJack) dynamically detect host architecture and generalize for all targets.
+ */
+#if defined(__arm__) || defined(__aarch64__)
+constexpr const char* kHostCPU = "arm_cpu";
+#else
+constexpr const char* kHostCPU = "cpu";
+#endif
 }  // namespace attr
 
 /*!
diff --git a/python/tvm/relay/op/strategy/arm_cpu.py b/python/tvm/relay/op/strategy/arm_cpu.py
index dc3b16aa82c2..addb74bb3661 100644
--- a/python/tvm/relay/op/strategy/arm_cpu.py
+++ b/python/tvm/relay/op/strategy/arm_cpu.py
@@ -152,6 +152,7 @@ def conv2d_strategy_arm_cpu(attrs, inputs, out_type, target):
                 is_winograd_applicable = (
                     "float" in data.dtype
                     and "float" in kernel.dtype
+                    and not data.dtype.count("custom")
                     and kh == 3
                     and kw == 3
                     and stride_h == 1
@@ -284,8 +285,21 @@ def conv2d_strategy_arm_cpu(attrs, inputs, out_type, target):
                     name="depthwise_conv2d_nchw.x86",
                 )
         elif layout == "NHWC":
-            assert kernel_layout == "HWOI"
-            if target.features.has_asimd:
+            # TODO(@FranklandJack)
+            # Handle kernel layouts other than HWOI in the arm_cpu schedules/compute definitions.
+            if kernel_layout != "HWOI":
+                logger.warning(
+                    "depthwise_conv2d with layout NHWC and a kernel layout "
+                    "other than HWOI is not optimized for the arm_cpu "
+                    "target; falling back to the generic schedule."
+                )
+                strategy.add_implementation(
+                    wrap_compute_conv2d(topi.nn.depthwise_conv2d_nhwc, need_kernel_layout=True),
+                    wrap_topi_schedule(conv2d_generic.schedule_depthwise_conv2d_nhwc),
+                    name="depthwise_conv2d_nhwc.generic",
+                )
+
+            elif target.features.has_asimd:
                 strategy.add_implementation(
                     wrap_compute_conv2d(topi.arm_cpu.compute_depthwise_conv2d_nhwc),
                     wrap_topi_schedule(topi.arm_cpu.schedule_depthwise_conv2d_nhwc),
@@ -304,8 +318,11 @@ def conv2d_strategy_arm_cpu(attrs, inputs, out_type, target):
             and kernel.shape[3] == 1  # channel_multiplier == 1
             and out_type.dtype == "int32"
             and (
-                (data.shape[3] % 4 == 0 and data.dtype == "int8" and target.features.has_dsp)
-                or (data.shape[3] % 2 == 0 and data.dtype == "int16")
+                (
+                    (data.shape[3] % 4 == 0 and data.dtype == "int8")
+                    or (data.shape[3] % 2 == 0 and data.dtype == "int16")
+                )
+                and target.features.has_dsp
             )
             and (padding != "SAME" or data.shape[1] % stride_h == data.shape[2] % stride_w == 0)
             # Ideally we should check that kernel is a Relay constant, but strategy functions
diff --git a/python/tvm/relay/qnn/op/legalizations.py b/python/tvm/relay/qnn/op/legalizations.py
index 8dc54a19b998..415410a11772 100644
--- a/python/tvm/relay/qnn/op/legalizations.py
+++ b/python/tvm/relay/qnn/op/legalizations.py
@@ -111,8 +111,7 @@ def qnn_conv2d_transpose_legalize(attrs, inputs, types):
     # Otherwise it needs to be broadcast.
     else:
         shift_data = relay.nn.bias_add(
-            relay.cast(data, dtype="int16"),
-            -relay.cast(input_zero_point, dtype="int16"),
+            relay.cast(data, dtype="int16"), -relay.cast(input_zero_point, dtype="int16")
         )
 
     # If kernel zero point is a scalar, we can directly subtract it.
@@ -123,8 +122,7 @@ def qnn_conv2d_transpose_legalize(attrs, inputs, types):
     # Otherwise it needs to be broadcast.
     else:
         shift_kernel = relay.nn.bias_add(
-            relay.cast(kernel, dtype="int16"),
-            -relay.cast(kernel_zero_point, dtype="int16"),
+            relay.cast(kernel, dtype="int16"), -relay.cast(kernel_zero_point, dtype="int16")
         )
 
     return relay.nn.conv2d_transpose(shift_data, shift_kernel, **attrs)
@@ -486,7 +484,10 @@ def _qnn_conv2d_legalize_arm_cpu(attrs, inputs, types):
     if target.features.has_asimd and not other_options:
         return helper_no_fast_int8_hw_legalization(attrs, inputs, types, relay.nn.conv2d)
     # ARM prefers the dtypes to be same.
-    return helper_change_dtypes_to_be_same(attrs, inputs, types, relay.qnn.op.conv2d)
+    if types[0].dtype in ["int8", "uint8"]:
+        return helper_change_dtypes_to_be_same(attrs, inputs, types, relay.qnn.op.conv2d)
+
+    return helper_no_fast_int8_hw_legalization(attrs, inputs, types, relay.nn.conv2d)
 
 
 @qnn_dense_legalize.register("arm_cpu")
@@ -495,7 +496,10 @@ def _qnn_dense_legalize_arm_cpu(attrs, inputs, types):
     if target.features.has_asimd and not target.features.has_dotprod:
         return helper_no_fast_int8_hw_legalization(attrs, inputs, types, relay.nn.dense)
     # ARM prefers the dtypes to be same.
-    return helper_change_dtypes_to_be_same(attrs, inputs, types, relay.qnn.op.dense)
+    if types[0].dtype in ["int8", "uint8"]:
+        return helper_change_dtypes_to_be_same(attrs, inputs, types, relay.qnn.op.dense)
+
+    return helper_no_fast_int8_hw_legalization(attrs, inputs, types, relay.nn.dense)
 
 
 ##########################
diff --git a/python/tvm/topi/arm_cpu/conv2d.py b/python/tvm/topi/arm_cpu/conv2d.py
index a478818084d5..96b6b221236a 100644
--- a/python/tvm/topi/arm_cpu/conv2d.py
+++ b/python/tvm/topi/arm_cpu/conv2d.py
@@ -23,10 +23,11 @@
 from tvm import autotvm
 import tvm.contrib.nnpack
 
-from ..utils import traverse_inline, get_const_tuple
+from ..utils import traverse_inline, get_const_tuple, conv2d_infer_layout_helper
 from .. import nn
 from ..nn.utils import get_const_int, get_pad_tuple
 from ..nn.winograd_util import winograd_transform_matrices
+from ..nn.conv2d import conv2d_infer_layout
 from .conv2d_spatial_pack import (
     conv2d_spatial_pack_nchw,
     conv2d_spatial_pack_nhwc,
@@ -509,3 +510,8 @@ def conv2d_nhwc_dsp(cfg, data, kernel, strides, padding, dilation, out_dtype):
 def schedule_conv2d_nhwc_dsp(cfg, outs):
     """Create schedule for conv2d_nhwc_dsp"""
     return conv2d_nhwc_dsp_schedule(cfg, outs)
+
+
+@conv2d_infer_layout.register("arm_cpu")
+def _conv2d_infer_layout(workload, cfg):
+    return conv2d_infer_layout_helper(workload, cfg)
diff --git a/python/tvm/topi/arm_cpu/injective.py b/python/tvm/topi/arm_cpu/injective.py
index 5c63e5a513db..5d0ba8cd60e8 100644
--- a/python/tvm/topi/arm_cpu/injective.py
+++ b/python/tvm/topi/arm_cpu/injective.py
@@ -69,8 +69,10 @@ def schedule_injective(outs):
     if list(s[x].op.axis):
         # do not vectorize for broadcast
         dtype = "uint16" if x.dtype == "bfloat16" else x.dtype
-        (io, ii) = s[x].split(list(s[x].op.axis)[-1], 16 // np.dtype(dtype).itemsize)
-        s[x].vectorize(ii)
+        # do not vectorize for custom data types
+        if not dtype.count("custom"):
+            (io, ii) = s[x].split(list(s[x].op.axis)[-1], 16 // np.dtype(dtype).itemsize)
+            s[x].vectorize(ii)
     tvm.te.schedule.AutoInlineInjective(s)
 
     if not is_empty_shape(x.shape):
diff --git a/python/tvm/topi/intel_graphics/conv2d_alter_op.py b/python/tvm/topi/intel_graphics/conv2d_alter_op.py
index 3dc587e8710e..2e578d142593 100644
--- a/python/tvm/topi/intel_graphics/conv2d_alter_op.py
+++ b/python/tvm/topi/intel_graphics/conv2d_alter_op.py
@@ -22,7 +22,7 @@
 from tvm import relay
 from tvm import autotvm
 
-from ..utils import get_const_tuple
+from ..utils import get_const_tuple, conv2d_infer_layout_helper
 from ..nn import conv2d_alter_layout, conv2d_infer_layout
 from .conv2d import _get_default_config
 
@@ -102,14 +102,4 @@ def _alter_conv2d_layout(attrs, inputs, tinfos, out_type):
 
 @conv2d_infer_layout.register("intel_graphics")
 def _conv2d_infer_layout(workload, cfg):
-    _, data, kernel, strides, padding, dilation, layout, dtype = workload
-    batch_size, in_channel, in_height, in_width = data[1]
-    out_channel, _, k_height, k_width = kernel[1]
-    out_height = (in_height + 2 * padding[0] - k_height) // strides[0] + 1
-    out_width = (in_width + 2 * padding[1] - k_width) // strides[1] + 1
-    tile_ic, tile_oc = cfg["tile_ic"].size[-1], cfg["tile_oc"].size[-1]
-    in_shape = (batch_size, in_channel // tile_ic, in_height, in_width, tile_ic)
-    in_layout = f"NCHW{tile_ic}c"
-    out_shape = (batch_size, out_channel // tile_oc, out_height, out_width, tile_oc)
-    out_layout = f"NCHW{tile_oc}c"
-    return ((in_shape, in_layout),), ((out_shape, out_layout),)
+    return conv2d_infer_layout_helper(workload, cfg)
diff --git a/python/tvm/topi/testing/common.py b/python/tvm/topi/testing/common.py
index d040310ccc8f..52b243e1f476 100644
--- a/python/tvm/topi/testing/common.py
+++ b/python/tvm/topi/testing/common.py
@@ -35,6 +35,8 @@
 _reduce_schedule = {
     "generic": topi.generic.schedule_reduce,
     "cpu": topi.x86.schedule_reduce,
+    # TODO(@FranklandJack) Write arm_cpu specific reduction schedule.
+ "arm_cpu": topi.x86.schedule_reduce, "gpu": topi.cuda.schedule_reduce, "hls": topi.cuda.schedule_reduce, } diff --git a/python/tvm/topi/utils.py b/python/tvm/topi/utils.py index 71599ad74a62..cbddd44d9133 100644 --- a/python/tvm/topi/utils.py +++ b/python/tvm/topi/utils.py @@ -24,6 +24,7 @@ import tvm from tvm import te from tvm.tir import Any, SizeVar, bijective_layout, layout +import tvm.topi from . import cpp, tag @@ -526,3 +527,26 @@ def is_target(names): def is_dynamic_shape(shape): """Checks if any part of a shape is dynamic""" return any([isinstance(x, (Any, SizeVar)) for x in shape]) + + +def conv2d_infer_layout_helper(workload, cfg): + """Infers input and output layouts for a conv2d operator + scheduled using "tile_ic" and "tile_oc" scheduling configuration knobs which + is the case for cpu, arm_cpu and intel_graphics targets.""" + _, data, kernel, strides, padding, dilation, _, _, _ = workload + batch_size, in_channel, in_height, in_width = data[1] + out_channel, _, k_height, k_width = kernel[1] + idxdiv = tvm.tir.indexdiv + + pt, pl, pb, pr = tvm.topi.nn.get_pad_tuple(padding, (k_height, k_width)) + hdilation, wdilation = dilation if isinstance(dilation, (tuple, list)) else (dilation, dilation) + dilated_kernel_h = (k_height - 1) * hdilation + 1 + dilated_kernel_w = (k_width - 1) * wdilation + 1 + out_height = idxdiv(in_height + pt + pb - dilated_kernel_h, strides[0]) + 1 + out_width = idxdiv(in_width + pl + pr - dilated_kernel_w, strides[1]) + 1 + tile_ic, tile_oc = cfg["tile_ic"].size[-1], cfg["tile_oc"].size[-1] + in_shape = (batch_size, idxdiv(in_channel, tile_ic), in_height, in_width, tile_ic) + in_layout = f"NCHW{tile_ic}c" + out_shape = (batch_size, idxdiv(out_channel, tile_oc), out_height, out_width, tile_oc) + out_layout = f"NCHW{tile_oc}c" + return ((in_shape, in_layout),), ((out_shape, out_layout),) diff --git a/python/tvm/topi/x86/conv2d.py b/python/tvm/topi/x86/conv2d.py index 1b7f020d5014..d17fcbd905ef 100644 --- a/python/tvm/topi/x86/conv2d.py +++ b/python/tvm/topi/x86/conv2d.py @@ -30,7 +30,7 @@ from ..nn.conv2d import unpack_NCHWc_to_nchw from ..nn.depthwise_conv2d import _get_workload as _get_depthwise_conv2d_workload from ..nn.utils import get_pad_tuple -from ..utils import get_const_tuple, traverse_inline +from ..utils import get_const_tuple, traverse_inline, conv2d_infer_layout_helper from . 
 
 logger = logging.getLogger("topi")
@@ -65,23 +65,7 @@ def _get_default_config(
 
 @conv2d_infer_layout.register("cpu")
 def _conv2d_infer_layout(workload, cfg):
-    _, data, kernel, strides, padding, dilation, layout, _, dtype = workload
-    batch_size, in_channel, in_height, in_width = data[1]
-    out_channel, _, k_height, k_width = kernel[1]
-    idxdiv = tvm.tir.indexdiv
-
-    pt, pl, pb, pr = get_pad_tuple(padding, (k_height, k_width))
-    hdilation, wdilation = dilation if isinstance(dilation, (tuple, list)) else (dilation, dilation)
-    dilated_kernel_h = (k_height - 1) * hdilation + 1
-    dilated_kernel_w = (k_width - 1) * wdilation + 1
-    out_height = idxdiv(in_height + pt + pb - dilated_kernel_h, strides[0]) + 1
-    out_width = idxdiv(in_width + pl + pr - dilated_kernel_w, strides[1]) + 1
-    tile_ic, tile_oc = cfg["tile_ic"].size[-1], cfg["tile_oc"].size[-1]
-    in_shape = (batch_size, idxdiv(in_channel, tile_ic), in_height, in_width, tile_ic)
-    in_layout = f"NCHW{tile_ic}c"
-    out_shape = (batch_size, idxdiv(out_channel, tile_oc), out_height, out_width, tile_oc)
-    out_layout = f"NCHW{tile_oc}c"
-    return ((in_shape, in_layout),), ((out_shape, out_layout),)
+    return conv2d_infer_layout_helper(workload, cfg)
 
 
 def schedule_conv2d_nhwc(outs):
diff --git a/src/target/target_kind.cc b/src/target/target_kind.cc
index 9af50d3f54ed..d9766d42025e 100644
--- a/src/target/target_kind.cc
+++ b/src/target/target_kind.cc
@@ -293,7 +293,7 @@ TVM_REGISTER_TARGET_KIND("llvm", kDLCPU)
     .add_attr_option<Integer>("opt-level")
     // LLVM command line flags, see below
    .add_attr_option<Array<String>>("cl-opt")
-    .set_default_keys({"cpu"})
+    .set_default_keys({attr::kHostCPU})
     // Force the external codegen kind attribute to be registered, even if no external
     // codegen targets are enabled by the TVM build.
     .set_attr<Bool>(tvm::attr::kIsExternalCodegen, Bool(false))
diff --git a/src/tir/transforms/ir_utils.cc b/src/tir/transforms/ir_utils.cc
index 43bf6b983eb5..7f9f128740ea 100644
--- a/src/tir/transforms/ir_utils.cc
+++ b/src/tir/transforms/ir_utils.cc
@@ -28,6 +28,7 @@
 #include
 #include
 
+#include <algorithm>
 #include
 #include
 #include
@@ -701,7 +702,15 @@ std::optional<bool> IsHostFunc(const PrimFunc& func) {
   if (func->HasNonzeroAttr(tvm::tir::attr::kIsHostFunc)) {
     return true;
   } else if (auto target = func->GetAttr<Target>(tvm::attr::kTarget)) {
-    return target.value()->HasKey("cpu");
+    const auto keys = target.value()->keys;
+    const auto it = std::find_if(std::begin(keys), std::end(keys),
+        [](const String& key) { return std::string(key).find("cpu") != std::string::npos; });
+    if (std::end(keys) != it) {
+      const std::string key_string = *it;
+      return key_string == tvm::attr::kHostCPU;
+    } else {
+      return false;
+    }
   } else {
     return std::nullopt;
   }
diff --git a/tests/python/integration/test_ewise.py b/tests/python/integration/test_ewise.py
index d1d2b9902c60..149a00a767f2 100644
--- a/tests/python/integration/test_ewise.py
+++ b/tests/python/integration/test_ewise.py
@@ -222,7 +222,7 @@ def check_device(device):
             print("skip because %s is not enabled.." % device)
             return
         target = tvm.target.Target(device)
-        if "cpu" not in target.keys:
+        if not any(["cpu" in key for key in target.keys]):
             schedule[placeholder_b].bind(axis1, te.thread_axis("blockIdx.x"))
             schedule[placeholder_b].bind(axis2, te.thread_axis("threadIdx.x"))
         func = tvm.build(schedule, [placeholder_a, placeholder_b], device)
diff --git a/tests/python/relay/strategy/test_select_implementation.py b/tests/python/relay/strategy/test_select_implementation.py
index 3e63bc4751f7..3e223aa615c2 100644
--- a/tests/python/relay/strategy/test_select_implementation.py
+++ b/tests/python/relay/strategy/test_select_implementation.py
@@ -20,13 +20,17 @@
 import tvm
 from tvm import relay
 from tvm import te
+from tvm import target
 from tvm.relay.testing import run_infer_type
 import tvm.testing
 
 
+native_arch = target.Target("llvm").keys[0]
+
+
 @pytest.mark.parametrize(
     "target, expected_implementation",
-    [("llvm", "concatenate.cpu"), ("llvm -device=arm_cpu", "concatenate.arm_cpu")],
+    [("llvm", "concatenate." + native_arch), ("llvm -device=arm_cpu", "concatenate.arm_cpu")],
 )
 def test_concatenate(target, expected_implementation):
     target = tvm.target.Target(target)
diff --git a/tests/python/topi/python/test_topi_bitserial_dense.py b/tests/python/topi/python/test_topi_bitserial_dense.py
index 581de8ff98e5..5bf686819d42 100644
--- a/tests/python/topi/python/test_topi_bitserial_dense.py
+++ b/tests/python/topi/python/test_topi_bitserial_dense.py
@@ -17,6 +17,7 @@
 """Test code for bitserial_dense operator"""
 import os
 import numpy as np
+from tvm.target.target import Target
 import tvm
 from tvm import te
 from tvm import topi
@@ -53,11 +54,12 @@ def get_ref_data(a_shape, b_shape, input_dtype):
         c_np = np.dot(a_np, b_np.T)
         return a_np, b_np, c_np
 
-    for target in ["llvm", "llvm -device=arm_cpu"]:
-        if "arm_cpu" in target and "arm" not in os.uname()[4]:
+    for target_string in ["llvm", "llvm -device=arm_cpu"]:
+        target = Target(target_string)
+        if "arm_cpu" in target.keys and "arm" not in os.uname()[4]:
             print("Skipped running code, not an arm device")
             continue
-        input_dtype = "uint8" if "arm_cpu" in target else "uint32"
+        input_dtype = "uint8" if "arm_cpu" in target.keys else "uint32"
         A = te.placeholder((batch, in_dim), dtype=input_dtype, name="A")
         B = te.placeholder((out_dim, in_dim), dtype=input_dtype, name="B")
         fcompute, fschedule = tvm.topi.testing.dispatch(target, _bitserial_dense_implement)
diff --git a/tests/python/unittest/test_auto_scheduler_search_task.py b/tests/python/unittest/test_auto_scheduler_search_task.py
index 9197a2097ebc..83432e588a5c 100644
--- a/tests/python/unittest/test_auto_scheduler_search_task.py
+++ b/tests/python/unittest/test_auto_scheduler_search_task.py
@@ -23,6 +23,7 @@
 import tvm
 import tvm.testing
 from tvm import auto_scheduler
+from tvm.target import Target
 from tvm.auto_scheduler.utils import get_const_tuple
 from tvm.testing.auto_scheduler import (
     matmul_auto_scheduler_test,
@@ -97,10 +98,7 @@ def test_search_task_record():
         func="matmul_auto_scheduler_test",
         args=(N, N, N),
         target=target,
-        task_inputs={
-            "test_input_0": test_input_0,
-            "test_input_1": test_input_1,
-        },
+        task_inputs={"test_input_0": test_input_0, "test_input_1": test_input_1},
         task_inputs_overwrite=True,
     )
     task_record = auto_scheduler._ffi_api.SerializeSearchTask(task)
@@ -114,7 +112,12 @@ def test_search_task_record():
     assert new_task.task_input_names[1] == "test_input_1"
 
     # Log with version 0.5
-    v5_log = """["[\\\"matmul_auto_scheduler_test\\\", 64, 64, 64]", "llvm -keys=cpu", [6, 64, 64, 0, 0, 0, 0, 0], "", 1]"""
0], "", 1]""" + host_target_string = '"' + str(Target("llvm")) + '"' + v5_log = ( + """["[\\\"matmul_auto_scheduler_test\\\", 64, 64, 64]", """ + + host_target_string + + """, [6, 64, 64, 0, 0, 0, 0, 0], "", 1]""" + ) new_task = auto_scheduler._ffi_api.DeserializeSearchTask(v5_log) assert task.workload_key == new_task.workload_key assert str(task.target) == str(new_task.target) @@ -148,9 +151,7 @@ def test_recover_measure_input_with_task_input(): func=matmul_auto_scheduler_test, args=(512, 512, 512), target="llvm", - task_inputs={ - "test_input_0": test_input_0, - }, + task_inputs={"test_input_0": test_input_0}, task_inputs_overwrite=True, ) inp = auto_scheduler.measure.MeasureInput(task, task.compute_dag.init_state) @@ -171,10 +172,7 @@ def test_recover_measure_input_with_task_input(): func=matmul_auto_scheduler_test, args=(512, 512, 512), target="llvm", - task_inputs={ - "test_input_0": test_input_0, - "test_input_1": test_input_1, - }, + task_inputs={"test_input_0": test_input_0, "test_input_1": test_input_1}, task_inputs_overwrite=True, ) inp = auto_scheduler.measure.MeasureInput(task, task.compute_dag.init_state) @@ -191,7 +189,12 @@ def test_recover_measure_input_with_task_input(): assert new_task.task_input_names[1] == "test_input_1" # Log with version 0.5 - v5_log = """{"i": [["[\\\"matmul_auto_scheduler_test\\\", 512, 512, 512]", "llvm -keys=cpu", [6, 64, 64, 0, 0, 0, 0, 0], "", 1], [[], []]], "r": [[0.1], 0, 0.2, 1], "v": "v0.6"}""" + host_target_string = '"' + str(Target("llvm")) + '"' + v5_log = ( + """{"i": [["[\\\"matmul_auto_scheduler_test\\\", 512, 512, 512]", """ + + host_target_string + + """, [6, 64, 64, 0, 0, 0, 0, 0], "", 1], [[], []]], "r": [[0.1], 0, 0.2, 1], "v": "v0.6"}""" + ) measure_log = auto_scheduler.measure_record.load_record_from_string(v5_log) new_task = measure_log[0].task assert task.workload_key == new_task.workload_key diff --git a/tests/python/unittest/test_autotvm_graph_tuner_core.py b/tests/python/unittest/test_autotvm_graph_tuner_core.py index bcc43648de22..e1aff8724178 100644 --- a/tests/python/unittest/test_autotvm_graph_tuner_core.py +++ b/tests/python/unittest/test_autotvm_graph_tuner_core.py @@ -148,6 +148,7 @@ def _create_data(target, dshape, dtype, layout): return net, records, ltf_records, ltf_keys, tasks +@tvm.testing.requires_x86 def test_graph_tuner_layout_transform(): log_file = "%s/test_tuner.log" % (os.getcwd()) target = "llvm" @@ -188,6 +189,7 @@ def test_graph_tuner_layout_transform(): ) +@tvm.testing.requires_x86 def test_graph_tuner_layout_transform_runner(): log_file = "%s/test_tuner.log" % (os.getcwd()) target = "llvm" @@ -231,6 +233,7 @@ def test_graph_tuner_layout_transform_runner(): ) +@tvm.testing.requires_x86 def test_DPTuner_run(): log_file = "%s/test_tuner.log" % (os.getcwd()) target = "llvm" @@ -295,6 +298,7 @@ def test_DPTuner_run(): assert os.path.isfile(log_file), "No log file with name %s exists." 
 
 
+@tvm.testing.requires_x86
 def test_PBQPTuner_run():
     target = "llvm"
     dtype = "float32"
@@ -355,6 +359,7 @@ def test_PBQPTuner_run():
     )
 
 
+@tvm.testing.requires_x86
 def test_many_sub_graphs():
     target = "llvm"
     dtype = "float32"
@@ -517,6 +522,7 @@ def test_many_sub_graphs():
     )
 
 
+@tvm.testing.requires_x86
 def test_tuple():
     target = "llvm"
     dtype = "float32"
@@ -629,6 +635,7 @@ def test_tuple():
     )
 
 
+@tvm.testing.requires_x86
 def test_triangle_block():
     target = "llvm"
     dtype = "float32"
diff --git a/tests/python/unittest/test_meta_schedule_schedule_rule_apply_custom_rule.py b/tests/python/unittest/test_meta_schedule_schedule_rule_apply_custom_rule.py
index 2bfa3070d1b4..8534757e971d 100644
--- a/tests/python/unittest/test_meta_schedule_schedule_rule_apply_custom_rule.py
+++ b/tests/python/unittest/test_meta_schedule_schedule_rule_apply_custom_rule.py
@@ -23,6 +23,7 @@
 from tvm import meta_schedule as ms
 from tvm.meta_schedule.schedule_rule import ApplyCustomRule
 from tvm.script import tir as T
+from tvm.target import Target
 
 
 @tvm.script.ir_module
@@ -42,9 +43,14 @@ def main(a: T.handle, b: T.handle, c: T.handle) -> None:
                 C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj]
 
 
-@tvm.register_func("meta_schedule.cpu.test_apply_custom_rule")
+native_target = Target("llvm -num-cores=1")
+native_device_name = native_target.keys[0]
+schedule_name = f"meta_schedule.{native_device_name}.test_apply_custom_rule"
+
+
+@tvm.register_func(schedule_name)
 def sch_fn(sch: tvm.tir.Schedule, block: tvm.tir.Block) -> List[tvm.tir.Schedule]:
-    raise ValueError("Intended for meta_schedule.cpu.test_apply_custom_rule")
+    raise ValueError(f"Intended for {schedule_name}")
 
 
 def test_custom_rule():
@@ -54,12 +60,12 @@ def test_custom_rule():
             space_gen = ms.space_generator.PostOrderApply(sch_rules=sch_rules)
             ms.tune_tir(
                 mod=Matmul,
-                target="llvm -num-cores=1",
+                target=native_target,
                 work_dir=tmpdir,
                 max_trials_global=10,
                 space=space_gen,
             )
-    assert "ValueError: Intended for meta_schedule.cpu.test_apply_custom_rule" in str(e_info.value)
+    assert f"ValueError: Intended for {schedule_name}" in str(e_info.value)
 
 
 if __name__ == "__main__":
diff --git a/tests/python/unittest/test_roofline.py b/tests/python/unittest/test_roofline.py
index 7a1c3478c51d..e83c2029d03f 100644
--- a/tests/python/unittest/test_roofline.py
+++ b/tests/python/unittest/test_roofline.py
@@ -34,7 +34,7 @@
 from tvm.script import tir as T
 
 
-@tvm.testing.requires_llvm
+@tvm.testing.requires_x86
 @pytest.mark.parametrize("dtype", ["float32", "int8", "int32"])
 def test_estimate_peak_flops_cpu(dtype):
     server = rpc.Server(key="roofline_flops_cpu")
@@ -70,6 +70,7 @@ def test_estimate_peak_flops_gpu():
     ), f"FLOP/s should be between 10^12 and 10^14, but it is {flops}"
 
 
+@tvm.testing.requires_x86
 @tvm.testing.skip_if_32bit(reason="Cannot allocate enough memory on i386")
 @tvm.testing.requires_llvm
 def test_estimate_peak_bandwidth_cpu():
@@ -101,6 +102,7 @@ def test_estimate_peak_bandwidth_gpu():
     ), f"Bandwidth should be between 10^9 and 10^12, but it is {bandwidth}"
 
 
+@tvm.testing.requires_x86
 @tvm.testing.skip_if_32bit(reason="Cannot allocate enough memory on i386")
 @tvm.testing.parametrize_targets("llvm -mattr=+fma,+avx2", "cuda")
 def test_roofline_analysis(target, dev):
diff --git a/tests/python/unittest/test_runtime_module_based_interface.py b/tests/python/unittest/test_runtime_module_based_interface.py
index 0ed097ddf563..d8d46730a5b9 100644
--- a/tests/python/unittest/test_runtime_module_based_interface.py
+++ b/tests/python/unittest/test_runtime_module_based_interface.py
@@ -128,7 +128,8 @@ def test_cpu_get_graph_params_run():
     tvm.testing.assert_allclose(out, verify(data), atol=1e-5)
 
 
-@tvm.testing.requires_llvm
+# TODO(@franklandjack) Fix this test for the arm_cpu target.
+@tvm.testing.requires_x86
 def test_cpu_get_graph_params_compare():
     # Create sample net
     from tvm.relay.testing.init import create_workload, Constant
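
A minimal usage sketch of the behaviour this patch relies on, assuming a TVM build with the patch applied (the variable names below are illustrative only): with set_default_keys({attr::kHostCPU}), the first key of a plain "llvm" target is "arm_cpu" on Arm hosts and "cpu" elsewhere, and strategy implementation names are keyed accordingly, which is what the updated tests read back via Target("llvm").keys[0].

    import tvm

    # First key of the default "llvm" target: "arm_cpu" on Arm hosts, "cpu" elsewhere.
    host_key = tvm.target.Target("llvm").keys[0]
    print("default llvm target key:", host_key)
    # Strategy implementations are named after the same key, e.g. "concatenate.cpu".
    print("expected concatenate implementation:", "concatenate." + host_key)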