Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 28 additions & 11 deletions src/runtime/vulkan/vulkan.cc
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
#include <vulkan/vulkan.h>
#include <vulkan/vulkan_core.h>

#include <algorithm>
#include <array>
#include <cstring>

Expand Down Expand Up @@ -621,6 +622,12 @@ VulkanDeviceAPI::VulkanDeviceAPI() {
}
return extensions;
}();

// All TVM-generated spirv shaders are marked as requiring int64
// support, so we need to request it from the device, too.
VkPhysicalDeviceFeatures enabled_features = {};
enabled_features.shaderInt64 = VK_TRUE;

VkDeviceCreateInfo device_create_info;
device_create_info.sType = VK_STRUCTURE_TYPE_DEVICE_CREATE_INFO;
device_create_info.pNext = nullptr;
Expand All @@ -631,7 +638,7 @@ VulkanDeviceAPI::VulkanDeviceAPI() {
device_create_info.ppEnabledLayerNames = nullptr;
device_create_info.enabledExtensionCount = extensions.size();
device_create_info.ppEnabledExtensionNames = extensions.data();
device_create_info.pEnabledFeatures = nullptr;
device_create_info.pEnabledFeatures = &enabled_features;
VULKAN_CALL(vkCreateDevice(phy_dev, &device_create_info, nullptr, &(ctx.device)));
ctx.queue_mutex.reset(new std::mutex());
vkGetDeviceQueue(ctx.device, queue_family_index, 0, &(ctx.queue));
Expand Down Expand Up @@ -882,10 +889,25 @@ class VulkanModuleNode final : public runtime::ModuleNode {
}
std::vector<VkDescriptorSetLayoutBinding> arg_binding;
std::vector<VkDescriptorUpdateTemplateEntryKHR> arg_template;
std::vector<VkDescriptorPoolSize> descriptor_set_pool_sizes;
uint32_t num_pod = 0, num_buffer = 0;

auto push_arg_info = [&arg_binding, &arg_template](uint32_t binding,
VkDescriptorType desc_type) {
auto push_arg_info = [&arg_binding, &arg_template, &descriptor_set_pool_sizes](
uint32_t binding, VkDescriptorType desc_type) {
{
auto result =
std::find_if(descriptor_set_pool_sizes.begin(), descriptor_set_pool_sizes.end(),
[&](const auto& psize) { return psize.type == desc_type; });
if (result == descriptor_set_pool_sizes.end()) {
VkDescriptorPoolSize new_size;
new_size.type = desc_type;
new_size.descriptorCount = 1;
descriptor_set_pool_sizes.push_back(new_size);
} else {
result->descriptorCount++;
}
}

{
VkDescriptorSetLayoutBinding bd;
bd.binding = binding;
Expand Down Expand Up @@ -941,22 +963,17 @@ class VulkanModuleNode final : public runtime::ModuleNode {
&(pe->descriptor_set_layout)));
}

{
VkDescriptorPoolSize pool_size;
pool_size.type = VK_DESCRIPTOR_TYPE_STORAGE_BUFFER;
pool_size.descriptorCount = arg_binding.size();
if (!vctx.UseImmediate()) {
VkDescriptorPoolCreateInfo descrip_pool_cinfo;
descrip_pool_cinfo.sType = VK_STRUCTURE_TYPE_DESCRIPTOR_POOL_CREATE_INFO;
descrip_pool_cinfo.pNext = nullptr;
descrip_pool_cinfo.flags = VK_DESCRIPTOR_POOL_CREATE_FREE_DESCRIPTOR_SET_BIT;
descrip_pool_cinfo.maxSets = 1;
descrip_pool_cinfo.poolSizeCount = 1;
descrip_pool_cinfo.pPoolSizes = &pool_size;
descrip_pool_cinfo.poolSizeCount = descriptor_set_pool_sizes.size();
descrip_pool_cinfo.pPoolSizes = descriptor_set_pool_sizes.data();
VULKAN_CALL(vkCreateDescriptorPool(vctx.device, &descrip_pool_cinfo, nullptr,
&(pe->descriptor_pool)));
}

if (!vctx.UseImmediate()) {
VkDescriptorSetAllocateInfo alloc_info;
alloc_info.sType = VK_STRUCTURE_TYPE_DESCRIPTOR_SET_ALLOCATE_INFO;
alloc_info.pNext = nullptr;
Expand Down
10 changes: 7 additions & 3 deletions src/target/spirv/codegen_spirv.cc
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,10 @@ runtime::VulkanShader CodeGenSPIRV::BuildFunction(const PrimFunc& f, const std::
std::vector<Var> pod_args;
uint32_t num_buffer = 0;

// Currently, all storage and uniform buffer arguments are passed as
// a single descriptor set at index 0.
const uint32_t descriptor_set = 0;

for (Var arg : f->params) {
DataType t = arg.dtype();
if (t.is_handle()) {
Expand All @@ -55,8 +59,8 @@ runtime::VulkanShader CodeGenSPIRV::BuildFunction(const PrimFunc& f, const std::
// The loaded byte is cast to bool inside the LoadNode visitor below.
value_storage_type = DataType::UInt(8);
}
spirv::Value arg_value =
builder_->BufferArgument(builder_->GetSType(value_storage_type), 0, num_buffer);
spirv::Value arg_value = builder_->BufferArgument(builder_->GetSType(value_storage_type),
descriptor_set, num_buffer);
storage_info_[arg.get()].UpdateContentType(value_storage_type);
var_map_[arg.get()] = arg_value;
} else {
Expand Down Expand Up @@ -87,7 +91,7 @@ runtime::VulkanShader CodeGenSPIRV::BuildFunction(const PrimFunc& f, const std::
} else {
shader.flag |= 1 << runtime::vulkan::ShaderMetaDataFlagMask::kUseUBO;
// If we need to pass more arguments than push constants could handle, we use UBO.
spirv::Value ptr = builder_->DeclareUniformBuffer(value_types, num_buffer);
spirv::Value ptr = builder_->DeclareUniformBuffer(value_types, descriptor_set, num_buffer);
for (size_t i = 0; i < pod_args.size(); ++i) {
spirv::Value value = builder_->GetUniform(ptr, value_types[i], static_cast<uint32_t>(i));
var_map_[pod_args[i].get()] = value;
Expand Down
13 changes: 9 additions & 4 deletions src/target/spirv/ir_builder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -200,8 +200,7 @@ Value IRBuilder::BufferArgument(const SType& value_type, uint32_t descriptor_set

ib_.Begin(spv::OpVariable).AddSeq(ptr_type, val, storage_class).Commit(&global_);

this->Decorate(spv::OpDecorate, val, spv::DecorationDescriptorSet, descriptor_set);
this->Decorate(spv::OpDecorate, val, spv::DecorationBinding, binding);
this->DecorateBufferArgument(val, descriptor_set, binding);
return val;
}

Expand Down Expand Up @@ -253,12 +252,18 @@ Value IRBuilder::GetPushConstant(Value ptr_push_const, const SType& v_type, uint
return this->MakeValue(spv::OpLoad, v_type, ptr);
}

Value IRBuilder::DeclareUniformBuffer(const std::vector<SType>& value_types, uint32_t binding) {
Value IRBuilder::DeclareUniformBuffer(const std::vector<SType>& value_types,
uint32_t descriptor_set, uint32_t binding) {
Value val = DeclareStorageVariable(value_types, spv::StorageClassUniform, kUniformPtr);
this->Decorate(spv::OpDecorate, val, spv::DecorationBinding, binding);
this->DecorateBufferArgument(val, descriptor_set, binding);
return val;
}

void IRBuilder::DecorateBufferArgument(Value val, uint32_t descriptor_set, uint32_t binding) {
this->Decorate(spv::OpDecorate, val, spv::DecorationDescriptorSet, descriptor_set);
this->Decorate(spv::OpDecorate, val, spv::DecorationBinding, binding);
}

Value IRBuilder::GetUniform(Value ptr_push_const, const SType& v_type, uint32_t index) {
SType ptr_vtype = this->GetPointerType(v_type, spv::StorageClassUniform);
Value ptr = this->MakeValue(spv::OpAccessChain, ptr_vtype, ptr_push_const,
Expand Down
16 changes: 13 additions & 3 deletions src/target/spirv/ir_builder.h
Original file line number Diff line number Diff line change
Expand Up @@ -470,7 +470,7 @@ class IRBuilder {
*
* \param arg_type The type of argument.
* \param descriptor_set The descriptor set we want to use.
* \param binding The binding locaiton in descriptor set.
* \param binding The binding location in descriptor set.
* \param The argument type.
*/
Value BufferArgument(const SType& value_type, uint32_t descriptor_set, uint32_t binding);
Expand All @@ -496,10 +496,12 @@ class IRBuilder {
*
* \note Only call this function once!
* \param value_types The values in the uniform buffer
* \param binding The binding locaiton in descriptor set
* \param descriptor_set The descriptor set we want to use
* \param binding The binding location in descriptor set
* \return reference to self.
*/
Value DeclareUniformBuffer(const std::vector<SType>& value_types, uint32_t binding);
Value DeclareUniformBuffer(const std::vector<SType>& value_types, uint32_t descriptor_set,
uint32_t binding);
/*!
* \brief Get i-th uniform constant
* \param v_type The value type
Expand Down Expand Up @@ -585,6 +587,14 @@ class IRBuilder {
Value DeclareStorageVariable(const std::vector<SType>& value_types,
spv::StorageClass storage_class, ValueKind kind);

/*!
* \brief The common function to decorate storage buffer or uniform buffer arguments.
* \param val The Value to be decorated.
* \param descriptor_set The index of the descriptor set containing the buffer's descriptor
* \param binding The index of the buffer's descriptor within the descriptor set
*/
void DecorateBufferArgument(Value val, uint32_t descriptor_set, uint32_t binding);

// get constant given value encoded in uint64_t
Value GetConst_(const SType& dtype, const uint64_t* pvalue);
// declare type
Expand Down
135 changes: 0 additions & 135 deletions tests/python/unittest/test_target_codegen_spirv.py

This file was deleted.

Loading