diff --git a/src/runtime/vulkan/vulkan.cc b/src/runtime/vulkan/vulkan.cc
index 97241b05514e..ef280a749580 100644
--- a/src/runtime/vulkan/vulkan.cc
+++ b/src/runtime/vulkan/vulkan.cc
@@ -24,6 +24,7 @@
 #include
 #include
 
+#include <algorithm>
 #include
 #include
 
@@ -621,6 +622,12 @@ VulkanDeviceAPI::VulkanDeviceAPI() {
     }
     return extensions;
   }();
+
+  // All TVM-generated spirv shaders are marked as requiring int64
+  // support, so we need to request it from the device, too.
+  VkPhysicalDeviceFeatures enabled_features = {};
+  enabled_features.shaderInt64 = VK_TRUE;
+
   VkDeviceCreateInfo device_create_info;
   device_create_info.sType = VK_STRUCTURE_TYPE_DEVICE_CREATE_INFO;
   device_create_info.pNext = nullptr;
@@ -631,7 +638,7 @@ VulkanDeviceAPI::VulkanDeviceAPI() {
   device_create_info.ppEnabledLayerNames = nullptr;
   device_create_info.enabledExtensionCount = extensions.size();
   device_create_info.ppEnabledExtensionNames = extensions.data();
-  device_create_info.pEnabledFeatures = nullptr;
+  device_create_info.pEnabledFeatures = &enabled_features;
   VULKAN_CALL(vkCreateDevice(phy_dev, &device_create_info, nullptr, &(ctx.device)));
   ctx.queue_mutex.reset(new std::mutex());
   vkGetDeviceQueue(ctx.device, queue_family_index, 0, &(ctx.queue));
@@ -882,10 +889,25 @@ class VulkanModuleNode final : public runtime::ModuleNode {
     }
 
     std::vector<VkDescriptorSetLayoutBinding> arg_binding;
    std::vector<VkDescriptorUpdateTemplateEntryKHR> arg_template;
+    std::vector<VkDescriptorPoolSize> descriptor_set_pool_sizes;
     uint32_t num_pod = 0, num_buffer = 0;
-    auto push_arg_info = [&arg_binding, &arg_template](uint32_t binding,
-                                                       VkDescriptorType desc_type) {
+    auto push_arg_info = [&arg_binding, &arg_template, &descriptor_set_pool_sizes](
+                             uint32_t binding, VkDescriptorType desc_type) {
+      {
+        auto result =
+            std::find_if(descriptor_set_pool_sizes.begin(), descriptor_set_pool_sizes.end(),
+                         [&](const auto& psize) { return psize.type == desc_type; });
+        if (result == descriptor_set_pool_sizes.end()) {
+          VkDescriptorPoolSize new_size;
+          new_size.type = desc_type;
+          new_size.descriptorCount = 1;
+          descriptor_set_pool_sizes.push_back(new_size);
+        } else {
+          result->descriptorCount++;
+        }
+      }
+
       {
         VkDescriptorSetLayoutBinding bd;
         bd.binding = binding;
@@ -941,22 +963,17 @@ class VulkanModuleNode final : public runtime::ModuleNode {
                                                  &(pe->descriptor_set_layout)));
     }
 
-    {
-      VkDescriptorPoolSize pool_size;
-      pool_size.type = VK_DESCRIPTOR_TYPE_STORAGE_BUFFER;
-      pool_size.descriptorCount = arg_binding.size();
+    if (!vctx.UseImmediate()) {
       VkDescriptorPoolCreateInfo descrip_pool_cinfo;
       descrip_pool_cinfo.sType = VK_STRUCTURE_TYPE_DESCRIPTOR_POOL_CREATE_INFO;
       descrip_pool_cinfo.pNext = nullptr;
       descrip_pool_cinfo.flags = VK_DESCRIPTOR_POOL_CREATE_FREE_DESCRIPTOR_SET_BIT;
       descrip_pool_cinfo.maxSets = 1;
-      descrip_pool_cinfo.poolSizeCount = 1;
-      descrip_pool_cinfo.pPoolSizes = &pool_size;
+      descrip_pool_cinfo.poolSizeCount = descriptor_set_pool_sizes.size();
+      descrip_pool_cinfo.pPoolSizes = descriptor_set_pool_sizes.data();
       VULKAN_CALL(vkCreateDescriptorPool(vctx.device, &descrip_pool_cinfo, nullptr,
                                          &(pe->descriptor_pool)));
 
-    }
-    if (!vctx.UseImmediate()) {
       VkDescriptorSetAllocateInfo alloc_info;
       alloc_info.sType = VK_STRUCTURE_TYPE_DESCRIPTOR_SET_ALLOCATE_INFO;
       alloc_info.pNext = nullptr;
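The device-creation hunk above requests `shaderInt64` unconditionally. Per the Vulkan spec, `vkCreateDevice` returns `VK_ERROR_FEATURE_NOT_PRESENT` when a requested feature is unavailable, so a device without int64 support now fails at device creation instead of silently miscompiling shaders. A minimal sketch, not part of the patch, of deriving the request from the reported capabilities instead (`phy_dev` stands in for the physical device already selected in this constructor):

```cpp
#include <vulkan/vulkan.h>

VkPhysicalDeviceFeatures RequestedFeatures(VkPhysicalDevice phy_dev) {
  // Query the features the device actually supports.
  VkPhysicalDeviceFeatures available;
  vkGetPhysicalDeviceFeatures(phy_dev, &available);

  // Request only what the generated shaders rely on; requesting an
  // unsupported feature makes vkCreateDevice fail with
  // VK_ERROR_FEATURE_NOT_PRESENT.
  VkPhysicalDeviceFeatures enabled = {};
  enabled.shaderInt64 = available.shaderInt64;
  return enabled;
}
```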
diff --git a/src/target/spirv/codegen_spirv.cc b/src/target/spirv/codegen_spirv.cc
index 5b26e9acf5a2..8188744ce687 100644
--- a/src/target/spirv/codegen_spirv.cc
+++ b/src/target/spirv/codegen_spirv.cc
@@ -43,6 +43,10 @@ runtime::VulkanShader CodeGenSPIRV::BuildFunction(const PrimFunc& f, const std::string& name) {
   std::vector<Var> pod_args;
   uint32_t num_buffer = 0;
 
+  // Currently, all storage and uniform buffer arguments are passed as
+  // a single descriptor set at index 0.
+  const uint32_t descriptor_set = 0;
+
   for (Var arg : f->params) {
     DataType t = arg.dtype();
     if (t.is_handle()) {
@@ -55,8 +59,8 @@ runtime::VulkanShader CodeGenSPIRV::BuildFunction(const PrimFunc& f, const std::string& name) {
         // The loaded byte is cast to bool inside the LoadNode visitor below.
         value_storage_type = DataType::UInt(8);
       }
-      spirv::Value arg_value =
-          builder_->BufferArgument(builder_->GetSType(value_storage_type), 0, num_buffer);
+      spirv::Value arg_value = builder_->BufferArgument(builder_->GetSType(value_storage_type),
+                                                        descriptor_set, num_buffer);
       storage_info_[arg.get()].UpdateContentType(value_storage_type);
       var_map_[arg.get()] = arg_value;
     } else {
@@ -87,7 +91,7 @@ runtime::VulkanShader CodeGenSPIRV::BuildFunction(const PrimFunc& f, const std::string& name) {
   } else {
     shader.flag |= 1 << runtime::vulkan::ShaderMetaDataFlagMask::kUseUBO;
     // If we need to pass more arguments than push constants could handle, we use UBO.
-    spirv::Value ptr = builder_->DeclareUniformBuffer(value_types, num_buffer);
+    spirv::Value ptr = builder_->DeclareUniformBuffer(value_types, descriptor_set, num_buffer);
     for (size_t i = 0; i < pod_args.size(); ++i) {
       spirv::Value value = builder_->GetUniform(ptr, value_types[i], static_cast<uint32_t>(i));
       var_map_[pod_args[i].get()] = value;
diff --git a/src/target/spirv/ir_builder.cc b/src/target/spirv/ir_builder.cc
index cd48c93530ec..ce2b4bc15211 100644
--- a/src/target/spirv/ir_builder.cc
+++ b/src/target/spirv/ir_builder.cc
@@ -200,8 +200,7 @@ Value IRBuilder::BufferArgument(const SType& value_type, uint32_t descriptor_set,
 
   ib_.Begin(spv::OpVariable).AddSeq(ptr_type, val, storage_class).Commit(&global_);
 
-  this->Decorate(spv::OpDecorate, val, spv::DecorationDescriptorSet, descriptor_set);
-  this->Decorate(spv::OpDecorate, val, spv::DecorationBinding, binding);
+  this->DecorateBufferArgument(val, descriptor_set, binding);
 
   return val;
 }
@@ -253,12 +252,18 @@ Value IRBuilder::GetPushConstant(Value ptr_push_const, const SType& v_type, uint32_t index) {
   return this->MakeValue(spv::OpLoad, v_type, ptr);
 }
 
-Value IRBuilder::DeclareUniformBuffer(const std::vector<SType>& value_types, uint32_t binding) {
+Value IRBuilder::DeclareUniformBuffer(const std::vector<SType>& value_types,
+                                      uint32_t descriptor_set, uint32_t binding) {
   Value val = DeclareStorageVariable(value_types, spv::StorageClassUniform, kUniformPtr);
-  this->Decorate(spv::OpDecorate, val, spv::DecorationBinding, binding);
+  this->DecorateBufferArgument(val, descriptor_set, binding);
   return val;
 }
 
+void IRBuilder::DecorateBufferArgument(Value val, uint32_t descriptor_set, uint32_t binding) {
+  this->Decorate(spv::OpDecorate, val, spv::DecorationDescriptorSet, descriptor_set);
+  this->Decorate(spv::OpDecorate, val, spv::DecorationBinding, binding);
+}
+
 Value IRBuilder::GetUniform(Value ptr_push_const, const SType& v_type, uint32_t index) {
   SType ptr_vtype = this->GetPointerType(v_type, spv::StorageClassUniform);
   Value ptr = this->MakeValue(spv::OpAccessChain, ptr_vtype, ptr_push_const,
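Taken together, the two builder changes mean every buffer argument is decorated with the full (descriptor set, binding) pair the runtime uses to wire descriptors, equivalent to GLSL's `layout(set = N, binding = M)`. A sketch of the call pattern after this patch; it assumes TVM's internal `spirv::IRBuilder` from `ir_builder.h`, and `DeclareArgs`, `tensor_type`, and `pod_types` are hypothetical names used only for illustration:

```cpp
#include <cstdint>
#include <vector>

#include "ir_builder.h"  // TVM's src/target/spirv/ir_builder.h

namespace spirv = tvm::codegen::spirv;

// All buffer arguments share descriptor set 0; only the binding differs.
spirv::Value DeclareArgs(spirv::IRBuilder* builder, const spirv::SType& tensor_type,
                         const std::vector<spirv::SType>& pod_types, uint32_t* num_buffer) {
  const uint32_t descriptor_set = 0;
  // One storage-buffer binding per tensor argument.
  spirv::Value arg = builder->BufferArgument(tensor_type, descriptor_set, (*num_buffer)++);
  // When POD arguments overflow push constants, a uniform buffer follows.
  builder->DeclareUniformBuffer(pod_types, descriptor_set, *num_buffer);
  // Both calls go through DecorateBufferArgument, which emits
  //   OpDecorate %v DescriptorSet 0
  //   OpDecorate %v Binding <binding>
  return arg;
}
```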
diff --git a/src/target/spirv/ir_builder.h b/src/target/spirv/ir_builder.h
index 05a2bc631743..250d67067a81 100644
--- a/src/target/spirv/ir_builder.h
+++ b/src/target/spirv/ir_builder.h
@@ -470,7 +470,7 @@ class IRBuilder {
    *
    * \param arg_type The type of argument.
    * \param descriptor_set The descriptor set we want to use.
-   * \param binding The binding locaiton in descriptor set.
+   * \param binding The binding location in descriptor set.
    * \param The argument type.
    */
  Value BufferArgument(const SType& value_type, uint32_t descriptor_set, uint32_t binding);
@@ -496,10 +496,12 @@ class IRBuilder {
    *
    * \note Only call this function once!
    * \param value_types The values in the uniform buffer
-   * \param binding The binding locaiton in descriptor set
+   * \param descriptor_set The descriptor set we want to use
+   * \param binding The binding location in descriptor set
    * \return reference to self.
    */
-  Value DeclareUniformBuffer(const std::vector<SType>& value_types, uint32_t binding);
+  Value DeclareUniformBuffer(const std::vector<SType>& value_types, uint32_t descriptor_set,
+                             uint32_t binding);
   /*!
    * \brief Get i-th uniform constant
    * \param v_type The value type
@@ -585,6 +587,14 @@ class IRBuilder {
   Value DeclareStorageVariable(const std::vector<SType>& value_types,
                                spv::StorageClass storage_class, ValueKind kind);
 
+  /*!
+   * \brief The common function to decorate storage buffer or uniform buffer arguments.
+   * \param val The Value to be decorated.
+   * \param descriptor_set The index of the descriptor set containing the buffer's descriptor
+   * \param binding The index of the buffer's descriptor within the descriptor set
+   */
+  void DecorateBufferArgument(Value val, uint32_t descriptor_set, uint32_t binding);
+
   // get constant given value encoded in uint64_t
   Value GetConst_(const SType& dtype, const uint64_t* pvalue);
   // declare type
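On the runtime side (the vulkan.cc hunks above), the descriptor pool must now be sized per descriptor type, since a shader's single descriptor set can mix storage buffers with the uniform buffer used for overflowing POD arguments. The bookkeeping added to `push_arg_info`, restated as a self-contained sketch:

```cpp
#include <vulkan/vulkan.h>

#include <algorithm>
#include <vector>

// Accumulate one VkDescriptorPoolSize per distinct descriptor type;
// descriptorCount counts how many bindings of that type the set uses.
void CountPoolSize(std::vector<VkDescriptorPoolSize>* pool_sizes, VkDescriptorType desc_type) {
  auto it = std::find_if(pool_sizes->begin(), pool_sizes->end(),
                         [&](const VkDescriptorPoolSize& p) { return p.type == desc_type; });
  if (it == pool_sizes->end()) {
    pool_sizes->push_back(VkDescriptorPoolSize{desc_type, 1});
  } else {
    it->descriptorCount++;
  }
}
```

For example, three storage buffers plus one uniform buffer produce two entries, `{STORAGE_BUFFER, 3}` and `{UNIFORM_BUFFER, 1}`, which is exactly what `poolSizeCount`/`pPoolSizes` now receive in `vkCreateDescriptorPool`.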
diff --git a/tests/python/unittest/test_target_codegen_spirv.py b/tests/python/unittest/test_target_codegen_spirv.py
deleted file mode 100644
index b9f07cf426fe..000000000000
--- a/tests/python/unittest/test_target_codegen_spirv.py
+++ /dev/null
@@ -1,135 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements.  See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership.  The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License.  You may obtain a copy of the License at
-#
-#   http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied.  See the License for the
-# specific language governing permissions and limitations
-# under the License.
-import tvm
-import tvm.testing
-from tvm import te
-from tvm import relay
-from tvm.topi.math import cast
-import numpy as np
-
-
-def test_bool_load():
-    def do_copy(A, B, n):
-        ib = tvm.tir.ir_builder.create()
-        A = ib.buffer_ptr(A)
-        B = ib.buffer_ptr(B)
-
-        tx = te.thread_axis("threadIdx.x")
-        bx = te.thread_axis("blockIdx.x")
-
-        max_threads = 32
-        ib.scope_attr(bx, "thread_extent", tvm.tir.indexdiv(n + max_threads - 1, max_threads))
-        ib.scope_attr(tx, "thread_extent", max_threads)
-        tid = bx * max_threads + tx
-
-        with ib.if_scope(tid < n):
-            B[tid] = cast(A[tid], "int32")
-
-        return ib.get()
-
-    n = 1024
-    A = te.placeholder((n,), name="A", dtype="bool")
-    B = te.placeholder((n,), name="B", dtype="int32")
-
-    target = "vulkan"
-
-    if not tvm.testing.device_enabled(target):
-        return
-
-    B = te.extern(
-        A.shape,
-        [A],
-        lambda ins, outs: do_copy(ins[0], outs[0], n),
-        name="bool_copy_ir",
-        dtype="int32",
-    )
-    s = te.create_schedule(B.op)
-
-    with tvm.transform.PassContext(opt_level=3):
-        func = tvm.build(s, [A, B], target)
-
-    dev = tvm.device(target, 0)
-    a_np = np.random.uniform(size=n) > 0.5
-    b_np = np.zeros((n,), dtype="int32")
-    a = tvm.nd.array(a_np, dev)
-    b = tvm.nd.array(b_np, dev)
-    func(a, b)
-    ref = a_np.astype(np.int32)
-    tvm.testing.assert_allclose(b.asnumpy(), ref)
-
-
-def check_mod(mod, x_np, res_np):
-    target = "vulkan"
-    dev = tvm.device(target, 0)
-    ex = relay.create_executor("vm", mod=mod, device=dev, target=target)
-    res = ex.evaluate()(x_np).asnumpy()
-    tvm.testing.assert_allclose(res, res_np, atol=1e-5)
-
-
-def test_pushconstants():
-    if not tvm.testing.device_enabled("vulkan"):
-        return
-
-    # Three 32 bit pushconstants: any_dim, stride, stride
-    dtype = "float32"
-    x = relay.var("x", shape=(relay.Any(),), dtype=dtype)
-    mod = tvm.IRModule()
-    mod["main"] = relay.Function([x], relay.sqrt(x))
-    x_np = np.random.uniform(size=(10,)).astype(dtype)
-    res_np = np.sqrt(x_np)
-
-    check_mod(mod, x_np, res_np)
-
-    # One 64 bit and one 32 bit constants
-    dtype = "int32"
-    x = relay.var("x", shape=(relay.Any(),), dtype=dtype)
-    mod = tvm.IRModule()
-    mod["main"] = relay.Function([x], relay.argsort(x))
-    x_np = np.random.randint(0, high=10, size=(10,)).astype(dtype)
-    res_np = np.argsort(x_np)
-
-    check_mod(mod, x_np, res_np)
-
-    # One 64 bit and one 32 bit constants
-    dtype = "int32"
-    x = relay.var("x", shape=(relay.Any(),), dtype=dtype)
-    mod = tvm.IRModule()
-    mod["main"] = relay.Function([x], relay.cumsum(x))
-    x_np = np.random.randint(0, high=10, size=(10,)).astype(dtype)
-    res_np = np.cumsum(x_np)
-
-    check_mod(mod, x_np, res_np)
-
-
-def test_unique():
-    if not tvm.testing.device_enabled("vulkan"):
-        return
-
-    dtype = "int32"
-    x = relay.var("x", shape=(relay.Any(),), dtype=dtype)
-    mod = tvm.IRModule()
-    [unique, _, num_unique] = relay.unique(x, is_sorted=True)
-    mod["main"] = relay.Function([x], relay.op.strided_slice(unique, begin=[0], end=num_unique))
-    x_np = np.random.randint(0, high=10, size=(10,)).astype(dtype)
-    res_np = np.unique(x_np)
-    check_mod(mod, x_np, res_np)
-
-
-if __name__ == "__main__":
-    test_bool_load()
-    test_pushconstants()
-    test_unique()
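The relocated bool-load test (now `test_vulkan_bool_load` below) exercises the `DataType::UInt(8)` path noted in codegen_spirv.cc above: SPIR-V booleans have no physical layout and cannot live in storage buffers, so boolean tensors are stored one byte per element and widened on load. A host-side sketch of the same convention (`PackBools` is a hypothetical helper, not TVM API):

```cpp
#include <cstdint>
#include <vector>

// Pack host-side bools into the byte-per-element layout the generated
// shader reads; the shader-side load then widens each byte, effectively
// static_cast<int32_t>(byte != 0).
std::vector<uint8_t> PackBools(const std::vector<bool>& src) {
  std::vector<uint8_t> packed(src.size());
  for (size_t i = 0; i < src.size(); ++i) {
    packed[i] = src[i] ? 1 : 0;
  }
  return packed;
}
```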
diff --git a/tests/python/unittest/test_target_codegen_vulkan.py b/tests/python/unittest/test_target_codegen_vulkan.py
index e68996df531f..9528741b6c52 100644
--- a/tests/python/unittest/test_target_codegen_vulkan.py
+++ b/tests/python/unittest/test_target_codegen_vulkan.py
@@ -14,12 +14,23 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
-import tvm
-import tvm.testing
-from tvm import te
+
 import re
 import numpy as np
 
+import tvm
+import tvm.testing
+from tvm import relay, te
+from tvm.topi.math import cast
+
+
+def check_mod(mod, x_np, res_np):
+    target = "vulkan"
+    dev = tvm.device(target, 0)
+    ex = relay.create_executor("vm", mod=mod, device=dev, target=target)
+    res = ex.evaluate()(x_np).asnumpy()
+    tvm.testing.assert_allclose(res, res_np, atol=1e-5)
+
 
 @tvm.testing.requires_vulkan
 def test_vector_comparison():
@@ -158,8 +169,150 @@ def build_f(f_ref):
     run_stress()
 
 
+@tvm.testing.requires_vulkan
+def test_vulkan_bool_load():
+    def do_copy(A, B, n):
+        ib = tvm.tir.ir_builder.create()
+        A = ib.buffer_ptr(A)
+        B = ib.buffer_ptr(B)
+
+        tx = te.thread_axis("threadIdx.x")
+        bx = te.thread_axis("blockIdx.x")
+
+        max_threads = 32
+        ib.scope_attr(bx, "thread_extent", tvm.tir.indexdiv(n + max_threads - 1, max_threads))
+        ib.scope_attr(tx, "thread_extent", max_threads)
+        tid = bx * max_threads + tx
+
+        with ib.if_scope(tid < n):
+            B[tid] = cast(A[tid], "int32")
+
+        return ib.get()
+
+    n = 1024
+    A = te.placeholder((n,), name="A", dtype="bool")
+    B = te.placeholder((n,), name="B", dtype="int32")
+
+    target = "vulkan"
+
+    B = te.extern(
+        A.shape,
+        [A],
+        lambda ins, outs: do_copy(ins[0], outs[0], n),
+        name="bool_copy_ir",
+        dtype="int32",
+    )
+    s = te.create_schedule(B.op)
+
+    with tvm.transform.PassContext(opt_level=3):
+        func = tvm.build(s, [A, B], target)
+
+    dev = tvm.device(target, 0)
+    a_np = np.random.uniform(size=n) > 0.5
+    b_np = np.zeros((n,), dtype="int32")
+    a = tvm.nd.array(a_np, dev)
+    b = tvm.nd.array(b_np, dev)
+    func(a, b)
+    ref = a_np.astype(np.int32)
+    tvm.testing.assert_allclose(b.asnumpy(), ref)
+
+
+@tvm.testing.requires_vulkan
+def test_vulkan_pushconstants():
+    # Three 32 bit pushconstants: any_dim, stride, stride
+    dtype = "float32"
+    x = relay.var("x", shape=(relay.Any(),), dtype=dtype)
+    mod = tvm.IRModule()
+    mod["main"] = relay.Function([x], relay.sqrt(x))
+    x_np = np.random.uniform(size=(10,)).astype(dtype)
+    res_np = np.sqrt(x_np)
+
+    check_mod(mod, x_np, res_np)
+
+    # One 64 bit and one 32 bit constants
+    dtype = "int32"
+    x = relay.var("x", shape=(relay.Any(),), dtype=dtype)
+    mod = tvm.IRModule()
+    mod["main"] = relay.Function([x], relay.argsort(x))
+    x_np = np.random.randint(0, high=10, size=(10,)).astype(dtype)
+    res_np = np.argsort(x_np)
+
+    check_mod(mod, x_np, res_np)
+
+    # One 64 bit and one 32 bit constants
+    dtype = "int32"
+    x = relay.var("x", shape=(relay.Any(),), dtype=dtype)
+    mod = tvm.IRModule()
+    mod["main"] = relay.Function([x], relay.cumsum(x))
+    x_np = np.random.randint(0, high=10, size=(10,)).astype(dtype)
+    res_np = np.cumsum(x_np)
+
+    check_mod(mod, x_np, res_np)
+
+
+@tvm.testing.requires_vulkan
+def test_vulkan_unique():
+    dtype = "int32"
+    x = relay.var("x", shape=(relay.Any(),), dtype=dtype)
+    mod = tvm.IRModule()
+    [unique, _, num_unique] = relay.unique(x, is_sorted=True)
+    mod["main"] = relay.Function([x], relay.op.strided_slice(unique, begin=[0], end=num_unique))
+    x_np = np.random.randint(0, high=10, size=(10,)).astype(dtype)
+    res_np = np.unique(x_np)
+    check_mod(mod, x_np, res_np)
+
+
+@tvm.testing.requires_vulkan
+def test_vulkan_constant_passing():
+    target = "vulkan"
+
+    def test_scalar_params(num_int_params):
+        n = te.var("n")
+        scalars = [te.var("scale{}".format(i)) for i in range(num_int_params)]
+        scalar_sum = scalars[0]
+        for s in scalars[1:]:
+            scalar_sum += s
+
+        A = te.placeholder((n,), name="A")
+        B = te.compute(A.shape, lambda i: scalar_sum + A[i], name="B")
+
+        s = te.create_schedule(B.op)
+        xo, xi = s[B].split(B.op.axis[0], factor=64)
+        bx = te.thread_axis("blockIdx.x")
+        tx = te.thread_axis("threadIdx.x")
+        s[B].bind(xo, bx)
+        s[B].bind(xi, tx)
+        f_add = tvm.build(s, scalars + [A, B], target)
+
+        n = 1024
+        scalars = [1 for _ in scalars]
+        dev = tvm.vulkan(0)
+        a = tvm.nd.array(np.random.uniform(size=n).astype(A.dtype), dev)
+        b = tvm.nd.array(np.zeros(n, dtype=B.dtype), dev)
+        f_add(*scalars, a, b)
+
+        tvm.testing.assert_allclose(a.asnumpy() + sum(scalars), b.asnumpy())
+
+    # f_add has 3+num_int_params scalar parameters.  The other three
+    # are length_n, stride1, and stride2.
+
+    # 4 params, 32 bytes.  Within 128-byte spec-guaranteed size of
+    # push constants.  Uses push constants.
+    test_scalar_params(1)
+
+    # 24 params, 192 bytes.  Too big for push constants, uses uniform
+    # buffer.
+    test_scalar_params(20)
+
+    # 2047 params, 16376 bytes, just below 16kB of uniform buffer
+    # space guaranteed by the vulkan spec.
+    test_scalar_params(2044)
+
+
 if __name__ == "__main__":
     test_vector_comparison()
     test_vulkan_copy()
     test_vulkan_vectorize_add()
     test_vulkan_stress()
+    test_vulkan_constant_passing()
+    test_vulkan_bool_load()
+    test_vulkan_pushconstants()
+    test_vulkan_unique()
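The thresholds in `test_vulkan_constant_passing` come from two minimums the Vulkan spec guarantees: `maxPushConstantsSize` of at least 128 bytes and `maxUniformBufferRange` of at least 16384 bytes. A sketch of the arithmetic, assuming (per the test comments) that each scalar is widened to 64 bits and that a dynamic shape adds three implicit parameters (`length_n`, `stride1`, `stride2`):

```cpp
#include <cstdint>

constexpr uint32_t kScalarBytes = 8;     // each scalar parameter widened to 64 bits
constexpr uint32_t kImplicitParams = 3;  // length_n, stride1, stride2

constexpr uint32_t ParamBytes(uint32_t num_int_params) {
  return (num_int_params + kImplicitParams) * kScalarBytes;
}

static_assert(ParamBytes(1) == 32, "fits the 128-byte push-constant minimum");
static_assert(ParamBytes(2044) == 16376, "just below the 16384-byte uniform buffer minimum");
```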