46 changes: 33 additions & 13 deletions include/tvm/tir/function.h
@@ -272,34 +272,54 @@ PrimFunc Specialize(PrimFunc func, const Map<Var, ObjectRef>& param_map);
* \sa tvm::attr
*/
namespace attr {

/*!
* \brief List of kernel launch parameters that a DeviceLaunch function expects from its caller.
*
* Type: Array<tir::IterVar>
* Type: Array<String>
*
* We call a device kernel launch function f using the following convention:
*
* Call(f,
* [arg1, arg2, ..., arg_n,
* work_size_1, work_size_2, ... work_size_m, dyn_shmem_size])
*
* Here n = len(arg), m = len(work_size) = len(device_thread_axis).
* Here n = len(arg), m = len(work_size) = len(launch_params)-1.
*
* When kDeviceUseDynSharedMemory is not set, dyn_shmem_size argument is omitted.
* The list of kernel launch params indicates which additional
* parameters will be provided to the PackedFunc by the calling
* scope.
*
* The list of device_thread_axis indicates how the work_size
* arguments are bound to the corresponding threads.
* - "threadIdx.x", "threadIdx.y", "threadIdx.z"
*
* \sa tvm::CallingConv::kDeviceKernelLaunch
*/
constexpr const char* kDeviceThreadAxis = "tir.device_thread_axis";

/*!
* \brief Whether or not to use dynamic shared memory.
* The extent of the thread count in x/y/z, i.e. the number of threads
* within each block, to be used when launching the compute kernel on
* the device. For example, the blockDimX/Y/Z parameters passed to
* cuLaunchKernel when launching a CUDA kernel. For runtimes that do not
* take these extents at launch time, this parameter is ignored; for
* example, SPIR-V shaders on Vulkan fix the workgroup size in the
* shader via spv::ExecutionModeLocalSize.
*
* Type: Integer
* - "blockIdx.x", "blockIdx.y", "blockIdx.z"
*
* The extent of the block iterators in x/y/z, i.e. the number of blocks
* in the launch grid, to be used when launching the compute kernel on
* the device. For example, the gridDimX/Y/Z parameters passed to
* cuLaunchKernel when launching a CUDA kernel, or the groupCountX/Y/Z
* parameters passed to vkCmdDispatch when dispatching a compute
* pipeline on Vulkan.
*
* - tvm::runtime::launch_param::kUseDynamicSharedMemoryTag
*
* The size of the dynamic shared memory that may be allocated by the
* kernel at launch. For example, the sharedMemBytes argument passed to
* cuLaunchKernel when launching a CUDA kernel; sizes above the default
* limit also require raising the CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES
* function attribute.
*
* Defined as "tir.use_dyn_shared_memory".
*
* \sa tvm::CallingConv::kDeviceKernelLaunch
*/
constexpr const char* kDeviceUseDynSharedMemory = "tir.device_use_dyn_shared_memory";
constexpr const char* kKernelLaunchParams = "tir.kernel_launch_params";

/*!
* \brief Whether to set noalias rule on the function arguments.
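For illustration only, and not part of this change: a minimal sketch of how a consumer of the convention documented above could split the tir.kernel_launch_params attribute into thread-extent tags and a dynamic-shared-memory flag. The helper name ParseLaunchParams is hypothetical; only PrimFunc::GetAttr, the attribute key, and the tag strings come from the sources shown here.

#include <string>
#include <utility>
#include <vector>

#include <tvm/tir/function.h>

// Hypothetical helper: recover the thread-extent tags and the dynamic shared
// memory flag from a device kernel's launch params.
std::pair<std::vector<std::string>, bool> ParseLaunchParams(const tvm::tir::PrimFunc& f) {
  std::vector<std::string> thread_tags;
  bool use_dyn_shmem = false;
  if (auto opt = f->GetAttr<tvm::runtime::Array<tvm::runtime::String>>(
          tvm::tir::attr::kKernelLaunchParams)) {
    for (const auto& tag : opt.value()) {
      // "tir.use_dyn_shared_memory" is runtime::launch_param::kUseDynamicSharedMemoryTag.
      if (std::string(tag) == "tir.use_dyn_shared_memory") {
        use_dyn_shmem = true;  // the caller appends dyn_shmem_size as the last work_size argument
      } else {
        thread_tags.push_back(tag);  // e.g. "blockIdx.x", "threadIdx.x"
      }
    }
  }
  return {std::move(thread_tags), use_dyn_shmem};
}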
12 changes: 3 additions & 9 deletions src/target/build_common.h
@@ -50,15 +50,9 @@ inline std::unordered_map<std::string, runtime::FunctionInfo> ExtractFuncInfo(co
for (size_t i = 0; i < f->params.size(); ++i) {
info.arg_types.push_back(f->params[i].dtype());
}
if (auto opt = f->GetAttr<Array<tir::IterVar>>(tir::attr::kDeviceThreadAxis)) {
auto thread_axis = opt.value();
for (size_t i = 0; i < thread_axis.size(); ++i) {
info.launch_param_tags.push_back(thread_axis[i]->thread_tag);
}
}
if (auto opt = f->GetAttr<Integer>(tir::attr::kDeviceUseDynSharedMemory)) {
if (opt.value().IntValue() != 0) {
info.launch_param_tags.push_back(runtime::launch_param::kUseDynamicSharedMemoryTag);
if (auto opt = f->GetAttr<Array<String>>(tir::attr::kKernelLaunchParams)) {
for (const auto& tag : opt.value()) {
info.launch_param_tags.push_back(tag);
}
}
auto global_symbol = f->GetAttr<String>(tvm::attr::kGlobalSymbol);
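As a hedged illustration of what ExtractFuncInfo now records per kernel (the kernel name and argument types below are made up), for a kernel that uses blockIdx.x, threadIdx.x, and dynamic shared memory:

// Illustrative contents of runtime::FunctionInfo for one kernel (values hypothetical):
//
//   name              = "default_kernel0"
//   arg_types         = {handle, handle, int32}
//   launch_param_tags = {"blockIdx.x", "threadIdx.x", "tir.use_dyn_shared_memory"}
//
// The device module later pairs these tags with the trailing packed-call
// arguments, i.e. [..., extent(blockIdx.x), extent(threadIdx.x), dyn_shmem_size],
// matching the convention documented in include/tvm/tir/function.h.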
22 changes: 7 additions & 15 deletions src/target/source/codegen_metal.cc
@@ -130,12 +130,14 @@ void CodeGenMetal::AddFunction(const PrimFunc& f) {
ICHECK_EQ(name_supply_->FreshName("threadIdx"), "threadIdx");
ICHECK_EQ(name_supply_->FreshName("blockIdx"), "blockIdx");
int work_dim = 0;
auto thread_axis = f->GetAttr<Array<tir::IterVar>>(tir::attr::kDeviceThreadAxis).value();

for (IterVar iv : thread_axis) {
runtime::ThreadScope scope = runtime::ThreadScope::Create(iv->thread_tag);
work_dim = std::max(work_dim, scope.dim_index + 1);
auto launch_params = f->GetAttr<Array<String>>(tir::attr::kKernelLaunchParams).value();
for (const auto& tag : launch_params) {
if (tag != runtime::launch_param::kUseDynamicSharedMemoryTag) {
runtime::ThreadScope scope = runtime::ThreadScope::Create(tag);
work_dim = std::max(work_dim, scope.dim_index + 1);
}
}

if (work_dim != 0) {
// use ushort by default for now
stream << " ";
@@ -145,16 +147,6 @@ void CodeGenMetal::AddFunction(const PrimFunc& f) {
PrintType(DataType::UInt(thread_index_bits_, work_dim), stream);
stream << " threadIdx [[thread_position_in_threadgroup]]\n";
}
// bind thread axis
for (IterVar iv : thread_axis) {
ICHECK(!var_idmap_.count(iv->var.get()));
std::string vname = iv->thread_tag;
if (work_dim <= 1) {
vname = vname.substr(0, iv->thread_tag.length() - 2);
}
var_idmap_[iv->var.get()] =
CastFromTo(vname, DataType::UInt(thread_index_bits_), iv->var.dtype());
}
// the function scope.
stream << ") {\n";
int func_scope = this->BeginScope();
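The work_dim computation above maps each thread tag to its trailing axis. Below is a standalone sketch of the same logic, assuming only that tags end in .x/.y/.z; it re-implements the effect of runtime::ThreadScope::Create for illustration and is not the TVM implementation.

#include <algorithm>
#include <string>
#include <vector>

// Conceptual sketch: the trailing .x/.y/.z of a launch-param tag selects
// dim_index 0/1/2, and work_dim is the largest dim_index seen plus one.
int ComputeWorkDim(const std::vector<std::string>& launch_params) {
  int work_dim = 0;
  for (const std::string& tag : launch_params) {
    if (tag == "tir.use_dyn_shared_memory") continue;  // not a thread axis
    char axis = tag.back();                            // 'x', 'y', or 'z'
    int dim_index = axis - 'x';
    work_dim = std::max(work_dim, dim_index + 1);
  }
  return work_dim;
}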
27 changes: 15 additions & 12 deletions src/tir/transforms/remap_thread_axis.cc
@@ -69,26 +69,29 @@ class ThreadAxisRewriter : private StmtExprMutator {
std::unordered_map<const VarNode*, Var> vmap_;
};

PrimFunc RemapThreadAxis(PrimFunc&& f, Map<runtime::String, IterVar> thread_map) {
PrimFunc RemapThreadAxis(PrimFunc func, Map<runtime::String, IterVar> thread_map) {
std::unordered_map<std::string, IterVar> tmap;
for (const auto& kv : thread_map) {
tmap[kv.first] = kv.second;
}

auto opt_thread_axis = f->GetAttr<Array<IterVar>>(tir::attr::kDeviceThreadAxis);
ICHECK(opt_thread_axis != nullptr) << "Require attribute " << tir::attr::kDeviceThreadAxis;
auto thread_axis = opt_thread_axis.value();
auto* n = f.CopyOnWrite();

// replace the thread axis
for (size_t i = 0; i < thread_axis.size(); ++i) {
auto it = tmap.find(thread_axis[i]->thread_tag);
if (it != tmap.end()) {
thread_axis.Set(i, it->second);
if (auto opt = func->GetAttr<Array<String>>(tir::attr::kKernelLaunchParams)) {
auto launch_params = opt.value();
// replace any remapped thread tags in the launch params attribute
for (size_t i = 0; i < launch_params.size(); ++i) {
auto it = tmap.find(launch_params[i]);
if (it != tmap.end()) {
launch_params.Set(i, it->second->thread_tag);
}
}

func = WithAttr(std::move(func), tir::attr::kKernelLaunchParams, launch_params);
}

auto* n = func.CopyOnWrite();
n->body = ThreadAxisRewriter(tmap).Rewrite(std::move(n->body));
return WithAttr(std::move(f), tir::attr::kDeviceThreadAxis, thread_axis);
return func;
}

namespace transform {
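For illustration, with hypothetical tags, the effect of the updated pass on the attribute:

// Given thread_map = {"threadIdx.x" -> IterVar tagged "threadIdx.y"} (hypothetical):
//
//   before: tir.kernel_launch_params = ["blockIdx.x", "threadIdx.x", "tir.use_dyn_shared_memory"]
//   after:  tir.kernel_launch_params = ["blockIdx.x", "threadIdx.y", "tir.use_dyn_shared_memory"]
//
// ThreadAxisRewriter additionally rewrites every use of the remapped thread
// variable in the function body to the new IterVar's variable.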
21 changes: 15 additions & 6 deletions src/tir/transforms/split_host_device.cc
@@ -51,6 +51,17 @@ class DeviceInfoCollector : public StmtVisitor {
PrimExpr dyn_shmem_size_{0};
bool use_dyn_shmem_{false};

Array<String> GetLaunchParams() const {
Array<String> output;
for (const auto& axis : thread_axis_) {
output.push_back(axis->thread_tag);
}
if (use_dyn_shmem_) {
output.push_back(runtime::launch_param::kUseDynamicSharedMemoryTag);
}
return output;
}

private:
void VisitStmt_(const AttrStmtNode* op) final {
if (op->attr_key == attr::thread_extent) {
@@ -199,19 +210,17 @@ class HostDeviceSplitter : public StmtMutator {
GlobalVar kernel_symbol_global = global_var_supply->FreshGlobal(kernel_symbol, false);

PrimFunc device_func(params, Substitute(body, remap_vars));
device_func =
WithAttr(std::move(device_func), tir::attr::kDeviceThreadAxis, dev_info.thread_axis_);
device_func = WithAttr(std::move(device_func), tir::attr::kKernelLaunchParams,
dev_info.GetLaunchParams());

device_func = WithAttr(std::move(device_func), tvm::attr::kCallingConv,
Integer(CallingConv::kDeviceKernelLaunch));
device_func = WithAttr(std::move(device_func), tvm::attr::kGlobalSymbol,
runtime::String(kernel_symbol_global->name_hint));
device_func = WithAttr(std::move(device_func), tir::attr::kNoAlias, Integer(1));
device_func = WithAttr(std::move(device_func), tvm::attr::kTarget, device_target_);
device_func = WithAttr(std::move(device_func), tir::attr::kIsGlobalFunc, Integer(1));
if (dev_info.use_dyn_shmem_) {
device_func =
WithAttr(std::move(device_func), tir::attr::kDeviceUseDynSharedMemory, Integer(1));
}

(*device_mod_)->Add(kernel_symbol_global, device_func);

// generate calls to the device function
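To make GetLaunchParams concrete: for a device kernel whose body binds blockIdx.x, threadIdx.x, and threadIdx.y and allocates dynamic shared memory, the attribute attached to the split-out device function would look like the sketch below; the axes and their order are illustrative and depend on the order they appear in the body.

// Resulting attribute on the device PrimFunc (illustrative ordering):
//   tir.kernel_launch_params = ["blockIdx.x", "threadIdx.x", "threadIdx.y",
//                               "tir.use_dyn_shared_memory"]
//
// The host-side call emitted by this pass then supplies the matching work sizes:
//   Call(kernel, [args..., extent(blockIdx.x), extent(threadIdx.x),
//                 extent(threadIdx.y), dyn_shmem_size])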