From 9c56f14eb7cf244ca1b0cf3da110d38eb428071b Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Wed, 3 Mar 2021 07:37:52 +0900 Subject: [PATCH 01/13] introduce ArgUnion64 --- src/runtime/pack_args.h | 26 +++++++++++++++++++------- src/runtime/vulkan/vulkan.cc | 12 ++++++------ 2 files changed, 25 insertions(+), 13 deletions(-) diff --git a/src/runtime/pack_args.h b/src/runtime/pack_args.h index 45cde22bda08..54a75d62a2e9 100644 --- a/src/runtime/pack_args.h +++ b/src/runtime/pack_args.h @@ -47,6 +47,15 @@ union ArgUnion { uint32_t v_uint32; float v_float32; }; + +union ArgUnion64 { + int32_t v_int32[2]; + uint32_t v_uint32[2]; + float v_float32[2]; + int64_t v_int64; + uint64_t v_uint64; + double v_float64; +}; /*! * \brief Create a packed function from void addr types. * @@ -177,25 +186,28 @@ template inline PackedFunc PackFuncNonBufferArg_(F f, int base, const std::vector& codes) { int num_args = static_cast(codes.size()); auto ret = [f, codes, base, num_args](TVMArgs args, TVMRetValue* ret) { - TempArray holder_(num_args); - ArgUnion* holder = holder_.data(); + TempArray holder_(num_args); + ArgUnion64* holder = holder_.data(); for (int i = 0; i < num_args; ++i) { switch (codes[i]) { - case INT64_TO_INT64: + case INT64_TO_INT64: { + holder[i].v_int64 = args.values[base + i].v_int64; + break; + } case FLOAT64_TO_FLOAT64: { - LOG(FATAL) << "Do not support 64bit argument to device function"; + holder[i].v_float64 = args.values[base + i].v_float64; break; } case INT64_TO_INT32: { - holder[i].v_int32 = static_cast(args.values[base + i].v_int64); + holder[i].v_int32[0] = static_cast(args.values[base + i].v_int64); break; } case INT64_TO_UINT32: { - holder[i].v_uint32 = static_cast(args.values[base + i].v_int64); + holder[i].v_uint32[0] = static_cast(args.values[base + i].v_int64); break; } case FLOAT64_TO_FLOAT32: { - holder[i].v_float32 = static_cast(args.values[base + i].v_float64); + holder[i].v_float32[0] = static_cast(args.values[base + i].v_float64); break; } case HANDLE_TO_HANDLE: { diff --git a/src/runtime/vulkan/vulkan.cc b/src/runtime/vulkan/vulkan.cc index f40fd80f38b5..4eb34819f40b 100644 --- a/src/runtime/vulkan/vulkan.cc +++ b/src/runtime/vulkan/vulkan.cc @@ -711,7 +711,7 @@ class VulkanWrappedFunc { thread_axis_cfg_.Init(num_buffer_args + num_pack_args, thread_axis_tags); } - void operator()(TVMArgs args, TVMRetValue* rv, const ArgUnion* pack_args) const; + void operator()(TVMArgs args, TVMRetValue* rv, const ArgUnion64* pack_args) const; private: // internal module @@ -875,7 +875,7 @@ class VulkanModuleNode final : public runtime::ModuleNode { VkPushConstantRange crange; crange.stageFlags = VK_SHADER_STAGE_COMPUTE_BIT; crange.offset = 0; - crange.size = sizeof(ArgUnion) * num_pack_args; + crange.size = sizeof(ArgUnion64) * num_pack_args; VkPipelineLayoutCreateInfo playout_cinfo; playout_cinfo.sType = VK_STRUCTURE_TYPE_PIPELINE_LAYOUT_CREATE_INFO; @@ -1046,7 +1046,7 @@ VulkanStream* VulkanThreadEntry::Stream(size_t device_id) { return streams_[device_id].get(); } -void VulkanWrappedFunc::operator()(TVMArgs args, TVMRetValue* rv, const ArgUnion* pack_args) const { +void VulkanWrappedFunc::operator()(TVMArgs args, TVMRetValue* rv, const ArgUnion64* pack_args) const { int device_id = VulkanThreadEntry::ThreadLocal()->ctx.device_id; ICHECK_LT(device_id, kVulkanMaxNumDevice); const auto& vctx = VulkanDeviceAPI::Global()->context(device_id); @@ -1075,7 +1075,7 @@ void VulkanWrappedFunc::operator()(TVMArgs args, TVMRetValue* rv, const ArgUnion descriptor_buffers.data()); if (num_pack_args_ != 0) { vkCmdPushConstants(state->cmd_buffer_, pipeline->pipeline_layout, - VK_SHADER_STAGE_COMPUTE_BIT, 0, num_pack_args_ * sizeof(ArgUnion), + VK_SHADER_STAGE_COMPUTE_BIT, 0, num_pack_args_ * sizeof(ArgUnion64), pack_args); } vkCmdDispatch(state->cmd_buffer_, wl.grid_dim(0), wl.grid_dim(1), wl.grid_dim(2)); @@ -1093,7 +1093,7 @@ void VulkanWrappedFunc::operator()(TVMArgs args, TVMRetValue* rv, const ArgUnion } // Otherwise, the more expensive deferred path. - std::vector pack_args_storage(pack_args, pack_args + num_pack_args_); + std::vector pack_args_storage(pack_args, pack_args + num_pack_args_); const auto& deferred_initializer = [&vctx, pipeline, descriptor_buffers]() { std::vector write_descriptor_sets; write_descriptor_sets.resize(descriptor_buffers.size()); @@ -1119,7 +1119,7 @@ void VulkanWrappedFunc::operator()(TVMArgs args, TVMRetValue* rv, const ArgUnion nullptr); if (pack_args_storage.size() != 0) { vkCmdPushConstants(state->cmd_buffer_, pipeline->pipeline_layout, VK_SHADER_STAGE_COMPUTE_BIT, - 0, pack_args_storage.size() * sizeof(ArgUnion), pack_args_storage.data()); + 0, pack_args_storage.size() * sizeof(ArgUnion64), pack_args_storage.data()); } vkCmdDispatch(state->cmd_buffer_, wl.grid_dim(0), wl.grid_dim(1), wl.grid_dim(2)); VkMemoryBarrier barrier_info; From 5c2640c7b14ff7dcc8a1c64f3c519adfd3219d11 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Wed, 3 Mar 2021 07:38:02 +0900 Subject: [PATCH 02/13] add missing intrinsic --- src/target/spirv/intrin_rule_spirv.cc | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/src/target/spirv/intrin_rule_spirv.cc b/src/target/spirv/intrin_rule_spirv.cc index 90b2eb2a671f..b75fb53b150d 100644 --- a/src/target/spirv/intrin_rule_spirv.cc +++ b/src/target/spirv/intrin_rule_spirv.cc @@ -62,8 +62,14 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.vulkan.fabs").set_body(DispatchGLSLPureIntr TVM_REGISTER_GLOBAL("tvm.intrin.rule.vulkan.exp").set_body(DispatchGLSLPureIntrin); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.vulkan.sin").set_body(DispatchGLSLPureIntrin); + +TVM_REGISTER_GLOBAL("tvm.intrin.rule.vulkan.cos").set_body(DispatchGLSLPureIntrin); + TVM_REGISTER_GLOBAL("tvm.intrin.rule.vulkan.log").set_body(DispatchGLSLPureIntrin); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.vulkan.log2").set_body(DispatchGLSLPureIntrin); + TVM_REGISTER_GLOBAL("tvm.intrin.rule.vulkan.sqrt").set_body(DispatchGLSLPureIntrin); TVM_REGISTER_GLOBAL("tvm.intrin.rule.vulkan.pow").set_body(DispatchGLSLPureIntrin); From 59e62f3373d8b1363e8a61e1ab05a7cdfb7380ec Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Wed, 3 Mar 2021 08:18:24 +0900 Subject: [PATCH 03/13] test cumsum on vulkan --- tests/python/topi/python/test_topi_cumsum.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/tests/python/topi/python/test_topi_cumsum.py b/tests/python/topi/python/test_topi_cumsum.py index a01a496f92e9..bf962d93fab3 100644 --- a/tests/python/topi/python/test_topi_cumsum.py +++ b/tests/python/topi/python/test_topi_cumsum.py @@ -28,6 +28,7 @@ def check_cumsum(np_ref, data, axis=None, dtype=None): "generic": (lambda x: topi.cumsum(x, axis, dtype), topi.generic.schedule_extern), "cuda": (lambda x: topi.cuda.cumsum(x, axis, dtype), topi.cuda.schedule_scan), "nvptx": (lambda x: topi.cuda.cumsum(x, axis, dtype), topi.cuda.schedule_scan), + "vulkan": (lambda x: topi.cuda.cumsum(x, axis, dtype), topi.cuda.schedule_scan), } fcompute, fschedule = tvm.topi.testing.dispatch(target, implementations) tvm.topi.testing.compare_numpy_tvm([data], np_ref, target, ctx, fcompute, fschedule) @@ -40,8 +41,10 @@ def check_cumsum(np_ref, data, axis=None, dtype=None): check_cumsum(np.cumsum(data, dtype=np.int32), data) check_cumsum(np.cumsum(data), data, dtype="int64") - data = np.random.rand(10) > 0.5 - check_cumsum(np.cumsum(data, dtype=np.int32), data, dtype="int32") + if str(target.kind) != "vulkan": + # TODO(masahi): Support bool tensor in SPIRV codegen + data = np.random.rand(10) > 0.5 + check_cumsum(np.cumsum(data, dtype=np.int32), data, dtype="int32") for in_dtype in ["float32", "float64"]: data = np.random.randn(10, 10).astype(in_dtype) @@ -70,3 +73,4 @@ def check_cumsum(np_ref, data, axis=None, dtype=None): test_cumsum(tvm.context("cpu"), tvm.target.Target("llvm")) test_cumsum(tvm.context("cuda"), tvm.target.Target("cuda")) test_cumsum(tvm.context("nvptx"), tvm.target.Target("nvptx")) + test_cumsum(tvm.context("vulkan"), tvm.target.Target("vulkan")) From 53eaa1fe0407a201184fd8394c8d0670b52632b3 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Wed, 3 Mar 2021 08:23:16 +0900 Subject: [PATCH 04/13] update metal runtime to use ArgUnion64 (not tested) --- src/runtime/metal/metal_module.mm | 4 ++-- src/runtime/pack_args.h | 4 +++- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/src/runtime/metal/metal_module.mm b/src/runtime/metal/metal_module.mm index 981dd6129f9e..8f1fde86f074 100644 --- a/src/runtime/metal/metal_module.mm +++ b/src/runtime/metal/metal_module.mm @@ -180,7 +180,7 @@ void Init(MetalModuleNode* m, ObjectPtr sptr, const std::string& func_na scache_[dev_id] = m->GetPipelineState(dev_id, func_name); } // invoke the function with void arguments - void operator()(TVMArgs args, TVMRetValue* rv, const ArgUnion* pack_args) const { + void operator()(TVMArgs args, TVMRetValue* rv, const ArgUnion64* pack_args) const { metal::MetalThreadEntry* t = metal::MetalThreadEntry::ThreadLocal(); int device_id = t->context.device_id; if (scache_[device_id] == nil) { @@ -197,7 +197,7 @@ void operator()(TVMArgs args, TVMRetValue* rv, const ArgUnion* pack_args) const } if (num_pack_args_ != 0) { [encoder setBytes:pack_args - length:num_pack_args_ * sizeof(ArgUnion) + length:num_pack_args_ * sizeof(ArgUnion64) atIndex:num_buffer_args_]; } // launch diff --git a/src/runtime/pack_args.h b/src/runtime/pack_args.h index 54a75d62a2e9..2e7a8814bf74 100644 --- a/src/runtime/pack_args.h +++ b/src/runtime/pack_args.h @@ -40,7 +40,6 @@ namespace tvm { namespace runtime { /*! * \brief argument union type of 32bit. - * Choose 32 bit because most GPU API do not work well with 64 bit. */ union ArgUnion { int32_t v_int32; @@ -48,6 +47,9 @@ union ArgUnion { float v_float32; }; +/*! + * \brief argument union type of 64 bit, for use by Vulkan and Metal runtime. + */ union ArgUnion64 { int32_t v_int32[2]; uint32_t v_uint32[2]; From 71ef001b35a0c10663f93d2e884471f38487400f Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Wed, 3 Mar 2021 08:25:08 +0900 Subject: [PATCH 05/13] test get_valid_counts on vulkan --- tests/python/topi/python/test_topi_vision.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/python/topi/python/test_topi_vision.py b/tests/python/topi/python/test_topi_vision.py index 839356892ab1..2fdf3cf4b170 100644 --- a/tests/python/topi/python/test_topi_vision.py +++ b/tests/python/topi/python/test_topi_vision.py @@ -112,7 +112,7 @@ def check_device(device): tvm.testing.assert_allclose(tvm_out2.asnumpy(), np_out2, rtol=1e-3) tvm.testing.assert_allclose(tvm_out3.asnumpy(), np_out3, rtol=1e-3) - for device in ["llvm", "cuda", "opencl"]: + for device in ["llvm", "cuda", "opencl", "vulkan"]: check_device(device) From 41bfd0257386f44a369a1e8bf04a2173140577cc Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Wed, 3 Mar 2021 09:50:43 +0900 Subject: [PATCH 06/13] formatting --- src/runtime/vulkan/vulkan.cc | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/runtime/vulkan/vulkan.cc b/src/runtime/vulkan/vulkan.cc index 4eb34819f40b..794f3c570f96 100644 --- a/src/runtime/vulkan/vulkan.cc +++ b/src/runtime/vulkan/vulkan.cc @@ -1046,7 +1046,8 @@ VulkanStream* VulkanThreadEntry::Stream(size_t device_id) { return streams_[device_id].get(); } -void VulkanWrappedFunc::operator()(TVMArgs args, TVMRetValue* rv, const ArgUnion64* pack_args) const { +void VulkanWrappedFunc::operator()(TVMArgs args, TVMRetValue* rv, + const ArgUnion64* pack_args) const { int device_id = VulkanThreadEntry::ThreadLocal()->ctx.device_id; ICHECK_LT(device_id, kVulkanMaxNumDevice); const auto& vctx = VulkanDeviceAPI::Global()->context(device_id); @@ -1119,7 +1120,8 @@ void VulkanWrappedFunc::operator()(TVMArgs args, TVMRetValue* rv, const ArgUnion nullptr); if (pack_args_storage.size() != 0) { vkCmdPushConstants(state->cmd_buffer_, pipeline->pipeline_layout, VK_SHADER_STAGE_COMPUTE_BIT, - 0, pack_args_storage.size() * sizeof(ArgUnion64), pack_args_storage.data()); + 0, pack_args_storage.size() * sizeof(ArgUnion64), + pack_args_storage.data()); } vkCmdDispatch(state->cmd_buffer_, wl.grid_dim(0), wl.grid_dim(1), wl.grid_dim(2)); VkMemoryBarrier barrier_info; From a22aa3f0d403b1f6cb7aafad10871a7df3388ef0 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Wed, 3 Mar 2021 18:14:23 +0900 Subject: [PATCH 07/13] pytest fix --- tests/python/topi/python/test_topi_cumsum.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/python/topi/python/test_topi_cumsum.py b/tests/python/topi/python/test_topi_cumsum.py index bf962d93fab3..6b99239cc007 100644 --- a/tests/python/topi/python/test_topi_cumsum.py +++ b/tests/python/topi/python/test_topi_cumsum.py @@ -41,7 +41,7 @@ def check_cumsum(np_ref, data, axis=None, dtype=None): check_cumsum(np.cumsum(data, dtype=np.int32), data) check_cumsum(np.cumsum(data), data, dtype="int64") - if str(target.kind) != "vulkan": + if target != "vulkan": # TODO(masahi): Support bool tensor in SPIRV codegen data = np.random.rand(10) > 0.5 check_cumsum(np.cumsum(data, dtype=np.int32), data, dtype="int32") From 0f944175289c35079e57fc51ca14e65851ddeca2 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Thu, 4 Mar 2021 04:35:20 +0900 Subject: [PATCH 08/13] ArgUnion -> ArgUnion32 --- src/runtime/pack_args.h | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/runtime/pack_args.h b/src/runtime/pack_args.h index 2e7a8814bf74..7c852da77df6 100644 --- a/src/runtime/pack_args.h +++ b/src/runtime/pack_args.h @@ -41,7 +41,7 @@ namespace runtime { /*! * \brief argument union type of 32bit. */ -union ArgUnion { +union ArgUnion32 { int32_t v_int32; uint32_t v_uint32; float v_float32; @@ -151,9 +151,9 @@ inline PackedFunc PackFuncVoidAddr_(F f, const std::vector& code int num_args = static_cast(codes.size()); auto ret = [f, codes, num_args](TVMArgs args, TVMRetValue* ret) { TempArray addr_(num_args); - TempArray holder_(num_args); + TempArray holder_(num_args); void** addr = addr_.data(); - ArgUnion* holder = holder_.data(); + ArgUnion32* holder = holder_.data(); for (int i = 0; i < num_args; ++i) { switch (codes[i]) { case INT64_TO_INT64: From f6c83b101fd958b63076e7b510032b91b6cff3f4 Mon Sep 17 00:00:00 2001 From: Wuwei Lin Date: Thu, 4 Mar 2021 16:14:42 -0500 Subject: [PATCH 09/13] Update metal codegen for ArgUnion64 --- src/target/source/codegen_metal.cc | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/target/source/codegen_metal.cc b/src/target/source/codegen_metal.cc index baa30065a7f9..f7219cbb6abe 100644 --- a/src/target/source/codegen_metal.cc +++ b/src/target/source/codegen_metal.cc @@ -47,7 +47,7 @@ CodeGenMetal::CodeGenMetal() { decl_stream << "#include \n"; decl_stream << "using namespace metal;\n\n"; decl_stream << "union __TVMArgUnion {\n" - << " int v_int;\n" + << " int v_long;\n" << "};\n\n"; } @@ -104,8 +104,8 @@ void CodeGenMetal::AddFunction(const PrimFunc& f) { if (v.dtype().bits() == 32) { decl_stream << " "; PrintType(v.dtype(), decl_stream); - decl_stream << " " << vid << ";\n"; - vref << varg << "." << vid; + decl_stream << " " << vid << "[2];\n"; + vref << varg << "." << vid << "[0]"; } else { // For non 32bit type, ref through arg union. decl_stream << " __TVMArgUnion " << vid << ";\n"; From 630a3fb9df53d33516ab9d85acc8db44474c0deb Mon Sep 17 00:00:00 2001 From: Wuwei Lin Date: Thu, 4 Mar 2021 17:01:42 -0500 Subject: [PATCH 10/13] Add explici 64-bit support in metal codegen --- src/target/source/codegen_metal.cc | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/src/target/source/codegen_metal.cc b/src/target/source/codegen_metal.cc index f7219cbb6abe..c95d578df686 100644 --- a/src/target/source/codegen_metal.cc +++ b/src/target/source/codegen_metal.cc @@ -47,7 +47,7 @@ CodeGenMetal::CodeGenMetal() { decl_stream << "#include \n"; decl_stream << "using namespace metal;\n\n"; decl_stream << "union __TVMArgUnion {\n" - << " int v_long;\n" + << " int v_int[2];\n" << "};\n\n"; } @@ -106,6 +106,11 @@ void CodeGenMetal::AddFunction(const PrimFunc& f) { PrintType(v.dtype(), decl_stream); decl_stream << " " << vid << "[2];\n"; vref << varg << "." << vid << "[0]"; + } else if (v.dtype().bits() == 64) { + decl_stream << " "; + PrintType(v.dtype(), decl_stream); + decl_stream << " " << vid << ";\n"; + vref << varg << "." << vid; } else { // For non 32bit type, ref through arg union. decl_stream << " __TVMArgUnion " << vid << ";\n"; From 76bcd217f03540b46c83cd7d1d76151b1fd6cc82 Mon Sep 17 00:00:00 2001 From: Wuwei Lin Date: Thu, 4 Mar 2021 17:03:35 -0500 Subject: [PATCH 11/13] add test --- tests/python/topi/python/test_topi_cumsum.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/tests/python/topi/python/test_topi_cumsum.py b/tests/python/topi/python/test_topi_cumsum.py index 6b99239cc007..79330e7063db 100644 --- a/tests/python/topi/python/test_topi_cumsum.py +++ b/tests/python/topi/python/test_topi_cumsum.py @@ -29,6 +29,7 @@ def check_cumsum(np_ref, data, axis=None, dtype=None): "cuda": (lambda x: topi.cuda.cumsum(x, axis, dtype), topi.cuda.schedule_scan), "nvptx": (lambda x: topi.cuda.cumsum(x, axis, dtype), topi.cuda.schedule_scan), "vulkan": (lambda x: topi.cuda.cumsum(x, axis, dtype), topi.cuda.schedule_scan), + "metal": (lambda x: topi.cuda.cumsum(x, axis, dtype), topi.cuda.schedule_scan), } fcompute, fschedule = tvm.topi.testing.dispatch(target, implementations) tvm.topi.testing.compare_numpy_tvm([data], np_ref, target, ctx, fcompute, fschedule) @@ -47,6 +48,9 @@ def check_cumsum(np_ref, data, axis=None, dtype=None): check_cumsum(np.cumsum(data, dtype=np.int32), data, dtype="int32") for in_dtype in ["float32", "float64"]: + if str(target.kind) == 'metal' and in_dtype == 'float64': + # float64 is not supported in metal + continue data = np.random.randn(10, 10).astype(in_dtype) check_cumsum(np.cumsum(data), data) check_cumsum(np.cumsum(data, axis=0), data, axis=0) @@ -74,3 +78,4 @@ def check_cumsum(np_ref, data, axis=None, dtype=None): test_cumsum(tvm.context("cuda"), tvm.target.Target("cuda")) test_cumsum(tvm.context("nvptx"), tvm.target.Target("nvptx")) test_cumsum(tvm.context("vulkan"), tvm.target.Target("vulkan")) + test_cumsum(tvm.context("metal"), tvm.target.Target("metal")) From 1a4076b5cac17727c89035823a5b4a2df95658c9 Mon Sep 17 00:00:00 2001 From: Wuwei Lin Date: Thu, 4 Mar 2021 17:21:46 -0500 Subject: [PATCH 12/13] update test target --- tests/python/topi/python/test_topi_cumsum.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/python/topi/python/test_topi_cumsum.py b/tests/python/topi/python/test_topi_cumsum.py index 79330e7063db..eaf46c07ff4b 100644 --- a/tests/python/topi/python/test_topi_cumsum.py +++ b/tests/python/topi/python/test_topi_cumsum.py @@ -48,7 +48,7 @@ def check_cumsum(np_ref, data, axis=None, dtype=None): check_cumsum(np.cumsum(data, dtype=np.int32), data, dtype="int32") for in_dtype in ["float32", "float64"]: - if str(target.kind) == 'metal' and in_dtype == 'float64': + if target == 'metal' and in_dtype == 'float64': # float64 is not supported in metal continue data = np.random.randn(10, 10).astype(in_dtype) From 940c95179ec65642ec26cbe69d789ac67ae788da Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Fri, 5 Mar 2021 10:25:10 +0900 Subject: [PATCH 13/13] enable boolean input cumsum test on vk --- tests/python/topi/python/test_topi_cumsum.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/tests/python/topi/python/test_topi_cumsum.py b/tests/python/topi/python/test_topi_cumsum.py index eaf46c07ff4b..cfe5130643c5 100644 --- a/tests/python/topi/python/test_topi_cumsum.py +++ b/tests/python/topi/python/test_topi_cumsum.py @@ -42,13 +42,11 @@ def check_cumsum(np_ref, data, axis=None, dtype=None): check_cumsum(np.cumsum(data, dtype=np.int32), data) check_cumsum(np.cumsum(data), data, dtype="int64") - if target != "vulkan": - # TODO(masahi): Support bool tensor in SPIRV codegen - data = np.random.rand(10) > 0.5 - check_cumsum(np.cumsum(data, dtype=np.int32), data, dtype="int32") + data = np.random.rand(10) > 0.5 + check_cumsum(np.cumsum(data, dtype=np.int32), data, dtype="int32") for in_dtype in ["float32", "float64"]: - if target == 'metal' and in_dtype == 'float64': + if target == "metal" and in_dtype == "float64": # float64 is not supported in metal continue data = np.random.randn(10, 10).astype(in_dtype)