From e829625ff72d3c3f84a4ace1af2a711373c3f0c9 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Wed, 26 May 2021 16:05:16 -0700 Subject: [PATCH] [Vulkan][Refactor] Split out vulkan.cc into separate distinct functionality. This is in preparation for additional refactoring. Functions are organized according to group similar functionality together, to minimize the amount of file-to-file transfers needed later. The main divisions are between VulkanDeviceAPI, VulkanModuleNode/VulkanWrappedFunc, VulkanThreadEntry, and VulkanContext. Other than minimal renaming of private functions and addition of some comments, this commit should have zero changes to the functions definitions themselves, only to their arrangement within the src/runtime/vulkan directory. --- cmake/modules/Vulkan.cmake | 2 +- src/runtime/pack_args.h | 1 + src/runtime/vulkan/vulkan.cc | 1590 --------------------- src/runtime/vulkan/vulkan_buffer.cc | 47 + src/runtime/vulkan/vulkan_buffer.h | 63 + src/runtime/vulkan/vulkan_common.h | 66 +- src/runtime/vulkan/vulkan_context.cc | 183 +++ src/runtime/vulkan/vulkan_context.h | 95 ++ src/runtime/vulkan/vulkan_device_api.cc | 830 +++++++++++ src/runtime/vulkan/vulkan_device_api.h | 104 ++ src/runtime/vulkan/vulkan_module.cc | 73 + src/runtime/vulkan/vulkan_stream.cc | 159 +++ src/runtime/vulkan/vulkan_stream.h | 174 +-- src/runtime/vulkan/vulkan_thread_entry.cc | 84 ++ src/runtime/vulkan/vulkan_thread_entry.h | 67 + src/runtime/vulkan/vulkan_wrapped_func.cc | 412 ++++++ src/runtime/vulkan/vulkan_wrapped_func.h | 123 ++ 17 files changed, 2299 insertions(+), 1774 deletions(-) delete mode 100644 src/runtime/vulkan/vulkan.cc create mode 100644 src/runtime/vulkan/vulkan_buffer.cc create mode 100644 src/runtime/vulkan/vulkan_buffer.h create mode 100644 src/runtime/vulkan/vulkan_context.cc create mode 100644 src/runtime/vulkan/vulkan_context.h create mode 100644 src/runtime/vulkan/vulkan_device_api.cc create mode 100644 src/runtime/vulkan/vulkan_device_api.h create mode 100644 src/runtime/vulkan/vulkan_module.cc create mode 100644 src/runtime/vulkan/vulkan_stream.cc create mode 100644 src/runtime/vulkan/vulkan_thread_entry.cc create mode 100644 src/runtime/vulkan/vulkan_thread_entry.h create mode 100644 src/runtime/vulkan/vulkan_wrapped_func.cc create mode 100644 src/runtime/vulkan/vulkan_wrapped_func.h diff --git a/cmake/modules/Vulkan.cmake b/cmake/modules/Vulkan.cmake index 4dc9bd664d8a..3ee13aa38b98 100644 --- a/cmake/modules/Vulkan.cmake +++ b/cmake/modules/Vulkan.cmake @@ -24,7 +24,7 @@ if(USE_VULKAN) endif() include_directories(SYSTEM ${Vulkan_INCLUDE_DIRS}) message(STATUS "Build with Vulkan support") - file(GLOB RUNTIME_VULKAN_SRCS src/runtime/vulkan/vulkan.cc) + file(GLOB RUNTIME_VULKAN_SRCS src/runtime/vulkan/*.cc) file(GLOB COMPILER_VULKAN_SRCS src/target/spirv/*.cc) list(APPEND RUNTIME_SRCS ${RUNTIME_VULKAN_SRCS}) list(APPEND COMPILER_SRCS ${COMPILER_VULKAN_SRCS}) diff --git a/src/runtime/pack_args.h b/src/runtime/pack_args.h index 7c852da77df6..3776d18fafcc 100644 --- a/src/runtime/pack_args.h +++ b/src/runtime/pack_args.h @@ -32,6 +32,7 @@ #define TVM_RUNTIME_PACK_ARGS_H_ #include +#include #include #include diff --git a/src/runtime/vulkan/vulkan.cc b/src/runtime/vulkan/vulkan.cc deleted file mode 100644 index 8982ea32648b..000000000000 --- a/src/runtime/vulkan/vulkan.cc +++ /dev/null @@ -1,1590 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -#include -#include -#include -#include -#include -#include -#include - -#include -#include -#include -#include - -#include "../file_utils.h" -#include "../pack_args.h" -#include "../thread_storage_scope.h" -#include "../workspace_pool.h" -#include "vulkan_common.h" -#include "vulkan_module.h" -#include "vulkan_shader.h" -#include "vulkan_stream.h" - -namespace tvm { -namespace runtime { -namespace vulkan { - -/*! \brief Maximum number of GPU supported in VulkanModule. */ -static constexpr const int kVulkanMaxNumDevice = 8; - -/*! \brief TVM Vulkan binary pack magic number */ -static constexpr const int kVulkanModuleMagic = 0x02700027; - -struct VulkanBuffer { - VkBuffer buffer{VK_NULL_HANDLE}; - VkDeviceMemory memory{VK_NULL_HANDLE}; -}; - -/*! \brief A struct to represent Vulkan buffers backed by host visible memory */ -struct VulkanHostVisibleBuffer { - // A device where the buffer is allocated - VkDevice device{nullptr}; - // Vulkan buffer and memory - VulkanBuffer* vk_buf{nullptr}; - // The corresponding pointer to the host memory - void* host_addr{nullptr}; - // The size of the buffer in bytes - size_t size{0}; -}; - -using VulkanStagingBuffer = VulkanHostVisibleBuffer; -using VulkanUniformBuffer = VulkanHostVisibleBuffer; - -void DeleteHostVisibleBuffer(VulkanHostVisibleBuffer* buf) { - if (buf && buf->vk_buf) { - if (buf->host_addr != nullptr) { - vkUnmapMemory(buf->device, buf->vk_buf->memory); - } - if (buf->vk_buf->memory != VK_NULL_HANDLE) { - vkFreeMemory(buf->device, buf->vk_buf->memory, nullptr); - } - if (buf->vk_buf->buffer != VK_NULL_HANDLE) { - vkDestroyBuffer(buf->device, buf->vk_buf->buffer, nullptr); - } - buf->host_addr = nullptr; - delete buf->vk_buf; - } -} - -class VulkanThreadEntry { - public: - VulkanThreadEntry(); - static VulkanThreadEntry* ThreadLocal(); - - ~VulkanThreadEntry() { - // Because the thread entry refers to Device API - // The command buffer always will be destroyed before - // the instance and device get destroyed. - // The destruction need to be manually called - // to ensure the destruction order. - - pool.reset(); - streams_.clear(); - for (const auto& kv : staging_buffers_) { - DeleteHostVisibleBuffer(kv.second.get()); - } - } - - Device device; - std::unique_ptr pool; - VulkanStream* Stream(size_t device_id); - VulkanStagingBuffer* StagingBuffer(int device_id, size_t size); - void AllocateUniformBuffer(int device_id, size_t size); - VulkanUniformBuffer* GetUniformBuffer(int device_id, size_t size); - - private: - std::unordered_map> streams_; - std::unordered_map> staging_buffers_; - std::unordered_map> uniform_buffers_; -}; - -struct VulkanPipeline { - VulkanContext* vctx_{nullptr}; - VkShaderModule shader{VK_NULL_HANDLE}; - VkDescriptorSetLayout descriptor_set_layout{VK_NULL_HANDLE}; - VkDescriptorPool descriptor_pool{VK_NULL_HANDLE}; - VkDescriptorSet descriptor_set{VK_NULL_HANDLE}; - VkPipelineLayout pipeline_layout{VK_NULL_HANDLE}; - VkPipeline pipeline{VK_NULL_HANDLE}; - VkDescriptorUpdateTemplateKHR descriptor_update_template{VK_NULL_HANDLE}; - bool use_ubo{false}; -}; - -typedef dmlc::ThreadLocalStore VulkanThreadStore; - -uint32_t FindMemoryType(const VulkanContext& vctx, VkBufferCreateInfo info, - VkMemoryPropertyFlags req_prop) { - VkBuffer buffer; - VULKAN_CALL(vkCreateBuffer(vctx.device, &info, nullptr, &buffer)); - - VkMemoryRequirements mem_reqs; - vkGetBufferMemoryRequirements(vctx.device, buffer, &mem_reqs); - uint32_t type_bits = mem_reqs.memoryTypeBits; - VkPhysicalDeviceMemoryProperties phy_mem_prop; - vkGetPhysicalDeviceMemoryProperties(vctx.phy_device, &phy_mem_prop); - for (uint32_t i = 0; i < phy_mem_prop.memoryTypeCount; i++) { - if ((type_bits & 1) == 1 && - (phy_mem_prop.memoryTypes[i].propertyFlags & req_prop) == req_prop) { - return i; - } - type_bits >>= 1; - } - LOG(FATAL) << "Requested memory type not found"; - return 0; -} - -VkBufferCreateInfo MakeBufferCreateInfo(const VulkanContext& vctx, size_t nbytes, - VkBufferUsageFlags usage) { - VkBufferCreateInfo info; - info.sType = VK_STRUCTURE_TYPE_BUFFER_CREATE_INFO; - info.pNext = nullptr; - info.flags = 0; - info.size = nbytes; - info.queueFamilyIndexCount = 1; - info.pQueueFamilyIndices = &(vctx.queue_family_index); - info.sharingMode = VK_SHARING_MODE_EXCLUSIVE; - info.usage = usage; - return info; -} - -VulkanBuffer* CreateBuffer(const VulkanContext& vctx, size_t nbytes, VkBufferUsageFlags usage, - uint32_t mem_type_index) { - auto info = MakeBufferCreateInfo(vctx, nbytes, usage); - // create buffer - VkBuffer buffer; - VULKAN_CALL(vkCreateBuffer(vctx.device, &info, nullptr, &buffer)); - - // bind to memory - bool dedicated_allocation = false; - VkMemoryRequirements2KHR req2; - - if (vctx.get_buffer_memory_requirements_2_functions) { - VkBufferMemoryRequirementsInfo2KHR req_info2; - req_info2.sType = VK_STRUCTURE_TYPE_BUFFER_MEMORY_REQUIREMENTS_INFO_2_KHR; - req_info2.pNext = 0; - req_info2.buffer = buffer; - - req2.sType = VK_STRUCTURE_TYPE_MEMORY_REQUIREMENTS_2_KHR; - req2.pNext = 0; - - VkMemoryDedicatedRequirementsKHR dedicated_req; - dedicated_req.sType = VK_STRUCTURE_TYPE_MEMORY_DEDICATED_REQUIREMENTS_KHR; - dedicated_req.pNext = 0; - req2.pNext = &dedicated_req; - - vctx.get_buffer_memory_requirements_2_functions->vkGetBufferMemoryRequirements2KHR( - vctx.device, &req_info2, &req2); - dedicated_allocation = - dedicated_req.requiresDedicatedAllocation || dedicated_req.prefersDedicatedAllocation; - } - - VkDeviceMemory memory; - if (!dedicated_allocation) { - VkMemoryAllocateInfo minfo; - minfo.sType = VK_STRUCTURE_TYPE_MEMORY_ALLOCATE_INFO; - minfo.pNext = nullptr; - minfo.allocationSize = info.size; - minfo.memoryTypeIndex = mem_type_index; - VULKAN_CALL(vkAllocateMemory(vctx.device, &minfo, nullptr, &memory)); - } else { - VkMemoryAllocateInfo minfo; - minfo.sType = VK_STRUCTURE_TYPE_MEMORY_ALLOCATE_INFO; - minfo.pNext = nullptr; - minfo.allocationSize = req2.memoryRequirements.size; - minfo.memoryTypeIndex = mem_type_index; - - VkMemoryDedicatedAllocateInfoKHR mdinfo; - mdinfo.sType = VK_STRUCTURE_TYPE_MEMORY_DEDICATED_ALLOCATE_INFO_KHR; - mdinfo.pNext = 0; - mdinfo.image = 0; - mdinfo.buffer = buffer; - minfo.pNext = &mdinfo; - VULKAN_CALL(vkAllocateMemory(vctx.device, &minfo, nullptr, &memory)); - } - VULKAN_CALL(vkBindBufferMemory(vctx.device, buffer, memory, 0)); - VulkanBuffer* pbuf = new VulkanBuffer(); - pbuf->memory = memory; - pbuf->buffer = buffer; - return pbuf; -} - -class VulkanDeviceAPI final : public DeviceAPI { - public: - VulkanDeviceAPI(); - ~VulkanDeviceAPI() { - for (auto& vctx : context_) { - vkDestroyDevice(vctx.device, nullptr); - } - if (instance_) { - vkDestroyInstance(instance_, nullptr); - } - } - void SetDevice(Device dev) final { VulkanThreadEntry::ThreadLocal()->device = dev; } - void GetAttr(Device dev, DeviceAttrKind kind, TVMRetValue* rv) final; - std::vector GetComputeQueueFamilies(VkPhysicalDevice phy_dev); - void* AllocDataSpace(Device dev, size_t nbytes, size_t alignment, DLDataType type_hint) final { - if (nbytes == 0) { - // Vulkan seems to have issues if we return nullptr on zero size alloc - nbytes = 1; - } - const auto& vctx = context(dev.device_id); - auto usage = VK_BUFFER_USAGE_TRANSFER_SRC_BIT | VK_BUFFER_USAGE_TRANSFER_DST_BIT | - VK_BUFFER_USAGE_STORAGE_BUFFER_BIT; - return CreateBuffer(vctx, nbytes, usage, vctx.compute_mtype_index); - } - - void FreeDataSpace(Device dev, void* ptr) final { - // Before releasing the vkBuffer, call sync to - // finish all the vulkan commands that reference the buffer. - StreamSync(dev, nullptr); - - const auto& vctx = context(dev.device_id); - auto* pbuf = static_cast(ptr); - vkDestroyBuffer(vctx.device, pbuf->buffer, nullptr); - vkFreeMemory(vctx.device, pbuf->memory, nullptr); - delete pbuf; - } - - Target GetDeviceDescription(VkInstance instance, VkPhysicalDevice dev, - const std::vector& instance_extensions, - const std::vector& device_extensions); - - protected: - void CopyDataFromTo(const void* from, size_t from_offset, void* to, size_t to_offset, size_t size, - Device dev_from, Device dev_to, DLDataType type_hint, - TVMStreamHandle stream) final { - ICHECK(stream == nullptr); - Device dev = dev_from; - if (dev_from.device_type == kDLCPU) { - dev = dev_to; - } - - int from_dev_type = static_cast(dev_from.device_type); - int to_dev_type = static_cast(dev_to.device_type); - if (from_dev_type == kDLVulkan && to_dev_type == kDLVulkan) { - VulkanThreadEntry::ThreadLocal() - ->Stream(dev_from.device_id) - ->Launch([=](VulkanStreamState* state) { - // 1: copy - const auto* from_buf = static_cast(from); - auto* to_buf = static_cast(to); - VkBufferCopy copy_info; - copy_info.srcOffset = from_offset; - copy_info.dstOffset = to_offset; - copy_info.size = size; - vkCmdCopyBuffer(state->cmd_buffer_, from_buf->buffer, to_buf->buffer, 1, ©_info); - // 2: barrier(transfer-> compute|transfer) - ICHECK_EQ(dev_from.device_id, dev_to.device_id) << "Vulkan disallow cross device copy."; - VkMemoryBarrier barrier_info; - barrier_info.sType = VK_STRUCTURE_TYPE_MEMORY_BARRIER; - barrier_info.pNext = nullptr; - barrier_info.srcAccessMask = VK_ACCESS_TRANSFER_WRITE_BIT; - barrier_info.dstAccessMask = - (VK_ACCESS_TRANSFER_READ_BIT | VK_ACCESS_TRANSFER_WRITE_BIT | - VK_ACCESS_SHADER_READ_BIT | VK_ACCESS_SHADER_WRITE_BIT); - vkCmdPipelineBarrier( - state->cmd_buffer_, VK_PIPELINE_STAGE_TRANSFER_BIT, - VK_PIPELINE_STAGE_TRANSFER_BIT | VK_PIPELINE_STAGE_COMPUTE_SHADER_BIT, 0, 1, - &barrier_info, 0, nullptr, 0, nullptr); - }); - - } else if (from_dev_type == kDLVulkan && to_dev_type == kDLCPU) { - const auto* from_buf = static_cast(from); - const auto& vctx = context(dev_from.device_id); - auto* temp = VulkanThreadEntry::ThreadLocal()->StagingBuffer(dev_from.device_id, size); - VulkanThreadEntry::ThreadLocal() - ->Stream(dev_from.device_id) - ->Launch([&](VulkanStreamState* state) { - VkBufferCopy copy_info; - copy_info.srcOffset = from_offset; - copy_info.dstOffset = 0; - copy_info.size = size; - vkCmdCopyBuffer(state->cmd_buffer_, from_buf->buffer, temp->vk_buf->buffer, 1, - ©_info); - }); - VulkanThreadEntry::ThreadLocal()->Stream(dev_from.device_id)->Synchronize(); - if (!vctx.coherent_staging) { - VkMappedMemoryRange mrange; - mrange.sType = VK_STRUCTURE_TYPE_MAPPED_MEMORY_RANGE; - mrange.pNext = nullptr; - mrange.memory = temp->vk_buf->memory; - mrange.offset = 0; - mrange.size = VK_WHOLE_SIZE; // size; - VULKAN_CALL(vkInvalidateMappedMemoryRanges(vctx.device, 1, &mrange)); - } - memcpy(static_cast(to) + to_offset, static_cast(temp->host_addr), size); - } else if (from_dev_type == kDLCPU && to_dev_type == kDLVulkan) { - const auto& vctx = context(dev_to.device_id); - const auto* to_buf = static_cast(to); - VulkanStagingBuffer* temp = - VulkanThreadEntry::ThreadLocal()->StagingBuffer(dev_to.device_id, size); - memcpy(temp->host_addr, static_cast(from) + from_offset, size); - // host side flush if access is not coherent. - // so writes from CPU is visible to GPU - if (!vctx.coherent_staging) { - VkMappedMemoryRange mrange; - mrange.sType = VK_STRUCTURE_TYPE_MAPPED_MEMORY_RANGE; - mrange.pNext = nullptr; - mrange.memory = temp->vk_buf->memory; - mrange.offset = 0; - mrange.size = VK_WHOLE_SIZE; // size; - VULKAN_CALL(vkFlushMappedMemoryRanges(vctx.device, 1, &mrange)); - } - - VulkanThreadEntry::ThreadLocal() - ->Stream(dev_to.device_id) - ->Launch([&](VulkanStreamState* state) { - // 0: barrier(host->transfer) - VkMemoryBarrier barrier_info; - barrier_info.sType = VK_STRUCTURE_TYPE_MEMORY_BARRIER; - barrier_info.pNext = nullptr; - barrier_info.srcAccessMask = 0; - barrier_info.dstAccessMask = VK_ACCESS_TRANSFER_WRITE_BIT; - vkCmdPipelineBarrier(state->cmd_buffer_, VK_PIPELINE_STAGE_HOST_BIT, - VK_PIPELINE_STAGE_TRANSFER_BIT, 0, 1, &barrier_info, 0, nullptr, 0, - nullptr); - // 1: copy - VkBufferCopy copy_info; - copy_info.srcOffset = 0; - copy_info.dstOffset = to_offset; - copy_info.size = size; - vkCmdCopyBuffer(state->cmd_buffer_, temp->vk_buf->buffer, to_buf->buffer, 1, - ©_info); - }); - // TODO(tulloch): should we instead make the staging buffer a property of the - // Stream? This would allow us to elide synchronizations here. - VulkanThreadEntry::ThreadLocal()->Stream(dev_to.device_id)->Synchronize(); - } else { - LOG(FATAL) << "Expect copy from/to Vulkan or between Vulkan" - << ", from=" << from_dev_type << ", to=" << to_dev_type; - } - } - - public: - // Current vulkan implementation has one "stream" per CPU thread, - // with all commands writing into a single command buffer that is - // submitted on a call to StreamSync. Therefore, for now, these are - // mostly no-ops. If needed in the future, could have multiple - // command buffers to act as multiple streams. - TVMStreamHandle CreateStream(Device dev) final { return nullptr; } - - void FreeStream(Device dev, TVMStreamHandle stream) final { - ICHECK_EQ(stream, static_cast(nullptr)); - return; - } - - // Syncing two streams is a nop, since there is only one stream. - void SyncStreamFromTo(Device dev, TVMStreamHandle event_src, TVMStreamHandle event_dst) final { - ICHECK_EQ(event_src, static_cast(nullptr)); - ICHECK_EQ(event_dst, static_cast(nullptr)); - return; - } - - void StreamSync(Device dev, TVMStreamHandle stream) final { - ICHECK_EQ(stream, static_cast(nullptr)); - VulkanThreadEntry::ThreadLocal()->Stream(dev.device_id)->Synchronize(); - } - - void SetStream(Device dev, TVMStreamHandle stream) final { - ICHECK_EQ(stream, static_cast(nullptr)); - return; - } - - void* AllocWorkspace(Device dev, size_t size, DLDataType type_hint) final { - return VulkanThreadEntry::ThreadLocal()->pool->AllocWorkspace(dev, size); - } - - void FreeWorkspace(Device dev, void* data) final { - VulkanThreadEntry::ThreadLocal()->pool->FreeWorkspace(dev, data); - } - - static VulkanDeviceAPI* Global() { - // Most of the TVM Global() functions allocate with "new" and do - // not deallocate, as the OS can clean up any leftover buffers at - // the end. In this case, we need the VulkanDeviceAPI destructor - // to call vkDestroyInstance, to prevent a segfault on exit when - // using some nvidia drivers. - static VulkanDeviceAPI inst; - return &inst; - } - - const VulkanContext& context(size_t device_id) const { - ICHECK_LT(device_id, context_.size()); - return context_[device_id]; - } - - Target GenerateTarget(size_t device_id) const { return context(device_id).target; } - - private: - std::vector find_enabled_extensions( - const std::vector& ext_prop, - const std::vector& required_extensions, - const std::vector& optional_extensions) { - std::set available_extensions; - for (const auto& prop : ext_prop) { - if (prop.specVersion > 0) { - available_extensions.insert(prop.extensionName); - } - } - - std::vector enabled_extensions; - for (const auto& ext : required_extensions) { - ICHECK(available_extensions.count(ext)) - << "Required vulkan extension \"" << ext << "\" not supported by driver"; - enabled_extensions.push_back(ext); - } - - for (const auto& ext : optional_extensions) { - if (available_extensions.count(ext)) { - enabled_extensions.push_back(ext); - } - } - - return enabled_extensions; - } - - VkInstance instance_{nullptr}; - // The physical devices, have 1 to 1 mapping to devices - std::vector context_; -}; - -Target VulkanDeviceAPI::GetDeviceDescription(VkInstance instance, VkPhysicalDevice dev, - const std::vector& instance_extensions, - const std::vector& device_extensions) { - auto has_extension = [&](const char* query) { - return std::any_of(device_extensions.begin(), device_extensions.end(), - [&](const char* extension) { return std::strcmp(query, extension) == 0; }) || - std::any_of(instance_extensions.begin(), instance_extensions.end(), - [&](const char* extension) { return std::strcmp(query, extension) == 0; }); - }; - - // Declare output locations for properties - VkPhysicalDeviceProperties2 properties = {VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_PROPERTIES_2}; - VkPhysicalDeviceDriverProperties driver = {VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_DRIVER_PROPERTIES}; - VkPhysicalDeviceSubgroupProperties subgroup = { - VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SUBGROUP_PROPERTIES}; - - // Need to do initial query in order to check the apiVersion. - vkGetPhysicalDeviceProperties(dev, &properties.properties); - - // Set up linked list for property query - { - void** pp_next = &properties.pNext; - if (has_extension("VK_KHR_driver_properties")) { - *pp_next = &driver; - pp_next = &driver.pNext; - } - if (properties.properties.apiVersion >= VK_API_VERSION_1_1) { - *pp_next = &subgroup; - pp_next = &subgroup.pNext; - } - } - - // Declare output locations for features - VkPhysicalDeviceFeatures2 features = {VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_FEATURES_2}; - VkPhysicalDevice8BitStorageFeatures storage_8bit = { - VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_8BIT_STORAGE_FEATURES}; - VkPhysicalDevice16BitStorageFeatures storage_16bit = { - VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_16BIT_STORAGE_FEATURES}; - VkPhysicalDeviceShaderFloat16Int8Features float16_int8 = { - VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SHADER_FLOAT16_INT8_FEATURES}; - - // Set up linked list for feature query - { - void** pp_next = &features.pNext; - if (has_extension("VK_KHR_8bit_storage")) { - *pp_next = &storage_8bit; - pp_next = &storage_8bit.pNext; - } - if (has_extension("VK_KHR_16bit_storage")) { - *pp_next = &storage_16bit; - pp_next = &storage_16bit.pNext; - } - if (has_extension("VK_KHR_shader_float16_int8")) { - *pp_next = &float16_int8; - pp_next = &float16_int8.pNext; - } - } - - if (has_extension("VK_KHR_get_physical_device_properties2")) { - // Preferred method, call to get all properties that can be queried. - auto vkGetPhysicalDeviceProperties2KHR = (PFN_vkGetPhysicalDeviceProperties2KHR)ICHECK_NOTNULL( - vkGetInstanceProcAddr(instance, "vkGetPhysicalDeviceProperties2KHR")); - vkGetPhysicalDeviceProperties2KHR(dev, &properties); - - auto vkGetPhysicalDeviceFeatures2KHR = (PFN_vkGetPhysicalDeviceFeatures2KHR)ICHECK_NOTNULL( - vkGetInstanceProcAddr(instance, "vkGetPhysicalDeviceFeatures2KHR")); - vkGetPhysicalDeviceFeatures2KHR(dev, &features); - } else { - // Fallback, get as many features as we can from the Vulkan1.0 - // API. Corresponding vkGetPhysicalDeviceProperties was already done earlier. - vkGetPhysicalDeviceFeatures(dev, &features.features); - } - - //// Now, extracting all the information from the vulkan query. - - // Not technically needed, because VK_SHADER_STAGE_COMPUTE_BIT will - // be set so long at least one queue has VK_QUEUE_COMPUTE_BIT, but - // preferring the explicit check. - uint32_t supported_subgroup_operations = - (subgroup.supportedStages & VK_SHADER_STAGE_COMPUTE_BIT) ? subgroup.supportedOperations : 0; - - // Even if we can't query it, warp size must be at least 1. Must - // also be defined, as `transpose` operation requires it. - uint32_t thread_warp_size = std::max(subgroup.subgroupSize, 1U); - - // By default, use the maximum API version that the driver allows, - // so that any supported features can be used by TVM shaders. - // However, if we can query the conformance version, then limit to - // only using the api version that passes the vulkan conformance - // tests. - uint32_t vulkan_api_version = properties.properties.apiVersion; - if (has_extension("VK_KHR_driver_properties")) { - auto api_major = VK_VERSION_MAJOR(vulkan_api_version); - auto api_minor = VK_VERSION_MINOR(vulkan_api_version); - if ((api_major > driver.conformanceVersion.major) || - ((api_major == driver.conformanceVersion.major) && - (api_minor > driver.conformanceVersion.minor))) { - vulkan_api_version = - VK_MAKE_VERSION(driver.conformanceVersion.major, driver.conformanceVersion.minor, 0); - } - } - - // From "Versions and Formats" section of Vulkan spec. - uint32_t max_spirv_version = 0x10000; - if (vulkan_api_version >= VK_API_VERSION_1_2) { - max_spirv_version = 0x10500; - } else if (has_extension("VK_KHR_spirv_1_4")) { - max_spirv_version = 0x10400; - } else if (vulkan_api_version >= VK_API_VERSION_1_1) { - max_spirv_version = 0x10300; - } - - // Support is available based on these extensions, but allow it to - // be disabled based on an environment variable. - bool supports_push_descriptor = - has_extension("VK_KHR_push_descriptor") && has_extension("VK_KHR_descriptor_update_template"); - { - const char* disable = std::getenv("TVM_VULKAN_DISABLE_PUSH_DESCRIPTOR"); - if (disable && *disable) { - supports_push_descriptor = false; - } - } - - // Support is available based on these extensions, but allow it to - // be disabled based on an environment variable. - bool supports_dedicated_allocation = has_extension("VK_KHR_get_memory_requirements2") && - has_extension("VK_KHR_dedicated_allocation"); - { - const char* disable = std::getenv("TVM_VULKAN_DISABLE_DEDICATED_ALLOCATION"); - if (disable && *disable) { - supports_dedicated_allocation = false; - } - } - - Map config = { - {"kind", String("vulkan")}, - // Feature support - {"supports_float16", Bool(float16_int8.shaderFloat16)}, - {"supports_float32", Bool(true)}, - {"supports_float64", Bool(features.features.shaderFloat64)}, - {"supports_int8", Bool(float16_int8.shaderInt8)}, - {"supports_int16", Bool(features.features.shaderInt16)}, - {"supports_int32", Bool(true)}, - {"supports_int64", Bool(features.features.shaderInt64)}, - {"supports_8bit_buffer", Bool(storage_8bit.storageBuffer8BitAccess)}, - {"supports_16bit_buffer", Bool(storage_16bit.storageBuffer16BitAccess)}, - {"supports_storage_buffer_storage_class", - Bool(has_extension("VK_KHR_storage_buffer_storage_class"))}, - {"supports_push_descriptor", Bool(supports_push_descriptor)}, - {"supports_dedicated_allocation", Bool(supports_dedicated_allocation)}, - {"supported_subgroup_operations", Integer(supported_subgroup_operations)}, - // Physical device limits - {"max_num_threads", Integer(properties.properties.limits.maxComputeWorkGroupInvocations)}, - {"thread_warp_size", Integer(thread_warp_size)}, - {"max_block_size_x", Integer(properties.properties.limits.maxComputeWorkGroupSize[0])}, - {"max_block_size_y", Integer(properties.properties.limits.maxComputeWorkGroupSize[1])}, - {"max_block_size_z", Integer(properties.properties.limits.maxComputeWorkGroupSize[2])}, - {"max_push_constants_size", Integer(properties.properties.limits.maxPushConstantsSize)}, - {"max_uniform_buffer_range", Integer(properties.properties.limits.maxUniformBufferRange)}, - {"max_storage_buffer_range", - Integer(IntImm(DataType::UInt(32), properties.properties.limits.maxStorageBufferRange))}, - {"max_per_stage_descriptor_storage_buffer", - Integer(properties.properties.limits.maxPerStageDescriptorStorageBuffers)}, - {"max_shared_memory_per_block", - Integer(properties.properties.limits.maxComputeSharedMemorySize)}, - // Other device properties - {"device_name", String(properties.properties.deviceName)}, - {"driver_version", Integer(properties.properties.driverVersion)}, - {"vulkan_api_version", Integer(vulkan_api_version)}, - {"max_spirv_version", Integer(max_spirv_version)}, - }; - - return Target(config); -} - -void VulkanDeviceAPI::GetAttr(Device dev, DeviceAttrKind kind, TVMRetValue* rv) { - size_t index = static_cast(dev.device_id); - if (kind == kExist) { - *rv = static_cast(index < context_.size()); - return; - } - ICHECK_LT(index, context_.size()) << "Invalid device id " << index; - - const auto& target = context(index).target; - - switch (kind) { - case kMaxThreadsPerBlock: { - *rv = target->GetAttr("max_num_threads").value(); - break; - } - case kMaxSharedMemoryPerBlock: { - *rv = target->GetAttr("max_shared_memory_per_block"); - break; - } - case kWarpSize: { - *rv = target->GetAttr("thread_warp_size").value(); - break; - } - case kComputeVersion: { - int64_t value = target->GetAttr("vulkan_api_version").value(); - std::ostringstream os; - os << VK_VERSION_MAJOR(value) << "." << VK_VERSION_MINOR(value) << "." - << VK_VERSION_PATCH(value); - *rv = os.str(); - break; - } - case kDeviceName: - *rv = target->GetAttr("device_name").value(); - break; - - case kMaxClockRate: - break; - - case kMultiProcessorCount: - break; - - case kExist: - break; - - case kMaxThreadDimensions: { - std::stringstream ss; // use json string to return multiple int values; - ss << "[" << target->GetAttr("max_block_size_x").value() << ", " - << target->GetAttr("max_block_size_y").value() << ", " - << target->GetAttr("max_block_size_z").value() << "]"; - *rv = ss.str(); - break; - } - - case kMaxRegistersPerBlock: - break; - - case kGcnArch: - break; - - case kApiVersion: - *rv = VK_HEADER_VERSION; - break; - - case kDriverVersion: { - int64_t value = target->GetAttr("driver_version").value(); - std::ostringstream os; - os << VK_VERSION_MAJOR(value) << "." << VK_VERSION_MINOR(value) << "." - << VK_VERSION_PATCH(value); - *rv = os.str(); - break; - } - } -} - -VulkanDeviceAPI::VulkanDeviceAPI() { - const auto layers = []() -> std::vector { - uint32_t inst_layer_prop_count; - VULKAN_CALL(vkEnumerateInstanceLayerProperties(&inst_layer_prop_count, nullptr)); - std::vector inst_layer_prop(inst_layer_prop_count); - VULKAN_CALL(vkEnumerateInstanceLayerProperties(&inst_layer_prop_count, inst_layer_prop.data())); - std::vector l; - - const char* enable = std::getenv("TVM_VULKAN_ENABLE_VALIDATION_LAYERS"); - bool validation_enabled = enable && *enable; - if (validation_enabled) { - for (const auto& lp : inst_layer_prop) { - if (std::strcmp(lp.layerName, "VK_LAYER_LUNARG_standard_validation") == 0) { - l.push_back("VK_LAYER_LUNARG_standard_validation"); - } - if (std::strcmp(lp.layerName, "VK_LAYER_LUNARG_parameter_validation") == 0) { - l.push_back("VK_LAYER_LUNARG_parameter_validation"); - } - if (std::strcmp(lp.layerName, "VK_LAYER_KHRONOS_validation") == 0) { - l.push_back("VK_LAYER_KHRONOS_validation"); - } - } - } - return l; - }(); - - const auto instance_extensions = [this]() { - std::vector required_extensions{}; - std::vector optional_extensions{"VK_KHR_get_physical_device_properties2"}; - - uint32_t inst_extension_prop_count; - VULKAN_CALL( - vkEnumerateInstanceExtensionProperties(nullptr, &inst_extension_prop_count, nullptr)); - std::vector inst_extension_prop(inst_extension_prop_count); - VULKAN_CALL(vkEnumerateInstanceExtensionProperties(nullptr, &inst_extension_prop_count, - inst_extension_prop.data())); - - return find_enabled_extensions(inst_extension_prop, required_extensions, optional_extensions); - }(); - - auto has_instance_extension = [&instance_extensions](const char* query) { - return std::any_of(instance_extensions.begin(), instance_extensions.end(), - [&](const char* extension) { return std::strcmp(query, extension) == 0; }); - }; - - const auto instance_api_version = []() { - uint32_t api_version = VK_MAKE_VERSION(1, 0, 0); - - // Result from vkGetInstanceProcAddr is NULL if driver only - // supports vulkan 1.0. - auto vkEnumerateInstanceVersion = - (PFN_vkEnumerateInstanceVersion)vkGetInstanceProcAddr(NULL, "vkEnumerateInstanceVersion"); - if (vkEnumerateInstanceVersion) { - vkEnumerateInstanceVersion(&api_version); - } - - return api_version; - }(); - - { - VkApplicationInfo app_info; - app_info.sType = VK_STRUCTURE_TYPE_APPLICATION_INFO; - app_info.pNext = nullptr; - app_info.pApplicationName = "TVM"; - app_info.applicationVersion = 0; - app_info.pEngineName = ""; - app_info.engineVersion = 0; - app_info.apiVersion = instance_api_version; - - VkInstanceCreateInfo inst_info; - inst_info.sType = VK_STRUCTURE_TYPE_INSTANCE_CREATE_INFO; - inst_info.pNext = nullptr; - inst_info.flags = 0; - inst_info.pApplicationInfo = &app_info; - inst_info.enabledLayerCount = layers.size(); - inst_info.ppEnabledLayerNames = layers.data(); - inst_info.enabledExtensionCount = instance_extensions.size(); - inst_info.ppEnabledExtensionNames = instance_extensions.data(); - - VULKAN_CALL(vkCreateInstance(&inst_info, nullptr, &instance_)); - } - - uint32_t phy_dev_count = 0; - VULKAN_CALL(vkEnumeratePhysicalDevices(instance_, &phy_dev_count, nullptr)); - std::vector all_phy_devs(phy_dev_count); - VULKAN_CALL(vkEnumeratePhysicalDevices(instance_, &phy_dev_count, dmlc::BeginPtr(all_phy_devs))); - for (VkPhysicalDevice phy_dev : all_phy_devs) { - // Get a list of queue families supporting compute, in order of preference. We currently only - // make use of the most preferred one family. - std::vector queue_family_indexes = GetComputeQueueFamilies(phy_dev); - if (queue_family_indexes.empty()) continue; - uint32_t queue_family_index = queue_family_indexes[0]; - float priority = 1.0f; - - struct VkDeviceQueueCreateInfo queue_create_info; - queue_create_info.sType = VK_STRUCTURE_TYPE_DEVICE_QUEUE_CREATE_INFO; - queue_create_info.pNext = nullptr; - queue_create_info.flags = 0; - queue_create_info.queueFamilyIndex = queue_family_index; - queue_create_info.queueCount = 1; - queue_create_info.pQueuePriorities = &priority; - - VulkanContext ctx; - // setup context - ctx.phy_device = phy_dev; - vkGetPhysicalDeviceProperties(ctx.phy_device, &(ctx.phy_device_prop)); - - const auto device_extensions = [&]() { - std::vector required_extensions{}; - std::vector optional_extensions{ - "VK_KHR_driver_properties", - "VK_KHR_storage_buffer_storage_class", - "VK_KHR_8bit_storage", - "VK_KHR_16bit_storage", - "VK_KHR_shader_float16_int8", - "VK_KHR_push_descriptor", - "VK_KHR_descriptor_update_template", - "VK_KHR_get_memory_requirements2", - "VK_KHR_dedicated_allocation", - "VK_KHR_spirv_1_4", - }; - - uint32_t device_extension_prop_count; - VULKAN_CALL(vkEnumerateDeviceExtensionProperties(ctx.phy_device, nullptr, - &device_extension_prop_count, nullptr)); - std::vector device_extension_prop(device_extension_prop_count); - VULKAN_CALL(vkEnumerateDeviceExtensionProperties( - ctx.phy_device, nullptr, &device_extension_prop_count, device_extension_prop.data())); - - return find_enabled_extensions(device_extension_prop, required_extensions, - optional_extensions); - }(); - - ctx.target = GetDeviceDescription(instance_, phy_dev, instance_extensions, device_extensions); - - { - // Enable all features we may use that a device supports. - VkPhysicalDeviceFeatures2 enabled_features = {VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_FEATURES_2}; - VkPhysicalDevice8BitStorageFeatures storage_8bit = { - VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_8BIT_STORAGE_FEATURES}; - VkPhysicalDevice16BitStorageFeatures storage_16bit = { - VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_16BIT_STORAGE_FEATURES}; - VkPhysicalDeviceShaderFloat16Int8Features float16_int8 = { - VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SHADER_FLOAT16_INT8_FEATURES}; - - void** pp_next = &enabled_features.pNext; - bool needs_float16_int8 = false; - - auto has_support = [&](const char* name) { return ctx.target->GetAttr(name).value(); }; - if (has_support("supports_float16")) { - float16_int8.shaderFloat16 = true; - needs_float16_int8 = true; - } - if (has_support("supports_float64")) { - enabled_features.features.shaderFloat64 = true; - } - if (has_support("supports_int8")) { - float16_int8.shaderInt8 = true; - needs_float16_int8 = true; - } - if (has_support("supports_int16")) { - enabled_features.features.shaderInt16 = true; - } - if (has_support("supports_int64")) { - enabled_features.features.shaderInt64 = true; - } - if (has_support("supports_8bit_buffer")) { - storage_8bit.storageBuffer8BitAccess = true; - *pp_next = &storage_8bit; - pp_next = &storage_8bit.pNext; - } - if (has_support("supports_16bit_buffer")) { - storage_16bit.storageBuffer16BitAccess = true; - *pp_next = &storage_16bit; - pp_next = &storage_16bit.pNext; - } - - if (needs_float16_int8) { - *pp_next = &float16_int8; - pp_next = &float16_int8.pNext; - } - - VkDeviceCreateInfo device_create_info; - device_create_info.sType = VK_STRUCTURE_TYPE_DEVICE_CREATE_INFO; - device_create_info.pNext = nullptr; - device_create_info.flags = 0; - device_create_info.queueCreateInfoCount = 1; - device_create_info.pQueueCreateInfos = &queue_create_info; - device_create_info.enabledLayerCount = 0; - device_create_info.ppEnabledLayerNames = nullptr; - device_create_info.enabledExtensionCount = device_extensions.size(); - device_create_info.ppEnabledExtensionNames = device_extensions.data(); - - if (has_instance_extension("VK_KHR_get_physical_device_properties2")) { - device_create_info.pEnabledFeatures = nullptr; - device_create_info.pNext = &enabled_features; - } else { - device_create_info.pNext = nullptr; - device_create_info.pEnabledFeatures = &enabled_features.features; - } - VULKAN_CALL(vkCreateDevice(phy_dev, &device_create_info, nullptr, &(ctx.device))); - } - - ctx.queue_mutex.reset(new std::mutex()); - vkGetDeviceQueue(ctx.device, queue_family_index, 0, &(ctx.queue)); - ctx.queue_family_index = queue_family_index; - // Find suitable memory type for staging and compute - // Find suitable compute index. - VkBuffer buffer; - VkMemoryRequirements req_staging, req_compute; - VkBufferCreateInfo info; - info.sType = VK_STRUCTURE_TYPE_BUFFER_CREATE_INFO; - info.pNext = nullptr; - info.flags = 0; - info.size = 1024; - info.queueFamilyIndexCount = 1; - info.pQueueFamilyIndices = &(ctx.queue_family_index); - info.sharingMode = VK_SHARING_MODE_EXCLUSIVE; - - // get staging requirement - info.usage = VK_BUFFER_USAGE_TRANSFER_SRC_BIT | VK_BUFFER_USAGE_TRANSFER_DST_BIT; - VULKAN_CALL(vkCreateBuffer(ctx.device, &info, nullptr, &buffer)); - vkGetBufferMemoryRequirements(ctx.device, buffer, &req_staging); - vkDestroyBuffer(ctx.device, buffer, nullptr); - // get compute requirement - info.usage = VK_BUFFER_USAGE_TRANSFER_SRC_BIT | VK_BUFFER_USAGE_TRANSFER_DST_BIT | - VK_BUFFER_USAGE_STORAGE_BUFFER_BIT; - VULKAN_CALL(vkCreateBuffer(ctx.device, &info, nullptr, &buffer)); - vkGetBufferMemoryRequirements(ctx.device, buffer, &req_compute); - vkDestroyBuffer(ctx.device, buffer, nullptr); - - // Query phyiscal device property - // find a memory that is host visible, no need to be consistent - int win_rank = -1; - VkPhysicalDeviceMemoryProperties prop; - vkGetPhysicalDeviceMemoryProperties(ctx.phy_device, &prop); - - for (uint32_t k = 0; k < prop.memoryTypeCount; ++k) { - VkMemoryType ty = prop.memoryTypes[k]; - size_t heap_size = prop.memoryHeaps[ty.heapIndex].size; - // host visible - if (!(ty.propertyFlags & VK_MEMORY_PROPERTY_HOST_VISIBLE_BIT)) continue; - // match copy requirment - if (!(req_staging.memoryTypeBits & (1 << k))) continue; - if (heap_size < 1024) continue; - int rank = 0; - rank += ty.propertyFlags & VK_MEMORY_PROPERTY_HOST_CACHED_BIT; - if (rank > win_rank) { - win_rank = rank; - ctx.staging_mtype_index = k; - ctx.coherent_staging = ty.propertyFlags & VK_MEMORY_PROPERTY_HOST_COHERENT_BIT; - } - } - ICHECK_GE(win_rank, 0) << "Cannot find suitable staging memory on device."; - - win_rank = -1; - for (uint32_t k = 0; k < prop.memoryTypeCount; ++k) { - VkMemoryType ty = prop.memoryTypes[k]; - size_t heap_size = prop.memoryHeaps[ty.heapIndex].size; - // host visible - if (!(ty.propertyFlags & VK_MEMORY_PROPERTY_DEVICE_LOCAL_BIT)) continue; - // match copy requirment - if (!(req_staging.memoryTypeBits & (1 << k))) continue; - if (heap_size < 1024) continue; - int rank = 0; - // prefer not host visible - rank += !(ty.propertyFlags & VK_MEMORY_PROPERTY_HOST_VISIBLE_BIT); - if (rank > win_rank) { - win_rank = rank; - ctx.compute_mtype_index = k; - } - } - ICHECK_GE(win_rank, 0) << "Cannot find suitable local memory on device."; - - if (ctx.target->GetAttr("supports_push_descriptor").value()) { - ctx.descriptor_template_khr_functions = - std::make_unique(ctx.device); - } - - if (ctx.target->GetAttr("supports_dedicated_allocation").value()) { - ctx.get_buffer_memory_requirements_2_functions = - std::make_unique(ctx.device); - } - - context_.push_back(std::move(ctx)); - } - - LOG(INFO) << "Initialize Vulkan with " << context_.size() << " devices.."; - for (size_t i = 0; i < context_.size(); ++i) { - LOG(INFO) << "vulkan(" << i << ")=\'" << context_[i].phy_device_prop.deviceName - << "\' phy_dev_id=" << context_[i].phy_device - << " use_immediate=" << context_[i].UseImmediate(); - } -} - -std::vector VulkanDeviceAPI::GetComputeQueueFamilies(VkPhysicalDevice phy_dev) { - uint32_t queue_prop_count = 0; - vkGetPhysicalDeviceQueueFamilyProperties(phy_dev, &queue_prop_count, nullptr); - std::vector queue_props(queue_prop_count); - vkGetPhysicalDeviceQueueFamilyProperties(phy_dev, &queue_prop_count, dmlc::BeginPtr(queue_props)); - - std::vector result; - // Prefer compute-only queues. On cerain devices supporting this (e.g. Mesa RADV), using - // compute-only queues gives better responsiveness for other graphics workload (e.g. desktop). - for (uint32_t i = 0; i != queue_prop_count; ++i) { - if ((VK_QUEUE_COMPUTE_BIT & queue_props[i].queueFlags) != 0 && - (VK_QUEUE_GRAPHICS_BIT & queue_props[i].queueFlags) == 0) { - result.push_back(i); - } - } - // Now, push the compute queues that we skipped above into the list. - for (uint32_t i = 0; i != queue_prop_count; ++i) { - if ((VK_QUEUE_COMPUTE_BIT & queue_props[i].queueFlags) != 0 && - (VK_QUEUE_GRAPHICS_BIT & queue_props[i].queueFlags) != 0) { - result.push_back(i); - } - } - return result; -} - -// namespace vulkan -class VulkanModuleNode; - -// a wrapped function class to get packed func. -class VulkanWrappedFunc { - public: - void Init(VulkanModuleNode* m, ObjectPtr sptr, const std::string& func_name, - size_t num_buffer_args, size_t num_pack_args, - const std::vector& thread_axis_tags) { - m_ = m; - sptr_ = sptr; - func_name_ = func_name; - num_buffer_args_ = num_buffer_args; - num_pack_args_ = num_pack_args; - thread_axis_cfg_.Init(num_buffer_args + num_pack_args, thread_axis_tags); - } - - void operator()(TVMArgs args, TVMRetValue* rv, const ArgUnion64* pack_args) const; - - private: - // internal module - VulkanModuleNode* m_; - // the resource holder - ObjectPtr sptr_; - // v The name of the function. - std::string func_name_; - // Number of buffer arguments - size_t num_buffer_args_; - // number of packed arguments. - size_t num_pack_args_; - // Device state cache per device. - // mark as mutable, to enable lazy initialization - // thread axis configuration - ThreadAxisConfig thread_axis_cfg_; - - mutable std::array, kVulkanMaxNumDevice> scache_; -}; - -// Multi-device enabled module. -class VulkanModuleNode final : public runtime::ModuleNode { - public: - explicit VulkanModuleNode(std::unordered_map smap, - std::unordered_map fmap, std::string source) - : smap_(smap), fmap_(fmap), source_(source) {} - - const char* type_key() const final { return "vulkan"; } - - PackedFunc GetFunction(const std::string& name, const ObjectPtr& sptr_to_self) final { - ICHECK_EQ(sptr_to_self.get(), this); - ICHECK_NE(name, symbol::tvm_module_main) << "Device function do not have main"; - auto it = fmap_.find(name); - if (it == fmap_.end()) return PackedFunc(); - const FunctionInfo& info = it->second; - VulkanWrappedFunc f; - size_t num_buffer_args = NumBufferArgs(info.arg_types); - f.Init(this, sptr_to_self, name, num_buffer_args, info.arg_types.size() - num_buffer_args, - info.thread_axis_tags); - return PackFuncNonBufferArg(std::move(f), info.arg_types); - } - - ~VulkanModuleNode() { - // cleanup vulkan related caches. - for (size_t device_id = 0; device_id < ecache_.size(); ++device_id) { - for (auto& kv : ecache_[device_id]) { - auto& pe = kv.second; - ICHECK(pe); - const auto& vctx = VulkanDeviceAPI::Global()->context(device_id); - - if (pe->descriptor_update_template != VK_NULL_HANDLE) { - vctx.descriptor_template_khr_functions->vkDestroyDescriptorUpdateTemplateKHR( - vctx.device, pe->descriptor_update_template, nullptr); - } - vkDestroyPipeline(vctx.device, pe->pipeline, nullptr); - vkDestroyPipelineLayout(vctx.device, pe->pipeline_layout, nullptr); - vkDestroyDescriptorPool(vctx.device, pe->descriptor_pool, nullptr); - vkDestroyDescriptorSetLayout(vctx.device, pe->descriptor_set_layout, nullptr); - vkDestroyShaderModule(vctx.device, pe->shader, nullptr); - } - } - } - - std::shared_ptr GetPipeline(size_t device_id, const std::string& func_name, - size_t num_pack_args) { - const auto& vctx = VulkanDeviceAPI::Global()->context(device_id); - std::lock_guard lock(mutex_); - const auto& cp = ecache_[device_id][func_name]; - if (cp) { - return cp; - } - // Create new pipeline - auto pe = std::make_shared(); - { - // create shader - auto sit = smap_.find(func_name); - ICHECK(sit != smap_.end()); - pe->use_ubo = sit->second.flag & (1 << ShaderMetaDataFlagMask::kUseUBO); - const std::vector& data = sit->second.data; - VkShaderModuleCreateInfo shader_cinfo; - shader_cinfo.sType = VK_STRUCTURE_TYPE_SHADER_MODULE_CREATE_INFO; - shader_cinfo.pNext = nullptr; - shader_cinfo.flags = 0; - shader_cinfo.codeSize = data.size() * sizeof(uint32_t); - shader_cinfo.pCode = data.data(); - VULKAN_CALL(vkCreateShaderModule(vctx.device, &shader_cinfo, nullptr, &(pe->shader))); - } - std::vector arg_binding; - std::vector arg_template; - std::vector descriptor_set_pool_sizes; - uint32_t num_pod = 0, num_buffer = 0; - - auto push_arg_info = [&arg_binding, &arg_template, &descriptor_set_pool_sizes]( - uint32_t binding, VkDescriptorType desc_type) { - { - auto result = - std::find_if(descriptor_set_pool_sizes.begin(), descriptor_set_pool_sizes.end(), - [&](const auto& psize) { return psize.type == desc_type; }); - if (result == descriptor_set_pool_sizes.end()) { - VkDescriptorPoolSize new_size; - new_size.type = desc_type; - new_size.descriptorCount = 1; - descriptor_set_pool_sizes.push_back(new_size); - } else { - result->descriptorCount++; - } - } - - { - VkDescriptorSetLayoutBinding bd; - bd.binding = binding; - bd.descriptorType = desc_type; - bd.descriptorCount = 1; - bd.stageFlags = VK_SHADER_STAGE_COMPUTE_BIT; - bd.pImmutableSamplers = nullptr; - arg_binding.push_back(bd); - } - { - VkDescriptorUpdateTemplateEntryKHR tpl; - tpl.dstBinding = binding; - tpl.dstArrayElement = 0; - tpl.descriptorCount = 1; - tpl.descriptorType = desc_type; - tpl.offset = binding * sizeof(VkDescriptorBufferInfo); - tpl.stride = sizeof(VkDescriptorBufferInfo); - arg_template.push_back(tpl); - } - }; - - { - auto fit = fmap_.find(func_name); - ICHECK(fit != fmap_.end()); - for (DLDataType arg_type : fit->second.arg_types) { - if (arg_type.code == kTVMOpaqueHandle) { - push_arg_info(num_buffer, VK_DESCRIPTOR_TYPE_STORAGE_BUFFER); - ++num_buffer; - } else { - ++num_pod; - } - } - } - - size_t nbytes_scalars = num_pod * sizeof(ArgUnion64); - if (pe->use_ubo) { - // Use UBO instead of push constants - push_arg_info(num_buffer, VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER); - VulkanThreadEntry::ThreadLocal()->AllocateUniformBuffer(device_id, nbytes_scalars); - } - - { - VkDescriptorSetLayoutCreateInfo descrip_cinfo; - descrip_cinfo.sType = VK_STRUCTURE_TYPE_DESCRIPTOR_SET_LAYOUT_CREATE_INFO; - descrip_cinfo.pNext = nullptr; - descrip_cinfo.flags = 0; - if (vctx.UseImmediate()) { - descrip_cinfo.flags |= VK_DESCRIPTOR_SET_LAYOUT_CREATE_PUSH_DESCRIPTOR_BIT_KHR; - } - descrip_cinfo.bindingCount = arg_binding.size(); - descrip_cinfo.pBindings = arg_binding.data(); - VULKAN_CALL(vkCreateDescriptorSetLayout(vctx.device, &descrip_cinfo, nullptr, - &(pe->descriptor_set_layout))); - } - - if (!vctx.UseImmediate()) { - VkDescriptorPoolCreateInfo descrip_pool_cinfo; - descrip_pool_cinfo.sType = VK_STRUCTURE_TYPE_DESCRIPTOR_POOL_CREATE_INFO; - descrip_pool_cinfo.pNext = nullptr; - descrip_pool_cinfo.flags = VK_DESCRIPTOR_POOL_CREATE_FREE_DESCRIPTOR_SET_BIT; - descrip_pool_cinfo.maxSets = 1; - descrip_pool_cinfo.poolSizeCount = descriptor_set_pool_sizes.size(); - descrip_pool_cinfo.pPoolSizes = descriptor_set_pool_sizes.data(); - VULKAN_CALL(vkCreateDescriptorPool(vctx.device, &descrip_pool_cinfo, nullptr, - &(pe->descriptor_pool))); - - VkDescriptorSetAllocateInfo alloc_info; - alloc_info.sType = VK_STRUCTURE_TYPE_DESCRIPTOR_SET_ALLOCATE_INFO; - alloc_info.pNext = nullptr; - alloc_info.descriptorPool = pe->descriptor_pool; - alloc_info.descriptorSetCount = 1; - alloc_info.pSetLayouts = &(pe->descriptor_set_layout); - VULKAN_CALL(vkAllocateDescriptorSets(vctx.device, &alloc_info, &(pe->descriptor_set))); - } - - VkPushConstantRange crange; - crange.stageFlags = VK_SHADER_STAGE_COMPUTE_BIT; - crange.offset = 0; - crange.size = sizeof(ArgUnion64) * num_pack_args; - - VkPipelineLayoutCreateInfo playout_cinfo; - playout_cinfo.sType = VK_STRUCTURE_TYPE_PIPELINE_LAYOUT_CREATE_INFO; - playout_cinfo.pNext = nullptr; - playout_cinfo.flags = 0; - playout_cinfo.setLayoutCount = 1; - playout_cinfo.pSetLayouts = &(pe->descriptor_set_layout); - - if (0 < nbytes_scalars && !pe->use_ubo) { - playout_cinfo.pushConstantRangeCount = 1; - playout_cinfo.pPushConstantRanges = &crange; - ICHECK_LE(crange.size, vctx.phy_device_prop.limits.maxPushConstantsSize); - } else { - playout_cinfo.pushConstantRangeCount = 0; - playout_cinfo.pPushConstantRanges = nullptr; - } - - VULKAN_CALL( - vkCreatePipelineLayout(vctx.device, &playout_cinfo, nullptr, &(pe->pipeline_layout))); - - VkComputePipelineCreateInfo pipeline_cinfo; - pipeline_cinfo.sType = VK_STRUCTURE_TYPE_COMPUTE_PIPELINE_CREATE_INFO; - pipeline_cinfo.pNext = nullptr; - pipeline_cinfo.flags = 0; - pipeline_cinfo.stage.sType = VK_STRUCTURE_TYPE_PIPELINE_SHADER_STAGE_CREATE_INFO; - pipeline_cinfo.stage.pNext = nullptr; - pipeline_cinfo.stage.flags = 0; - pipeline_cinfo.stage.stage = VK_SHADER_STAGE_COMPUTE_BIT; - pipeline_cinfo.stage.module = pe->shader; - pipeline_cinfo.stage.pName = func_name.c_str(); - pipeline_cinfo.stage.pSpecializationInfo = nullptr; - pipeline_cinfo.layout = pe->pipeline_layout; - pipeline_cinfo.basePipelineHandle = VK_NULL_HANDLE; - pipeline_cinfo.basePipelineIndex = 0; - VULKAN_CALL(vkCreateComputePipelines(vctx.device, VK_NULL_HANDLE, 1, &pipeline_cinfo, nullptr, - &(pe->pipeline))); - - if (vctx.UseImmediate()) { - VkDescriptorUpdateTemplateCreateInfoKHR descrip_template_cinfo; - descrip_template_cinfo.sType = VK_STRUCTURE_TYPE_DESCRIPTOR_UPDATE_TEMPLATE_CREATE_INFO_KHR; - descrip_template_cinfo.pNext = 0; - descrip_template_cinfo.flags = 0; - descrip_template_cinfo.descriptorUpdateEntryCount = arg_template.size(); - descrip_template_cinfo.pDescriptorUpdateEntries = arg_template.data(); - descrip_template_cinfo.templateType = VK_DESCRIPTOR_UPDATE_TEMPLATE_TYPE_PUSH_DESCRIPTORS_KHR; - descrip_template_cinfo.descriptorSetLayout = pe->descriptor_set_layout; - descrip_template_cinfo.pipelineBindPoint = VK_PIPELINE_BIND_POINT_COMPUTE; - descrip_template_cinfo.pipelineLayout = pe->pipeline_layout; - descrip_template_cinfo.set = 0; - VULKAN_CALL(vctx.descriptor_template_khr_functions->vkCreateDescriptorUpdateTemplateKHR( - vctx.device, &descrip_template_cinfo, 0, &(pe->descriptor_update_template))); - } - ecache_[device_id][func_name] = pe; - return pe; - } - - void SaveToFile(const std::string& file_name, const std::string& format) final { - std::string fmt = GetFileFormat(file_name, format); - ICHECK_EQ(fmt, fmt_) << "Can only save to customized format vulkan"; - std::string meta_file = GetMetaFilePath(file_name); - SaveMetaDataToFile(meta_file, fmap_); - std::string data_bin; - dmlc::MemoryStringStream fs(&data_bin); - dmlc::Stream* stream = &fs; - uint32_t magic = kVulkanModuleMagic; - stream->Write(magic); - stream->Write(smap_); - SaveBinaryToFile(file_name, data_bin); - } - - void SaveToBinary(dmlc::Stream* stream) final { - stream->Write(fmt_); - stream->Write(fmap_); - stream->Write(smap_); - } - std::string GetSource(const std::string& format) final { - // can only return source code. - return source_; - } - - private: - // function information table. - std::unordered_map smap_; - // function information table. - std::unordered_map fmap_; - // The format - std::string fmt_{"vulkan"}; - // The source - std::string source_; - - // Guards accesses to `ecache_` - std::mutex mutex_; - std::array>, kVulkanMaxNumDevice> - ecache_; -}; - -Module VulkanModuleCreate(std::unordered_map smap, - std::unordered_map fmap, std::string source) { - auto n = make_object(smap, fmap, source); - return Module(n); -} - -VulkanThreadEntry* VulkanThreadEntry::ThreadLocal() { return VulkanThreadStore::Get(); } - -VulkanHostVisibleBuffer* GetOrAllocate( - int device_id, size_t size, VkBufferUsageFlags usage, uint32_t mem_type_index, - std::unordered_map>* buffers_ptr, - bool sync_before_realloc = false) { - auto& buffers = *buffers_ptr; - if (!buffers[device_id]) { - buffers[device_id] = std::make_unique(); - } - - auto& buf = *(buffers[device_id]); - if (buf.device != nullptr && buf.size < size) { - // free previous buffer - if (sync_before_realloc) { - // For the deferred execution mode, we need to make sure that old tasks that use - // the older, smaller buffer get finished - // Synchronization on staging buffers is done after host to device memory copy - // For UBO, we sync here before we reallocate a larger buffer, to minimize synchronization - // points - VulkanThreadEntry::ThreadLocal()->Stream(device_id)->Synchronize(); - } - DeleteHostVisibleBuffer(&buf); - } - - const auto& vctx = VulkanDeviceAPI::Global()->context(device_id); - - if (buf.device == nullptr) { - buf.device = vctx.device; - } - if (buf.host_addr == nullptr) { - buf.vk_buf = CreateBuffer(vctx, size, usage, mem_type_index); - VULKAN_CALL(vkMapMemory(vctx.device, buf.vk_buf->memory, 0, size, 0, &(buf.host_addr))); - buf.size = size; - } - return &buf; -} - -VulkanStagingBuffer* VulkanThreadEntry::StagingBuffer(int device_id, size_t size) { - const auto& vctx = VulkanDeviceAPI::Global()->context(device_id); - auto usage = VK_BUFFER_USAGE_TRANSFER_SRC_BIT | VK_BUFFER_USAGE_TRANSFER_DST_BIT; - return GetOrAllocate(device_id, size, usage, vctx.staging_mtype_index, &staging_buffers_); -} - -void VulkanThreadEntry::AllocateUniformBuffer(int device_id, size_t size) { - const auto& vctx = VulkanDeviceAPI::Global()->context(device_id); - auto prop = VK_MEMORY_PROPERTY_HOST_VISIBLE_BIT | VK_MEMORY_PROPERTY_HOST_COHERENT_BIT; - auto info = MakeBufferCreateInfo(vctx, size, VK_BUFFER_USAGE_UNIFORM_BUFFER_BIT); - auto mem_type_index = FindMemoryType(vctx, info, prop); - GetOrAllocate(device_id, size, VK_BUFFER_USAGE_UNIFORM_BUFFER_BIT, mem_type_index, - &uniform_buffers_, true); -} - -VulkanUniformBuffer* VulkanThreadEntry::GetUniformBuffer(int device_id, size_t size) { - auto& buf = uniform_buffers_[device_id]; - ICHECK(buf); - ICHECK_GE(buf->size, size); - return buf.get(); -} - -VulkanThreadEntry::VulkanThreadEntry() - : pool(std::make_unique(static_cast(kDLVulkan), - VulkanDeviceAPI::Global())) { - device.device_id = 0; - device.device_type = static_cast(kDLVulkan); -} - -VulkanStream* VulkanThreadEntry::Stream(size_t device_id) { - if (!streams_[device_id]) { - streams_[device_id] = std::unique_ptr( - new VulkanStream(&VulkanDeviceAPI::Global()->context(device_id))); - } - return streams_[device_id].get(); -} - -void VulkanWrappedFunc::operator()(TVMArgs args, TVMRetValue* rv, - const ArgUnion64* pack_args) const { - int device_id = VulkanThreadEntry::ThreadLocal()->device.device_id; - ICHECK_LT(device_id, kVulkanMaxNumDevice); - const auto& vctx = VulkanDeviceAPI::Global()->context(device_id); - if (!scache_[device_id]) { - scache_[device_id] = m_->GetPipeline(device_id, func_name_, num_pack_args_); - } - const auto& pipeline = scache_[device_id]; - ThreadWorkLoad wl = thread_axis_cfg_.Extract(args); - std::vector descriptor_buffers; - descriptor_buffers.resize(num_buffer_args_); - for (size_t i = 0; i < num_buffer_args_; ++i) { - void* buf = args[static_cast(i)]; - VkDescriptorBufferInfo binfo; - binfo.buffer = static_cast(buf)->buffer; - binfo.offset = 0; - binfo.range = VK_WHOLE_SIZE; - descriptor_buffers[i] = binfo; - } - const size_t nbytes_scalars = num_pack_args_ * sizeof(ArgUnion64); - if (pipeline->use_ubo) { - auto ubo = VulkanThreadEntry::ThreadLocal()->GetUniformBuffer(device_id, nbytes_scalars); - CHECK(ubo->host_addr) << "The UBO host buffer is not allocated"; - VkDescriptorBufferInfo binfo; - binfo.buffer = ubo->vk_buf->buffer; - binfo.offset = 0; - binfo.range = VK_WHOLE_SIZE; - descriptor_buffers.push_back(binfo); - } - if (vctx.UseImmediate()) { - // Can safely capture by reference as this lambda is immediately executed on the calling thread. - VulkanThreadEntry::ThreadLocal()->Stream(device_id)->Launch([&](VulkanStreamState* state) { - vkCmdBindPipeline(state->cmd_buffer_, VK_PIPELINE_BIND_POINT_COMPUTE, pipeline->pipeline); - ICHECK(pipeline->descriptor_update_template != VK_NULL_HANDLE); - vctx.descriptor_template_khr_functions->vkCmdPushDescriptorSetWithTemplateKHR( - state->cmd_buffer_, pipeline->descriptor_update_template, pipeline->pipeline_layout, 0, - descriptor_buffers.data()); - - if (pipeline->use_ubo) { - auto ubo = VulkanThreadEntry::ThreadLocal()->GetUniformBuffer(device_id, nbytes_scalars); - memcpy(ubo->host_addr, pack_args, nbytes_scalars); - } else if (num_pack_args_ > 0) { - vkCmdPushConstants(state->cmd_buffer_, pipeline->pipeline_layout, - VK_SHADER_STAGE_COMPUTE_BIT, 0, num_pack_args_ * sizeof(ArgUnion64), - pack_args); - } - - vkCmdDispatch(state->cmd_buffer_, wl.grid_dim(0), wl.grid_dim(1), wl.grid_dim(2)); - VkMemoryBarrier barrier_info; - barrier_info.sType = VK_STRUCTURE_TYPE_MEMORY_BARRIER; - barrier_info.pNext = nullptr; - barrier_info.srcAccessMask = VK_ACCESS_SHADER_WRITE_BIT | VK_ACCESS_SHADER_READ_BIT; - barrier_info.dstAccessMask = (VK_ACCESS_TRANSFER_READ_BIT | VK_ACCESS_TRANSFER_WRITE_BIT | - VK_ACCESS_SHADER_READ_BIT | VK_ACCESS_SHADER_WRITE_BIT); - vkCmdPipelineBarrier(state->cmd_buffer_, VK_PIPELINE_STAGE_COMPUTE_SHADER_BIT, - VK_PIPELINE_STAGE_TRANSFER_BIT | VK_PIPELINE_STAGE_COMPUTE_SHADER_BIT, 0, - 1, &barrier_info, 0, nullptr, 0, nullptr); - }); - return; - } - - // Otherwise, the more expensive deferred path. - std::vector pack_args_storage(pack_args, pack_args + num_pack_args_); - const auto& deferred_initializer = [&vctx, pipeline, descriptor_buffers]() { - std::vector write_descriptor_sets; - write_descriptor_sets.resize(descriptor_buffers.size()); - for (size_t i = 0; i < write_descriptor_sets.size(); i++) { - write_descriptor_sets[i].sType = VK_STRUCTURE_TYPE_WRITE_DESCRIPTOR_SET; - write_descriptor_sets[i].pNext = 0; - write_descriptor_sets[i].dstSet = pipeline->descriptor_set; - write_descriptor_sets[i].dstBinding = i; - write_descriptor_sets[i].dstArrayElement = 0; - write_descriptor_sets[i].descriptorCount = 1; - write_descriptor_sets[i].pImageInfo = 0; - write_descriptor_sets[i].pBufferInfo = &(descriptor_buffers[i]); - write_descriptor_sets[i].pTexelBufferView = 0; - - if (pipeline->use_ubo && i == write_descriptor_sets.size() - 1) { - // The last binding is for UBO - write_descriptor_sets[i].descriptorType = VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER; - } else { - write_descriptor_sets[i].descriptorType = VK_DESCRIPTOR_TYPE_STORAGE_BUFFER; - } - } - vkUpdateDescriptorSets(vctx.device, write_descriptor_sets.size(), write_descriptor_sets.data(), - 0, 0); - }; - const auto& deferred_kernel = [this, pipeline, wl, pack_args_storage, nbytes_scalars, - device_id](VulkanStreamState* state) { - vkCmdBindPipeline(state->cmd_buffer_, VK_PIPELINE_BIND_POINT_COMPUTE, pipeline->pipeline); - vkCmdBindDescriptorSets(state->cmd_buffer_, VK_PIPELINE_BIND_POINT_COMPUTE, - pipeline->pipeline_layout, 0, 1, &(pipeline->descriptor_set), 0, - nullptr); - - if (pipeline->use_ubo) { - auto ubo = VulkanThreadEntry::ThreadLocal()->GetUniformBuffer(device_id, nbytes_scalars); - memcpy(ubo->host_addr, pack_args_storage.data(), nbytes_scalars); - } else if (num_pack_args_ > 0) { - vkCmdPushConstants(state->cmd_buffer_, pipeline->pipeline_layout, VK_SHADER_STAGE_COMPUTE_BIT, - 0, pack_args_storage.size() * sizeof(ArgUnion64), - pack_args_storage.data()); - } - - vkCmdDispatch(state->cmd_buffer_, wl.grid_dim(0), wl.grid_dim(1), wl.grid_dim(2)); - VkMemoryBarrier barrier_info; - barrier_info.sType = VK_STRUCTURE_TYPE_MEMORY_BARRIER; - barrier_info.pNext = nullptr; - barrier_info.srcAccessMask = VK_ACCESS_SHADER_WRITE_BIT | VK_ACCESS_SHADER_READ_BIT; - barrier_info.dstAccessMask = (VK_ACCESS_TRANSFER_READ_BIT | VK_ACCESS_TRANSFER_WRITE_BIT | - VK_ACCESS_SHADER_READ_BIT | VK_ACCESS_SHADER_WRITE_BIT); - vkCmdPipelineBarrier(state->cmd_buffer_, VK_PIPELINE_STAGE_COMPUTE_SHADER_BIT, - VK_PIPELINE_STAGE_TRANSFER_BIT | VK_PIPELINE_STAGE_COMPUTE_SHADER_BIT, 0, - 1, &barrier_info, 0, nullptr, 0, nullptr); - }; - VulkanStreamToken deferred_token; - deferred_token.descriptor_set_ = pipeline->descriptor_set; - deferred_token.buffers_.resize(descriptor_buffers.size()); - for (size_t i = 0; i < descriptor_buffers.size(); ++i) { - deferred_token.buffers_[i] = descriptor_buffers[i].buffer; - } - VulkanThreadEntry::ThreadLocal()->Stream(device_id)->LaunchDeferred( - deferred_initializer, deferred_kernel, deferred_token); -} - -Module VulkanModuleLoadFile(const std::string& file_name, const std::string& format) { - std::string data; - std::unordered_map smap; - std::unordered_map fmap; - std::string fmt = GetFileFormat(file_name, format); - std::string meta_file = GetMetaFilePath(file_name); - LoadBinaryFromFile(file_name, &data); - LoadMetaDataFromFile(meta_file, &fmap); - dmlc::MemoryStringStream fs(&data); - dmlc::Stream* stream = &fs; - uint32_t magic; - stream->Read(&magic); - ICHECK_EQ(magic, kVulkanModuleMagic) << "VulkanModule Magic mismatch"; - stream->Read(&smap); - return VulkanModuleCreate(smap, fmap, ""); -} - -Module VulkanModuleLoadBinary(void* strm) { - dmlc::Stream* stream = static_cast(strm); - std::unordered_map smap; - std::unordered_map fmap; - - std::string fmt; - stream->Read(&fmt); - stream->Read(&fmap); - stream->Read(&smap); - return VulkanModuleCreate(smap, fmap, ""); -} - -TVM_REGISTER_GLOBAL("runtime.module.loadfile_vulkan").set_body_typed(VulkanModuleLoadFile); - -TVM_REGISTER_GLOBAL("runtime.module.loadbinary_vulkan").set_body_typed(VulkanModuleLoadBinary); - -TVM_REGISTER_GLOBAL("device_api.vulkan").set_body([](TVMArgs args, TVMRetValue* rv) { - DeviceAPI* ptr = VulkanDeviceAPI::Global(); - *rv = static_cast(ptr); -}); - -TVM_REGISTER_GLOBAL("device_api.vulkan.generate_target").set_body_typed([](int device_id) { - return VulkanDeviceAPI::Global()->GenerateTarget(device_id); -}); - -} // namespace vulkan -} // namespace runtime -} // namespace tvm diff --git a/src/runtime/vulkan/vulkan_buffer.cc b/src/runtime/vulkan/vulkan_buffer.cc new file mode 100644 index 000000000000..7059e7c617f4 --- /dev/null +++ b/src/runtime/vulkan/vulkan_buffer.cc @@ -0,0 +1,47 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include "vulkan_buffer.h" + +#include "vulkan_device_api.h" +#include "vulkan_thread_entry.h" + +namespace tvm { +namespace runtime { +namespace vulkan { + +void DeleteHostVisibleBuffer(VulkanHostVisibleBuffer* buf) { + if (buf && buf->vk_buf) { + if (buf->host_addr != nullptr) { + vkUnmapMemory(buf->device, buf->vk_buf->memory); + } + if (buf->vk_buf->memory != VK_NULL_HANDLE) { + vkFreeMemory(buf->device, buf->vk_buf->memory, nullptr); + } + if (buf->vk_buf->buffer != VK_NULL_HANDLE) { + vkDestroyBuffer(buf->device, buf->vk_buf->buffer, nullptr); + } + buf->host_addr = nullptr; + delete buf->vk_buf; + } +} + +} // namespace vulkan +} // namespace runtime +} // namespace tvm diff --git a/src/runtime/vulkan/vulkan_buffer.h b/src/runtime/vulkan/vulkan_buffer.h new file mode 100644 index 000000000000..77406ec2b2f8 --- /dev/null +++ b/src/runtime/vulkan/vulkan_buffer.h @@ -0,0 +1,63 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#ifndef TVM_RUNTIME_VULKAN_VULKAN_BUFFER_H_ +#define TVM_RUNTIME_VULKAN_VULKAN_BUFFER_H_ + +#include + +#include +#include + +namespace tvm { +namespace runtime { +namespace vulkan { + +struct VulkanBuffer { + VkBuffer buffer{VK_NULL_HANDLE}; + VkDeviceMemory memory{VK_NULL_HANDLE}; +}; + +/*! \brief A struct to represent Vulkan buffers backed by host visible memory */ +struct VulkanHostVisibleBuffer { + // A device where the buffer is allocated + VkDevice device{nullptr}; + // Vulkan buffer and memory + VulkanBuffer* vk_buf{nullptr}; + // The corresponding pointer to the host memory + void* host_addr{nullptr}; + // The size of the buffer in bytes + size_t size{0}; +}; + +using VulkanStagingBuffer = VulkanHostVisibleBuffer; +using VulkanUniformBuffer = VulkanHostVisibleBuffer; + +VulkanHostVisibleBuffer* GetOrAllocate( + int device_id, size_t size, VkBufferUsageFlags usage, uint32_t mem_type_index, + std::unordered_map>* buffers_ptr, + bool sync_before_realloc = false); + +void DeleteHostVisibleBuffer(VulkanHostVisibleBuffer* buf); + +} // namespace vulkan +} // namespace runtime +} // namespace tvm + +#endif // TVM_RUNTIME_VULKAN_VULKAN_BUFFER_H_ diff --git a/src/runtime/vulkan/vulkan_common.h b/src/runtime/vulkan/vulkan_common.h index 14ecdba6ca40..8fce5dbd192a 100644 --- a/src/runtime/vulkan/vulkan_common.h +++ b/src/runtime/vulkan/vulkan_common.h @@ -36,6 +36,12 @@ namespace tvm { namespace runtime { namespace vulkan { +/*! \brief Maximum number of GPU supported in VulkanModule. */ +static constexpr const int kVulkanMaxNumDevice = 8; + +/*! \brief TVM Vulkan binary pack magic number */ +static constexpr const int kVulkanModuleMagic = 0x02700027; + const int kMaxPushConstantsBytes = 128; /*! \brief A mask used when we attach additional information to shaders */ @@ -100,66 +106,6 @@ inline const char* VKGetErrorString(VkResult error) { VULKAN_CHECK_ERROR(__e); \ } -struct VulkanDescriptorTemplateKHRFunctions { - explicit VulkanDescriptorTemplateKHRFunctions(VkDevice device) { - vkCreateDescriptorUpdateTemplateKHR = (PFN_vkCreateDescriptorUpdateTemplateKHR)ICHECK_NOTNULL( - vkGetDeviceProcAddr(device, "vkCreateDescriptorUpdateTemplateKHR")); - vkDestroyDescriptorUpdateTemplateKHR = (PFN_vkDestroyDescriptorUpdateTemplateKHR)ICHECK_NOTNULL( - vkGetDeviceProcAddr(device, "vkDestroyDescriptorUpdateTemplateKHR")); - vkUpdateDescriptorSetWithTemplateKHR = (PFN_vkUpdateDescriptorSetWithTemplateKHR)ICHECK_NOTNULL( - vkGetDeviceProcAddr(device, "vkUpdateDescriptorSetWithTemplateKHR")); - vkCmdPushDescriptorSetWithTemplateKHR = - (PFN_vkCmdPushDescriptorSetWithTemplateKHR)ICHECK_NOTNULL( - vkGetDeviceProcAddr(device, "vkCmdPushDescriptorSetWithTemplateKHR")); - } - - PFN_vkCreateDescriptorUpdateTemplateKHR vkCreateDescriptorUpdateTemplateKHR{nullptr}; - PFN_vkDestroyDescriptorUpdateTemplateKHR vkDestroyDescriptorUpdateTemplateKHR{nullptr}; - PFN_vkUpdateDescriptorSetWithTemplateKHR vkUpdateDescriptorSetWithTemplateKHR{nullptr}; - PFN_vkCmdPushDescriptorSetWithTemplateKHR vkCmdPushDescriptorSetWithTemplateKHR{nullptr}; -}; - -struct VulkanGetBufferMemoryRequirements2Functions { - explicit VulkanGetBufferMemoryRequirements2Functions(VkDevice device) { - vkGetBufferMemoryRequirements2KHR = (PFN_vkGetBufferMemoryRequirements2KHR)ICHECK_NOTNULL( - vkGetDeviceProcAddr(device, "vkGetBufferMemoryRequirements2KHR")); - } - - PFN_vkGetBufferMemoryRequirements2KHR vkGetBufferMemoryRequirements2KHR{nullptr}; -}; - -struct VulkanContext { - // physical device - VkPhysicalDevice phy_device{nullptr}; - - // Phyiscal device property - VkPhysicalDeviceProperties phy_device_prop; - // Target that best represents this physical device - Target target; - // Memory type index for staging. - uint32_t staging_mtype_index{0}; - // whether staging is coherent - bool coherent_staging{false}; - - std::unique_ptr descriptor_template_khr_functions{nullptr}; - std::unique_ptr - get_buffer_memory_requirements_2_functions{nullptr}; - // Memory type index for compute - uint32_t compute_mtype_index{0}; - // The logical device - VkDevice device{nullptr}; - // command queue - - std::unique_ptr queue_mutex; - VkQueue queue{nullptr}; - // queue family_index; - uint32_t queue_family_index{0}; - // Queue family index. - VkQueueFamilyProperties queue_prop; - - bool UseImmediate() const { return descriptor_template_khr_functions != nullptr; } -}; - } // namespace vulkan } // namespace runtime } // namespace tvm diff --git a/src/runtime/vulkan/vulkan_context.cc b/src/runtime/vulkan/vulkan_context.cc new file mode 100644 index 000000000000..659e6bd225f6 --- /dev/null +++ b/src/runtime/vulkan/vulkan_context.cc @@ -0,0 +1,183 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include "vulkan_context.h" + +#include + +#include "vulkan_common.h" +#include "vulkan_device_api.h" +#include "vulkan_thread_entry.h" + +namespace tvm { +namespace runtime { +namespace vulkan { + +VulkanDescriptorTemplateKHRFunctions::VulkanDescriptorTemplateKHRFunctions(VkDevice device) { + vkCreateDescriptorUpdateTemplateKHR = (PFN_vkCreateDescriptorUpdateTemplateKHR)ICHECK_NOTNULL( + vkGetDeviceProcAddr(device, "vkCreateDescriptorUpdateTemplateKHR")); + vkDestroyDescriptorUpdateTemplateKHR = (PFN_vkDestroyDescriptorUpdateTemplateKHR)ICHECK_NOTNULL( + vkGetDeviceProcAddr(device, "vkDestroyDescriptorUpdateTemplateKHR")); + vkUpdateDescriptorSetWithTemplateKHR = (PFN_vkUpdateDescriptorSetWithTemplateKHR)ICHECK_NOTNULL( + vkGetDeviceProcAddr(device, "vkUpdateDescriptorSetWithTemplateKHR")); + vkCmdPushDescriptorSetWithTemplateKHR = (PFN_vkCmdPushDescriptorSetWithTemplateKHR)ICHECK_NOTNULL( + vkGetDeviceProcAddr(device, "vkCmdPushDescriptorSetWithTemplateKHR")); +} + +VulkanGetBufferMemoryRequirements2Functions::VulkanGetBufferMemoryRequirements2Functions( + VkDevice device) { + vkGetBufferMemoryRequirements2KHR = (PFN_vkGetBufferMemoryRequirements2KHR)ICHECK_NOTNULL( + vkGetDeviceProcAddr(device, "vkGetBufferMemoryRequirements2KHR")); +} + +uint32_t FindMemoryType(const VulkanContext& vctx, VkBufferCreateInfo info, + VkMemoryPropertyFlags req_prop) { + VkBuffer buffer; + VULKAN_CALL(vkCreateBuffer(vctx.device, &info, nullptr, &buffer)); + + VkMemoryRequirements mem_reqs; + vkGetBufferMemoryRequirements(vctx.device, buffer, &mem_reqs); + uint32_t type_bits = mem_reqs.memoryTypeBits; + VkPhysicalDeviceMemoryProperties phy_mem_prop; + vkGetPhysicalDeviceMemoryProperties(vctx.phy_device, &phy_mem_prop); + for (uint32_t i = 0; i < phy_mem_prop.memoryTypeCount; i++) { + if ((type_bits & 1) == 1 && + (phy_mem_prop.memoryTypes[i].propertyFlags & req_prop) == req_prop) { + return i; + } + type_bits >>= 1; + } + LOG(FATAL) << "Requested memory type not found"; + return 0; +} + +VkBufferCreateInfo MakeBufferCreateInfo(const VulkanContext& vctx, size_t nbytes, + VkBufferUsageFlags usage) { + VkBufferCreateInfo info; + info.sType = VK_STRUCTURE_TYPE_BUFFER_CREATE_INFO; + info.pNext = nullptr; + info.flags = 0; + info.size = nbytes; + info.queueFamilyIndexCount = 1; + info.pQueueFamilyIndices = &(vctx.queue_family_index); + info.sharingMode = VK_SHARING_MODE_EXCLUSIVE; + info.usage = usage; + return info; +} + +VulkanBuffer* CreateBuffer(const VulkanContext& vctx, size_t nbytes, VkBufferUsageFlags usage, + uint32_t mem_type_index) { + auto info = MakeBufferCreateInfo(vctx, nbytes, usage); + // create buffer + VkBuffer buffer; + VULKAN_CALL(vkCreateBuffer(vctx.device, &info, nullptr, &buffer)); + + // bind to memory + bool dedicated_allocation = false; + VkMemoryRequirements2KHR req2; + + if (vctx.get_buffer_memory_requirements_2_functions) { + VkBufferMemoryRequirementsInfo2KHR req_info2; + req_info2.sType = VK_STRUCTURE_TYPE_BUFFER_MEMORY_REQUIREMENTS_INFO_2_KHR; + req_info2.pNext = 0; + req_info2.buffer = buffer; + + req2.sType = VK_STRUCTURE_TYPE_MEMORY_REQUIREMENTS_2_KHR; + req2.pNext = 0; + + VkMemoryDedicatedRequirementsKHR dedicated_req; + dedicated_req.sType = VK_STRUCTURE_TYPE_MEMORY_DEDICATED_REQUIREMENTS_KHR; + dedicated_req.pNext = 0; + req2.pNext = &dedicated_req; + + vctx.get_buffer_memory_requirements_2_functions->vkGetBufferMemoryRequirements2KHR( + vctx.device, &req_info2, &req2); + dedicated_allocation = + dedicated_req.requiresDedicatedAllocation || dedicated_req.prefersDedicatedAllocation; + } + + VkDeviceMemory memory; + if (!dedicated_allocation) { + VkMemoryAllocateInfo minfo; + minfo.sType = VK_STRUCTURE_TYPE_MEMORY_ALLOCATE_INFO; + minfo.pNext = nullptr; + minfo.allocationSize = info.size; + minfo.memoryTypeIndex = mem_type_index; + VULKAN_CALL(vkAllocateMemory(vctx.device, &minfo, nullptr, &memory)); + } else { + VkMemoryAllocateInfo minfo; + minfo.sType = VK_STRUCTURE_TYPE_MEMORY_ALLOCATE_INFO; + minfo.pNext = nullptr; + minfo.allocationSize = req2.memoryRequirements.size; + minfo.memoryTypeIndex = mem_type_index; + + VkMemoryDedicatedAllocateInfoKHR mdinfo; + mdinfo.sType = VK_STRUCTURE_TYPE_MEMORY_DEDICATED_ALLOCATE_INFO_KHR; + mdinfo.pNext = 0; + mdinfo.image = 0; + mdinfo.buffer = buffer; + minfo.pNext = &mdinfo; + VULKAN_CALL(vkAllocateMemory(vctx.device, &minfo, nullptr, &memory)); + } + VULKAN_CALL(vkBindBufferMemory(vctx.device, buffer, memory, 0)); + VulkanBuffer* pbuf = new VulkanBuffer(); + pbuf->memory = memory; + pbuf->buffer = buffer; + return pbuf; +} + +VulkanHostVisibleBuffer* GetOrAllocate( + int device_id, size_t size, VkBufferUsageFlags usage, uint32_t mem_type_index, + std::unordered_map>* buffers_ptr, + bool sync_before_realloc) { + auto& buffers = *buffers_ptr; + if (!buffers[device_id]) { + buffers[device_id] = std::make_unique(); + } + + auto& buf = *(buffers[device_id]); + if (buf.device != nullptr && buf.size < size) { + // free previous buffer + if (sync_before_realloc) { + // For the deferred execution mode, we need to make sure that old tasks that use + // the older, smaller buffer get finished + // Synchronization on staging buffers is done after host to device memory copy + // For UBO, we sync here before we reallocate a larger buffer, to minimize synchronization + // points + VulkanThreadEntry::ThreadLocal()->Stream(device_id)->Synchronize(); + } + DeleteHostVisibleBuffer(&buf); + } + + const auto& vctx = VulkanDeviceAPI::Global()->context(device_id); + + if (buf.device == nullptr) { + buf.device = vctx.device; + } + if (buf.host_addr == nullptr) { + buf.vk_buf = CreateBuffer(vctx, size, usage, mem_type_index); + VULKAN_CALL(vkMapMemory(vctx.device, buf.vk_buf->memory, 0, size, 0, &(buf.host_addr))); + buf.size = size; + } + return &buf; +} + +} // namespace vulkan +} // namespace runtime +} // namespace tvm diff --git a/src/runtime/vulkan/vulkan_context.h b/src/runtime/vulkan/vulkan_context.h new file mode 100644 index 000000000000..08f0f97def14 --- /dev/null +++ b/src/runtime/vulkan/vulkan_context.h @@ -0,0 +1,95 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#ifndef TVM_RUNTIME_VULKAN_VULKAN_CONTEXT_H_ +#define TVM_RUNTIME_VULKAN_VULKAN_CONTEXT_H_ + +#include +#include + +#include + +#include "vulkan/vulkan_core.h" +#include "vulkan_buffer.h" + +namespace tvm { +namespace runtime { +namespace vulkan { + +struct VulkanDescriptorTemplateKHRFunctions { + explicit VulkanDescriptorTemplateKHRFunctions(VkDevice device); + + PFN_vkCreateDescriptorUpdateTemplateKHR vkCreateDescriptorUpdateTemplateKHR{nullptr}; + PFN_vkDestroyDescriptorUpdateTemplateKHR vkDestroyDescriptorUpdateTemplateKHR{nullptr}; + PFN_vkUpdateDescriptorSetWithTemplateKHR vkUpdateDescriptorSetWithTemplateKHR{nullptr}; + PFN_vkCmdPushDescriptorSetWithTemplateKHR vkCmdPushDescriptorSetWithTemplateKHR{nullptr}; +}; + +struct VulkanGetBufferMemoryRequirements2Functions { + explicit VulkanGetBufferMemoryRequirements2Functions(VkDevice device); + + PFN_vkGetBufferMemoryRequirements2KHR vkGetBufferMemoryRequirements2KHR{nullptr}; +}; + +struct VulkanContext { + // physical device + VkPhysicalDevice phy_device{nullptr}; + + // Phyiscal device property + VkPhysicalDeviceProperties phy_device_prop; + // Target that best represents this physical device + Target target; + // Memory type index for staging. + uint32_t staging_mtype_index{0}; + // whether staging is coherent + bool coherent_staging{false}; + + std::unique_ptr descriptor_template_khr_functions{nullptr}; + std::unique_ptr + get_buffer_memory_requirements_2_functions{nullptr}; + // Memory type index for compute + uint32_t compute_mtype_index{0}; + // The logical device + VkDevice device{nullptr}; + // command queue + + std::unique_ptr queue_mutex; + VkQueue queue{nullptr}; + // queue family_index; + uint32_t queue_family_index{0}; + // Queue family index. + VkQueueFamilyProperties queue_prop; + + bool UseImmediate() const { return descriptor_template_khr_functions != nullptr; } +}; + +uint32_t FindMemoryType(const VulkanContext& vctx, VkBufferCreateInfo info, + VkMemoryPropertyFlags req_prop); + +VkBufferCreateInfo MakeBufferCreateInfo(const VulkanContext& vctx, size_t nbytes, + VkBufferUsageFlags usage); + +VulkanBuffer* CreateBuffer(const VulkanContext& vctx, size_t nbytes, VkBufferUsageFlags usage, + uint32_t mem_type_index); + +} // namespace vulkan +} // namespace runtime +} // namespace tvm + +#endif // TVM_RUNTIME_VULKAN_VULKAN_CONTEXT_H_ diff --git a/src/runtime/vulkan/vulkan_device_api.cc b/src/runtime/vulkan/vulkan_device_api.cc new file mode 100644 index 000000000000..d318204ce2c1 --- /dev/null +++ b/src/runtime/vulkan/vulkan_device_api.cc @@ -0,0 +1,830 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include "vulkan_device_api.h" + +#include +#include +#include +#include +#include + +#include "vulkan_thread_entry.h" + +namespace tvm { +namespace runtime { +namespace vulkan { + +VulkanDeviceAPI* VulkanDeviceAPI::Global() { + // Most of the TVM Global() functions allocate with "new" and do + // not deallocate, as the OS can clean up any leftover buffers at + // the end. In this case, we need the VulkanDeviceAPI destructor + // to call vkDestroyInstance, to prevent a segfault on exit when + // using some nvidia drivers. + static VulkanDeviceAPI inst; + return &inst; +} + +VulkanDeviceAPI::VulkanDeviceAPI() { + const auto layers = []() -> std::vector { + uint32_t inst_layer_prop_count; + VULKAN_CALL(vkEnumerateInstanceLayerProperties(&inst_layer_prop_count, nullptr)); + std::vector inst_layer_prop(inst_layer_prop_count); + VULKAN_CALL(vkEnumerateInstanceLayerProperties(&inst_layer_prop_count, inst_layer_prop.data())); + std::vector l; + + const char* enable = std::getenv("TVM_VULKAN_ENABLE_VALIDATION_LAYERS"); + bool validation_enabled = enable && *enable; + if (validation_enabled) { + for (const auto& lp : inst_layer_prop) { + if (std::strcmp(lp.layerName, "VK_LAYER_LUNARG_standard_validation") == 0) { + l.push_back("VK_LAYER_LUNARG_standard_validation"); + } + if (std::strcmp(lp.layerName, "VK_LAYER_LUNARG_parameter_validation") == 0) { + l.push_back("VK_LAYER_LUNARG_parameter_validation"); + } + if (std::strcmp(lp.layerName, "VK_LAYER_KHRONOS_validation") == 0) { + l.push_back("VK_LAYER_KHRONOS_validation"); + } + } + } + return l; + }(); + + const auto instance_extensions = [this]() { + std::vector required_extensions{}; + std::vector optional_extensions{"VK_KHR_get_physical_device_properties2"}; + + uint32_t inst_extension_prop_count; + VULKAN_CALL( + vkEnumerateInstanceExtensionProperties(nullptr, &inst_extension_prop_count, nullptr)); + std::vector inst_extension_prop(inst_extension_prop_count); + VULKAN_CALL(vkEnumerateInstanceExtensionProperties(nullptr, &inst_extension_prop_count, + inst_extension_prop.data())); + + return FindEnabledExtensions(inst_extension_prop, required_extensions, optional_extensions); + }(); + + auto has_instance_extension = [&instance_extensions](const char* query) { + return std::any_of(instance_extensions.begin(), instance_extensions.end(), + [&](const char* extension) { return std::strcmp(query, extension) == 0; }); + }; + + const auto instance_api_version = []() { + uint32_t api_version = VK_MAKE_VERSION(1, 0, 0); + + // Result from vkGetInstanceProcAddr is NULL if driver only + // supports vulkan 1.0. + auto vkEnumerateInstanceVersion = + (PFN_vkEnumerateInstanceVersion)vkGetInstanceProcAddr(NULL, "vkEnumerateInstanceVersion"); + if (vkEnumerateInstanceVersion) { + vkEnumerateInstanceVersion(&api_version); + } + + return api_version; + }(); + + { + VkApplicationInfo app_info; + app_info.sType = VK_STRUCTURE_TYPE_APPLICATION_INFO; + app_info.pNext = nullptr; + app_info.pApplicationName = "TVM"; + app_info.applicationVersion = 0; + app_info.pEngineName = ""; + app_info.engineVersion = 0; + app_info.apiVersion = instance_api_version; + + VkInstanceCreateInfo inst_info; + inst_info.sType = VK_STRUCTURE_TYPE_INSTANCE_CREATE_INFO; + inst_info.pNext = nullptr; + inst_info.flags = 0; + inst_info.pApplicationInfo = &app_info; + inst_info.enabledLayerCount = layers.size(); + inst_info.ppEnabledLayerNames = layers.data(); + inst_info.enabledExtensionCount = instance_extensions.size(); + inst_info.ppEnabledExtensionNames = instance_extensions.data(); + + VULKAN_CALL(vkCreateInstance(&inst_info, nullptr, &instance_)); + } + + uint32_t phy_dev_count = 0; + VULKAN_CALL(vkEnumeratePhysicalDevices(instance_, &phy_dev_count, nullptr)); + std::vector all_phy_devs(phy_dev_count); + VULKAN_CALL(vkEnumeratePhysicalDevices(instance_, &phy_dev_count, dmlc::BeginPtr(all_phy_devs))); + for (VkPhysicalDevice phy_dev : all_phy_devs) { + // Get a list of queue families supporting compute, in order of preference. We currently only + // make use of the most preferred one family. + std::vector queue_family_indexes = GetComputeQueueFamilies(phy_dev); + if (queue_family_indexes.empty()) continue; + uint32_t queue_family_index = queue_family_indexes[0]; + float priority = 1.0f; + + struct VkDeviceQueueCreateInfo queue_create_info; + queue_create_info.sType = VK_STRUCTURE_TYPE_DEVICE_QUEUE_CREATE_INFO; + queue_create_info.pNext = nullptr; + queue_create_info.flags = 0; + queue_create_info.queueFamilyIndex = queue_family_index; + queue_create_info.queueCount = 1; + queue_create_info.pQueuePriorities = &priority; + + VulkanContext ctx; + // setup context + ctx.phy_device = phy_dev; + vkGetPhysicalDeviceProperties(ctx.phy_device, &(ctx.phy_device_prop)); + + const auto device_extensions = [&]() { + std::vector required_extensions{}; + std::vector optional_extensions{ + "VK_KHR_driver_properties", + "VK_KHR_storage_buffer_storage_class", + "VK_KHR_8bit_storage", + "VK_KHR_16bit_storage", + "VK_KHR_shader_float16_int8", + "VK_KHR_push_descriptor", + "VK_KHR_descriptor_update_template", + "VK_KHR_get_memory_requirements2", + "VK_KHR_dedicated_allocation", + "VK_KHR_spirv_1_4", + }; + + uint32_t device_extension_prop_count; + VULKAN_CALL(vkEnumerateDeviceExtensionProperties(ctx.phy_device, nullptr, + &device_extension_prop_count, nullptr)); + std::vector device_extension_prop(device_extension_prop_count); + VULKAN_CALL(vkEnumerateDeviceExtensionProperties( + ctx.phy_device, nullptr, &device_extension_prop_count, device_extension_prop.data())); + + return FindEnabledExtensions(device_extension_prop, required_extensions, optional_extensions); + }(); + + ctx.target = GetDeviceDescription(instance_, phy_dev, instance_extensions, device_extensions); + + { + // Enable all features we may use that a device supports. + VkPhysicalDeviceFeatures2 enabled_features = {VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_FEATURES_2}; + VkPhysicalDevice8BitStorageFeatures storage_8bit = { + VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_8BIT_STORAGE_FEATURES}; + VkPhysicalDevice16BitStorageFeatures storage_16bit = { + VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_16BIT_STORAGE_FEATURES}; + VkPhysicalDeviceShaderFloat16Int8Features float16_int8 = { + VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SHADER_FLOAT16_INT8_FEATURES}; + + void** pp_next = &enabled_features.pNext; + bool needs_float16_int8 = false; + + auto has_support = [&](const char* name) { return ctx.target->GetAttr(name).value(); }; + if (has_support("supports_float16")) { + float16_int8.shaderFloat16 = true; + needs_float16_int8 = true; + } + if (has_support("supports_float64")) { + enabled_features.features.shaderFloat64 = true; + } + if (has_support("supports_int8")) { + float16_int8.shaderInt8 = true; + needs_float16_int8 = true; + } + if (has_support("supports_int16")) { + enabled_features.features.shaderInt16 = true; + } + if (has_support("supports_int64")) { + enabled_features.features.shaderInt64 = true; + } + if (has_support("supports_8bit_buffer")) { + storage_8bit.storageBuffer8BitAccess = true; + *pp_next = &storage_8bit; + pp_next = &storage_8bit.pNext; + } + if (has_support("supports_16bit_buffer")) { + storage_16bit.storageBuffer16BitAccess = true; + *pp_next = &storage_16bit; + pp_next = &storage_16bit.pNext; + } + + if (needs_float16_int8) { + *pp_next = &float16_int8; + pp_next = &float16_int8.pNext; + } + + VkDeviceCreateInfo device_create_info; + device_create_info.sType = VK_STRUCTURE_TYPE_DEVICE_CREATE_INFO; + device_create_info.pNext = nullptr; + device_create_info.flags = 0; + device_create_info.queueCreateInfoCount = 1; + device_create_info.pQueueCreateInfos = &queue_create_info; + device_create_info.enabledLayerCount = 0; + device_create_info.ppEnabledLayerNames = nullptr; + device_create_info.enabledExtensionCount = device_extensions.size(); + device_create_info.ppEnabledExtensionNames = device_extensions.data(); + + if (has_instance_extension("VK_KHR_get_physical_device_properties2")) { + device_create_info.pEnabledFeatures = nullptr; + device_create_info.pNext = &enabled_features; + } else { + device_create_info.pNext = nullptr; + device_create_info.pEnabledFeatures = &enabled_features.features; + } + VULKAN_CALL(vkCreateDevice(phy_dev, &device_create_info, nullptr, &(ctx.device))); + } + + ctx.queue_mutex.reset(new std::mutex()); + vkGetDeviceQueue(ctx.device, queue_family_index, 0, &(ctx.queue)); + ctx.queue_family_index = queue_family_index; + // Find suitable memory type for staging and compute + // Find suitable compute index. + VkBuffer buffer; + VkMemoryRequirements req_staging, req_compute; + VkBufferCreateInfo info; + info.sType = VK_STRUCTURE_TYPE_BUFFER_CREATE_INFO; + info.pNext = nullptr; + info.flags = 0; + info.size = 1024; + info.queueFamilyIndexCount = 1; + info.pQueueFamilyIndices = &(ctx.queue_family_index); + info.sharingMode = VK_SHARING_MODE_EXCLUSIVE; + + // get staging requirement + info.usage = VK_BUFFER_USAGE_TRANSFER_SRC_BIT | VK_BUFFER_USAGE_TRANSFER_DST_BIT; + VULKAN_CALL(vkCreateBuffer(ctx.device, &info, nullptr, &buffer)); + vkGetBufferMemoryRequirements(ctx.device, buffer, &req_staging); + vkDestroyBuffer(ctx.device, buffer, nullptr); + // get compute requirement + info.usage = VK_BUFFER_USAGE_TRANSFER_SRC_BIT | VK_BUFFER_USAGE_TRANSFER_DST_BIT | + VK_BUFFER_USAGE_STORAGE_BUFFER_BIT; + VULKAN_CALL(vkCreateBuffer(ctx.device, &info, nullptr, &buffer)); + vkGetBufferMemoryRequirements(ctx.device, buffer, &req_compute); + vkDestroyBuffer(ctx.device, buffer, nullptr); + + // Query phyiscal device property + // find a memory that is host visible, no need to be consistent + int win_rank = -1; + VkPhysicalDeviceMemoryProperties prop; + vkGetPhysicalDeviceMemoryProperties(ctx.phy_device, &prop); + + for (uint32_t k = 0; k < prop.memoryTypeCount; ++k) { + VkMemoryType ty = prop.memoryTypes[k]; + size_t heap_size = prop.memoryHeaps[ty.heapIndex].size; + // host visible + if (!(ty.propertyFlags & VK_MEMORY_PROPERTY_HOST_VISIBLE_BIT)) continue; + // match copy requirment + if (!(req_staging.memoryTypeBits & (1 << k))) continue; + if (heap_size < 1024) continue; + int rank = 0; + rank += ty.propertyFlags & VK_MEMORY_PROPERTY_HOST_CACHED_BIT; + if (rank > win_rank) { + win_rank = rank; + ctx.staging_mtype_index = k; + ctx.coherent_staging = ty.propertyFlags & VK_MEMORY_PROPERTY_HOST_COHERENT_BIT; + } + } + ICHECK_GE(win_rank, 0) << "Cannot find suitable staging memory on device."; + + win_rank = -1; + for (uint32_t k = 0; k < prop.memoryTypeCount; ++k) { + VkMemoryType ty = prop.memoryTypes[k]; + size_t heap_size = prop.memoryHeaps[ty.heapIndex].size; + // host visible + if (!(ty.propertyFlags & VK_MEMORY_PROPERTY_DEVICE_LOCAL_BIT)) continue; + // match copy requirment + if (!(req_staging.memoryTypeBits & (1 << k))) continue; + if (heap_size < 1024) continue; + int rank = 0; + // prefer not host visible + rank += !(ty.propertyFlags & VK_MEMORY_PROPERTY_HOST_VISIBLE_BIT); + if (rank > win_rank) { + win_rank = rank; + ctx.compute_mtype_index = k; + } + } + ICHECK_GE(win_rank, 0) << "Cannot find suitable local memory on device."; + + if (ctx.target->GetAttr("supports_push_descriptor").value()) { + ctx.descriptor_template_khr_functions = + std::make_unique(ctx.device); + } + + if (ctx.target->GetAttr("supports_dedicated_allocation").value()) { + ctx.get_buffer_memory_requirements_2_functions = + std::make_unique(ctx.device); + } + + context_.push_back(std::move(ctx)); + } + + LOG(INFO) << "Initialize Vulkan with " << context_.size() << " devices.."; + for (size_t i = 0; i < context_.size(); ++i) { + LOG(INFO) << "vulkan(" << i << ")=\'" << context_[i].phy_device_prop.deviceName + << "\' phy_dev_id=" << context_[i].phy_device + << " use_immediate=" << context_[i].UseImmediate(); + } +} + +VulkanDeviceAPI::~VulkanDeviceAPI() { + for (auto& vctx : context_) { + vkDestroyDevice(vctx.device, nullptr); + } + if (instance_) { + vkDestroyInstance(instance_, nullptr); + } +} + +void VulkanDeviceAPI::SetDevice(Device dev) { VulkanThreadEntry::ThreadLocal()->device = dev; } + +void VulkanDeviceAPI::GetAttr(Device dev, DeviceAttrKind kind, TVMRetValue* rv) { + size_t index = static_cast(dev.device_id); + if (kind == kExist) { + *rv = static_cast(index < context_.size()); + return; + } + ICHECK_LT(index, context_.size()) << "Invalid device id " << index; + + const auto& target = context(index).target; + + switch (kind) { + case kMaxThreadsPerBlock: { + *rv = target->GetAttr("max_num_threads").value(); + break; + } + case kMaxSharedMemoryPerBlock: { + *rv = target->GetAttr("max_shared_memory_per_block"); + break; + } + case kWarpSize: { + *rv = target->GetAttr("thread_warp_size").value(); + break; + } + case kComputeVersion: { + int64_t value = target->GetAttr("vulkan_api_version").value(); + std::ostringstream os; + os << VK_VERSION_MAJOR(value) << "." << VK_VERSION_MINOR(value) << "." + << VK_VERSION_PATCH(value); + *rv = os.str(); + break; + } + case kDeviceName: + *rv = target->GetAttr("device_name").value(); + break; + + case kMaxClockRate: + break; + + case kMultiProcessorCount: + break; + + case kExist: + break; + + case kMaxThreadDimensions: { + std::stringstream ss; // use json string to return multiple int values; + ss << "[" << target->GetAttr("max_block_size_x").value() << ", " + << target->GetAttr("max_block_size_y").value() << ", " + << target->GetAttr("max_block_size_z").value() << "]"; + *rv = ss.str(); + break; + } + + case kMaxRegistersPerBlock: + break; + + case kGcnArch: + break; + + case kApiVersion: + *rv = VK_HEADER_VERSION; + break; + + case kDriverVersion: { + int64_t value = target->GetAttr("driver_version").value(); + std::ostringstream os; + os << VK_VERSION_MAJOR(value) << "." << VK_VERSION_MINOR(value) << "." + << VK_VERSION_PATCH(value); + *rv = os.str(); + break; + } + } +} + +void* VulkanDeviceAPI::AllocDataSpace(Device dev, size_t nbytes, size_t alignment, + DLDataType type_hint) { + if (nbytes == 0) { + // Vulkan seems to have issues if we return nullptr on zero size alloc + nbytes = 1; + } + const auto& vctx = context(dev.device_id); + auto usage = VK_BUFFER_USAGE_TRANSFER_SRC_BIT | VK_BUFFER_USAGE_TRANSFER_DST_BIT | + VK_BUFFER_USAGE_STORAGE_BUFFER_BIT; + return CreateBuffer(vctx, nbytes, usage, vctx.compute_mtype_index); +} + +void VulkanDeviceAPI::FreeDataSpace(Device dev, void* ptr) { + // Before releasing the vkBuffer, call sync to + // finish all the vulkan commands that reference the buffer. + StreamSync(dev, nullptr); + + const auto& vctx = context(dev.device_id); + auto* pbuf = static_cast(ptr); + vkDestroyBuffer(vctx.device, pbuf->buffer, nullptr); + vkFreeMemory(vctx.device, pbuf->memory, nullptr); + delete pbuf; +} + +void* VulkanDeviceAPI::AllocWorkspace(Device dev, size_t size, DLDataType type_hint) { + return VulkanThreadEntry::ThreadLocal()->pool->AllocWorkspace(dev, size); +} + +void VulkanDeviceAPI::FreeWorkspace(Device dev, void* data) { + VulkanThreadEntry::ThreadLocal()->pool->FreeWorkspace(dev, data); +} + +TVMStreamHandle VulkanDeviceAPI::CreateStream(Device dev) { return nullptr; } + +void VulkanDeviceAPI::FreeStream(Device dev, TVMStreamHandle stream) { + ICHECK_EQ(stream, static_cast(nullptr)); +} + +// Syncing two streams is a nop, since there is only one stream. +void VulkanDeviceAPI::SyncStreamFromTo(Device dev, TVMStreamHandle event_src, + TVMStreamHandle event_dst) { + ICHECK_EQ(event_src, static_cast(nullptr)); + ICHECK_EQ(event_dst, static_cast(nullptr)); +} + +void VulkanDeviceAPI::StreamSync(Device dev, TVMStreamHandle stream) { + ICHECK_EQ(stream, static_cast(nullptr)); + VulkanThreadEntry::ThreadLocal()->Stream(dev.device_id)->Synchronize(); +} + +void VulkanDeviceAPI::SetStream(Device dev, TVMStreamHandle stream) { + ICHECK_EQ(stream, static_cast(nullptr)); +} + +void VulkanDeviceAPI::CopyDataFromTo(const void* from, size_t from_offset, void* to, + size_t to_offset, size_t size, Device dev_from, Device dev_to, + DLDataType type_hint, TVMStreamHandle stream) { + ICHECK(stream == nullptr); + Device dev = dev_from; + if (dev_from.device_type == kDLCPU) { + dev = dev_to; + } + + int from_dev_type = static_cast(dev_from.device_type); + int to_dev_type = static_cast(dev_to.device_type); + if (from_dev_type == kDLVulkan && to_dev_type == kDLVulkan) { + VulkanThreadEntry::ThreadLocal() + ->Stream(dev_from.device_id) + ->Launch([=](VulkanStreamState* state) { + // 1: copy + const auto* from_buf = static_cast(from); + auto* to_buf = static_cast(to); + VkBufferCopy copy_info; + copy_info.srcOffset = from_offset; + copy_info.dstOffset = to_offset; + copy_info.size = size; + vkCmdCopyBuffer(state->cmd_buffer_, from_buf->buffer, to_buf->buffer, 1, ©_info); + // 2: barrier(transfer-> compute|transfer) + ICHECK_EQ(dev_from.device_id, dev_to.device_id) << "Vulkan disallow cross device copy."; + VkMemoryBarrier barrier_info; + barrier_info.sType = VK_STRUCTURE_TYPE_MEMORY_BARRIER; + barrier_info.pNext = nullptr; + barrier_info.srcAccessMask = VK_ACCESS_TRANSFER_WRITE_BIT; + barrier_info.dstAccessMask = (VK_ACCESS_TRANSFER_READ_BIT | VK_ACCESS_TRANSFER_WRITE_BIT | + VK_ACCESS_SHADER_READ_BIT | VK_ACCESS_SHADER_WRITE_BIT); + vkCmdPipelineBarrier( + state->cmd_buffer_, VK_PIPELINE_STAGE_TRANSFER_BIT, + VK_PIPELINE_STAGE_TRANSFER_BIT | VK_PIPELINE_STAGE_COMPUTE_SHADER_BIT, 0, 1, + &barrier_info, 0, nullptr, 0, nullptr); + }); + + } else if (from_dev_type == kDLVulkan && to_dev_type == kDLCPU) { + const auto* from_buf = static_cast(from); + const auto& vctx = context(dev_from.device_id); + auto* temp = VulkanThreadEntry::ThreadLocal()->StagingBuffer(dev_from.device_id, size); + VulkanThreadEntry::ThreadLocal() + ->Stream(dev_from.device_id) + ->Launch([&](VulkanStreamState* state) { + VkBufferCopy copy_info; + copy_info.srcOffset = from_offset; + copy_info.dstOffset = 0; + copy_info.size = size; + vkCmdCopyBuffer(state->cmd_buffer_, from_buf->buffer, temp->vk_buf->buffer, 1, + ©_info); + }); + VulkanThreadEntry::ThreadLocal()->Stream(dev_from.device_id)->Synchronize(); + if (!vctx.coherent_staging) { + VkMappedMemoryRange mrange; + mrange.sType = VK_STRUCTURE_TYPE_MAPPED_MEMORY_RANGE; + mrange.pNext = nullptr; + mrange.memory = temp->vk_buf->memory; + mrange.offset = 0; + mrange.size = VK_WHOLE_SIZE; // size; + VULKAN_CALL(vkInvalidateMappedMemoryRanges(vctx.device, 1, &mrange)); + } + memcpy(static_cast(to) + to_offset, static_cast(temp->host_addr), size); + } else if (from_dev_type == kDLCPU && to_dev_type == kDLVulkan) { + const auto& vctx = context(dev_to.device_id); + const auto* to_buf = static_cast(to); + VulkanStagingBuffer* temp = + VulkanThreadEntry::ThreadLocal()->StagingBuffer(dev_to.device_id, size); + memcpy(temp->host_addr, static_cast(from) + from_offset, size); + // host side flush if access is not coherent. + // so writes from CPU is visible to GPU + if (!vctx.coherent_staging) { + VkMappedMemoryRange mrange; + mrange.sType = VK_STRUCTURE_TYPE_MAPPED_MEMORY_RANGE; + mrange.pNext = nullptr; + mrange.memory = temp->vk_buf->memory; + mrange.offset = 0; + mrange.size = VK_WHOLE_SIZE; // size; + VULKAN_CALL(vkFlushMappedMemoryRanges(vctx.device, 1, &mrange)); + } + + VulkanThreadEntry::ThreadLocal() + ->Stream(dev_to.device_id) + ->Launch([&](VulkanStreamState* state) { + // 0: barrier(host->transfer) + VkMemoryBarrier barrier_info; + barrier_info.sType = VK_STRUCTURE_TYPE_MEMORY_BARRIER; + barrier_info.pNext = nullptr; + barrier_info.srcAccessMask = 0; + barrier_info.dstAccessMask = VK_ACCESS_TRANSFER_WRITE_BIT; + vkCmdPipelineBarrier(state->cmd_buffer_, VK_PIPELINE_STAGE_HOST_BIT, + VK_PIPELINE_STAGE_TRANSFER_BIT, 0, 1, &barrier_info, 0, nullptr, 0, + nullptr); + // 1: copy + VkBufferCopy copy_info; + copy_info.srcOffset = 0; + copy_info.dstOffset = to_offset; + copy_info.size = size; + vkCmdCopyBuffer(state->cmd_buffer_, temp->vk_buf->buffer, to_buf->buffer, 1, ©_info); + }); + // TODO(tulloch): should we instead make the staging buffer a property of the + // Stream? This would allow us to elide synchronizations here. + VulkanThreadEntry::ThreadLocal()->Stream(dev_to.device_id)->Synchronize(); + } else { + LOG(FATAL) << "Expect copy from/to Vulkan or between Vulkan" + << ", from=" << from_dev_type << ", to=" << to_dev_type; + } +} + +std::vector VulkanDeviceAPI::FindEnabledExtensions( + const std::vector& ext_prop, + const std::vector& required_extensions, + const std::vector& optional_extensions) { + std::set available_extensions; + for (const auto& prop : ext_prop) { + if (prop.specVersion > 0) { + available_extensions.insert(prop.extensionName); + } + } + + std::vector enabled_extensions; + for (const auto& ext : required_extensions) { + ICHECK(available_extensions.count(ext)) + << "Required vulkan extension \"" << ext << "\" not supported by driver"; + enabled_extensions.push_back(ext); + } + + for (const auto& ext : optional_extensions) { + if (available_extensions.count(ext)) { + enabled_extensions.push_back(ext); + } + } + + return enabled_extensions; +} + +const VulkanContext& VulkanDeviceAPI::context(size_t device_id) const { + ICHECK_LT(device_id, context_.size()); + return context_[device_id]; +} + +Target VulkanDeviceAPI::GenerateTarget(size_t device_id) const { return context(device_id).target; } + +std::vector VulkanDeviceAPI::GetComputeQueueFamilies(VkPhysicalDevice phy_dev) { + uint32_t queue_prop_count = 0; + vkGetPhysicalDeviceQueueFamilyProperties(phy_dev, &queue_prop_count, nullptr); + std::vector queue_props(queue_prop_count); + vkGetPhysicalDeviceQueueFamilyProperties(phy_dev, &queue_prop_count, dmlc::BeginPtr(queue_props)); + + std::vector result; + // Prefer compute-only queues. On certain devices supporting this (e.g. Mesa RADV), using + // compute-only queues gives better responsiveness for other graphics workload (e.g. desktop). + for (uint32_t i = 0; i != queue_prop_count; ++i) { + if ((VK_QUEUE_COMPUTE_BIT & queue_props[i].queueFlags) != 0 && + (VK_QUEUE_GRAPHICS_BIT & queue_props[i].queueFlags) == 0) { + result.push_back(i); + } + } + // Now, push the compute queues that we skipped above into the list. + for (uint32_t i = 0; i != queue_prop_count; ++i) { + if ((VK_QUEUE_COMPUTE_BIT & queue_props[i].queueFlags) != 0 && + (VK_QUEUE_GRAPHICS_BIT & queue_props[i].queueFlags) != 0) { + result.push_back(i); + } + } + return result; +} + +Target VulkanDeviceAPI::GetDeviceDescription(VkInstance instance, VkPhysicalDevice dev, + const std::vector& instance_extensions, + const std::vector& device_extensions) { + auto has_extension = [&](const char* query) { + return std::any_of(device_extensions.begin(), device_extensions.end(), + [&](const char* extension) { return std::strcmp(query, extension) == 0; }) || + std::any_of(instance_extensions.begin(), instance_extensions.end(), + [&](const char* extension) { return std::strcmp(query, extension) == 0; }); + }; + + // Declare output locations for properties + VkPhysicalDeviceProperties2 properties = {VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_PROPERTIES_2}; + VkPhysicalDeviceDriverProperties driver = {VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_DRIVER_PROPERTIES}; + VkPhysicalDeviceSubgroupProperties subgroup = { + VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SUBGROUP_PROPERTIES}; + + // Need to do initial query in order to check the apiVersion. + vkGetPhysicalDeviceProperties(dev, &properties.properties); + + // Set up linked list for property query + { + void** pp_next = &properties.pNext; + if (has_extension("VK_KHR_driver_properties")) { + *pp_next = &driver; + pp_next = &driver.pNext; + } + if (properties.properties.apiVersion >= VK_API_VERSION_1_1) { + *pp_next = &subgroup; + pp_next = &subgroup.pNext; + } + } + + // Declare output locations for features + VkPhysicalDeviceFeatures2 features = {VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_FEATURES_2}; + VkPhysicalDevice8BitStorageFeatures storage_8bit = { + VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_8BIT_STORAGE_FEATURES}; + VkPhysicalDevice16BitStorageFeatures storage_16bit = { + VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_16BIT_STORAGE_FEATURES}; + VkPhysicalDeviceShaderFloat16Int8Features float16_int8 = { + VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SHADER_FLOAT16_INT8_FEATURES}; + + // Set up linked list for feature query + { + void** pp_next = &features.pNext; + if (has_extension("VK_KHR_8bit_storage")) { + *pp_next = &storage_8bit; + pp_next = &storage_8bit.pNext; + } + if (has_extension("VK_KHR_16bit_storage")) { + *pp_next = &storage_16bit; + pp_next = &storage_16bit.pNext; + } + if (has_extension("VK_KHR_shader_float16_int8")) { + *pp_next = &float16_int8; + pp_next = &float16_int8.pNext; + } + } + + if (has_extension("VK_KHR_get_physical_device_properties2")) { + // Preferred method, call to get all properties that can be queried. + auto vkGetPhysicalDeviceProperties2KHR = (PFN_vkGetPhysicalDeviceProperties2KHR)ICHECK_NOTNULL( + vkGetInstanceProcAddr(instance, "vkGetPhysicalDeviceProperties2KHR")); + vkGetPhysicalDeviceProperties2KHR(dev, &properties); + + auto vkGetPhysicalDeviceFeatures2KHR = (PFN_vkGetPhysicalDeviceFeatures2KHR)ICHECK_NOTNULL( + vkGetInstanceProcAddr(instance, "vkGetPhysicalDeviceFeatures2KHR")); + vkGetPhysicalDeviceFeatures2KHR(dev, &features); + } else { + // Fallback, get as many features as we can from the Vulkan1.0 + // API. Corresponding vkGetPhysicalDeviceProperties was already done earlier. + vkGetPhysicalDeviceFeatures(dev, &features.features); + } + + //// Now, extracting all the information from the vulkan query. + + // Not technically needed, because VK_SHADER_STAGE_COMPUTE_BIT will + // be set so long at least one queue has VK_QUEUE_COMPUTE_BIT, but + // preferring the explicit check. + uint32_t supported_subgroup_operations = + (subgroup.supportedStages & VK_SHADER_STAGE_COMPUTE_BIT) ? subgroup.supportedOperations : 0; + + // Even if we can't query it, warp size must be at least 1. Must + // also be defined, as `transpose` operation requires it. + uint32_t thread_warp_size = std::max(subgroup.subgroupSize, 1U); + + // By default, use the maximum API version that the driver allows, + // so that any supported features can be used by TVM shaders. + // However, if we can query the conformance version, then limit to + // only using the api version that passes the vulkan conformance + // tests. + uint32_t vulkan_api_version = properties.properties.apiVersion; + if (has_extension("VK_KHR_driver_properties")) { + auto api_major = VK_VERSION_MAJOR(vulkan_api_version); + auto api_minor = VK_VERSION_MINOR(vulkan_api_version); + if ((api_major > driver.conformanceVersion.major) || + ((api_major == driver.conformanceVersion.major) && + (api_minor > driver.conformanceVersion.minor))) { + vulkan_api_version = + VK_MAKE_VERSION(driver.conformanceVersion.major, driver.conformanceVersion.minor, 0); + } + } + + // From "Versions and Formats" section of Vulkan spec. + uint32_t max_spirv_version = 0x10000; + if (vulkan_api_version >= VK_API_VERSION_1_2) { + max_spirv_version = 0x10500; + } else if (has_extension("VK_KHR_spirv_1_4")) { + max_spirv_version = 0x10400; + } else if (vulkan_api_version >= VK_API_VERSION_1_1) { + max_spirv_version = 0x10300; + } + + // Support is available based on these extensions, but allow it to + // be disabled based on an environment variable. + bool supports_push_descriptor = + has_extension("VK_KHR_push_descriptor") && has_extension("VK_KHR_descriptor_update_template"); + { + const char* disable = std::getenv("TVM_VULKAN_DISABLE_PUSH_DESCRIPTOR"); + if (disable && *disable) { + supports_push_descriptor = false; + } + } + + // Support is available based on these extensions, but allow it to + // be disabled based on an environment variable. + bool supports_dedicated_allocation = has_extension("VK_KHR_get_memory_requirements2") && + has_extension("VK_KHR_dedicated_allocation"); + { + const char* disable = std::getenv("TVM_VULKAN_DISABLE_DEDICATED_ALLOCATION"); + if (disable && *disable) { + supports_dedicated_allocation = false; + } + } + + Map config = { + {"kind", String("vulkan")}, + // Feature support + {"supports_float16", Bool(float16_int8.shaderFloat16)}, + {"supports_float32", Bool(true)}, + {"supports_float64", Bool(features.features.shaderFloat64)}, + {"supports_int8", Bool(float16_int8.shaderInt8)}, + {"supports_int16", Bool(features.features.shaderInt16)}, + {"supports_int32", Bool(true)}, + {"supports_int64", Bool(features.features.shaderInt64)}, + {"supports_8bit_buffer", Bool(storage_8bit.storageBuffer8BitAccess)}, + {"supports_16bit_buffer", Bool(storage_16bit.storageBuffer16BitAccess)}, + {"supports_storage_buffer_storage_class", + Bool(has_extension("VK_KHR_storage_buffer_storage_class"))}, + {"supports_push_descriptor", Bool(supports_push_descriptor)}, + {"supports_dedicated_allocation", Bool(supports_dedicated_allocation)}, + {"supported_subgroup_operations", Integer(supported_subgroup_operations)}, + // Physical device limits + {"max_num_threads", Integer(properties.properties.limits.maxComputeWorkGroupInvocations)}, + {"thread_warp_size", Integer(thread_warp_size)}, + {"max_block_size_x", Integer(properties.properties.limits.maxComputeWorkGroupSize[0])}, + {"max_block_size_y", Integer(properties.properties.limits.maxComputeWorkGroupSize[1])}, + {"max_block_size_z", Integer(properties.properties.limits.maxComputeWorkGroupSize[2])}, + {"max_push_constants_size", Integer(properties.properties.limits.maxPushConstantsSize)}, + {"max_uniform_buffer_range", Integer(properties.properties.limits.maxUniformBufferRange)}, + {"max_storage_buffer_range", + Integer(IntImm(DataType::UInt(32), properties.properties.limits.maxStorageBufferRange))}, + {"max_per_stage_descriptor_storage_buffer", + Integer(properties.properties.limits.maxPerStageDescriptorStorageBuffers)}, + {"max_shared_memory_per_block", + Integer(properties.properties.limits.maxComputeSharedMemorySize)}, + // Other device properties + {"device_name", String(properties.properties.deviceName)}, + {"driver_version", Integer(properties.properties.driverVersion)}, + {"vulkan_api_version", Integer(vulkan_api_version)}, + {"max_spirv_version", Integer(max_spirv_version)}, + }; + + return Target(config); +} + +TVM_REGISTER_GLOBAL("device_api.vulkan").set_body([](TVMArgs args, TVMRetValue* rv) { + DeviceAPI* ptr = VulkanDeviceAPI::Global(); + *rv = static_cast(ptr); +}); + +TVM_REGISTER_GLOBAL("device_api.vulkan.generate_target").set_body_typed([](int device_id) { + return VulkanDeviceAPI::Global()->GenerateTarget(device_id); +}); + +} // namespace vulkan +} // namespace runtime +} // namespace tvm diff --git a/src/runtime/vulkan/vulkan_device_api.h b/src/runtime/vulkan/vulkan_device_api.h new file mode 100644 index 000000000000..d31af8945efd --- /dev/null +++ b/src/runtime/vulkan/vulkan_device_api.h @@ -0,0 +1,104 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#ifndef TVM_RUNTIME_VULKAN_VULKAN_DEVICE_API_H_ +#define TVM_RUNTIME_VULKAN_VULKAN_DEVICE_API_H_ + +#include + +#include + +#include "vulkan/vulkan_core.h" +#include "vulkan_context.h" +#include "vulkan_thread_entry.h" + +namespace tvm { +namespace runtime { +namespace vulkan { + +class VulkanDeviceAPI final : public DeviceAPI { + public: + static VulkanDeviceAPI* Global(); + VulkanDeviceAPI(); + ~VulkanDeviceAPI(); + + // Implement active device + void SetDevice(Device dev) final; + void GetAttr(Device dev, DeviceAttrKind kind, TVMRetValue* rv) final; + + // Implement memory management required by DeviceAPI + void* AllocDataSpace(Device dev, size_t nbytes, size_t alignment, DLDataType type_hint) final; + void FreeDataSpace(Device dev, void* ptr) final; + void* AllocWorkspace(Device dev, size_t size, DLDataType type_hint) final; + void FreeWorkspace(Device dev, void* data) final; + + // Current vulkan implementation has one "stream" per CPU thread, + // with all commands writing into a single command buffer that is + // submitted on a call to StreamSync. Therefore, for now, these are + // mostly no-ops. If needed in the future, could have multiple + // command buffers to act as multiple streams. + TVMStreamHandle CreateStream(Device dev) final; + void FreeStream(Device dev, TVMStreamHandle stream) final; + void SyncStreamFromTo(Device dev, TVMStreamHandle event_src, TVMStreamHandle event_dst) final; + void StreamSync(Device dev, TVMStreamHandle stream) final; + void SetStream(Device dev, TVMStreamHandle stream) final; + + protected: + void CopyDataFromTo(const void* from, size_t from_offset, void* to, size_t to_offset, size_t size, + Device dev_from, Device dev_to, DLDataType type_hint, + TVMStreamHandle stream) final; + + // End of required methods for the DeviceAPI interface + + public: + /*! \brief Return the context associated with a specific device. + * + * These are constructed during VulkanDeviceAPI initialization, so + * this function returns immediately. + */ + const VulkanContext& context(size_t device_id) const; + + /*! \brief Get a Target that best describes a particular device. + * + * Returns the results of feature/property queries done during the + * device initialization. + */ + Target GenerateTarget(size_t device_id) const; + + private: + std::vector GetComputeQueueFamilies(VkPhysicalDevice phy_dev); + + Target GetDeviceDescription(VkInstance instance, VkPhysicalDevice dev, + const std::vector& instance_extensions, + const std::vector& device_extensions); + + std::vector FindEnabledExtensions( + const std::vector& ext_prop, + const std::vector& required_extensions, + const std::vector& optional_extensions); + + VkInstance instance_{nullptr}; + // The physical devices, have 1 to 1 mapping to devices + std::vector context_; +}; + +} // namespace vulkan +} // namespace runtime +} // namespace tvm +#endif // TVM_RUNTIME_VULKAN_VULKAN_DEVICE_API_H_ diff --git a/src/runtime/vulkan/vulkan_module.cc b/src/runtime/vulkan/vulkan_module.cc new file mode 100644 index 000000000000..89104d9d63d9 --- /dev/null +++ b/src/runtime/vulkan/vulkan_module.cc @@ -0,0 +1,73 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include "vulkan_module.h" + +#include +#include + +#include "../file_utils.h" +#include "vulkan_wrapped_func.h" + +namespace tvm { +namespace runtime { +namespace vulkan { + +Module VulkanModuleCreate(std::unordered_map smap, + std::unordered_map fmap, std::string source) { + auto n = make_object(smap, fmap, source); + return Module(n); +} + +Module VulkanModuleLoadFile(const std::string& file_name, const std::string& format) { + std::string data; + std::unordered_map smap; + std::unordered_map fmap; + std::string fmt = GetFileFormat(file_name, format); + std::string meta_file = GetMetaFilePath(file_name); + LoadBinaryFromFile(file_name, &data); + LoadMetaDataFromFile(meta_file, &fmap); + dmlc::MemoryStringStream fs(&data); + dmlc::Stream* stream = &fs; + uint32_t magic; + stream->Read(&magic); + ICHECK_EQ(magic, kVulkanModuleMagic) << "VulkanModule Magic mismatch"; + stream->Read(&smap); + return VulkanModuleCreate(smap, fmap, ""); +} + +Module VulkanModuleLoadBinary(void* strm) { + dmlc::Stream* stream = static_cast(strm); + std::unordered_map smap; + std::unordered_map fmap; + + std::string fmt; + stream->Read(&fmt); + stream->Read(&fmap); + stream->Read(&smap); + return VulkanModuleCreate(smap, fmap, ""); +} + +TVM_REGISTER_GLOBAL("runtime.module.loadfile_vulkan").set_body_typed(VulkanModuleLoadFile); + +TVM_REGISTER_GLOBAL("runtime.module.loadbinary_vulkan").set_body_typed(VulkanModuleLoadBinary); + +} // namespace vulkan +} // namespace runtime +} // namespace tvm diff --git a/src/runtime/vulkan/vulkan_stream.cc b/src/runtime/vulkan/vulkan_stream.cc new file mode 100644 index 000000000000..fee390ad7e45 --- /dev/null +++ b/src/runtime/vulkan/vulkan_stream.cc @@ -0,0 +1,159 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include "vulkan_stream.h" + +namespace tvm { +namespace runtime { +namespace vulkan { + +VulkanStream::VulkanStream(const VulkanContext* vctx) + : vctx_(vctx), state_(new VulkanStreamState()) { + // create command pool + VkCommandPoolCreateInfo cmd_pool_cinfo; + cmd_pool_cinfo.sType = VK_STRUCTURE_TYPE_COMMAND_POOL_CREATE_INFO; + cmd_pool_cinfo.pNext = nullptr; + cmd_pool_cinfo.flags = VK_COMMAND_POOL_CREATE_RESET_COMMAND_BUFFER_BIT; + cmd_pool_cinfo.queueFamilyIndex = vctx_->queue_family_index; + VULKAN_CALL(vkCreateCommandPool(vctx_->device, &cmd_pool_cinfo, nullptr, &cmd_pool_)); + + VkCommandBufferAllocateInfo buffer_alloc_info; + buffer_alloc_info.sType = VK_STRUCTURE_TYPE_COMMAND_BUFFER_ALLOCATE_INFO; + buffer_alloc_info.pNext = nullptr; + buffer_alloc_info.commandPool = cmd_pool_; + buffer_alloc_info.level = VK_COMMAND_BUFFER_LEVEL_PRIMARY; + buffer_alloc_info.commandBufferCount = 1; + VULKAN_CALL(vkAllocateCommandBuffers(vctx_->device, &buffer_alloc_info, &(state_->cmd_buffer_))); + + VkFenceCreateInfo fence_cinfo; + fence_cinfo.sType = VK_STRUCTURE_TYPE_FENCE_CREATE_INFO; + fence_cinfo.pNext = nullptr; + fence_cinfo.flags = 0; // VK_FENCE_CREATE_SIGNALED_BIT; + VULKAN_CALL(vkCreateFence(vctx_->device, &fence_cinfo, nullptr, &(state_->fence_))); + + VkCommandBufferBeginInfo cb_begin; + cb_begin.sType = VK_STRUCTURE_TYPE_COMMAND_BUFFER_BEGIN_INFO; + cb_begin.pNext = nullptr; + cb_begin.flags = VK_COMMAND_BUFFER_USAGE_ONE_TIME_SUBMIT_BIT; + cb_begin.pInheritanceInfo = 0; + VULKAN_CALL(vkBeginCommandBuffer(state_->cmd_buffer_, &cb_begin)); +} + +VulkanStream::~VulkanStream() { + vkDestroyFence(vctx_->device, state_->fence_, nullptr); + vkDestroyCommandPool(vctx_->device, cmd_pool_, nullptr); +} + +void VulkanStream::Launch(const std::function& kernel) { + if (vctx_->UseImmediate()) { + kernel(state_.get()); + } else { + deferred_kernels_.push_back(kernel); + } +} + +void VulkanStream::LaunchDeferred(const std::function& deferred_initializer, + const std::function& deferred_kernel, + const VulkanStreamToken& deferred_token) { + ICHECK(!vctx_->UseImmediate()); + + // If the new kernel uses the same descriptor set as one of the + // kernels already in the command buffer, we need to synchronize + // first. + if (std::any_of(deferred_tokens_[deferred_token.descriptor_set_].begin(), + deferred_tokens_[deferred_token.descriptor_set_].end(), + [&](const VulkanStreamToken& token) { + DCHECK(token.descriptor_set_ == deferred_token.descriptor_set_); + return token.descriptor_set_ == deferred_token.descriptor_set_ && + token.buffers_ != deferred_token.buffers_; + })) { + Synchronize(); + } + + // If the new kernel uses the same buffers in the same descriptor + // set as an already-queued kernel, we don't need to initialize it + // again. Since every VulkanWrappedFunc owns a single descriptor + // set, unless the same function is called with the same buffer + // arguments, deferred_initializer() will always be called. + if (!std::any_of(deferred_tokens_[deferred_token.descriptor_set_].begin(), + deferred_tokens_[deferred_token.descriptor_set_].end(), + [&](const VulkanStreamToken& token) { + DCHECK(token.descriptor_set_ == deferred_token.descriptor_set_); + return token.descriptor_set_ == deferred_token.descriptor_set_ && + token.buffers_ == deferred_token.buffers_; + })) { + deferred_initializer(); + } + + // Save the kernel itself to be called later. + deferred_kernels_.push_back(deferred_kernel); + deferred_tokens_[deferred_token.descriptor_set_].push_back(deferred_token); +} + +void VulkanStream::Synchronize() { + if (!vctx_->UseImmediate()) { + for (const auto& deferred_kernel : deferred_kernels_) { + deferred_kernel(state_.get()); + } + deferred_kernels_.clear(); + deferred_tokens_.clear(); + } else { + DCHECK_EQ(deferred_kernels_.size(), 0); + DCHECK_EQ(deferred_tokens_.size(), 0); + } + + VULKAN_CALL(vkEndCommandBuffer(state_->cmd_buffer_)); + VkSubmitInfo cb_submit; + cb_submit.sType = VK_STRUCTURE_TYPE_SUBMIT_INFO; + cb_submit.pNext = nullptr; + cb_submit.waitSemaphoreCount = 0; + cb_submit.pWaitSemaphores = nullptr; + cb_submit.pWaitDstStageMask = 0; + cb_submit.commandBufferCount = 1; + cb_submit.pCommandBuffers = &(state_->cmd_buffer_); + cb_submit.signalSemaphoreCount = 0; + cb_submit.pSignalSemaphores = nullptr; + + { + // Multiple streams (on different threads) use the same VulkanContext + // instance, so we need to externally synchronize accesses. + std::lock_guard g(*(vctx_->queue_mutex)); + VULKAN_CALL(vkQueueSubmit(vctx_->queue, 1, &cb_submit, state_->fence_)); + } + uint64_t timeout = 1UL << 30UL; + VkResult res; + do { + res = vkWaitForFences(vctx_->device, 1, &(state_->fence_), 0, timeout); + } while (res == VK_TIMEOUT); + VULKAN_CHECK_ERROR(res); + VULKAN_CALL(vkResetCommandBuffer(state_->cmd_buffer_, 0)); + VULKAN_CALL(vkResetFences(vctx_->device, 1, &(state_->fence_))); + + // Re-initialize the command buffer + VkCommandBufferBeginInfo cb_begin; + cb_begin.sType = VK_STRUCTURE_TYPE_COMMAND_BUFFER_BEGIN_INFO; + cb_begin.pNext = nullptr; + cb_begin.flags = VK_COMMAND_BUFFER_USAGE_ONE_TIME_SUBMIT_BIT; + cb_begin.pInheritanceInfo = 0; + VULKAN_CALL(vkBeginCommandBuffer(state_->cmd_buffer_, &cb_begin)); +} + +} // namespace vulkan +} // namespace runtime +} // namespace tvm diff --git a/src/runtime/vulkan/vulkan_stream.h b/src/runtime/vulkan/vulkan_stream.h index d096a644a1f0..f328262a8b10 100644 --- a/src/runtime/vulkan/vulkan_stream.h +++ b/src/runtime/vulkan/vulkan_stream.h @@ -26,6 +26,7 @@ #include #include "vulkan_common.h" +#include "vulkan_context.h" namespace tvm { namespace runtime { @@ -43,135 +44,62 @@ struct VulkanStreamToken { std::vector buffers_; }; +/*! + * \brief Wrapper around a vulkan command buffer + * + * The VulkanStream collects commands into a VkCommandBuffer. When a + * newly submitted command requires resources reserved by an + * already-submitted command, all of the queued commands are + * submitted to the GPU, and the CPU waits for all queued commands to + * finish. The queued commands can also be explicitly pushed/waited + * on by calling VulkanStream::Synchronize. + * + * Currently, there exists one VulkanStream for each GPU device, for + * each CPU thread. Each time a VulkanWrappedFunc is called, it is + * submitted to the VulkanStream associated with the submitting CPU + * thread, and associated the thread-specific active device set by + * `DeviceAPI::SetDevice`. + */ class VulkanStream { public: - explicit VulkanStream(const VulkanContext* vctx) : vctx_(vctx), state_(new VulkanStreamState()) { - // create command pool - VkCommandPoolCreateInfo cmd_pool_cinfo; - cmd_pool_cinfo.sType = VK_STRUCTURE_TYPE_COMMAND_POOL_CREATE_INFO; - cmd_pool_cinfo.pNext = nullptr; - cmd_pool_cinfo.flags = VK_COMMAND_POOL_CREATE_RESET_COMMAND_BUFFER_BIT; - cmd_pool_cinfo.queueFamilyIndex = vctx_->queue_family_index; - VULKAN_CALL(vkCreateCommandPool(vctx_->device, &cmd_pool_cinfo, nullptr, &cmd_pool_)); - - VkCommandBufferAllocateInfo buffer_alloc_info; - buffer_alloc_info.sType = VK_STRUCTURE_TYPE_COMMAND_BUFFER_ALLOCATE_INFO; - buffer_alloc_info.pNext = nullptr; - buffer_alloc_info.commandPool = cmd_pool_; - buffer_alloc_info.level = VK_COMMAND_BUFFER_LEVEL_PRIMARY; - buffer_alloc_info.commandBufferCount = 1; - VULKAN_CALL( - vkAllocateCommandBuffers(vctx_->device, &buffer_alloc_info, &(state_->cmd_buffer_))); - - VkFenceCreateInfo fence_cinfo; - fence_cinfo.sType = VK_STRUCTURE_TYPE_FENCE_CREATE_INFO; - fence_cinfo.pNext = nullptr; - fence_cinfo.flags = 0; // VK_FENCE_CREATE_SIGNALED_BIT; - VULKAN_CALL(vkCreateFence(vctx_->device, &fence_cinfo, nullptr, &(state_->fence_))); - - VkCommandBufferBeginInfo cb_begin; - cb_begin.sType = VK_STRUCTURE_TYPE_COMMAND_BUFFER_BEGIN_INFO; - cb_begin.pNext = nullptr; - cb_begin.flags = VK_COMMAND_BUFFER_USAGE_ONE_TIME_SUBMIT_BIT; - cb_begin.pInheritanceInfo = 0; - VULKAN_CALL(vkBeginCommandBuffer(state_->cmd_buffer_, &cb_begin)); - } - - ~VulkanStream() { - vkDestroyFence(vctx_->device, state_->fence_, nullptr); - vkDestroyCommandPool(vctx_->device, cmd_pool_, nullptr); - } - - // Launch the kernel on the current stream. - void Launch(const std::function& kernel) { - if (vctx_->UseImmediate()) { - kernel(state_.get()); - } else { - deferred_kernels_.push_back(kernel); - } - } - - // Launch the kernel on the current stream, + explicit VulkanStream(const VulkanContext* vctx); + + ~VulkanStream(); + + /*! \brief Push the kernel onto the stream's command buffer. + * + * If context.UseImmediate() is true, the kernel is executed + * immediately to update the command buffer. Otherwise, it is added + * to the list of deferred updates to be pushed onto the command + * buffer. + * + * Assumes that there are no descriptor sets or buffers accessed by this kernel. + * + */ + void Launch(const std::function& kernel); + + /*! \brief Push the kernel onto the stream's command buffer. + * + * Can only be called if context.UseImmediate() is false. The + * kernel is delayed, and isn't pushed to the command buffer until + * all kernels are collected. + * + * \param deferred_initializer Updates the descriptor set. Only + * called if the deferred_token has differences from + * + * \param deferred_kernel Submits updates to the command buffer. + * + * \param deferred_token Indicates which descriptor set and buffers + * are accessed by this kernel. No two kernels in the command + * buffer can use the same descriptor set. + * + */ void LaunchDeferred(const std::function& deferred_initializer, const std::function& deferred_kernel, - const VulkanStreamToken& deferred_token) { - ICHECK(!vctx_->UseImmediate()); - - // It is invalid to schedule this instance on the current stream if we already - // have a matching descriptor set and a non-matching buffer set. - if (std::any_of(deferred_tokens_[deferred_token.descriptor_set_].begin(), - deferred_tokens_[deferred_token.descriptor_set_].end(), - [&](const VulkanStreamToken& token) { - DCHECK(token.descriptor_set_ == deferred_token.descriptor_set_); - return token.descriptor_set_ == deferred_token.descriptor_set_ && - token.buffers_ != deferred_token.buffers_; - })) { - Synchronize(); - } - - // It is unnecessary to invoke our initializer if we have a matching token. - if (!std::any_of(deferred_tokens_[deferred_token.descriptor_set_].begin(), - deferred_tokens_[deferred_token.descriptor_set_].end(), - [&](const VulkanStreamToken& token) { - DCHECK(token.descriptor_set_ == deferred_token.descriptor_set_); - return token.descriptor_set_ == deferred_token.descriptor_set_ && - token.buffers_ == deferred_token.buffers_; - })) { - deferred_initializer(); - } - - deferred_kernels_.push_back(deferred_kernel); - deferred_tokens_[deferred_token.descriptor_set_].push_back(deferred_token); - } + const VulkanStreamToken& deferred_token); // Synchronize the current stream `state_` with respect to the host. - void Synchronize() { - if (!vctx_->UseImmediate()) { - for (const auto& deferred_kernel : deferred_kernels_) { - deferred_kernel(state_.get()); - } - deferred_kernels_.clear(); - deferred_tokens_.clear(); - } else { - DCHECK_EQ(deferred_kernels_.size(), 0); - DCHECK_EQ(deferred_tokens_.size(), 0); - } - - VULKAN_CALL(vkEndCommandBuffer(state_->cmd_buffer_)); - VkSubmitInfo cb_submit; - cb_submit.sType = VK_STRUCTURE_TYPE_SUBMIT_INFO; - cb_submit.pNext = nullptr; - cb_submit.waitSemaphoreCount = 0; - cb_submit.pWaitSemaphores = nullptr; - cb_submit.pWaitDstStageMask = 0; - cb_submit.commandBufferCount = 1; - cb_submit.pCommandBuffers = &(state_->cmd_buffer_); - cb_submit.signalSemaphoreCount = 0; - cb_submit.pSignalSemaphores = nullptr; - - { - // Multiple streams (on different threads) use the same VulkanContext - // instance, so we need to externally synchronize accesses. - std::lock_guard g(*(vctx_->queue_mutex)); - VULKAN_CALL(vkQueueSubmit(vctx_->queue, 1, &cb_submit, state_->fence_)); - } - uint64_t timeout = 1UL << 30UL; - VkResult res; - do { - res = vkWaitForFences(vctx_->device, 1, &(state_->fence_), 0, timeout); - } while (res == VK_TIMEOUT); - VULKAN_CHECK_ERROR(res); - VULKAN_CALL(vkResetCommandBuffer(state_->cmd_buffer_, 0)); - VULKAN_CALL(vkResetFences(vctx_->device, 1, &(state_->fence_))); - - // Re-initialize the command buffer - VkCommandBufferBeginInfo cb_begin; - cb_begin.sType = VK_STRUCTURE_TYPE_COMMAND_BUFFER_BEGIN_INFO; - cb_begin.pNext = nullptr; - cb_begin.flags = VK_COMMAND_BUFFER_USAGE_ONE_TIME_SUBMIT_BIT; - cb_begin.pInheritanceInfo = 0; - VULKAN_CALL(vkBeginCommandBuffer(state_->cmd_buffer_, &cb_begin)); - } + void Synchronize(); private: const VulkanContext* vctx_; diff --git a/src/runtime/vulkan/vulkan_thread_entry.cc b/src/runtime/vulkan/vulkan_thread_entry.cc new file mode 100644 index 000000000000..e7e01b9c2d06 --- /dev/null +++ b/src/runtime/vulkan/vulkan_thread_entry.cc @@ -0,0 +1,84 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include "vulkan_thread_entry.h" + +#include "vulkan_buffer.h" +#include "vulkan_device_api.h" + +namespace tvm { +namespace runtime { +namespace vulkan { + +VulkanThreadEntry::~VulkanThreadEntry() { + // Because the thread entry refers to Device API + // The command buffer always will be destroyed before + // the instance and device get destroyed. + // The destruction need to be manually called + // to ensure the destruction order. + + pool.reset(); + streams_.clear(); + for (const auto& kv : staging_buffers_) { + DeleteHostVisibleBuffer(kv.second.get()); + } +} + +VulkanThreadEntry* VulkanThreadEntry::ThreadLocal() { return VulkanThreadStore::Get(); } + +void VulkanThreadEntry::AllocateUniformBuffer(int device_id, size_t size) { + const auto& vctx = VulkanDeviceAPI::Global()->context(device_id); + auto prop = VK_MEMORY_PROPERTY_HOST_VISIBLE_BIT | VK_MEMORY_PROPERTY_HOST_COHERENT_BIT; + auto info = MakeBufferCreateInfo(vctx, size, VK_BUFFER_USAGE_UNIFORM_BUFFER_BIT); + auto mem_type_index = FindMemoryType(vctx, info, prop); + GetOrAllocate(device_id, size, VK_BUFFER_USAGE_UNIFORM_BUFFER_BIT, mem_type_index, + &uniform_buffers_, true); +} + +VulkanUniformBuffer* VulkanThreadEntry::GetUniformBuffer(int device_id, size_t size) { + auto& buf = uniform_buffers_[device_id]; + ICHECK(buf); + ICHECK_GE(buf->size, size); + return buf.get(); +} + +VulkanStagingBuffer* VulkanThreadEntry::StagingBuffer(int device_id, size_t size) { + const auto& vctx = VulkanDeviceAPI::Global()->context(device_id); + auto usage = VK_BUFFER_USAGE_TRANSFER_SRC_BIT | VK_BUFFER_USAGE_TRANSFER_DST_BIT; + return GetOrAllocate(device_id, size, usage, vctx.staging_mtype_index, &staging_buffers_); +} + +VulkanThreadEntry::VulkanThreadEntry() + : pool(std::make_unique(static_cast(kDLVulkan), + VulkanDeviceAPI::Global())) { + device.device_id = 0; + device.device_type = static_cast(kDLVulkan); +} + +VulkanStream* VulkanThreadEntry::Stream(size_t device_id) { + if (!streams_[device_id]) { + streams_[device_id] = std::unique_ptr( + new VulkanStream(&VulkanDeviceAPI::Global()->context(device_id))); + } + return streams_[device_id].get(); +} + +} // namespace vulkan +} // namespace runtime +} // namespace tvm diff --git a/src/runtime/vulkan/vulkan_thread_entry.h b/src/runtime/vulkan/vulkan_thread_entry.h new file mode 100644 index 000000000000..cea5494823fd --- /dev/null +++ b/src/runtime/vulkan/vulkan_thread_entry.h @@ -0,0 +1,67 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#ifndef TVM_RUNTIME_VULKAN_VULKAN_THREAD_ENTRY_H_ +#define TVM_RUNTIME_VULKAN_VULKAN_THREAD_ENTRY_H_ + +#include + +#include +#include + +#include "../workspace_pool.h" +#include "vulkan_buffer.h" +#include "vulkan_stream.h" + +namespace tvm { +namespace runtime { +namespace vulkan { + +/*! \brief Contains all per-CPU-thread resources. + */ +class VulkanThreadEntry { + public: + VulkanThreadEntry(); + static VulkanThreadEntry* ThreadLocal(); + + ~VulkanThreadEntry(); + + Device device; + std::unique_ptr pool; + VulkanStream* Stream(size_t device_id); + VulkanStagingBuffer* StagingBuffer(int device_id, size_t size); + void AllocateUniformBuffer(int device_id, size_t size); + VulkanUniformBuffer* GetUniformBuffer(int device_id, size_t size); + + private: + //! Map from device to the VulkanStream for it + std::unordered_map> streams_; + //! Map from device to the StagingBuffer for it + std::unordered_map> staging_buffers_; + //! Map from device to the UniformBuffer associated with it + std::unordered_map> uniform_buffers_; +}; + +typedef dmlc::ThreadLocalStore VulkanThreadStore; + +} // namespace vulkan +} // namespace runtime +} // namespace tvm + +#endif // TVM_RUNTIME_VULKAN_VULKAN_THREAD_ENTRY_H_ diff --git a/src/runtime/vulkan/vulkan_wrapped_func.cc b/src/runtime/vulkan/vulkan_wrapped_func.cc new file mode 100644 index 000000000000..2ee46b7db80c --- /dev/null +++ b/src/runtime/vulkan/vulkan_wrapped_func.cc @@ -0,0 +1,412 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include "vulkan_wrapped_func.h" + +#include + +#include + +#include "../file_utils.h" +#include "vulkan_device_api.h" +#include "vulkan_thread_entry.h" + +namespace tvm { +namespace runtime { +namespace vulkan { + +void VulkanWrappedFunc::Init(VulkanModuleNode* m, ObjectPtr sptr, + const std::string& func_name, size_t num_buffer_args, + size_t num_pack_args, + const std::vector& thread_axis_tags) { + m_ = m; + sptr_ = sptr; + func_name_ = func_name; + num_buffer_args_ = num_buffer_args; + num_pack_args_ = num_pack_args; + thread_axis_cfg_.Init(num_buffer_args + num_pack_args, thread_axis_tags); +} + +void VulkanWrappedFunc::operator()(TVMArgs args, TVMRetValue* rv, + const ArgUnion64* pack_args) const { + int device_id = VulkanThreadEntry::ThreadLocal()->device.device_id; + ICHECK_LT(device_id, kVulkanMaxNumDevice); + const auto& vctx = VulkanDeviceAPI::Global()->context(device_id); + if (!scache_[device_id]) { + scache_[device_id] = m_->GetPipeline(device_id, func_name_, num_pack_args_); + } + const auto& pipeline = scache_[device_id]; + ThreadWorkLoad wl = thread_axis_cfg_.Extract(args); + std::vector descriptor_buffers; + descriptor_buffers.resize(num_buffer_args_); + for (size_t i = 0; i < num_buffer_args_; ++i) { + void* buf = args[static_cast(i)]; + VkDescriptorBufferInfo binfo; + binfo.buffer = static_cast(buf)->buffer; + binfo.offset = 0; + binfo.range = VK_WHOLE_SIZE; + descriptor_buffers[i] = binfo; + } + const size_t nbytes_scalars = num_pack_args_ * sizeof(ArgUnion64); + if (pipeline->use_ubo) { + auto ubo = VulkanThreadEntry::ThreadLocal()->GetUniformBuffer(device_id, nbytes_scalars); + CHECK(ubo->host_addr) << "The UBO host buffer is not allocated"; + VkDescriptorBufferInfo binfo; + binfo.buffer = ubo->vk_buf->buffer; + binfo.offset = 0; + binfo.range = VK_WHOLE_SIZE; + descriptor_buffers.push_back(binfo); + } + if (vctx.UseImmediate()) { + // Can safely capture by reference as this lambda is immediately executed on the calling thread. + VulkanThreadEntry::ThreadLocal()->Stream(device_id)->Launch([&](VulkanStreamState* state) { + vkCmdBindPipeline(state->cmd_buffer_, VK_PIPELINE_BIND_POINT_COMPUTE, pipeline->pipeline); + ICHECK(pipeline->descriptor_update_template != VK_NULL_HANDLE); + vctx.descriptor_template_khr_functions->vkCmdPushDescriptorSetWithTemplateKHR( + state->cmd_buffer_, pipeline->descriptor_update_template, pipeline->pipeline_layout, 0, + descriptor_buffers.data()); + + if (pipeline->use_ubo) { + auto ubo = VulkanThreadEntry::ThreadLocal()->GetUniformBuffer(device_id, nbytes_scalars); + memcpy(ubo->host_addr, pack_args, nbytes_scalars); + } else if (num_pack_args_ > 0) { + vkCmdPushConstants(state->cmd_buffer_, pipeline->pipeline_layout, + VK_SHADER_STAGE_COMPUTE_BIT, 0, num_pack_args_ * sizeof(ArgUnion64), + pack_args); + } + + vkCmdDispatch(state->cmd_buffer_, wl.grid_dim(0), wl.grid_dim(1), wl.grid_dim(2)); + VkMemoryBarrier barrier_info; + barrier_info.sType = VK_STRUCTURE_TYPE_MEMORY_BARRIER; + barrier_info.pNext = nullptr; + barrier_info.srcAccessMask = VK_ACCESS_SHADER_WRITE_BIT | VK_ACCESS_SHADER_READ_BIT; + barrier_info.dstAccessMask = (VK_ACCESS_TRANSFER_READ_BIT | VK_ACCESS_TRANSFER_WRITE_BIT | + VK_ACCESS_SHADER_READ_BIT | VK_ACCESS_SHADER_WRITE_BIT); + vkCmdPipelineBarrier(state->cmd_buffer_, VK_PIPELINE_STAGE_COMPUTE_SHADER_BIT, + VK_PIPELINE_STAGE_TRANSFER_BIT | VK_PIPELINE_STAGE_COMPUTE_SHADER_BIT, 0, + 1, &barrier_info, 0, nullptr, 0, nullptr); + }); + return; + } + + // Otherwise, the more expensive deferred path. + std::vector pack_args_storage(pack_args, pack_args + num_pack_args_); + const auto& deferred_initializer = [&vctx, pipeline, descriptor_buffers]() { + std::vector write_descriptor_sets; + write_descriptor_sets.resize(descriptor_buffers.size()); + for (size_t i = 0; i < write_descriptor_sets.size(); i++) { + write_descriptor_sets[i].sType = VK_STRUCTURE_TYPE_WRITE_DESCRIPTOR_SET; + write_descriptor_sets[i].pNext = 0; + write_descriptor_sets[i].dstSet = pipeline->descriptor_set; + write_descriptor_sets[i].dstBinding = i; + write_descriptor_sets[i].dstArrayElement = 0; + write_descriptor_sets[i].descriptorCount = 1; + write_descriptor_sets[i].pImageInfo = 0; + write_descriptor_sets[i].pBufferInfo = &(descriptor_buffers[i]); + write_descriptor_sets[i].pTexelBufferView = 0; + + if (pipeline->use_ubo && i == write_descriptor_sets.size() - 1) { + // The last binding is for UBO + write_descriptor_sets[i].descriptorType = VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER; + } else { + write_descriptor_sets[i].descriptorType = VK_DESCRIPTOR_TYPE_STORAGE_BUFFER; + } + } + vkUpdateDescriptorSets(vctx.device, write_descriptor_sets.size(), write_descriptor_sets.data(), + 0, 0); + }; + const auto& deferred_kernel = [this, pipeline, wl, pack_args_storage, nbytes_scalars, + device_id](VulkanStreamState* state) { + vkCmdBindPipeline(state->cmd_buffer_, VK_PIPELINE_BIND_POINT_COMPUTE, pipeline->pipeline); + vkCmdBindDescriptorSets(state->cmd_buffer_, VK_PIPELINE_BIND_POINT_COMPUTE, + pipeline->pipeline_layout, 0, 1, &(pipeline->descriptor_set), 0, + nullptr); + + if (pipeline->use_ubo) { + auto ubo = VulkanThreadEntry::ThreadLocal()->GetUniformBuffer(device_id, nbytes_scalars); + memcpy(ubo->host_addr, pack_args_storage.data(), nbytes_scalars); + } else if (num_pack_args_ > 0) { + vkCmdPushConstants(state->cmd_buffer_, pipeline->pipeline_layout, VK_SHADER_STAGE_COMPUTE_BIT, + 0, pack_args_storage.size() * sizeof(ArgUnion64), + pack_args_storage.data()); + } + + vkCmdDispatch(state->cmd_buffer_, wl.grid_dim(0), wl.grid_dim(1), wl.grid_dim(2)); + VkMemoryBarrier barrier_info; + barrier_info.sType = VK_STRUCTURE_TYPE_MEMORY_BARRIER; + barrier_info.pNext = nullptr; + barrier_info.srcAccessMask = VK_ACCESS_SHADER_WRITE_BIT | VK_ACCESS_SHADER_READ_BIT; + barrier_info.dstAccessMask = (VK_ACCESS_TRANSFER_READ_BIT | VK_ACCESS_TRANSFER_WRITE_BIT | + VK_ACCESS_SHADER_READ_BIT | VK_ACCESS_SHADER_WRITE_BIT); + vkCmdPipelineBarrier(state->cmd_buffer_, VK_PIPELINE_STAGE_COMPUTE_SHADER_BIT, + VK_PIPELINE_STAGE_TRANSFER_BIT | VK_PIPELINE_STAGE_COMPUTE_SHADER_BIT, 0, + 1, &barrier_info, 0, nullptr, 0, nullptr); + }; + VulkanStreamToken deferred_token; + deferred_token.descriptor_set_ = pipeline->descriptor_set; + deferred_token.buffers_.resize(descriptor_buffers.size()); + for (size_t i = 0; i < descriptor_buffers.size(); ++i) { + deferred_token.buffers_[i] = descriptor_buffers[i].buffer; + } + VulkanThreadEntry::ThreadLocal()->Stream(device_id)->LaunchDeferred( + deferred_initializer, deferred_kernel, deferred_token); +} + +VulkanModuleNode::~VulkanModuleNode() { + // cleanup vulkan related caches. + for (size_t device_id = 0; device_id < ecache_.size(); ++device_id) { + for (auto& kv : ecache_[device_id]) { + auto& pe = kv.second; + ICHECK(pe); + const auto& vctx = VulkanDeviceAPI::Global()->context(device_id); + + if (pe->descriptor_update_template != VK_NULL_HANDLE) { + vctx.descriptor_template_khr_functions->vkDestroyDescriptorUpdateTemplateKHR( + vctx.device, pe->descriptor_update_template, nullptr); + } + vkDestroyPipeline(vctx.device, pe->pipeline, nullptr); + vkDestroyPipelineLayout(vctx.device, pe->pipeline_layout, nullptr); + vkDestroyDescriptorPool(vctx.device, pe->descriptor_pool, nullptr); + vkDestroyDescriptorSetLayout(vctx.device, pe->descriptor_set_layout, nullptr); + vkDestroyShaderModule(vctx.device, pe->shader, nullptr); + } + } +} + +PackedFunc VulkanModuleNode::GetFunction(const std::string& name, + const ObjectPtr& sptr_to_self) { + ICHECK_EQ(sptr_to_self.get(), this); + ICHECK_NE(name, symbol::tvm_module_main) << "Device function do not have main"; + auto it = fmap_.find(name); + if (it == fmap_.end()) return PackedFunc(); + const FunctionInfo& info = it->second; + VulkanWrappedFunc f; + size_t num_buffer_args = NumBufferArgs(info.arg_types); + f.Init(this, sptr_to_self, name, num_buffer_args, info.arg_types.size() - num_buffer_args, + info.thread_axis_tags); + return PackFuncNonBufferArg(std::move(f), info.arg_types); +} + +std::shared_ptr VulkanModuleNode::GetPipeline(size_t device_id, + const std::string& func_name, + size_t num_pack_args) { + const auto& vctx = VulkanDeviceAPI::Global()->context(device_id); + std::lock_guard lock(mutex_); + const auto& cp = ecache_[device_id][func_name]; + if (cp) { + return cp; + } + // Create new pipeline + auto pe = std::make_shared(); + { + // create shader + auto sit = smap_.find(func_name); + ICHECK(sit != smap_.end()); + pe->use_ubo = sit->second.flag & (1 << ShaderMetaDataFlagMask::kUseUBO); + const std::vector& data = sit->second.data; + VkShaderModuleCreateInfo shader_cinfo; + shader_cinfo.sType = VK_STRUCTURE_TYPE_SHADER_MODULE_CREATE_INFO; + shader_cinfo.pNext = nullptr; + shader_cinfo.flags = 0; + shader_cinfo.codeSize = data.size() * sizeof(uint32_t); + shader_cinfo.pCode = data.data(); + VULKAN_CALL(vkCreateShaderModule(vctx.device, &shader_cinfo, nullptr, &(pe->shader))); + } + std::vector arg_binding; + std::vector arg_template; + std::vector descriptor_set_pool_sizes; + uint32_t num_pod = 0, num_buffer = 0; + + auto push_arg_info = [&arg_binding, &arg_template, &descriptor_set_pool_sizes]( + uint32_t binding, VkDescriptorType desc_type) { + { + auto result = std::find_if(descriptor_set_pool_sizes.begin(), descriptor_set_pool_sizes.end(), + [&](const auto& psize) { return psize.type == desc_type; }); + if (result == descriptor_set_pool_sizes.end()) { + VkDescriptorPoolSize new_size; + new_size.type = desc_type; + new_size.descriptorCount = 1; + descriptor_set_pool_sizes.push_back(new_size); + } else { + result->descriptorCount++; + } + } + + { + VkDescriptorSetLayoutBinding bd; + bd.binding = binding; + bd.descriptorType = desc_type; + bd.descriptorCount = 1; + bd.stageFlags = VK_SHADER_STAGE_COMPUTE_BIT; + bd.pImmutableSamplers = nullptr; + arg_binding.push_back(bd); + } + { + VkDescriptorUpdateTemplateEntryKHR tpl; + tpl.dstBinding = binding; + tpl.dstArrayElement = 0; + tpl.descriptorCount = 1; + tpl.descriptorType = desc_type; + tpl.offset = binding * sizeof(VkDescriptorBufferInfo); + tpl.stride = sizeof(VkDescriptorBufferInfo); + arg_template.push_back(tpl); + } + }; + + { + auto fit = fmap_.find(func_name); + ICHECK(fit != fmap_.end()); + for (DLDataType arg_type : fit->second.arg_types) { + if (arg_type.code == kTVMOpaqueHandle) { + push_arg_info(num_buffer, VK_DESCRIPTOR_TYPE_STORAGE_BUFFER); + ++num_buffer; + } else { + ++num_pod; + } + } + } + + size_t nbytes_scalars = num_pod * sizeof(ArgUnion64); + if (pe->use_ubo) { + // Use UBO instead of push constants + push_arg_info(num_buffer, VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER); + VulkanThreadEntry::ThreadLocal()->AllocateUniformBuffer(device_id, nbytes_scalars); + } + + { + VkDescriptorSetLayoutCreateInfo descrip_cinfo; + descrip_cinfo.sType = VK_STRUCTURE_TYPE_DESCRIPTOR_SET_LAYOUT_CREATE_INFO; + descrip_cinfo.pNext = nullptr; + descrip_cinfo.flags = 0; + if (vctx.UseImmediate()) { + descrip_cinfo.flags |= VK_DESCRIPTOR_SET_LAYOUT_CREATE_PUSH_DESCRIPTOR_BIT_KHR; + } + descrip_cinfo.bindingCount = arg_binding.size(); + descrip_cinfo.pBindings = arg_binding.data(); + VULKAN_CALL(vkCreateDescriptorSetLayout(vctx.device, &descrip_cinfo, nullptr, + &(pe->descriptor_set_layout))); + } + + if (!vctx.UseImmediate()) { + VkDescriptorPoolCreateInfo descrip_pool_cinfo; + descrip_pool_cinfo.sType = VK_STRUCTURE_TYPE_DESCRIPTOR_POOL_CREATE_INFO; + descrip_pool_cinfo.pNext = nullptr; + descrip_pool_cinfo.flags = VK_DESCRIPTOR_POOL_CREATE_FREE_DESCRIPTOR_SET_BIT; + descrip_pool_cinfo.maxSets = 1; + descrip_pool_cinfo.poolSizeCount = descriptor_set_pool_sizes.size(); + descrip_pool_cinfo.pPoolSizes = descriptor_set_pool_sizes.data(); + VULKAN_CALL( + vkCreateDescriptorPool(vctx.device, &descrip_pool_cinfo, nullptr, &(pe->descriptor_pool))); + + VkDescriptorSetAllocateInfo alloc_info; + alloc_info.sType = VK_STRUCTURE_TYPE_DESCRIPTOR_SET_ALLOCATE_INFO; + alloc_info.pNext = nullptr; + alloc_info.descriptorPool = pe->descriptor_pool; + alloc_info.descriptorSetCount = 1; + alloc_info.pSetLayouts = &(pe->descriptor_set_layout); + VULKAN_CALL(vkAllocateDescriptorSets(vctx.device, &alloc_info, &(pe->descriptor_set))); + } + + VkPushConstantRange crange; + crange.stageFlags = VK_SHADER_STAGE_COMPUTE_BIT; + crange.offset = 0; + crange.size = sizeof(ArgUnion64) * num_pack_args; + + VkPipelineLayoutCreateInfo playout_cinfo; + playout_cinfo.sType = VK_STRUCTURE_TYPE_PIPELINE_LAYOUT_CREATE_INFO; + playout_cinfo.pNext = nullptr; + playout_cinfo.flags = 0; + playout_cinfo.setLayoutCount = 1; + playout_cinfo.pSetLayouts = &(pe->descriptor_set_layout); + + if (0 < nbytes_scalars && !pe->use_ubo) { + playout_cinfo.pushConstantRangeCount = 1; + playout_cinfo.pPushConstantRanges = &crange; + ICHECK_LE(crange.size, vctx.phy_device_prop.limits.maxPushConstantsSize); + } else { + playout_cinfo.pushConstantRangeCount = 0; + playout_cinfo.pPushConstantRanges = nullptr; + } + + VULKAN_CALL(vkCreatePipelineLayout(vctx.device, &playout_cinfo, nullptr, &(pe->pipeline_layout))); + + VkComputePipelineCreateInfo pipeline_cinfo; + pipeline_cinfo.sType = VK_STRUCTURE_TYPE_COMPUTE_PIPELINE_CREATE_INFO; + pipeline_cinfo.pNext = nullptr; + pipeline_cinfo.flags = 0; + pipeline_cinfo.stage.sType = VK_STRUCTURE_TYPE_PIPELINE_SHADER_STAGE_CREATE_INFO; + pipeline_cinfo.stage.pNext = nullptr; + pipeline_cinfo.stage.flags = 0; + pipeline_cinfo.stage.stage = VK_SHADER_STAGE_COMPUTE_BIT; + pipeline_cinfo.stage.module = pe->shader; + pipeline_cinfo.stage.pName = func_name.c_str(); + pipeline_cinfo.stage.pSpecializationInfo = nullptr; + pipeline_cinfo.layout = pe->pipeline_layout; + pipeline_cinfo.basePipelineHandle = VK_NULL_HANDLE; + pipeline_cinfo.basePipelineIndex = 0; + VULKAN_CALL(vkCreateComputePipelines(vctx.device, VK_NULL_HANDLE, 1, &pipeline_cinfo, nullptr, + &(pe->pipeline))); + + if (vctx.UseImmediate()) { + VkDescriptorUpdateTemplateCreateInfoKHR descrip_template_cinfo; + descrip_template_cinfo.sType = VK_STRUCTURE_TYPE_DESCRIPTOR_UPDATE_TEMPLATE_CREATE_INFO_KHR; + descrip_template_cinfo.pNext = 0; + descrip_template_cinfo.flags = 0; + descrip_template_cinfo.descriptorUpdateEntryCount = arg_template.size(); + descrip_template_cinfo.pDescriptorUpdateEntries = arg_template.data(); + descrip_template_cinfo.templateType = VK_DESCRIPTOR_UPDATE_TEMPLATE_TYPE_PUSH_DESCRIPTORS_KHR; + descrip_template_cinfo.descriptorSetLayout = pe->descriptor_set_layout; + descrip_template_cinfo.pipelineBindPoint = VK_PIPELINE_BIND_POINT_COMPUTE; + descrip_template_cinfo.pipelineLayout = pe->pipeline_layout; + descrip_template_cinfo.set = 0; + VULKAN_CALL(vctx.descriptor_template_khr_functions->vkCreateDescriptorUpdateTemplateKHR( + vctx.device, &descrip_template_cinfo, 0, &(pe->descriptor_update_template))); + } + ecache_[device_id][func_name] = pe; + return pe; +} + +void VulkanModuleNode::SaveToFile(const std::string& file_name, const std::string& format) { + std::string fmt = GetFileFormat(file_name, format); + ICHECK_EQ(fmt, fmt_) << "Can only save to customized format vulkan"; + std::string meta_file = GetMetaFilePath(file_name); + SaveMetaDataToFile(meta_file, fmap_); + std::string data_bin; + dmlc::MemoryStringStream fs(&data_bin); + dmlc::Stream* stream = &fs; + uint32_t magic = kVulkanModuleMagic; + stream->Write(magic); + stream->Write(smap_); + SaveBinaryToFile(file_name, data_bin); +} + +void VulkanModuleNode::SaveToBinary(dmlc::Stream* stream) { + stream->Write(fmt_); + stream->Write(fmap_); + stream->Write(smap_); +} + +std::string VulkanModuleNode::GetSource(const std::string& format) { + // can only return disassembly code. + return source_; +} + +} // namespace vulkan +} // namespace runtime +} // namespace tvm diff --git a/src/runtime/vulkan/vulkan_wrapped_func.h b/src/runtime/vulkan/vulkan_wrapped_func.h new file mode 100644 index 000000000000..be5f385316ea --- /dev/null +++ b/src/runtime/vulkan/vulkan_wrapped_func.h @@ -0,0 +1,123 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#ifndef TVM_RUNTIME_VULKAN_VULKAN_WRAPPED_FUNC_H_ +#define TVM_RUNTIME_VULKAN_VULKAN_WRAPPED_FUNC_H_ + +#include +#include +#include +#include +#include +#include + +#include "../meta_data.h" +#include "../pack_args.h" +#include "../thread_storage_scope.h" +#include "vulkan/vulkan_core.h" +#include "vulkan_common.h" +#include "vulkan_context.h" +#include "vulkan_shader.h" + +namespace tvm { +namespace runtime { +namespace vulkan { + +struct VulkanPipeline { + VulkanContext* vctx_{nullptr}; + VkShaderModule shader{VK_NULL_HANDLE}; + VkDescriptorSetLayout descriptor_set_layout{VK_NULL_HANDLE}; + VkDescriptorPool descriptor_pool{VK_NULL_HANDLE}; + VkDescriptorSet descriptor_set{VK_NULL_HANDLE}; + VkPipelineLayout pipeline_layout{VK_NULL_HANDLE}; + VkPipeline pipeline{VK_NULL_HANDLE}; + VkDescriptorUpdateTemplateKHR descriptor_update_template{VK_NULL_HANDLE}; + bool use_ubo{false}; +}; + +class VulkanModuleNode; + +// a wrapped function class to get packed func. +class VulkanWrappedFunc { + public: + void Init(VulkanModuleNode* m, ObjectPtr sptr, const std::string& func_name, + size_t num_buffer_args, size_t num_pack_args, + const std::vector& thread_axis_tags); + + void operator()(TVMArgs args, TVMRetValue* rv, const ArgUnion64* pack_args) const; + + private: + // internal module + VulkanModuleNode* m_; + // the resource holder + ObjectPtr sptr_; + // v The name of the function. + std::string func_name_; + // Number of buffer arguments + size_t num_buffer_args_; + // number of packed arguments. + size_t num_pack_args_; + // Device state cache per device. + // mark as mutable, to enable lazy initialization + // thread axis configuration + ThreadAxisConfig thread_axis_cfg_; + + mutable std::array, kVulkanMaxNumDevice> scache_; +}; + +class VulkanModuleNode final : public runtime::ModuleNode { + public: + explicit VulkanModuleNode(std::unordered_map smap, + std::unordered_map fmap, std::string source) + : smap_(smap), fmap_(fmap), source_(source) {} + ~VulkanModuleNode(); + + const char* type_key() const final { return "vulkan"; } + + PackedFunc GetFunction(const std::string& name, const ObjectPtr& sptr_to_self) final; + + std::shared_ptr GetPipeline(size_t device_id, const std::string& func_name, + size_t num_pack_args); + + void SaveToFile(const std::string& file_name, const std::string& format) final; + + void SaveToBinary(dmlc::Stream* stream) final; + std::string GetSource(const std::string& format) final; + + private: + // function information table. + std::unordered_map smap_; + // function information table. + std::unordered_map fmap_; + // The format + std::string fmt_{"vulkan"}; + // The source + std::string source_; + + // Guards accesses to `ecache_` + std::mutex mutex_; + std::array>, kVulkanMaxNumDevice> + ecache_; +}; + +} // namespace vulkan +} // namespace runtime +} // namespace tvm + +#endif // TVM_RUNTIME_VULKAN_VULKAN_WRAPPED_FUNC_H_