Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/runtime/metal/metal_module.mm
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,7 @@ void Init(MetalModuleNode* m, ObjectPtr<Object> sptr, const std::string& func_na
scache_[dev_id] = m->GetPipelineState(dev_id, func_name);
}
// invoke the function with void arguments
void operator()(TVMArgs args, TVMRetValue* rv, const ArgUnion* pack_args) const {
void operator()(TVMArgs args, TVMRetValue* rv, const ArgUnion64* pack_args) const {
metal::MetalThreadEntry* t = metal::MetalThreadEntry::ThreadLocal();
int device_id = t->context.device_id;
if (scache_[device_id] == nil) {
Expand All @@ -197,7 +197,7 @@ void operator()(TVMArgs args, TVMRetValue* rv, const ArgUnion* pack_args) const
}
if (num_pack_args_ != 0) {
[encoder setBytes:pack_args
length:num_pack_args_ * sizeof(ArgUnion)
length:num_pack_args_ * sizeof(ArgUnion64)
atIndex:num_buffer_args_];
}
// launch
Expand Down
36 changes: 25 additions & 11 deletions src/runtime/pack_args.h
Original file line number Diff line number Diff line change
Expand Up @@ -40,13 +40,24 @@ namespace tvm {
namespace runtime {
/*!
* \brief argument union type of 32bit.
* Choose 32 bit because most GPU API do not work well with 64 bit.
*/
union ArgUnion {
union ArgUnion32 {
int32_t v_int32;
uint32_t v_uint32;
float v_float32;
};

/*!
* \brief argument union type of 64 bit, for use by Vulkan and Metal runtime.
*/
union ArgUnion64 {
int32_t v_int32[2];
uint32_t v_uint32[2];
float v_float32[2];
int64_t v_int64;
uint64_t v_uint64;
double v_float64;
};
/*!
* \brief Create a packed function from void addr types.
*
Expand Down Expand Up @@ -140,9 +151,9 @@ inline PackedFunc PackFuncVoidAddr_(F f, const std::vector<ArgConvertCode>& code
int num_args = static_cast<int>(codes.size());
auto ret = [f, codes, num_args](TVMArgs args, TVMRetValue* ret) {
TempArray<void*, N> addr_(num_args);
TempArray<ArgUnion, N> holder_(num_args);
TempArray<ArgUnion32, N> holder_(num_args);
void** addr = addr_.data();
ArgUnion* holder = holder_.data();
ArgUnion32* holder = holder_.data();
for (int i = 0; i < num_args; ++i) {
switch (codes[i]) {
case INT64_TO_INT64:
Expand Down Expand Up @@ -177,25 +188,28 @@ template <int N, typename F>
inline PackedFunc PackFuncNonBufferArg_(F f, int base, const std::vector<ArgConvertCode>& codes) {
int num_args = static_cast<int>(codes.size());
auto ret = [f, codes, base, num_args](TVMArgs args, TVMRetValue* ret) {
TempArray<ArgUnion, N> holder_(num_args);
ArgUnion* holder = holder_.data();
TempArray<ArgUnion64, N> holder_(num_args);
ArgUnion64* holder = holder_.data();
for (int i = 0; i < num_args; ++i) {
switch (codes[i]) {
case INT64_TO_INT64:
case INT64_TO_INT64: {
holder[i].v_int64 = args.values[base + i].v_int64;
break;
}
case FLOAT64_TO_FLOAT64: {
LOG(FATAL) << "Do not support 64bit argument to device function";
holder[i].v_float64 = args.values[base + i].v_float64;
break;
}
case INT64_TO_INT32: {
holder[i].v_int32 = static_cast<int32_t>(args.values[base + i].v_int64);
holder[i].v_int32[0] = static_cast<int32_t>(args.values[base + i].v_int64);
break;
}
case INT64_TO_UINT32: {
holder[i].v_uint32 = static_cast<uint32_t>(args.values[base + i].v_int64);
holder[i].v_uint32[0] = static_cast<uint32_t>(args.values[base + i].v_int64);
break;
}
case FLOAT64_TO_FLOAT32: {
holder[i].v_float32 = static_cast<float>(args.values[base + i].v_float64);
holder[i].v_float32[0] = static_cast<float>(args.values[base + i].v_float64);
break;
}
case HANDLE_TO_HANDLE: {
Expand Down
14 changes: 8 additions & 6 deletions src/runtime/vulkan/vulkan.cc
Original file line number Diff line number Diff line change
Expand Up @@ -711,7 +711,7 @@ class VulkanWrappedFunc {
thread_axis_cfg_.Init(num_buffer_args + num_pack_args, thread_axis_tags);
}

void operator()(TVMArgs args, TVMRetValue* rv, const ArgUnion* pack_args) const;
void operator()(TVMArgs args, TVMRetValue* rv, const ArgUnion64* pack_args) const;

private:
// internal module
Expand Down Expand Up @@ -875,7 +875,7 @@ class VulkanModuleNode final : public runtime::ModuleNode {
VkPushConstantRange crange;
crange.stageFlags = VK_SHADER_STAGE_COMPUTE_BIT;
crange.offset = 0;
crange.size = sizeof(ArgUnion) * num_pack_args;
crange.size = sizeof(ArgUnion64) * num_pack_args;

VkPipelineLayoutCreateInfo playout_cinfo;
playout_cinfo.sType = VK_STRUCTURE_TYPE_PIPELINE_LAYOUT_CREATE_INFO;
Expand Down Expand Up @@ -1046,7 +1046,8 @@ VulkanStream* VulkanThreadEntry::Stream(size_t device_id) {
return streams_[device_id].get();
}

void VulkanWrappedFunc::operator()(TVMArgs args, TVMRetValue* rv, const ArgUnion* pack_args) const {
void VulkanWrappedFunc::operator()(TVMArgs args, TVMRetValue* rv,
const ArgUnion64* pack_args) const {
int device_id = VulkanThreadEntry::ThreadLocal()->ctx.device_id;
ICHECK_LT(device_id, kVulkanMaxNumDevice);
const auto& vctx = VulkanDeviceAPI::Global()->context(device_id);
Expand Down Expand Up @@ -1075,7 +1076,7 @@ void VulkanWrappedFunc::operator()(TVMArgs args, TVMRetValue* rv, const ArgUnion
descriptor_buffers.data());
if (num_pack_args_ != 0) {
vkCmdPushConstants(state->cmd_buffer_, pipeline->pipeline_layout,
VK_SHADER_STAGE_COMPUTE_BIT, 0, num_pack_args_ * sizeof(ArgUnion),
VK_SHADER_STAGE_COMPUTE_BIT, 0, num_pack_args_ * sizeof(ArgUnion64),
pack_args);
}
vkCmdDispatch(state->cmd_buffer_, wl.grid_dim(0), wl.grid_dim(1), wl.grid_dim(2));
Expand All @@ -1093,7 +1094,7 @@ void VulkanWrappedFunc::operator()(TVMArgs args, TVMRetValue* rv, const ArgUnion
}

// Otherwise, the more expensive deferred path.
std::vector<ArgUnion> pack_args_storage(pack_args, pack_args + num_pack_args_);
std::vector<ArgUnion64> pack_args_storage(pack_args, pack_args + num_pack_args_);
const auto& deferred_initializer = [&vctx, pipeline, descriptor_buffers]() {
std::vector<VkWriteDescriptorSet> write_descriptor_sets;
write_descriptor_sets.resize(descriptor_buffers.size());
Expand All @@ -1119,7 +1120,8 @@ void VulkanWrappedFunc::operator()(TVMArgs args, TVMRetValue* rv, const ArgUnion
nullptr);
if (pack_args_storage.size() != 0) {
vkCmdPushConstants(state->cmd_buffer_, pipeline->pipeline_layout, VK_SHADER_STAGE_COMPUTE_BIT,
0, pack_args_storage.size() * sizeof(ArgUnion), pack_args_storage.data());
0, pack_args_storage.size() * sizeof(ArgUnion64),
pack_args_storage.data());
}
vkCmdDispatch(state->cmd_buffer_, wl.grid_dim(0), wl.grid_dim(1), wl.grid_dim(2));
VkMemoryBarrier barrier_info;
Expand Down
7 changes: 6 additions & 1 deletion src/target/source/codegen_metal.cc
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ CodeGenMetal::CodeGenMetal() {
decl_stream << "#include <metal_stdlib>\n";
decl_stream << "using namespace metal;\n\n";
decl_stream << "union __TVMArgUnion {\n"
<< " int v_int;\n"
<< " int v_int[2];\n"
<< "};\n\n";
}

Expand Down Expand Up @@ -102,6 +102,11 @@ void CodeGenMetal::AddFunction(const PrimFunc& f) {
std::string vid = AllocVarID(v.get());
std::ostringstream vref;
if (v.dtype().bits() == 32) {
decl_stream << " ";
PrintType(v.dtype(), decl_stream);
decl_stream << " " << vid << "[2];\n";
vref << varg << "." << vid << "[0]";
} else if (v.dtype().bits() == 64) {
decl_stream << " ";
PrintType(v.dtype(), decl_stream);
decl_stream << " " << vid << ";\n";
Expand Down
6 changes: 6 additions & 0 deletions src/target/spirv/intrin_rule_spirv.cc
Original file line number Diff line number Diff line change
Expand Up @@ -62,8 +62,14 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.vulkan.fabs").set_body(DispatchGLSLPureIntr

TVM_REGISTER_GLOBAL("tvm.intrin.rule.vulkan.exp").set_body(DispatchGLSLPureIntrin<GLSLstd450Exp>);

TVM_REGISTER_GLOBAL("tvm.intrin.rule.vulkan.sin").set_body(DispatchGLSLPureIntrin<GLSLstd450Sin>);

TVM_REGISTER_GLOBAL("tvm.intrin.rule.vulkan.cos").set_body(DispatchGLSLPureIntrin<GLSLstd450Cos>);

TVM_REGISTER_GLOBAL("tvm.intrin.rule.vulkan.log").set_body(DispatchGLSLPureIntrin<GLSLstd450Log>);

TVM_REGISTER_GLOBAL("tvm.intrin.rule.vulkan.log2").set_body(DispatchGLSLPureIntrin<GLSLstd450Log2>);

TVM_REGISTER_GLOBAL("tvm.intrin.rule.vulkan.sqrt").set_body(DispatchGLSLPureIntrin<GLSLstd450Sqrt>);

TVM_REGISTER_GLOBAL("tvm.intrin.rule.vulkan.pow").set_body(DispatchGLSLPureIntrin<GLSLstd450Pow>);
Expand Down
7 changes: 7 additions & 0 deletions tests/python/topi/python/test_topi_cumsum.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@ def check_cumsum(np_ref, data, axis=None, dtype=None):
"generic": (lambda x: topi.cumsum(x, axis, dtype), topi.generic.schedule_extern),
"cuda": (lambda x: topi.cuda.cumsum(x, axis, dtype), topi.cuda.schedule_scan),
"nvptx": (lambda x: topi.cuda.cumsum(x, axis, dtype), topi.cuda.schedule_scan),
"vulkan": (lambda x: topi.cuda.cumsum(x, axis, dtype), topi.cuda.schedule_scan),
"metal": (lambda x: topi.cuda.cumsum(x, axis, dtype), topi.cuda.schedule_scan),
}
fcompute, fschedule = tvm.topi.testing.dispatch(target, implementations)
tvm.topi.testing.compare_numpy_tvm([data], np_ref, target, ctx, fcompute, fschedule)
Expand All @@ -44,6 +46,9 @@ def check_cumsum(np_ref, data, axis=None, dtype=None):
check_cumsum(np.cumsum(data, dtype=np.int32), data, dtype="int32")

for in_dtype in ["float32", "float64"]:
if target == "metal" and in_dtype == "float64":
# float64 is not supported in metal
continue
data = np.random.randn(10, 10).astype(in_dtype)
check_cumsum(np.cumsum(data), data)
check_cumsum(np.cumsum(data, axis=0), data, axis=0)
Expand All @@ -70,3 +75,5 @@ def check_cumsum(np_ref, data, axis=None, dtype=None):
test_cumsum(tvm.context("cpu"), tvm.target.Target("llvm"))
test_cumsum(tvm.context("cuda"), tvm.target.Target("cuda"))
test_cumsum(tvm.context("nvptx"), tvm.target.Target("nvptx"))
test_cumsum(tvm.context("vulkan"), tvm.target.Target("vulkan"))
test_cumsum(tvm.context("metal"), tvm.target.Target("metal"))
2 changes: 1 addition & 1 deletion tests/python/topi/python/test_topi_vision.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ def check_device(device):
tvm.testing.assert_allclose(tvm_out2.asnumpy(), np_out2, rtol=1e-3)
tvm.testing.assert_allclose(tvm_out3.asnumpy(), np_out3, rtol=1e-3)

for device in ["llvm", "cuda", "opencl"]:
for device in ["llvm", "cuda", "opencl", "vulkan"]:
check_device(device)


Expand Down