diff --git a/include/tvm/tir/function.h b/include/tvm/tir/function.h
index 48328263fb55..2cb1269d010f 100644
--- a/include/tvm/tir/function.h
+++ b/include/tvm/tir/function.h
@@ -272,10 +272,11 @@ PrimFunc Specialize(PrimFunc func, const Map<Var, ObjectRef>& param_map);
  * \sa tvm::attr
  */
 namespace attr {
+
 /*!
  * \brief List of thread IterVar that a DeviceLaunch function corresponds to.
  *
- * Type: Array<IterVar>
+ * Type: Array<String>
  *
  * We call a device kernel launch function f using the following convention:
  *
@@ -283,23 +284,42 @@ namespace attr {
  * Call(f,
  *      [arg1, arg2, ..., arg_n,
  *       work_size_1, work_size_2, ... work_size_m, dyn_shmem_size])
  *
- * Here n = len(arg), m = len(work_size) = len(device_thread_axis).
+ * Here n = len(arg), m = len(work_size) = len(launch_params)-1.
  *
- * When kDeviceUseDynSharedMemory is not set, dyn_shmem_size argument is omitted.
+ * The list of kernel launch params indicates which additional
+ * parameters will be provided to the PackedFunc by the calling
+ * scope.
  *
- * The list of device_thread_axis indicates how can be bind the
- * work_size arguments to the corresponding threads.
+ * - "threadIdx.x", "threadIdx.y", "threadIdx.z"
  *
- * \sa tvm::CallingConv::kDeviceKernelLaunch
- */
-constexpr const char* kDeviceThreadAxis = "tir.device_thread_axis";
-
-/*!
- * \brief Whether or not use dynamic shared memory.
+ *   The extent of the thread count in x/y/z, to be used when
+ *   launching the compute kernel on the device.  For example, the
+ *   blockDimX/Y/Z parameters passed to cuLaunchKernel when launching
+ *   a CUDA kernel.  For runtimes that define the workgroup size in
+ *   the shader itself, such as spv::ExecutionModeLocalSize for
+ *   SPIR-V shaders on Vulkan, this parameter is ignored at launch.
  *
- * Type: Integer
+ * - "blockIdx.x", "blockIdx.y", "blockIdx.z"
+ *
+ *   The extent of the block iterators, to be used when launching the
+ *   compute kernel on the device.  For example, the gridDimX/Y/Z
+ *   parameters passed to cuLaunchKernel when launching a CUDA
+ *   kernel, or the groupCountX/Y/Z parameters passed to
+ *   vkCmdDispatch when dispatching a compute pipeline to Vulkan.
+ *
+ * - tvm::runtime::launch_param::kUseDynamicSharedMemoryTag
+ *
+ *   The size of the shared memory that may be allocated internally
+ *   by the kernel.  For example, exposed as the
+ *   CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES attribute in
+ *   CUDA.
+ *
+ *   Defined as "tir.use_dyn_shared_memory".
+ *
+ * \sa tvm::CallingConv::kDeviceKernelLaunch
  */
-constexpr const char* kDeviceUseDynSharedMemory = "tir.device_use_dyn_shared_memory";
+constexpr const char* kKernelLaunchParams = "tir.kernel_launch_params";
 
 /*!
  * \brief Whether to set noalias rule on the function arguments.
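
To make the convention concrete, here is a sketch of a host-side launch under the merged attribute. The kernel symbol `default_function_kernel0`, the argument handles, and the extents are invented for illustration; only the argument ordering follows the `Call(f, [...])` convention documented above.

```cpp
#include <tvm/runtime/module.h>
#include <tvm/runtime/packed_func.h>

// Suppose a device kernel was compiled with
//   tir.kernel_launch_params = ["blockIdx.x", "threadIdx.x",
//                               "tir.use_dyn_shared_memory"]
// The calling scope then invokes
//   f(arg_1, ..., arg_n, work_size_1, ..., work_size_m, dyn_shmem_size)
void LaunchSketch(tvm::runtime::Module device_mod, void* a, void* b) {
  tvm::runtime::PackedFunc kernel =
      device_mod.GetFunction("default_function_kernel0");  // hypothetical symbol
  kernel(a, b,   // n = 2 regular arguments
         256,    // extent bound to "blockIdx.x"
         1024,   // extent bound to "threadIdx.x"
         4096);  // bytes of dynamic shared memory, for the trailing tag
}
```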
diff --git a/src/target/build_common.h b/src/target/build_common.h
index 35b3d92eb814..7c9ad8cb3c68 100644
--- a/src/target/build_common.h
+++ b/src/target/build_common.h
@@ -50,15 +50,9 @@ inline std::unordered_map<std::string, runtime::FunctionInfo> ExtractFuncInfo(co
   for (size_t i = 0; i < f->params.size(); ++i) {
     info.arg_types.push_back(f->params[i].dtype());
   }
-  if (auto opt = f->GetAttr<Array<tir::IterVar>>(tir::attr::kDeviceThreadAxis)) {
-    auto thread_axis = opt.value();
-    for (size_t i = 0; i < thread_axis.size(); ++i) {
-      info.launch_param_tags.push_back(thread_axis[i]->thread_tag);
-    }
-  }
-  if (auto opt = f->GetAttr<Integer>(tir::attr::kDeviceUseDynSharedMemory)) {
-    if (opt.value().IntValue() != 0) {
-      info.launch_param_tags.push_back(runtime::launch_param::kUseDynamicSharedMemoryTag);
+  if (auto opt = f->GetAttr<Array<String>>(tir::attr::kKernelLaunchParams)) {
+    for (const auto& tag : opt.value()) {
+      info.launch_param_tags.push_back(tag);
     }
   }
   auto global_symbol = f->GetAttr<String>(tvm::attr::kGlobalSymbol);
diff --git a/src/target/source/codegen_metal.cc b/src/target/source/codegen_metal.cc
index 534e2c3654c4..36ef44bc4814 100644
--- a/src/target/source/codegen_metal.cc
+++ b/src/target/source/codegen_metal.cc
@@ -130,12 +130,14 @@ void CodeGenMetal::AddFunction(const PrimFunc& f) {
   ICHECK_EQ(name_supply_->FreshName("threadIdx"), "threadIdx");
   ICHECK_EQ(name_supply_->FreshName("blockIdx"), "blockIdx");
   int work_dim = 0;
-  auto thread_axis = f->GetAttr<Array<IterVar>>(tir::attr::kDeviceThreadAxis).value();
-
-  for (IterVar iv : thread_axis) {
-    runtime::ThreadScope scope = runtime::ThreadScope::Create(iv->thread_tag);
-    work_dim = std::max(work_dim, scope.dim_index + 1);
+  auto launch_params = f->GetAttr<Array<String>>(tir::attr::kKernelLaunchParams).value();
+  for (const auto& tag : launch_params) {
+    if (tag != runtime::launch_param::kUseDynamicSharedMemoryTag) {
+      runtime::ThreadScope scope = runtime::ThreadScope::Create(tag);
+      work_dim = std::max(work_dim, scope.dim_index + 1);
+    }
   }
+
   if (work_dim != 0) {
     // use ushort by default for now
     stream << "  ";
@@ -145,16 +147,6 @@ void CodeGenMetal::AddFunction(const PrimFunc& f) {
     PrintType(DataType::UInt(thread_index_bits_, work_dim), stream);
     stream << " threadIdx [[thread_position_in_threadgroup]]\n";
   }
-  // bind thread axis
-  for (IterVar iv : thread_axis) {
-    ICHECK(!var_idmap_.count(iv->var.get()));
-    std::string vname = iv->thread_tag;
-    if (work_dim <= 1) {
-      vname = vname.substr(0, iv->thread_tag.length() - 2);
-    }
-    var_idmap_[iv->var.get()] =
-        CastFromTo(vname, DataType::UInt(thread_index_bits_), iv->var.dtype());
-  }
   // the function scope.
   stream << ") {\n";
   int func_scope = this->BeginScope();
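
With the merge, the read side collapses to a single attribute lookup, as `ExtractFuncInfo` now shows. The same pattern works anywhere a PrimFunc's launch params are needed; a minimal sketch (the helper name `DumpLaunchParams` is ours, not part of the patch):

```cpp
#include <tvm/runtime/logging.h>
#include <tvm/tir/function.h>

// Print a PrimFunc's launch params.  GetAttr returns an Optional, so host
// functions, which never carry the attribute, are skipped without error.
void DumpLaunchParams(const tvm::tir::PrimFunc& f) {
  using tvm::Array;
  using tvm::String;
  if (auto opt = f->GetAttr<Array<String>>(tvm::tir::attr::kKernelLaunchParams)) {
    for (const String& tag : opt.value()) {
      LOG(INFO) << "launch param: " << tag;
    }
  }
}
```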
stream << ") {\n"; int func_scope = this->BeginScope(); diff --git a/src/tir/transforms/remap_thread_axis.cc b/src/tir/transforms/remap_thread_axis.cc index e101e6b904ce..519a3e1f80d8 100644 --- a/src/tir/transforms/remap_thread_axis.cc +++ b/src/tir/transforms/remap_thread_axis.cc @@ -69,26 +69,29 @@ class ThreadAxisRewriter : private StmtExprMutator { std::unordered_map vmap_; }; -PrimFunc RemapThreadAxis(PrimFunc&& f, Map thread_map) { +PrimFunc RemapThreadAxis(PrimFunc func, Map thread_map) { std::unordered_map tmap; for (const auto& kv : thread_map) { tmap[kv.first] = kv.second; } - auto opt_thread_axis = f->GetAttr>(tir::attr::kDeviceThreadAxis); - ICHECK(opt_thread_axis != nullptr) << "Require attribute " << tir::attr::kDeviceThreadAxis; - auto thread_axis = opt_thread_axis.value(); - auto* n = f.CopyOnWrite(); - - // replace the thread axis - for (size_t i = 0; i < thread_axis.size(); ++i) { - auto it = tmap.find(thread_axis[i]->thread_tag); - if (it != tmap.end()) { - thread_axis.Set(i, it->second); + if (auto opt = func->GetAttr>(tir::attr::kKernelLaunchParams)) { + ICHECK(opt != nullptr) << "Require attribute " << tir::attr::kKernelLaunchParams; + auto launch_params = opt.value(); + // replace the thread axis attribute + for (size_t i = 0; i < launch_params.size(); ++i) { + auto it = tmap.find(launch_params[i]->thread_tag); + if (it != tmap.end()) { + launch_params.Set(i, it->second); + } } + + func = WithAttr(std::move(func), tir::attr::kKernelLaunchParams, launch_params); } + + auto* n = func.CopyOnWrite(); n->body = ThreadAxisRewriter(tmap).Rewrite(std::move(n->body)); - return WithAttr(std::move(f), tir::attr::kDeviceThreadAxis, thread_axis); + return func; } namespace transform { diff --git a/src/tir/transforms/split_host_device.cc b/src/tir/transforms/split_host_device.cc index c43fc403ed94..3696ff84e5b8 100644 --- a/src/tir/transforms/split_host_device.cc +++ b/src/tir/transforms/split_host_device.cc @@ -51,6 +51,17 @@ class DeviceInfoCollector : public StmtVisitor { PrimExpr dyn_shmem_size_{0}; bool use_dyn_shmem_{false}; + Array GetLaunchParams() const { + Array output; + for (const auto& axis : thread_axis_) { + output.push_back(axis->thread_tag); + } + if (use_dyn_shmem_) { + output.push_back(runtime::launch_param::kUseDynamicSharedMemoryTag); + } + return output; + } + private: void VisitStmt_(const AttrStmtNode* op) final { if (op->attr_key == attr::thread_extent) { @@ -199,8 +210,9 @@ class HostDeviceSplitter : public StmtMutator { GlobalVar kernel_symbol_global = global_var_supply->FreshGlobal(kernel_symbol, false); PrimFunc device_func(params, Substitute(body, remap_vars)); - device_func = - WithAttr(std::move(device_func), tir::attr::kDeviceThreadAxis, dev_info.thread_axis_); + device_func = WithAttr(std::move(device_func), tir::attr::kKernelLaunchParams, + dev_info.GetLaunchParams()); + device_func = WithAttr(std::move(device_func), tvm::attr::kCallingConv, Integer(CallingConv::kDeviceKernelLaunch)); device_func = WithAttr(std::move(device_func), tvm::attr::kGlobalSymbol, @@ -208,10 +220,7 @@ class HostDeviceSplitter : public StmtMutator { device_func = WithAttr(std::move(device_func), tir::attr::kNoAlias, Integer(1)); device_func = WithAttr(std::move(device_func), tvm::attr::kTarget, device_target_); device_func = WithAttr(std::move(device_func), tir::attr::kIsGlobalFunc, Integer(1)); - if (dev_info.use_dyn_shmem_) { - device_func = - WithAttr(std::move(device_func), tir::attr::kDeviceUseDynSharedMemory, Integer(1)); - } + 
diff --git a/src/tir/transforms/split_host_device.cc b/src/tir/transforms/split_host_device.cc
index c43fc403ed94..3696ff84e5b8 100644
--- a/src/tir/transforms/split_host_device.cc
+++ b/src/tir/transforms/split_host_device.cc
@@ -51,6 +51,17 @@ class DeviceInfoCollector : public StmtVisitor {
   PrimExpr dyn_shmem_size_{0};
   bool use_dyn_shmem_{false};
 
+  Array<String> GetLaunchParams() const {
+    Array<String> output;
+    for (const auto& axis : thread_axis_) {
+      output.push_back(axis->thread_tag);
+    }
+    if (use_dyn_shmem_) {
+      output.push_back(runtime::launch_param::kUseDynamicSharedMemoryTag);
+    }
+    return output;
+  }
+
  private:
   void VisitStmt_(const AttrStmtNode* op) final {
     if (op->attr_key == attr::thread_extent) {
@@ -199,8 +210,9 @@ class HostDeviceSplitter : public StmtMutator {
     GlobalVar kernel_symbol_global = global_var_supply->FreshGlobal(kernel_symbol, false);
     PrimFunc device_func(params, Substitute(body, remap_vars));
-    device_func =
-        WithAttr(std::move(device_func), tir::attr::kDeviceThreadAxis, dev_info.thread_axis_);
+    device_func = WithAttr(std::move(device_func), tir::attr::kKernelLaunchParams,
+                           dev_info.GetLaunchParams());
+
     device_func = WithAttr(std::move(device_func), tvm::attr::kCallingConv,
                            Integer(CallingConv::kDeviceKernelLaunch));
     device_func = WithAttr(std::move(device_func), tvm::attr::kGlobalSymbol,
@@ -208,10 +220,7 @@ class HostDeviceSplitter : public StmtMutator {
     device_func = WithAttr(std::move(device_func), tir::attr::kNoAlias, Integer(1));
     device_func = WithAttr(std::move(device_func), tvm::attr::kTarget, device_target_);
     device_func = WithAttr(std::move(device_func), tir::attr::kIsGlobalFunc, Integer(1));
-    if (dev_info.use_dyn_shmem_) {
-      device_func =
-          WithAttr(std::move(device_func), tir::attr::kDeviceUseDynSharedMemory, Integer(1));
-    }
+
     (*device_mod_)->Add(kernel_symbol_global, device_func);
 
     // generate calls to the device function
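
Since `SplitHostDevice` is now the single producer of the attribute, a useful invariant to check after the pass is that every function with the `kDeviceKernelLaunch` calling convention carries `tir.kernel_launch_params`. A test-style sketch; the helper `SplitAndCheck` and the lowered input module are assumptions, not part of the patch:

```cpp
#include <tvm/ir/module.h>
#include <tvm/runtime/logging.h>
#include <tvm/tir/function.h>
#include <tvm/tir/transform.h>

tvm::IRModule SplitAndCheck(tvm::IRModule mod) {
  using namespace tvm;
  mod = tir::transform::SplitHostDevice()(mod);
  for (const auto& kv : mod->functions) {
    const auto* func = kv.second.as<tir::PrimFuncNode>();
    if (func == nullptr) continue;
    auto conv = func->GetAttr<Integer>(tvm::attr::kCallingConv, Integer(0));
    if (conv.value().IntValue() == static_cast<int>(CallingConv::kDeviceKernelLaunch)) {
      // Every device kernel must carry the merged attribute; the two
      // retired attributes should no longer appear anywhere.
      ICHECK(func->GetAttr<Array<String>>(tir::attr::kKernelLaunchParams))
          << "device kernel missing tir.kernel_launch_params";
    }
  }
  return mod;
}
```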