diff --git a/python/tvm/relay/op/strategy/rocm.py b/python/tvm/relay/op/strategy/rocm.py
index f52bbc36f12b..c52da541a8ab 100644
--- a/python/tvm/relay/op/strategy/rocm.py
+++ b/python/tvm/relay/op/strategy/rocm.py
@@ -17,8 +17,10 @@
 """Definition of ROCm operator strategy."""
 # pylint: disable=invalid-name,unused-argument,unused-wildcard-import,wildcard-import
 from tvm import topi
+from tvm.auto_scheduler import is_auto_scheduler_enabled
 from .generic import *
 from .. import op as _op
+from .cuda import judge_winograd, naive_schedule
 
 
 @schedule_lrn.register("rocm")
@@ -67,6 +69,49 @@ def conv2d_strategy_rocm(attrs, inputs, out_type, target):
                     name="conv2d_nchw_winograd.cuda",
                     plevel=5,
                 )
+        elif layout == "NHWC":
+            assert kernel_layout == "HWIO"
+            strategy.add_implementation(
+                wrap_compute_conv2d(topi.cuda.conv2d_nhwc),
+                wrap_topi_schedule(topi.cuda.schedule_conv2d_nhwc),
+                name="conv2d_nhwc.cuda",
+            )
+            N, H, W, _ = get_const_tuple(data.shape)
+            KH, KW, CI, CO = get_const_tuple(kernel.shape)
+
+            (_, judge_winograd_autotvm, judge_winograd_auto_scheduler,) = judge_winograd(
+                N,
+                H,
+                W,
+                KH,
+                KW,
+                CI,
+                CO,
+                padding,
+                stride_h,
+                stride_w,
+                dilation_h,
+                dilation_w,
+                data.dtype,
+                kernel.dtype,
+                pre_flag=False,
+            )
+
+            if judge_winograd_autotvm:
+                strategy.add_implementation(
+                    wrap_compute_conv2d(topi.cuda.conv2d_nhwc_winograd_direct),
+                    wrap_topi_schedule(topi.cuda.schedule_conv2d_nhwc_winograd_direct),
+                    name="conv2d_nhwc_winograd_direct.cuda",
+                    plevel=5,
+                )
+
+            if is_auto_scheduler_enabled() and judge_winograd_auto_scheduler:
+                strategy.add_implementation(
+                    wrap_compute_conv2d(topi.nn.conv2d_winograd_nhwc),
+                    naive_schedule,  # this implementation should never be picked by autotvm
+                    name="conv2d_nhwc.winograd",
+                    plevel=15,
+                )
         elif layout == "HWCN":
             assert kernel_layout == "HWIO"
             strategy.add_implementation(
@@ -74,13 +119,6 @@ def conv2d_strategy_rocm(attrs, inputs, out_type, target):
                 wrap_topi_schedule(topi.cuda.schedule_conv2d_hwcn),
                 name="conv2d_hwcn.cuda",
             )
-        # TODO(@alexgl-github): Re-enable this after fix the conv2d_nhwc for cuda
-        # elif layout == "NHWC":
-        #     assert kernel_layout == "HWIO"
-        #     strategy.add_implementation(
-        #         wrap_compute_conv2d(topi.cuda.conv2d_nhwc),
-        #         wrap_topi_schedule(topi.cuda.schedule_conv2d_nhwc),
-        #         name="conv2d_nhwc.cuda")
         elif layout == "NCHW4c" and data.dtype in ["int8", "uint8"]:
             assert kernel_layout == "OIHW4o4i"
             strategy.add_implementation(
diff --git a/src/auto_scheduler/feature.cc b/src/auto_scheduler/feature.cc
index 0a3d705d9898..53287a0eddeb 100755
--- a/src/auto_scheduler/feature.cc
+++ b/src/auto_scheduler/feature.cc
@@ -41,6 +41,7 @@
 #include <unordered_map>
 #include <vector>
 
+#include "search_policy/utils.h"
 #include "utils.h"
 
 namespace tvm {
@@ -1296,7 +1297,7 @@ void GetPerStoreFeaturesWorkerFunc(const SearchTask& task, const State& state, i
     }
 
     auto mod = IRModule(Map<GlobalVar, BaseFunc>({{global_var, f}}));
-    if (task->target->kind->device_type == kDLGPU) {
+    if (IsGPUTask(task)) {
       auto pass_list = Array<tvm::transform::Pass>();
       // Phase 0
       pass_list.push_back(tir::transform::InjectPrefetch());
diff --git a/src/auto_scheduler/search_task.cc b/src/auto_scheduler/search_task.cc
index 4c8cc6d70ac8..5a3475542878 100755
--- a/src/auto_scheduler/search_task.cc
+++ b/src/auto_scheduler/search_task.cc
@@ -22,6 +22,7 @@
  * \file auto_scheduler/search_task.cc
  * \brief Meta information and hardware parameters for a search task.
  */
 
+#include <dlpack/dlpack.h>
 #include <tvm/auto_scheduler/search_task.h>
 #include <tvm/runtime/device_api.h>
 #include <tvm/runtime/registry.h>
@@ -52,11 +53,13 @@ HardwareParams::HardwareParams(int num_cores, int vector_unit_bytes, int cache_l
 
 HardwareParams HardwareParamsNode::GetDefaultHardwareParams(const Target& target,
                                                             const Target& target_host) {
-  if (target->kind->device_type == kDLCPU) {
+  const auto device_type = target->kind->device_type;
+  if (device_type == kDLCPU) {
     return HardwareParams(tvm::runtime::threading::MaxConcurrency(), 64, 64, 0, 0, 0, 0, 0);
-  } else if (target->kind->device_type == kDLGPU) {
-    auto ctx = TVMContext{kDLGPU, 0};
-    auto func = tvm::runtime::Registry::Get("device_api.gpu");
+  } else if (device_type == kDLGPU || device_type == kDLROCM) {
+    auto ctx = TVMContext{static_cast<DLDeviceType>(device_type), 0};
+    auto device_name = device_type == kDLGPU ? "device_api.gpu" : "device_api.rocm";
+    auto func = tvm::runtime::Registry::Get(device_name);
     ICHECK(func != nullptr) << "Cannot find GPU device_api in registry";
     auto device_api = static_cast<runtime::DeviceAPI*>(((*func)()).operator void*());
 
@@ -77,7 +80,7 @@ HardwareParams HardwareParamsNode::GetDefaultHardwareParams(const Target& target
     int max_vthread_extent = warp_size / 4;
     return HardwareParams(-1, 16, 64, max_shared_memory_per_block, max_local_memory_per_block,
                           max_threads_per_block, max_vthread_extent, warp_size);
-  } else if (target->kind->device_type == kDLMetal) {
+  } else if (device_type == kDLMetal) {
     // Reference: https://developer.apple.com/metal/Metal-Feature-Set-Tables.pdf
     // This setting looks working for Metal GPUs later than A10
     int max_shared_memory_per_block = 32 * 1024;
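
Usage sketch (illustrative, not part of the patch): with the ROCm branch in GetDefaultHardwareParams and the NHWC conv2d strategy above, a conv2d workload can be auto-scheduled on a "rocm" target roughly as below. The workload definition, shapes, and tuning settings are assumptions for demonstration; only the "rocm" target handling comes from this diff.

    # Minimal sketch, assuming TVM's Python auto_scheduler API (register_workload,
    # SearchTask, TuningOptions); shapes and trial counts are arbitrary examples.
    import tvm
    from tvm import auto_scheduler, te, topi


    @auto_scheduler.register_workload
    def conv2d_nhwc_workload(N, H, W, CI, CO, KH, KW, stride, padding):
        data = te.placeholder((N, H, W, CI), name="data")
        kernel = te.placeholder((KH, KW, CI, CO), name="kernel")
        out = topi.nn.conv2d_nhwc(data, kernel, stride, padding, dilation=1)
        return [data, kernel, out]


    # "rocm" now resolves to GPU-style default HardwareParams (shared memory,
    # warp size, max threads per block) queried through "device_api.rocm".
    target = tvm.target.Target("rocm")
    task = auto_scheduler.SearchTask(
        func=conv2d_nhwc_workload, args=(1, 56, 56, 64, 64, 3, 3, 1, 1), target=target
    )

    log_file = "conv2d_nhwc_rocm.json"
    task.tune(
        auto_scheduler.TuningOptions(
            num_measure_trials=64,
            measure_callbacks=[auto_scheduler.RecordToFile(log_file)],
        )
    )
    sch, args = task.apply_best(log_file)

When the Relay strategy above runs with the auto-scheduler enabled, is_auto_scheduler_enabled() lets the plevel=15 conv2d_nhwc.winograd implementation be selected for shapes that pass judge_winograd.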