From 57d96471672e64ee9d04345b48360f883973ab3e Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Tue, 1 Jun 2021 12:24:12 -0700 Subject: [PATCH 1/4] [Vulkan][Refactor] Broke out VkInstance setup/teardown into managed class. - Previously, the VkInstance was directly owned by the VulkanDeviceAPI. Now, VulkanDeviceAPI owns a tvm::runtime::vulkan::VulkanInstance that does setup/teardown of the VkInstance. This way, the teardown is done even if a later initialization step throws an exception. --- src/runtime/vulkan/vulkan_common.cc | 57 +++++++++ src/runtime/vulkan/vulkan_common.h | 4 + src/runtime/vulkan/vulkan_context.cc | 12 +- src/runtime/vulkan/vulkan_context.h | 5 +- src/runtime/vulkan/vulkan_device_api.cc | 124 +------------------- src/runtime/vulkan/vulkan_device_api.h | 8 +- src/runtime/vulkan/vulkan_instance.cc | 147 ++++++++++++++++++++++++ src/runtime/vulkan/vulkan_instance.h | 90 +++++++++++++++ 8 files changed, 312 insertions(+), 135 deletions(-) create mode 100644 src/runtime/vulkan/vulkan_common.cc create mode 100644 src/runtime/vulkan/vulkan_instance.cc create mode 100644 src/runtime/vulkan/vulkan_instance.h diff --git a/src/runtime/vulkan/vulkan_common.cc b/src/runtime/vulkan/vulkan_common.cc new file mode 100644 index 000000000000..30df8b86ecd5 --- /dev/null +++ b/src/runtime/vulkan/vulkan_common.cc @@ -0,0 +1,57 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */
+
+#include "vulkan_common.h"
+
+#include <set>
+
+namespace tvm {
+namespace runtime {
+namespace vulkan {
+
+std::vector<const char*> FindEnabledExtensions(
+    const std::vector<VkExtensionProperties>& ext_prop,
+    const std::vector<const char*>& required_extensions,
+    const std::vector<const char*>& optional_extensions) {
+  std::set<std::string> available_extensions;
+  for (const auto& prop : ext_prop) {
+    if (prop.specVersion > 0) {
+      available_extensions.insert(prop.extensionName);
+    }
+  }
+
+  std::vector<const char*> enabled_extensions;
+  for (const auto& ext : required_extensions) {
+    ICHECK(available_extensions.count(ext))
+        << "Required vulkan extension \"" << ext << "\" not supported by driver";
+    enabled_extensions.push_back(ext);
+  }
+
+  for (const auto& ext : optional_extensions) {
+    if (available_extensions.count(ext)) {
+      enabled_extensions.push_back(ext);
+    }
+  }
+
+  return enabled_extensions;
+}
+
+}  // namespace vulkan
+}  // namespace runtime
+}  // namespace tvm
diff --git a/src/runtime/vulkan/vulkan_common.h b/src/runtime/vulkan/vulkan_common.h
index 8fce5dbd192a..a03801cf511f 100644
--- a/src/runtime/vulkan/vulkan_common.h
+++ b/src/runtime/vulkan/vulkan_common.h
@@ -106,6 +106,10 @@ inline const char* VKGetErrorString(VkResult error) {
     VULKAN_CHECK_ERROR(__e);                        \
   }
 
+std::vector<const char*> FindEnabledExtensions(const std::vector<VkExtensionProperties>& ext_prop,
+                                               const std::vector<const char*>& required_extensions,
+                                               const std::vector<const char*>& optional_extensions);
+
 }  // namespace vulkan
 }  // namespace runtime
 }  // namespace tvm
diff --git a/src/runtime/vulkan/vulkan_context.cc b/src/runtime/vulkan/vulkan_context.cc
index 7e59c9da47b5..bdbc2838cf6e 100644
--- a/src/runtime/vulkan/vulkan_context.cc
+++ b/src/runtime/vulkan/vulkan_context.cc
@@ -24,20 +24,16 @@
 
 #include "vulkan_common.h"
 #include "vulkan_device_api.h"
+#include "vulkan_instance.h"
 #include "vulkan_thread_entry.h"
 
 namespace tvm {
 namespace runtime {
 namespace vulkan {
 
-VulkanDeviceProperties::VulkanDeviceProperties(VkInstance instance, VkPhysicalDevice phy_dev,
-                                               const std::vector<const char*> instance_extensions,
+VulkanDeviceProperties::VulkanDeviceProperties(const VulkanInstance& instance,
+                                               VkPhysicalDevice phy_dev,
                                                const std::vector<const char*> device_extensions) {
-  auto has_instance_extension = [&](const char* query) {
-    return std::any_of(instance_extensions.begin(), instance_extensions.end(),
-                       [&](const char* extension) { return std::strcmp(query, extension) == 0; });
-  };
-
   auto has_device_extension = [&](const char* query) {
     return std::any_of(device_extensions.begin(), device_extensions.end(),
                        [&](const char* extension) { return std::strcmp(query, extension) == 0; });
@@ -95,7 +91,7 @@ VulkanDeviceProperties::VulkanDeviceProperties(VkInstance instance, VkPhysicalDe
     }
   }
 
-  if (has_instance_extension("VK_KHR_get_physical_device_properties2")) {
+  if (instance.HasExtension("VK_KHR_get_physical_device_properties2")) {
     // Preferred method, call to get all properties that can be queried.
auto vkGetPhysicalDeviceProperties2KHR = (PFN_vkGetPhysicalDeviceProperties2KHR)ICHECK_NOTNULL( vkGetInstanceProcAddr(instance, "vkGetPhysicalDeviceProperties2KHR")); diff --git a/src/runtime/vulkan/vulkan_context.h b/src/runtime/vulkan/vulkan_context.h index 158a53043c7b..306cbd606c44 100644 --- a/src/runtime/vulkan/vulkan_context.h +++ b/src/runtime/vulkan/vulkan_context.h @@ -34,6 +34,8 @@ namespace tvm { namespace runtime { namespace vulkan { +class VulkanInstance; + struct VulkanDescriptorTemplateKHRFunctions { explicit VulkanDescriptorTemplateKHRFunctions(VkDevice device); @@ -59,8 +61,7 @@ struct VulkanGetBufferMemoryRequirements2Functions { */ struct VulkanDeviceProperties { VulkanDeviceProperties() {} - VulkanDeviceProperties(VkInstance instance, VkPhysicalDevice phy_device, - const std::vector instance_extensions, + VulkanDeviceProperties(const VulkanInstance& instance, VkPhysicalDevice phy_dev, const std::vector device_extensions); bool supports_float16{false}; diff --git a/src/runtime/vulkan/vulkan_device_api.cc b/src/runtime/vulkan/vulkan_device_api.cc index 7cea2489cb1b..13d7918a9532 100644 --- a/src/runtime/vulkan/vulkan_device_api.cc +++ b/src/runtime/vulkan/vulkan_device_api.cc @@ -25,6 +25,7 @@ #include #include +#include "vulkan_common.h" #include "vulkan_thread_entry.h" namespace tvm { @@ -42,92 +43,8 @@ VulkanDeviceAPI* VulkanDeviceAPI::Global() { } VulkanDeviceAPI::VulkanDeviceAPI() { - const auto layers = []() -> std::vector { - uint32_t inst_layer_prop_count; - VULKAN_CALL(vkEnumerateInstanceLayerProperties(&inst_layer_prop_count, nullptr)); - std::vector inst_layer_prop(inst_layer_prop_count); - VULKAN_CALL(vkEnumerateInstanceLayerProperties(&inst_layer_prop_count, inst_layer_prop.data())); - std::vector l; - - const char* enable = std::getenv("TVM_VULKAN_ENABLE_VALIDATION_LAYERS"); - bool validation_enabled = enable && *enable; - if (validation_enabled) { - for (const auto& lp : inst_layer_prop) { - if (std::strcmp(lp.layerName, "VK_LAYER_LUNARG_standard_validation") == 0) { - l.push_back("VK_LAYER_LUNARG_standard_validation"); - } - if (std::strcmp(lp.layerName, "VK_LAYER_LUNARG_parameter_validation") == 0) { - l.push_back("VK_LAYER_LUNARG_parameter_validation"); - } - if (std::strcmp(lp.layerName, "VK_LAYER_KHRONOS_validation") == 0) { - l.push_back("VK_LAYER_KHRONOS_validation"); - } - } - } - return l; - }(); - - const auto instance_extensions = [this]() { - std::vector required_extensions{}; - std::vector optional_extensions{"VK_KHR_get_physical_device_properties2"}; - - uint32_t inst_extension_prop_count; - VULKAN_CALL( - vkEnumerateInstanceExtensionProperties(nullptr, &inst_extension_prop_count, nullptr)); - std::vector inst_extension_prop(inst_extension_prop_count); - VULKAN_CALL(vkEnumerateInstanceExtensionProperties(nullptr, &inst_extension_prop_count, - inst_extension_prop.data())); - - return FindEnabledExtensions(inst_extension_prop, required_extensions, optional_extensions); - }(); - - auto has_instance_extension = [&instance_extensions](const char* query) { - return std::any_of(instance_extensions.begin(), instance_extensions.end(), - [&](const char* extension) { return std::strcmp(query, extension) == 0; }); - }; - - const auto instance_api_version = []() { - uint32_t api_version = VK_MAKE_VERSION(1, 0, 0); - - // Result from vkGetInstanceProcAddr is NULL if driver only - // supports vulkan 1.0. 
- auto vkEnumerateInstanceVersion = - (PFN_vkEnumerateInstanceVersion)vkGetInstanceProcAddr(NULL, "vkEnumerateInstanceVersion"); - if (vkEnumerateInstanceVersion) { - vkEnumerateInstanceVersion(&api_version); - } - - return api_version; - }(); - - { - VkApplicationInfo app_info; - app_info.sType = VK_STRUCTURE_TYPE_APPLICATION_INFO; - app_info.pNext = nullptr; - app_info.pApplicationName = "TVM"; - app_info.applicationVersion = 0; - app_info.pEngineName = ""; - app_info.engineVersion = 0; - app_info.apiVersion = instance_api_version; - - VkInstanceCreateInfo inst_info; - inst_info.sType = VK_STRUCTURE_TYPE_INSTANCE_CREATE_INFO; - inst_info.pNext = nullptr; - inst_info.flags = 0; - inst_info.pApplicationInfo = &app_info; - inst_info.enabledLayerCount = layers.size(); - inst_info.ppEnabledLayerNames = layers.data(); - inst_info.enabledExtensionCount = instance_extensions.size(); - inst_info.ppEnabledExtensionNames = instance_extensions.data(); - - VULKAN_CALL(vkCreateInstance(&inst_info, nullptr, &instance_)); - } - - uint32_t phy_dev_count = 0; - VULKAN_CALL(vkEnumeratePhysicalDevices(instance_, &phy_dev_count, nullptr)); - std::vector all_phy_devs(phy_dev_count); - VULKAN_CALL(vkEnumeratePhysicalDevices(instance_, &phy_dev_count, dmlc::BeginPtr(all_phy_devs))); - for (VkPhysicalDevice phy_dev : all_phy_devs) { + std::vector vulkan_physical_devices = instance_.GetPhysicalDevices(); + for (VkPhysicalDevice phy_dev : vulkan_physical_devices) { // Get a list of queue families supporting compute, in order of preference. We currently only // make use of the most preferred one family. std::vector queue_family_indexes = GetComputeQueueFamilies(phy_dev); @@ -173,8 +90,7 @@ VulkanDeviceAPI::VulkanDeviceAPI() { return FindEnabledExtensions(device_extension_prop, required_extensions, optional_extensions); }(); - ctx.device_properties = - VulkanDeviceProperties(instance_, phy_dev, instance_extensions, device_extensions); + ctx.device_properties = VulkanDeviceProperties(instance_, phy_dev, device_extensions); { // Enable all features we may use that a device supports. 
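The exception-safety rationale from the commit message is easiest to see in isolation. The standalone sketch below is editorial and not part of the patch; `InstanceGuard` and `DeviceAPISketch` are hypothetical stand-ins for `VulkanInstance` and `VulkanDeviceAPI`, assuming only `<vulkan/vulkan.h>`. It shows why owning the `VkInstance` through a RAII member guarantees `vkDestroyInstance` runs even when a later initialization step throws, which the raw `instance_` handle removed above could not guarantee.

```cpp
// Editorial sketch (not part of the patch): RAII ownership of a VkInstance.
#include <stdexcept>
#include <vulkan/vulkan.h>

class InstanceGuard {
 public:
  InstanceGuard() {
    VkInstanceCreateInfo info{VK_STRUCTURE_TYPE_INSTANCE_CREATE_INFO};  // remaining fields zeroed
    if (vkCreateInstance(&info, nullptr, &instance_) != VK_SUCCESS) {
      throw std::runtime_error("vkCreateInstance failed");
    }
  }
  ~InstanceGuard() {
    if (instance_) vkDestroyInstance(instance_, nullptr);
  }
  InstanceGuard(const InstanceGuard&) = delete;
  InstanceGuard& operator=(const InstanceGuard&) = delete;
  operator VkInstance() const { return instance_; }  // pass to Vulkan calls directly

 private:
  VkInstance instance_{nullptr};
};

class DeviceAPISketch {
 public:
  DeviceAPISketch() {
    // If this (or any later) step throws, ~InstanceGuard still runs during
    // stack unwinding and destroys the VkInstance -- the point of the refactor.
    uint32_t count = 0;
    if (vkEnumeratePhysicalDevices(instance_, &count, nullptr) != VK_SUCCESS || count == 0) {
      throw std::runtime_error("no Vulkan devices available");
    }
  }

 private:
  InstanceGuard instance_;  // constructed first, destroyed last
};
```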
@@ -233,7 +149,7 @@ VulkanDeviceAPI::VulkanDeviceAPI() { device_create_info.enabledExtensionCount = device_extensions.size(); device_create_info.ppEnabledExtensionNames = device_extensions.data(); - if (has_instance_extension("VK_KHR_get_physical_device_properties2")) { + if (instance_.HasExtension("VK_KHR_get_physical_device_properties2")) { device_create_info.pEnabledFeatures = nullptr; device_create_info.pNext = &enabled_features; } else { @@ -339,9 +255,6 @@ VulkanDeviceAPI::~VulkanDeviceAPI() { for (auto& vctx : context_) { vkDestroyDevice(vctx.device, nullptr); } - if (instance_) { - vkDestroyInstance(instance_, nullptr); - } } void VulkanDeviceAPI::SetDevice(Device dev) { VulkanThreadEntry::ThreadLocal()->device = dev; } @@ -667,33 +580,6 @@ void VulkanDeviceAPI::CopyDataFromTo(const void* from, size_t from_offset, void* } } -std::vector VulkanDeviceAPI::FindEnabledExtensions( - const std::vector& ext_prop, - const std::vector& required_extensions, - const std::vector& optional_extensions) { - std::set available_extensions; - for (const auto& prop : ext_prop) { - if (prop.specVersion > 0) { - available_extensions.insert(prop.extensionName); - } - } - - std::vector enabled_extensions; - for (const auto& ext : required_extensions) { - ICHECK(available_extensions.count(ext)) - << "Required vulkan extension \"" << ext << "\" not supported by driver"; - enabled_extensions.push_back(ext); - } - - for (const auto& ext : optional_extensions) { - if (available_extensions.count(ext)) { - enabled_extensions.push_back(ext); - } - } - - return enabled_extensions; -} - const VulkanContext& VulkanDeviceAPI::context(size_t device_id) const { ICHECK_LT(device_id, context_.size()) << "Requested Vulkan device_id=" << device_id << ", but only " << context_.size() << " devices present"; diff --git a/src/runtime/vulkan/vulkan_device_api.h b/src/runtime/vulkan/vulkan_device_api.h index 71c73afb0d61..27c21825fbf1 100644 --- a/src/runtime/vulkan/vulkan_device_api.h +++ b/src/runtime/vulkan/vulkan_device_api.h @@ -27,6 +27,7 @@ #include "vulkan/vulkan_core.h" #include "vulkan_context.h" +#include "vulkan_instance.h" #include "vulkan_thread_entry.h" namespace tvm { @@ -85,12 +86,7 @@ class VulkanDeviceAPI final : public DeviceAPI { private: std::vector GetComputeQueueFamilies(VkPhysicalDevice phy_dev); - std::vector FindEnabledExtensions( - const std::vector& ext_prop, - const std::vector& required_extensions, - const std::vector& optional_extensions); - - VkInstance instance_{nullptr}; + VulkanInstance instance_; // The physical devices, have 1 to 1 mapping to devices std::vector context_; }; diff --git a/src/runtime/vulkan/vulkan_instance.cc b/src/runtime/vulkan/vulkan_instance.cc new file mode 100644 index 000000000000..351319e0e898 --- /dev/null +++ b/src/runtime/vulkan/vulkan_instance.cc @@ -0,0 +1,147 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. 
See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include "vulkan_instance.h" + +#include +#include + +#include "vulkan_common.h" + +namespace tvm { +namespace runtime { +namespace vulkan { + +VulkanInstance::VulkanInstance() { + const auto layers = []() { + std::vector layers; + + const char* validation_enabled_env = std::getenv("TVM_VULKAN_ENABLE_VALIDATION_LAYERS"); + bool validation_enabled = validation_enabled_env && *validation_enabled_env; + + if (validation_enabled) { + uint32_t inst_layer_prop_count; + VULKAN_CALL(vkEnumerateInstanceLayerProperties(&inst_layer_prop_count, nullptr)); + std::vector inst_layer_prop(inst_layer_prop_count); + VULKAN_CALL( + vkEnumerateInstanceLayerProperties(&inst_layer_prop_count, inst_layer_prop.data())); + + for (const auto& lp : inst_layer_prop) { + if (std::strcmp(lp.layerName, "VK_LAYER_LUNARG_standard_validation") == 0) { + layers.push_back("VK_LAYER_LUNARG_standard_validation"); + } + if (std::strcmp(lp.layerName, "VK_LAYER_LUNARG_parameter_validation") == 0) { + layers.push_back("VK_LAYER_LUNARG_parameter_validation"); + } + if (std::strcmp(lp.layerName, "VK_LAYER_KHRONOS_validation") == 0) { + layers.push_back("VK_LAYER_KHRONOS_validation"); + } + } + } + return layers; + }(); + + { + std::vector required_extensions{}; + std::vector optional_extensions{"VK_KHR_get_physical_device_properties2"}; + + uint32_t inst_extension_prop_count; + VULKAN_CALL( + vkEnumerateInstanceExtensionProperties(nullptr, &inst_extension_prop_count, nullptr)); + std::vector inst_extension_prop(inst_extension_prop_count); + VULKAN_CALL(vkEnumerateInstanceExtensionProperties(nullptr, &inst_extension_prop_count, + inst_extension_prop.data())); + + enabled_extensions_ = + FindEnabledExtensions(inst_extension_prop, required_extensions, optional_extensions); + } + + uint32_t api_version = VK_MAKE_VERSION(1, 0, 0); + { + // Result from vkGetInstanceProcAddr is NULL if driver only + // supports vulkan 1.0. 
+ auto vkEnumerateInstanceVersion = + (PFN_vkEnumerateInstanceVersion)vkGetInstanceProcAddr(NULL, "vkEnumerateInstanceVersion"); + if (vkEnumerateInstanceVersion) { + vkEnumerateInstanceVersion(&api_version); + } + } + + { + VkApplicationInfo app_info; + app_info.sType = VK_STRUCTURE_TYPE_APPLICATION_INFO; + app_info.pNext = nullptr; + app_info.pApplicationName = "TVM"; + app_info.applicationVersion = 0; + app_info.pEngineName = ""; + app_info.engineVersion = 0; + app_info.apiVersion = api_version; + + VkInstanceCreateInfo inst_info; + inst_info.sType = VK_STRUCTURE_TYPE_INSTANCE_CREATE_INFO; + inst_info.pNext = nullptr; + inst_info.flags = 0; + inst_info.pApplicationInfo = &app_info; + inst_info.enabledLayerCount = layers.size(); + inst_info.ppEnabledLayerNames = layers.data(); + inst_info.enabledExtensionCount = enabled_extensions_.size(); + inst_info.ppEnabledExtensionNames = enabled_extensions_.data(); + + VULKAN_CALL(vkCreateInstance(&inst_info, nullptr, &instance_)); + } +} + +VulkanInstance::~VulkanInstance() { + if (instance_) { + vkDestroyInstance(instance_, nullptr); + } +} + +VulkanInstance::VulkanInstance(VulkanInstance&& other) { do_swap(std::move(other)); } + +VulkanInstance& VulkanInstance::operator=(VulkanInstance&& other) { + do_swap(std::move(other)); + return *this; +} + +void VulkanInstance::do_swap(VulkanInstance&& other) { + if (this == &other) { + return; + } + + std::swap(enabled_extensions_, other.enabled_extensions_); + std::swap(instance_, other.instance_); +} + +bool VulkanInstance::HasExtension(const char* query) const { + return std::any_of(enabled_extensions_.begin(), enabled_extensions_.end(), + [&](const char* extension) { return std::strcmp(query, extension) == 0; }); +} + +std::vector VulkanInstance::GetPhysicalDevices() const { + uint32_t device_count = 0; + VULKAN_CALL(vkEnumeratePhysicalDevices(instance_, &device_count, nullptr)); + std::vector devices(device_count); + VULKAN_CALL(vkEnumeratePhysicalDevices(instance_, &device_count, devices.data())); + return devices; +} + +} // namespace vulkan +} // namespace runtime +} // namespace tvm diff --git a/src/runtime/vulkan/vulkan_instance.h b/src/runtime/vulkan/vulkan_instance.h new file mode 100644 index 000000000000..06016d8f0aea --- /dev/null +++ b/src/runtime/vulkan/vulkan_instance.h @@ -0,0 +1,90 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +#ifndef TVM_RUNTIME_VULKAN_VULKAN_INSTANCE_H_ +#define TVM_RUNTIME_VULKAN_VULKAN_INSTANCE_H_ + +#include + +#include "vulkan/vulkan_core.h" + +namespace tvm { +namespace runtime { +namespace vulkan { + +class VulkanInstance { + public: + VulkanInstance(); + ~VulkanInstance(); + + // Allow move assignment/construction + VulkanInstance(VulkanInstance&&); + VulkanInstance& operator=(VulkanInstance&&); + + // Forbid copy assignment/construction + VulkanInstance(const VulkanInstance&) = delete; + VulkanInstance& operator=(const VulkanInstance&) = delete; + + /*! \brief Expose the internal VkInstance + * + * Allows the managed class to be passed to Vulkan APIs as if it + * were the VkInstance handler itself. + */ + operator VkInstance() const { return instance_; } + + /*! \brief Checks if the device has an extension enabled + * + * Returns true if the device was initialized with the extension + * given. + * + * \param query The name of the extension to check. + */ + bool HasExtension(const char* query) const; + + /*! \brief Return all accessible physical devices + * + * Wrapper around vkEnumeratePhysicalDevices. + */ + std::vector GetPhysicalDevices() const; + + private: + /*! \brief Helper function for move assignment/construction + * + * Named "do_swap" instead of "swap" because otherwise cpplint.py + * thinks that it needs the header include. + */ + void do_swap(VulkanInstance&& other); + + /*! \brief Extensions enabled for this instance + * + * Based on supported extensions queried through + * vkEnumerateInstanceExtensionProperties, prior to creating + * instance_. Contains only statically allocated string literals, + * no cleanup required. + */ + std::vector enabled_extensions_; + + //! \brief The Vulkan API instance handle + VkInstance instance_{nullptr}; +}; + +} // namespace vulkan +} // namespace runtime +} // namespace tvm +#endif // TVM_RUNTIME_VULKAN_VULKAN_INSTANCE_H_ From 6164ba224ba35c2e5fa73e41d83fe0ea97d1f4c6 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Tue, 1 Jun 2021 15:11:00 -0700 Subject: [PATCH 2/4] [Vulkan] Renamed VulkanContext to VulkanDevice Renaming to match with the tvm.context to tvm.device rename. --- .../{vulkan_context.cc => vulkan_device.cc} | 41 +++--- .../{vulkan_context.h => vulkan_device.h} | 14 +- src/runtime/vulkan/vulkan_device_api.cc | 121 +++++++++--------- src/runtime/vulkan/vulkan_device_api.h | 8 +- src/runtime/vulkan/vulkan_stream.cc | 33 ++--- src/runtime/vulkan/vulkan_stream.h | 10 +- src/runtime/vulkan/vulkan_thread_entry.cc | 12 +- src/runtime/vulkan/vulkan_wrapped_func.cc | 57 +++++---- src/runtime/vulkan/vulkan_wrapped_func.h | 4 +- 9 files changed, 151 insertions(+), 149 deletions(-) rename src/runtime/vulkan/{vulkan_context.cc => vulkan_device.cc} (90%) rename src/runtime/vulkan/{vulkan_context.h => vulkan_device.h} (91%) diff --git a/src/runtime/vulkan/vulkan_context.cc b/src/runtime/vulkan/vulkan_device.cc similarity index 90% rename from src/runtime/vulkan/vulkan_context.cc rename to src/runtime/vulkan/vulkan_device.cc index bdbc2838cf6e..19daba24b47d 100644 --- a/src/runtime/vulkan/vulkan_context.cc +++ b/src/runtime/vulkan/vulkan_device.cc @@ -17,7 +17,7 @@ * under the License. 
*/ -#include "vulkan_context.h" +#include "vulkan_device.h" #include #include @@ -213,16 +213,16 @@ VulkanGetBufferMemoryRequirements2Functions::VulkanGetBufferMemoryRequirements2F vkGetDeviceProcAddr(device, "vkGetBufferMemoryRequirements2KHR")); } -uint32_t FindMemoryType(const VulkanContext& vctx, VkBufferCreateInfo info, +uint32_t FindMemoryType(const VulkanDevice& device, VkBufferCreateInfo info, VkMemoryPropertyFlags req_prop) { VkBuffer buffer; - VULKAN_CALL(vkCreateBuffer(vctx.device, &info, nullptr, &buffer)); + VULKAN_CALL(vkCreateBuffer(device.device, &info, nullptr, &buffer)); VkMemoryRequirements mem_reqs; - vkGetBufferMemoryRequirements(vctx.device, buffer, &mem_reqs); + vkGetBufferMemoryRequirements(device.device, buffer, &mem_reqs); uint32_t type_bits = mem_reqs.memoryTypeBits; VkPhysicalDeviceMemoryProperties phy_mem_prop; - vkGetPhysicalDeviceMemoryProperties(vctx.phy_device, &phy_mem_prop); + vkGetPhysicalDeviceMemoryProperties(device.phy_device, &phy_mem_prop); for (uint32_t i = 0; i < phy_mem_prop.memoryTypeCount; i++) { if ((type_bits & 1) == 1 && (phy_mem_prop.memoryTypes[i].propertyFlags & req_prop) == req_prop) { @@ -234,7 +234,7 @@ uint32_t FindMemoryType(const VulkanContext& vctx, VkBufferCreateInfo info, return 0; } -VkBufferCreateInfo MakeBufferCreateInfo(const VulkanContext& vctx, size_t nbytes, +VkBufferCreateInfo MakeBufferCreateInfo(const VulkanDevice& device, size_t nbytes, VkBufferUsageFlags usage) { VkBufferCreateInfo info; info.sType = VK_STRUCTURE_TYPE_BUFFER_CREATE_INFO; @@ -242,24 +242,24 @@ VkBufferCreateInfo MakeBufferCreateInfo(const VulkanContext& vctx, size_t nbytes info.flags = 0; info.size = nbytes; info.queueFamilyIndexCount = 1; - info.pQueueFamilyIndices = &(vctx.queue_family_index); + info.pQueueFamilyIndices = &(device.queue_family_index); info.sharingMode = VK_SHARING_MODE_EXCLUSIVE; info.usage = usage; return info; } -VulkanBuffer* CreateBuffer(const VulkanContext& vctx, size_t nbytes, VkBufferUsageFlags usage, +VulkanBuffer* CreateBuffer(const VulkanDevice& device, size_t nbytes, VkBufferUsageFlags usage, uint32_t mem_type_index) { - auto info = MakeBufferCreateInfo(vctx, nbytes, usage); + auto info = MakeBufferCreateInfo(device, nbytes, usage); // create buffer VkBuffer buffer; - VULKAN_CALL(vkCreateBuffer(vctx.device, &info, nullptr, &buffer)); + VULKAN_CALL(vkCreateBuffer(device.device, &info, nullptr, &buffer)); // bind to memory bool dedicated_allocation = false; VkMemoryRequirements2KHR req2; - if (vctx.get_buffer_memory_requirements_2_functions) { + if (device.get_buffer_memory_requirements_2_functions) { VkBufferMemoryRequirementsInfo2KHR req_info2; req_info2.sType = VK_STRUCTURE_TYPE_BUFFER_MEMORY_REQUIREMENTS_INFO_2_KHR; req_info2.pNext = 0; @@ -273,8 +273,8 @@ VulkanBuffer* CreateBuffer(const VulkanContext& vctx, size_t nbytes, VkBufferUsa dedicated_req.pNext = 0; req2.pNext = &dedicated_req; - vctx.get_buffer_memory_requirements_2_functions->vkGetBufferMemoryRequirements2KHR( - vctx.device, &req_info2, &req2); + device.get_buffer_memory_requirements_2_functions->vkGetBufferMemoryRequirements2KHR( + device.device, &req_info2, &req2); dedicated_allocation = dedicated_req.requiresDedicatedAllocation || dedicated_req.prefersDedicatedAllocation; } @@ -286,7 +286,7 @@ VulkanBuffer* CreateBuffer(const VulkanContext& vctx, size_t nbytes, VkBufferUsa minfo.pNext = nullptr; minfo.allocationSize = info.size; minfo.memoryTypeIndex = mem_type_index; - VULKAN_CALL(vkAllocateMemory(vctx.device, &minfo, nullptr, &memory)); + 
VULKAN_CALL(vkAllocateMemory(device.device, &minfo, nullptr, &memory)); } else { VkMemoryAllocateInfo minfo; minfo.sType = VK_STRUCTURE_TYPE_MEMORY_ALLOCATE_INFO; @@ -300,9 +300,9 @@ VulkanBuffer* CreateBuffer(const VulkanContext& vctx, size_t nbytes, VkBufferUsa mdinfo.image = 0; mdinfo.buffer = buffer; minfo.pNext = &mdinfo; - VULKAN_CALL(vkAllocateMemory(vctx.device, &minfo, nullptr, &memory)); + VULKAN_CALL(vkAllocateMemory(device.device, &minfo, nullptr, &memory)); } - VULKAN_CALL(vkBindBufferMemory(vctx.device, buffer, memory, 0)); + VULKAN_CALL(vkBindBufferMemory(device.device, buffer, memory, 0)); VulkanBuffer* pbuf = new VulkanBuffer(); pbuf->memory = memory; pbuf->buffer = buffer; @@ -332,14 +332,15 @@ VulkanHostVisibleBuffer* GetOrAllocate( DeleteHostVisibleBuffer(&buf); } - const auto& vctx = VulkanDeviceAPI::Global()->context(device_id); + const auto& vulkan_device = VulkanDeviceAPI::Global()->device(device_id); if (buf.device == nullptr) { - buf.device = vctx.device; + buf.device = vulkan_device.device; } if (buf.host_addr == nullptr) { - buf.vk_buf = CreateBuffer(vctx, size, usage, mem_type_index); - VULKAN_CALL(vkMapMemory(vctx.device, buf.vk_buf->memory, 0, size, 0, &(buf.host_addr))); + buf.vk_buf = CreateBuffer(vulkan_device, size, usage, mem_type_index); + VULKAN_CALL( + vkMapMemory(vulkan_device.device, buf.vk_buf->memory, 0, size, 0, &(buf.host_addr))); buf.size = size; } return &buf; diff --git a/src/runtime/vulkan/vulkan_context.h b/src/runtime/vulkan/vulkan_device.h similarity index 91% rename from src/runtime/vulkan/vulkan_context.h rename to src/runtime/vulkan/vulkan_device.h index 306cbd606c44..3e5afc6f0812 100644 --- a/src/runtime/vulkan/vulkan_context.h +++ b/src/runtime/vulkan/vulkan_device.h @@ -17,8 +17,8 @@ * under the License. 
*/ -#ifndef TVM_RUNTIME_VULKAN_VULKAN_CONTEXT_H_ -#define TVM_RUNTIME_VULKAN_VULKAN_CONTEXT_H_ +#ifndef TVM_RUNTIME_VULKAN_VULKAN_DEVICE_H_ +#define TVM_RUNTIME_VULKAN_VULKAN_DEVICE_H_ #include #include @@ -93,7 +93,7 @@ struct VulkanDeviceProperties { uint32_t max_spirv_version{0x10000}; }; -struct VulkanContext { +struct VulkanDevice { // physical device VkPhysicalDevice phy_device{nullptr}; @@ -126,17 +126,17 @@ struct VulkanContext { bool UseImmediate() const { return descriptor_template_khr_functions != nullptr; } }; -uint32_t FindMemoryType(const VulkanContext& vctx, VkBufferCreateInfo info, +uint32_t FindMemoryType(const VulkanDevice& device, VkBufferCreateInfo info, VkMemoryPropertyFlags req_prop); -VkBufferCreateInfo MakeBufferCreateInfo(const VulkanContext& vctx, size_t nbytes, +VkBufferCreateInfo MakeBufferCreateInfo(const VulkanDevice& device, size_t nbytes, VkBufferUsageFlags usage); -VulkanBuffer* CreateBuffer(const VulkanContext& vctx, size_t nbytes, VkBufferUsageFlags usage, +VulkanBuffer* CreateBuffer(const VulkanDevice& device, size_t nbytes, VkBufferUsageFlags usage, uint32_t mem_type_index); } // namespace vulkan } // namespace runtime } // namespace tvm -#endif // TVM_RUNTIME_VULKAN_VULKAN_CONTEXT_H_ +#endif // TVM_RUNTIME_VULKAN_VULKAN_DEVICE_H_ diff --git a/src/runtime/vulkan/vulkan_device_api.cc b/src/runtime/vulkan/vulkan_device_api.cc index 13d7918a9532..cf2a9d0f515b 100644 --- a/src/runtime/vulkan/vulkan_device_api.cc +++ b/src/runtime/vulkan/vulkan_device_api.cc @@ -60,10 +60,9 @@ VulkanDeviceAPI::VulkanDeviceAPI() { queue_create_info.queueCount = 1; queue_create_info.pQueuePriorities = &priority; - VulkanContext ctx; - // setup context - ctx.phy_device = phy_dev; - vkGetPhysicalDeviceProperties(ctx.phy_device, &(ctx.phy_device_prop)); + VulkanDevice device; + device.phy_device = phy_dev; + vkGetPhysicalDeviceProperties(device.phy_device, &(device.phy_device_prop)); const auto device_extensions = [&]() { std::vector required_extensions{}; @@ -81,16 +80,16 @@ VulkanDeviceAPI::VulkanDeviceAPI() { }; uint32_t device_extension_prop_count; - VULKAN_CALL(vkEnumerateDeviceExtensionProperties(ctx.phy_device, nullptr, + VULKAN_CALL(vkEnumerateDeviceExtensionProperties(device.phy_device, nullptr, &device_extension_prop_count, nullptr)); std::vector device_extension_prop(device_extension_prop_count); VULKAN_CALL(vkEnumerateDeviceExtensionProperties( - ctx.phy_device, nullptr, &device_extension_prop_count, device_extension_prop.data())); + device.phy_device, nullptr, &device_extension_prop_count, device_extension_prop.data())); return FindEnabledExtensions(device_extension_prop, required_extensions, optional_extensions); }(); - ctx.device_properties = VulkanDeviceProperties(instance_, phy_dev, device_extensions); + device.device_properties = VulkanDeviceProperties(instance_, phy_dev, device_extensions); { // Enable all features we may use that a device supports. 
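The hunk that follows reworks the feature-enablement block, which appends optional feature structs to a `pNext` chain through a pointer-to-pointer tail. As a reading aid, here is a minimal editorial sketch (not part of the patch) of that chaining idiom; `BuildFeatureChain` is a hypothetical helper, and it assumes the caller has already filled in each struct's `sType`, as the surrounding code does.

```cpp
// Editorial sketch (not part of the patch): the pp_next chaining idiom used
// in the "enable all features" block below.
#include <vulkan/vulkan.h>

static void BuildFeatureChain(bool use_8bit, bool use_16bit,
                              VkPhysicalDeviceFeatures2KHR* enabled_features,
                              VkPhysicalDevice8BitStorageFeaturesKHR* storage_8bit,
                              VkPhysicalDevice16BitStorageFeaturesKHR* storage_16bit) {
  void** pp_next = &enabled_features->pNext;  // tail of the pNext chain
  if (use_8bit) {
    storage_8bit->storageBuffer8BitAccess = VK_TRUE;
    *pp_next = storage_8bit;           // link the struct into the chain...
    pp_next = &storage_8bit->pNext;    // ...and advance the tail
  }
  if (use_16bit) {
    storage_16bit->storageBuffer16BitAccess = VK_TRUE;
    *pp_next = storage_16bit;
    pp_next = &storage_16bit->pNext;
  }
  (void)pp_next;  // tail left ready for further optional features
}
```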
@@ -105,29 +104,29 @@ VulkanDeviceAPI::VulkanDeviceAPI() { void** pp_next = &enabled_features.pNext; bool needs_float16_int8 = false; - if (ctx.device_properties.supports_float16) { + if (device.device_properties.supports_float16) { float16_int8.shaderFloat16 = true; needs_float16_int8 = true; } - if (ctx.device_properties.supports_float64) { + if (device.device_properties.supports_float64) { enabled_features.features.shaderFloat64 = true; } - if (ctx.device_properties.supports_int8) { + if (device.device_properties.supports_int8) { float16_int8.shaderInt8 = true; needs_float16_int8 = true; } - if (ctx.device_properties.supports_int16) { + if (device.device_properties.supports_int16) { enabled_features.features.shaderInt16 = true; } - if (ctx.device_properties.supports_int64) { + if (device.device_properties.supports_int64) { enabled_features.features.shaderInt64 = true; } - if (ctx.device_properties.supports_8bit_buffer) { + if (device.device_properties.supports_8bit_buffer) { storage_8bit.storageBuffer8BitAccess = true; *pp_next = &storage_8bit; pp_next = &storage_8bit.pNext; } - if (ctx.device_properties.supports_16bit_buffer) { + if (device.device_properties.supports_16bit_buffer) { storage_16bit.storageBuffer16BitAccess = true; *pp_next = &storage_16bit; pp_next = &storage_16bit.pNext; @@ -156,12 +155,12 @@ VulkanDeviceAPI::VulkanDeviceAPI() { device_create_info.pNext = nullptr; device_create_info.pEnabledFeatures = &enabled_features.features; } - VULKAN_CALL(vkCreateDevice(phy_dev, &device_create_info, nullptr, &(ctx.device))); + VULKAN_CALL(vkCreateDevice(phy_dev, &device_create_info, nullptr, &(device.device))); } - ctx.queue_mutex.reset(new std::mutex()); - vkGetDeviceQueue(ctx.device, queue_family_index, 0, &(ctx.queue)); - ctx.queue_family_index = queue_family_index; + device.queue_mutex.reset(new std::mutex()); + vkGetDeviceQueue(device.device, queue_family_index, 0, &(device.queue)); + device.queue_family_index = queue_family_index; // Find suitable memory type for staging and compute // Find suitable compute index. 
VkBuffer buffer; @@ -172,26 +171,26 @@ VulkanDeviceAPI::VulkanDeviceAPI() { info.flags = 0; info.size = 1024; info.queueFamilyIndexCount = 1; - info.pQueueFamilyIndices = &(ctx.queue_family_index); + info.pQueueFamilyIndices = &(device.queue_family_index); info.sharingMode = VK_SHARING_MODE_EXCLUSIVE; // get staging requirement info.usage = VK_BUFFER_USAGE_TRANSFER_SRC_BIT | VK_BUFFER_USAGE_TRANSFER_DST_BIT; - VULKAN_CALL(vkCreateBuffer(ctx.device, &info, nullptr, &buffer)); - vkGetBufferMemoryRequirements(ctx.device, buffer, &req_staging); - vkDestroyBuffer(ctx.device, buffer, nullptr); + VULKAN_CALL(vkCreateBuffer(device.device, &info, nullptr, &buffer)); + vkGetBufferMemoryRequirements(device.device, buffer, &req_staging); + vkDestroyBuffer(device.device, buffer, nullptr); // get compute requirement info.usage = VK_BUFFER_USAGE_TRANSFER_SRC_BIT | VK_BUFFER_USAGE_TRANSFER_DST_BIT | VK_BUFFER_USAGE_STORAGE_BUFFER_BIT; - VULKAN_CALL(vkCreateBuffer(ctx.device, &info, nullptr, &buffer)); - vkGetBufferMemoryRequirements(ctx.device, buffer, &req_compute); - vkDestroyBuffer(ctx.device, buffer, nullptr); + VULKAN_CALL(vkCreateBuffer(device.device, &info, nullptr, &buffer)); + vkGetBufferMemoryRequirements(device.device, buffer, &req_compute); + vkDestroyBuffer(device.device, buffer, nullptr); // Query phyiscal device property // find a memory that is host visible, no need to be consistent int win_rank = -1; VkPhysicalDeviceMemoryProperties prop; - vkGetPhysicalDeviceMemoryProperties(ctx.phy_device, &prop); + vkGetPhysicalDeviceMemoryProperties(device.phy_device, &prop); for (uint32_t k = 0; k < prop.memoryTypeCount; ++k) { VkMemoryType ty = prop.memoryTypes[k]; @@ -205,8 +204,8 @@ VulkanDeviceAPI::VulkanDeviceAPI() { rank += ty.propertyFlags & VK_MEMORY_PROPERTY_HOST_CACHED_BIT; if (rank > win_rank) { win_rank = rank; - ctx.staging_mtype_index = k; - ctx.coherent_staging = ty.propertyFlags & VK_MEMORY_PROPERTY_HOST_COHERENT_BIT; + device.staging_mtype_index = k; + device.coherent_staging = ty.propertyFlags & VK_MEMORY_PROPERTY_HOST_COHERENT_BIT; } } ICHECK_GE(win_rank, 0) << "Cannot find suitable staging memory on device."; @@ -225,35 +224,35 @@ VulkanDeviceAPI::VulkanDeviceAPI() { rank += !(ty.propertyFlags & VK_MEMORY_PROPERTY_HOST_VISIBLE_BIT); if (rank > win_rank) { win_rank = rank; - ctx.compute_mtype_index = k; + device.compute_mtype_index = k; } } ICHECK_GE(win_rank, 0) << "Cannot find suitable local memory on device."; - if (ctx.device_properties.supports_push_descriptor) { - ctx.descriptor_template_khr_functions = - std::make_unique(ctx.device); + if (device.device_properties.supports_push_descriptor) { + device.descriptor_template_khr_functions = + std::make_unique(device.device); } - if (ctx.device_properties.supports_dedicated_allocation) { - ctx.get_buffer_memory_requirements_2_functions = - std::make_unique(ctx.device); + if (device.device_properties.supports_dedicated_allocation) { + device.get_buffer_memory_requirements_2_functions = + std::make_unique(device.device); } - context_.push_back(std::move(ctx)); + devices_.push_back(std::move(device)); } - LOG(INFO) << "Initialize Vulkan with " << context_.size() << " devices.."; - for (size_t i = 0; i < context_.size(); ++i) { - LOG(INFO) << "vulkan(" << i << ")=\'" << context_[i].phy_device_prop.deviceName - << "\' phy_dev_id=" << context_[i].phy_device - << " use_immediate=" << context_[i].UseImmediate(); + LOG(INFO) << "Initialize Vulkan with " << devices_.size() << " devices.."; + for (size_t i = 0; i < devices_.size(); ++i) { + 
LOG(INFO) << "vulkan(" << i << ")=\'" << devices_[i].phy_device_prop.deviceName + << "\' phy_dev_id=" << devices_[i].phy_device + << " use_immediate=" << devices_[i].UseImmediate(); } } VulkanDeviceAPI::~VulkanDeviceAPI() { - for (auto& vctx : context_) { - vkDestroyDevice(vctx.device, nullptr); + for (auto& device : devices_) { + vkDestroyDevice(device.device, nullptr); } } @@ -262,11 +261,11 @@ void VulkanDeviceAPI::SetDevice(Device dev) { VulkanThreadEntry::ThreadLocal()-> void VulkanDeviceAPI::GetAttr(Device dev, DeviceAttrKind kind, TVMRetValue* rv) { size_t index = static_cast(dev.device_id); if (kind == kExist) { - *rv = static_cast(index < context_.size()); + *rv = static_cast(index < devices_.size()); return; } - const auto& prop = context(index).device_properties; + const auto& prop = device(index).device_properties; switch (kind) { case kMaxThreadsPerBlock: { @@ -333,7 +332,7 @@ void VulkanDeviceAPI::GetAttr(Device dev, DeviceAttrKind kind, TVMRetValue* rv) void VulkanDeviceAPI::GetTargetProperty(Device dev, const std::string& property, TVMRetValue* rv) { size_t index = static_cast(dev.device_id); - const auto& prop = context(index).device_properties; + const auto& prop = device(index).device_properties; if (property == "supports_float16") { *rv = prop.supports_float16; @@ -424,10 +423,10 @@ void* VulkanDeviceAPI::AllocDataSpace(Device dev, size_t nbytes, size_t alignmen // Vulkan seems to have issues if we return nullptr on zero size alloc nbytes = 1; } - const auto& vctx = context(dev.device_id); + const auto& device = this->device(dev.device_id); auto usage = VK_BUFFER_USAGE_TRANSFER_SRC_BIT | VK_BUFFER_USAGE_TRANSFER_DST_BIT | VK_BUFFER_USAGE_STORAGE_BUFFER_BIT; - return CreateBuffer(vctx, nbytes, usage, vctx.compute_mtype_index); + return CreateBuffer(device, nbytes, usage, device.compute_mtype_index); } void VulkanDeviceAPI::FreeDataSpace(Device dev, void* ptr) { @@ -435,10 +434,10 @@ void VulkanDeviceAPI::FreeDataSpace(Device dev, void* ptr) { // finish all the vulkan commands that reference the buffer. 
StreamSync(dev, nullptr); - const auto& vctx = context(dev.device_id); + const auto& device = this->device(dev.device_id); auto* pbuf = static_cast(ptr); - vkDestroyBuffer(vctx.device, pbuf->buffer, nullptr); - vkFreeMemory(vctx.device, pbuf->memory, nullptr); + vkDestroyBuffer(device.device, pbuf->buffer, nullptr); + vkFreeMemory(device.device, pbuf->memory, nullptr); delete pbuf; } @@ -511,7 +510,7 @@ void VulkanDeviceAPI::CopyDataFromTo(const void* from, size_t from_offset, void* } else if (from_dev_type == kDLVulkan && to_dev_type == kDLCPU) { const auto* from_buf = static_cast(from); - const auto& vctx = context(dev_from.device_id); + const auto& device = this->device(dev_from.device_id); auto* temp = VulkanThreadEntry::ThreadLocal()->StagingBuffer(dev_from.device_id, size); VulkanThreadEntry::ThreadLocal() ->Stream(dev_from.device_id) @@ -524,32 +523,32 @@ void VulkanDeviceAPI::CopyDataFromTo(const void* from, size_t from_offset, void* ©_info); }); VulkanThreadEntry::ThreadLocal()->Stream(dev_from.device_id)->Synchronize(); - if (!vctx.coherent_staging) { + if (!device.coherent_staging) { VkMappedMemoryRange mrange; mrange.sType = VK_STRUCTURE_TYPE_MAPPED_MEMORY_RANGE; mrange.pNext = nullptr; mrange.memory = temp->vk_buf->memory; mrange.offset = 0; mrange.size = VK_WHOLE_SIZE; // size; - VULKAN_CALL(vkInvalidateMappedMemoryRanges(vctx.device, 1, &mrange)); + VULKAN_CALL(vkInvalidateMappedMemoryRanges(device.device, 1, &mrange)); } memcpy(static_cast(to) + to_offset, static_cast(temp->host_addr), size); } else if (from_dev_type == kDLCPU && to_dev_type == kDLVulkan) { - const auto& vctx = context(dev_to.device_id); + const auto& device = this->device(dev_to.device_id); const auto* to_buf = static_cast(to); VulkanStagingBuffer* temp = VulkanThreadEntry::ThreadLocal()->StagingBuffer(dev_to.device_id, size); memcpy(temp->host_addr, static_cast(from) + from_offset, size); // host side flush if access is not coherent. 
// so writes from CPU is visible to GPU - if (!vctx.coherent_staging) { + if (!device.coherent_staging) { VkMappedMemoryRange mrange; mrange.sType = VK_STRUCTURE_TYPE_MAPPED_MEMORY_RANGE; mrange.pNext = nullptr; mrange.memory = temp->vk_buf->memory; mrange.offset = 0; mrange.size = VK_WHOLE_SIZE; // size; - VULKAN_CALL(vkFlushMappedMemoryRanges(vctx.device, 1, &mrange)); + VULKAN_CALL(vkFlushMappedMemoryRanges(device.device, 1, &mrange)); } VulkanThreadEntry::ThreadLocal() @@ -580,10 +579,10 @@ void VulkanDeviceAPI::CopyDataFromTo(const void* from, size_t from_offset, void* } } -const VulkanContext& VulkanDeviceAPI::context(size_t device_id) const { - ICHECK_LT(device_id, context_.size()) << "Requested Vulkan device_id=" << device_id - << ", but only " << context_.size() << " devices present"; - return context_[device_id]; +const VulkanDevice& VulkanDeviceAPI::device(size_t device_id) const { + ICHECK_LT(device_id, devices_.size()) << "Requested Vulkan device_id=" << device_id + << ", but only " << devices_.size() << " devices present"; + return devices_[device_id]; } std::vector VulkanDeviceAPI::GetComputeQueueFamilies(VkPhysicalDevice phy_dev) { diff --git a/src/runtime/vulkan/vulkan_device_api.h b/src/runtime/vulkan/vulkan_device_api.h index 27c21825fbf1..cf5652a3d9c4 100644 --- a/src/runtime/vulkan/vulkan_device_api.h +++ b/src/runtime/vulkan/vulkan_device_api.h @@ -26,7 +26,7 @@ #include #include "vulkan/vulkan_core.h" -#include "vulkan_context.h" +#include "vulkan_device.h" #include "vulkan_instance.h" #include "vulkan_thread_entry.h" @@ -69,12 +69,12 @@ class VulkanDeviceAPI final : public DeviceAPI { // End of required methods for the DeviceAPI interface public: - /*! \brief Return the context associated with a specific device. + /*! \brief Return the VulkanDevice associated with a specific device_id * * These are constructed during VulkanDeviceAPI initialization, so * this function returns immediately. */ - const VulkanContext& context(size_t device_id) const; + const VulkanDevice& device(size_t device_id) const; /*! \brief Returns a property to be stored in a target. 
* @@ -88,7 +88,7 @@ class VulkanDeviceAPI final : public DeviceAPI { VulkanInstance instance_; // The physical devices, have 1 to 1 mapping to devices - std::vector context_; + std::vector devices_; }; } // namespace vulkan diff --git a/src/runtime/vulkan/vulkan_stream.cc b/src/runtime/vulkan/vulkan_stream.cc index fee390ad7e45..05befd0630b3 100644 --- a/src/runtime/vulkan/vulkan_stream.cc +++ b/src/runtime/vulkan/vulkan_stream.cc @@ -23,15 +23,15 @@ namespace tvm { namespace runtime { namespace vulkan { -VulkanStream::VulkanStream(const VulkanContext* vctx) - : vctx_(vctx), state_(new VulkanStreamState()) { +VulkanStream::VulkanStream(const VulkanDevice* device) + : device_(device), state_(new VulkanStreamState()) { // create command pool VkCommandPoolCreateInfo cmd_pool_cinfo; cmd_pool_cinfo.sType = VK_STRUCTURE_TYPE_COMMAND_POOL_CREATE_INFO; cmd_pool_cinfo.pNext = nullptr; cmd_pool_cinfo.flags = VK_COMMAND_POOL_CREATE_RESET_COMMAND_BUFFER_BIT; - cmd_pool_cinfo.queueFamilyIndex = vctx_->queue_family_index; - VULKAN_CALL(vkCreateCommandPool(vctx_->device, &cmd_pool_cinfo, nullptr, &cmd_pool_)); + cmd_pool_cinfo.queueFamilyIndex = device_->queue_family_index; + VULKAN_CALL(vkCreateCommandPool(device_->device, &cmd_pool_cinfo, nullptr, &cmd_pool_)); VkCommandBufferAllocateInfo buffer_alloc_info; buffer_alloc_info.sType = VK_STRUCTURE_TYPE_COMMAND_BUFFER_ALLOCATE_INFO; @@ -39,13 +39,14 @@ VulkanStream::VulkanStream(const VulkanContext* vctx) buffer_alloc_info.commandPool = cmd_pool_; buffer_alloc_info.level = VK_COMMAND_BUFFER_LEVEL_PRIMARY; buffer_alloc_info.commandBufferCount = 1; - VULKAN_CALL(vkAllocateCommandBuffers(vctx_->device, &buffer_alloc_info, &(state_->cmd_buffer_))); + VULKAN_CALL( + vkAllocateCommandBuffers(device_->device, &buffer_alloc_info, &(state_->cmd_buffer_))); VkFenceCreateInfo fence_cinfo; fence_cinfo.sType = VK_STRUCTURE_TYPE_FENCE_CREATE_INFO; fence_cinfo.pNext = nullptr; fence_cinfo.flags = 0; // VK_FENCE_CREATE_SIGNALED_BIT; - VULKAN_CALL(vkCreateFence(vctx_->device, &fence_cinfo, nullptr, &(state_->fence_))); + VULKAN_CALL(vkCreateFence(device_->device, &fence_cinfo, nullptr, &(state_->fence_))); VkCommandBufferBeginInfo cb_begin; cb_begin.sType = VK_STRUCTURE_TYPE_COMMAND_BUFFER_BEGIN_INFO; @@ -56,12 +57,12 @@ VulkanStream::VulkanStream(const VulkanContext* vctx) } VulkanStream::~VulkanStream() { - vkDestroyFence(vctx_->device, state_->fence_, nullptr); - vkDestroyCommandPool(vctx_->device, cmd_pool_, nullptr); + vkDestroyFence(device_->device, state_->fence_, nullptr); + vkDestroyCommandPool(device_->device, cmd_pool_, nullptr); } void VulkanStream::Launch(const std::function& kernel) { - if (vctx_->UseImmediate()) { + if (device_->UseImmediate()) { kernel(state_.get()); } else { deferred_kernels_.push_back(kernel); @@ -71,7 +72,7 @@ void VulkanStream::Launch(const std::function& kernel) void VulkanStream::LaunchDeferred(const std::function& deferred_initializer, const std::function& deferred_kernel, const VulkanStreamToken& deferred_token) { - ICHECK(!vctx_->UseImmediate()); + ICHECK(!device_->UseImmediate()); // If the new kernel uses the same descriptor set as one of the // kernels already in the command buffer, we need to synchronize @@ -107,7 +108,7 @@ void VulkanStream::LaunchDeferred(const std::function& deferred_initiali } void VulkanStream::Synchronize() { - if (!vctx_->UseImmediate()) { + if (!device_->UseImmediate()) { for (const auto& deferred_kernel : deferred_kernels_) { deferred_kernel(state_.get()); } @@ -131,19 +132,19 @@ void 
VulkanStream::Synchronize() { cb_submit.pSignalSemaphores = nullptr; { - // Multiple streams (on different threads) use the same VulkanContext + // Multiple streams (on different threads) use the same VulkanDevice // instance, so we need to externally synchronize accesses. - std::lock_guard g(*(vctx_->queue_mutex)); - VULKAN_CALL(vkQueueSubmit(vctx_->queue, 1, &cb_submit, state_->fence_)); + std::lock_guard g(*(device_->queue_mutex)); + VULKAN_CALL(vkQueueSubmit(device_->queue, 1, &cb_submit, state_->fence_)); } uint64_t timeout = 1UL << 30UL; VkResult res; do { - res = vkWaitForFences(vctx_->device, 1, &(state_->fence_), 0, timeout); + res = vkWaitForFences(device_->device, 1, &(state_->fence_), 0, timeout); } while (res == VK_TIMEOUT); VULKAN_CHECK_ERROR(res); VULKAN_CALL(vkResetCommandBuffer(state_->cmd_buffer_, 0)); - VULKAN_CALL(vkResetFences(vctx_->device, 1, &(state_->fence_))); + VULKAN_CALL(vkResetFences(device_->device, 1, &(state_->fence_))); // Re-initialize the command buffer VkCommandBufferBeginInfo cb_begin; diff --git a/src/runtime/vulkan/vulkan_stream.h b/src/runtime/vulkan/vulkan_stream.h index f328262a8b10..ff02be4c5c35 100644 --- a/src/runtime/vulkan/vulkan_stream.h +++ b/src/runtime/vulkan/vulkan_stream.h @@ -26,7 +26,7 @@ #include #include "vulkan_common.h" -#include "vulkan_context.h" +#include "vulkan_device.h" namespace tvm { namespace runtime { @@ -62,13 +62,13 @@ struct VulkanStreamToken { */ class VulkanStream { public: - explicit VulkanStream(const VulkanContext* vctx); + explicit VulkanStream(const VulkanDevice* device); ~VulkanStream(); /*! \brief Push the kernel onto the stream's command buffer. * - * If context.UseImmediate() is true, the kernel is executed + * If device.UseImmediate() is true, the kernel is executed * immediately to update the command buffer. Otherwise, it is added * to the list of deferred updates to be pushed onto the command * buffer. @@ -80,7 +80,7 @@ class VulkanStream { /*! \brief Push the kernel onto the stream's command buffer. * - * Can only be called if context.UseImmediate() is false. The + * Can only be called if device.UseImmediate() is false. The * kernel is delayed, and isn't pushed to the command buffer until * all kernels are collected. * @@ -102,7 +102,7 @@ class VulkanStream { void Synchronize(); private: - const VulkanContext* vctx_; + const VulkanDevice* device_; std::unique_ptr state_; // An index of deferred tokens, allowing us to efficiently detect duplicated // deferred_initializer blocks. 
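Before the thread-entry changes, the immediate-versus-deferred split that `VulkanStream::Launch` and `Synchronize` implement can be summarized in a few lines. This is an editorial sketch under the assumption that `UseImmediate()` reduces to a boolean flag; `StreamSketch` and `StreamState` are hypothetical names, not TVM types.

```cpp
// Editorial sketch (not part of the patch): immediate vs. deferred kernel launch.
#include <functional>
#include <vector>

struct StreamState {};  // stand-in for VulkanStreamState

class StreamSketch {
 public:
  explicit StreamSketch(bool use_immediate) : use_immediate_(use_immediate) {}

  void Launch(const std::function<void(StreamState*)>& kernel) {
    if (use_immediate_) {
      kernel(&state_);              // record into the command buffer right away
    } else {
      deferred_.push_back(kernel);  // queue for replay in Synchronize()
    }
  }

  void Synchronize() {
    for (const auto& kernel : deferred_) kernel(&state_);
    deferred_.clear();
    // ...then submit the command buffer and wait on the fence, as
    // VulkanStream::Synchronize() does above.
  }

 private:
  bool use_immediate_;
  StreamState state_;
  std::vector<std::function<void(StreamState*)>> deferred_;
};
```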
diff --git a/src/runtime/vulkan/vulkan_thread_entry.cc b/src/runtime/vulkan/vulkan_thread_entry.cc index e7e01b9c2d06..1e2815f31146 100644 --- a/src/runtime/vulkan/vulkan_thread_entry.cc +++ b/src/runtime/vulkan/vulkan_thread_entry.cc @@ -43,10 +43,10 @@ VulkanThreadEntry::~VulkanThreadEntry() { VulkanThreadEntry* VulkanThreadEntry::ThreadLocal() { return VulkanThreadStore::Get(); } void VulkanThreadEntry::AllocateUniformBuffer(int device_id, size_t size) { - const auto& vctx = VulkanDeviceAPI::Global()->context(device_id); + const auto& device = VulkanDeviceAPI::Global()->device(device_id); auto prop = VK_MEMORY_PROPERTY_HOST_VISIBLE_BIT | VK_MEMORY_PROPERTY_HOST_COHERENT_BIT; - auto info = MakeBufferCreateInfo(vctx, size, VK_BUFFER_USAGE_UNIFORM_BUFFER_BIT); - auto mem_type_index = FindMemoryType(vctx, info, prop); + auto info = MakeBufferCreateInfo(device, size, VK_BUFFER_USAGE_UNIFORM_BUFFER_BIT); + auto mem_type_index = FindMemoryType(device, info, prop); GetOrAllocate(device_id, size, VK_BUFFER_USAGE_UNIFORM_BUFFER_BIT, mem_type_index, &uniform_buffers_, true); } @@ -59,9 +59,9 @@ VulkanUniformBuffer* VulkanThreadEntry::GetUniformBuffer(int device_id, size_t s } VulkanStagingBuffer* VulkanThreadEntry::StagingBuffer(int device_id, size_t size) { - const auto& vctx = VulkanDeviceAPI::Global()->context(device_id); + const auto& device = VulkanDeviceAPI::Global()->device(device_id); auto usage = VK_BUFFER_USAGE_TRANSFER_SRC_BIT | VK_BUFFER_USAGE_TRANSFER_DST_BIT; - return GetOrAllocate(device_id, size, usage, vctx.staging_mtype_index, &staging_buffers_); + return GetOrAllocate(device_id, size, usage, device.staging_mtype_index, &staging_buffers_); } VulkanThreadEntry::VulkanThreadEntry() @@ -74,7 +74,7 @@ VulkanThreadEntry::VulkanThreadEntry() VulkanStream* VulkanThreadEntry::Stream(size_t device_id) { if (!streams_[device_id]) { streams_[device_id] = std::unique_ptr( - new VulkanStream(&VulkanDeviceAPI::Global()->context(device_id))); + new VulkanStream(&VulkanDeviceAPI::Global()->device(device_id))); } return streams_[device_id].get(); } diff --git a/src/runtime/vulkan/vulkan_wrapped_func.cc b/src/runtime/vulkan/vulkan_wrapped_func.cc index 2ee46b7db80c..51bf5d486b9c 100644 --- a/src/runtime/vulkan/vulkan_wrapped_func.cc +++ b/src/runtime/vulkan/vulkan_wrapped_func.cc @@ -47,7 +47,7 @@ void VulkanWrappedFunc::operator()(TVMArgs args, TVMRetValue* rv, const ArgUnion64* pack_args) const { int device_id = VulkanThreadEntry::ThreadLocal()->device.device_id; ICHECK_LT(device_id, kVulkanMaxNumDevice); - const auto& vctx = VulkanDeviceAPI::Global()->context(device_id); + const auto& device = VulkanDeviceAPI::Global()->device(device_id); if (!scache_[device_id]) { scache_[device_id] = m_->GetPipeline(device_id, func_name_, num_pack_args_); } @@ -73,12 +73,12 @@ void VulkanWrappedFunc::operator()(TVMArgs args, TVMRetValue* rv, binfo.range = VK_WHOLE_SIZE; descriptor_buffers.push_back(binfo); } - if (vctx.UseImmediate()) { + if (device.UseImmediate()) { // Can safely capture by reference as this lambda is immediately executed on the calling thread. 
VulkanThreadEntry::ThreadLocal()->Stream(device_id)->Launch([&](VulkanStreamState* state) { vkCmdBindPipeline(state->cmd_buffer_, VK_PIPELINE_BIND_POINT_COMPUTE, pipeline->pipeline); ICHECK(pipeline->descriptor_update_template != VK_NULL_HANDLE); - vctx.descriptor_template_khr_functions->vkCmdPushDescriptorSetWithTemplateKHR( + device.descriptor_template_khr_functions->vkCmdPushDescriptorSetWithTemplateKHR( state->cmd_buffer_, pipeline->descriptor_update_template, pipeline->pipeline_layout, 0, descriptor_buffers.data()); @@ -107,7 +107,7 @@ void VulkanWrappedFunc::operator()(TVMArgs args, TVMRetValue* rv, // Otherwise, the more expensive deferred path. std::vector pack_args_storage(pack_args, pack_args + num_pack_args_); - const auto& deferred_initializer = [&vctx, pipeline, descriptor_buffers]() { + const auto& deferred_initializer = [&device, pipeline, descriptor_buffers]() { std::vector write_descriptor_sets; write_descriptor_sets.resize(descriptor_buffers.size()); for (size_t i = 0; i < write_descriptor_sets.size(); i++) { @@ -128,8 +128,8 @@ void VulkanWrappedFunc::operator()(TVMArgs args, TVMRetValue* rv, write_descriptor_sets[i].descriptorType = VK_DESCRIPTOR_TYPE_STORAGE_BUFFER; } } - vkUpdateDescriptorSets(vctx.device, write_descriptor_sets.size(), write_descriptor_sets.data(), - 0, 0); + vkUpdateDescriptorSets(device.device, write_descriptor_sets.size(), + write_descriptor_sets.data(), 0, 0); }; const auto& deferred_kernel = [this, pipeline, wl, pack_args_storage, nbytes_scalars, device_id](VulkanStreamState* state) { @@ -174,17 +174,17 @@ VulkanModuleNode::~VulkanModuleNode() { for (auto& kv : ecache_[device_id]) { auto& pe = kv.second; ICHECK(pe); - const auto& vctx = VulkanDeviceAPI::Global()->context(device_id); + const auto& device = VulkanDeviceAPI::Global()->device(device_id); if (pe->descriptor_update_template != VK_NULL_HANDLE) { - vctx.descriptor_template_khr_functions->vkDestroyDescriptorUpdateTemplateKHR( - vctx.device, pe->descriptor_update_template, nullptr); + device.descriptor_template_khr_functions->vkDestroyDescriptorUpdateTemplateKHR( + device.device, pe->descriptor_update_template, nullptr); } - vkDestroyPipeline(vctx.device, pe->pipeline, nullptr); - vkDestroyPipelineLayout(vctx.device, pe->pipeline_layout, nullptr); - vkDestroyDescriptorPool(vctx.device, pe->descriptor_pool, nullptr); - vkDestroyDescriptorSetLayout(vctx.device, pe->descriptor_set_layout, nullptr); - vkDestroyShaderModule(vctx.device, pe->shader, nullptr); + vkDestroyPipeline(device.device, pe->pipeline, nullptr); + vkDestroyPipelineLayout(device.device, pe->pipeline_layout, nullptr); + vkDestroyDescriptorPool(device.device, pe->descriptor_pool, nullptr); + vkDestroyDescriptorSetLayout(device.device, pe->descriptor_set_layout, nullptr); + vkDestroyShaderModule(device.device, pe->shader, nullptr); } } } @@ -206,7 +206,7 @@ PackedFunc VulkanModuleNode::GetFunction(const std::string& name, std::shared_ptr VulkanModuleNode::GetPipeline(size_t device_id, const std::string& func_name, size_t num_pack_args) { - const auto& vctx = VulkanDeviceAPI::Global()->context(device_id); + const auto& device = VulkanDeviceAPI::Global()->device(device_id); std::lock_guard lock(mutex_); const auto& cp = ecache_[device_id][func_name]; if (cp) { @@ -226,7 +226,7 @@ std::shared_ptr VulkanModuleNode::GetPipeline(size_t device_id, shader_cinfo.flags = 0; shader_cinfo.codeSize = data.size() * sizeof(uint32_t); shader_cinfo.pCode = data.data(); - VULKAN_CALL(vkCreateShaderModule(vctx.device, &shader_cinfo, nullptr, 
&(pe->shader))); + VULKAN_CALL(vkCreateShaderModule(device.device, &shader_cinfo, nullptr, &(pe->shader))); } std::vector arg_binding; std::vector arg_template; @@ -294,16 +294,16 @@ std::shared_ptr VulkanModuleNode::GetPipeline(size_t device_id, descrip_cinfo.sType = VK_STRUCTURE_TYPE_DESCRIPTOR_SET_LAYOUT_CREATE_INFO; descrip_cinfo.pNext = nullptr; descrip_cinfo.flags = 0; - if (vctx.UseImmediate()) { + if (device.UseImmediate()) { descrip_cinfo.flags |= VK_DESCRIPTOR_SET_LAYOUT_CREATE_PUSH_DESCRIPTOR_BIT_KHR; } descrip_cinfo.bindingCount = arg_binding.size(); descrip_cinfo.pBindings = arg_binding.data(); - VULKAN_CALL(vkCreateDescriptorSetLayout(vctx.device, &descrip_cinfo, nullptr, + VULKAN_CALL(vkCreateDescriptorSetLayout(device.device, &descrip_cinfo, nullptr, &(pe->descriptor_set_layout))); } - if (!vctx.UseImmediate()) { + if (!device.UseImmediate()) { VkDescriptorPoolCreateInfo descrip_pool_cinfo; descrip_pool_cinfo.sType = VK_STRUCTURE_TYPE_DESCRIPTOR_POOL_CREATE_INFO; descrip_pool_cinfo.pNext = nullptr; @@ -311,8 +311,8 @@ std::shared_ptr VulkanModuleNode::GetPipeline(size_t device_id, descrip_pool_cinfo.maxSets = 1; descrip_pool_cinfo.poolSizeCount = descriptor_set_pool_sizes.size(); descrip_pool_cinfo.pPoolSizes = descriptor_set_pool_sizes.data(); - VULKAN_CALL( - vkCreateDescriptorPool(vctx.device, &descrip_pool_cinfo, nullptr, &(pe->descriptor_pool))); + VULKAN_CALL(vkCreateDescriptorPool(device.device, &descrip_pool_cinfo, nullptr, + &(pe->descriptor_pool))); VkDescriptorSetAllocateInfo alloc_info; alloc_info.sType = VK_STRUCTURE_TYPE_DESCRIPTOR_SET_ALLOCATE_INFO; @@ -320,7 +320,7 @@ std::shared_ptr VulkanModuleNode::GetPipeline(size_t device_id, alloc_info.descriptorPool = pe->descriptor_pool; alloc_info.descriptorSetCount = 1; alloc_info.pSetLayouts = &(pe->descriptor_set_layout); - VULKAN_CALL(vkAllocateDescriptorSets(vctx.device, &alloc_info, &(pe->descriptor_set))); + VULKAN_CALL(vkAllocateDescriptorSets(device.device, &alloc_info, &(pe->descriptor_set))); } VkPushConstantRange crange; @@ -338,13 +338,14 @@ std::shared_ptr VulkanModuleNode::GetPipeline(size_t device_id, if (0 < nbytes_scalars && !pe->use_ubo) { playout_cinfo.pushConstantRangeCount = 1; playout_cinfo.pPushConstantRanges = &crange; - ICHECK_LE(crange.size, vctx.phy_device_prop.limits.maxPushConstantsSize); + ICHECK_LE(crange.size, device.phy_device_prop.limits.maxPushConstantsSize); } else { playout_cinfo.pushConstantRangeCount = 0; playout_cinfo.pPushConstantRanges = nullptr; } - VULKAN_CALL(vkCreatePipelineLayout(vctx.device, &playout_cinfo, nullptr, &(pe->pipeline_layout))); + VULKAN_CALL( + vkCreatePipelineLayout(device.device, &playout_cinfo, nullptr, &(pe->pipeline_layout))); VkComputePipelineCreateInfo pipeline_cinfo; pipeline_cinfo.sType = VK_STRUCTURE_TYPE_COMPUTE_PIPELINE_CREATE_INFO; @@ -360,10 +361,10 @@ std::shared_ptr VulkanModuleNode::GetPipeline(size_t device_id, pipeline_cinfo.layout = pe->pipeline_layout; pipeline_cinfo.basePipelineHandle = VK_NULL_HANDLE; pipeline_cinfo.basePipelineIndex = 0; - VULKAN_CALL(vkCreateComputePipelines(vctx.device, VK_NULL_HANDLE, 1, &pipeline_cinfo, nullptr, + VULKAN_CALL(vkCreateComputePipelines(device.device, VK_NULL_HANDLE, 1, &pipeline_cinfo, nullptr, &(pe->pipeline))); - if (vctx.UseImmediate()) { + if (device.UseImmediate()) { VkDescriptorUpdateTemplateCreateInfoKHR descrip_template_cinfo; descrip_template_cinfo.sType = VK_STRUCTURE_TYPE_DESCRIPTOR_UPDATE_TEMPLATE_CREATE_INFO_KHR; descrip_template_cinfo.pNext = 0; @@ -375,8 +376,8 @@ std::shared_ptr 
VulkanModuleNode::GetPipeline(size_t device_id, descrip_template_cinfo.pipelineBindPoint = VK_PIPELINE_BIND_POINT_COMPUTE; descrip_template_cinfo.pipelineLayout = pe->pipeline_layout; descrip_template_cinfo.set = 0; - VULKAN_CALL(vctx.descriptor_template_khr_functions->vkCreateDescriptorUpdateTemplateKHR( - vctx.device, &descrip_template_cinfo, 0, &(pe->descriptor_update_template))); + VULKAN_CALL(device.descriptor_template_khr_functions->vkCreateDescriptorUpdateTemplateKHR( + device.device, &descrip_template_cinfo, 0, &(pe->descriptor_update_template))); } ecache_[device_id][func_name] = pe; return pe; diff --git a/src/runtime/vulkan/vulkan_wrapped_func.h b/src/runtime/vulkan/vulkan_wrapped_func.h index be5f385316ea..a174f22eba59 100644 --- a/src/runtime/vulkan/vulkan_wrapped_func.h +++ b/src/runtime/vulkan/vulkan_wrapped_func.h @@ -32,7 +32,7 @@ #include "../thread_storage_scope.h" #include "vulkan/vulkan_core.h" #include "vulkan_common.h" -#include "vulkan_context.h" +#include "vulkan_device.h" #include "vulkan_shader.h" namespace tvm { @@ -40,7 +40,7 @@ namespace runtime { namespace vulkan { struct VulkanPipeline { - VulkanContext* vctx_{nullptr}; + VulkanDevice* device{nullptr}; VkShaderModule shader{VK_NULL_HANDLE}; VkDescriptorSetLayout descriptor_set_layout{VK_NULL_HANDLE}; VkDescriptorPool descriptor_pool{VK_NULL_HANDLE}; From 99d6571587dbf0ca82bc2d066906719af59b4237 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Wed, 2 Jun 2021 11:20:44 -0700 Subject: [PATCH 3/4] [Vulkan][Refactor] Extracted VulkanDevice initialization into VulkanDevice class --- src/runtime/vulkan/vulkan_device.cc | 346 +++++++++++++++++++--- src/runtime/vulkan/vulkan_device.h | 118 +++++++- src/runtime/vulkan/vulkan_device_api.cc | 242 +-------------- src/runtime/vulkan/vulkan_stream.cc | 23 +- src/runtime/vulkan/vulkan_wrapped_func.cc | 35 ++- 5 files changed, 450 insertions(+), 314 deletions(-) diff --git a/src/runtime/vulkan/vulkan_device.cc b/src/runtime/vulkan/vulkan_device.cc index 19daba24b47d..f8c9df1ff061 100644 --- a/src/runtime/vulkan/vulkan_device.cc +++ b/src/runtime/vulkan/vulkan_device.cc @@ -20,9 +20,12 @@ #include "vulkan_device.h" #include +#include #include +#include #include "vulkan_common.h" +#include "vulkan_device.h" #include "vulkan_device_api.h" #include "vulkan_instance.h" #include "vulkan_thread_entry.h" @@ -32,13 +35,7 @@ namespace runtime { namespace vulkan { VulkanDeviceProperties::VulkanDeviceProperties(const VulkanInstance& instance, - VkPhysicalDevice phy_dev, - const std::vector device_extensions) { - auto has_device_extension = [&](const char* query) { - return std::any_of(device_extensions.begin(), device_extensions.end(), - [&](const char* extension) { return std::strcmp(query, extension) == 0; }); - }; - + const VulkanDevice& device) { /////////////////////////////////////////////////////////////// // Query properties from Vulkan API // /////////////////////////////////////////////////////////////// @@ -50,12 +47,12 @@ VulkanDeviceProperties::VulkanDeviceProperties(const VulkanInstance& instance, VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SUBGROUP_PROPERTIES}; // Need to do initial query in order to check the apiVersion. 
- vkGetPhysicalDeviceProperties(phy_dev, &properties.properties); + vkGetPhysicalDeviceProperties(device, &properties.properties); // Set up linked list for property query { void** pp_next = &properties.pNext; - if (has_device_extension("VK_KHR_driver_properties")) { + if (device.HasExtension("VK_KHR_driver_properties")) { *pp_next = &driver; pp_next = &driver.pNext; } @@ -77,15 +74,15 @@ VulkanDeviceProperties::VulkanDeviceProperties(const VulkanInstance& instance, // Set up linked list for feature query { void** pp_next = &features.pNext; - if (has_device_extension("VK_KHR_8bit_storage")) { + if (device.HasExtension("VK_KHR_8bit_storage")) { *pp_next = &storage_8bit; pp_next = &storage_8bit.pNext; } - if (has_device_extension("VK_KHR_16bit_storage")) { + if (device.HasExtension("VK_KHR_16bit_storage")) { *pp_next = &storage_16bit; pp_next = &storage_16bit.pNext; } - if (has_device_extension("VK_KHR_shader_float16_int8")) { + if (device.HasExtension("VK_KHR_shader_float16_int8")) { *pp_next = &float16_int8; pp_next = &float16_int8.pNext; } @@ -95,15 +92,15 @@ VulkanDeviceProperties::VulkanDeviceProperties(const VulkanInstance& instance, // Preferred method, call to get all properties that can be queried. auto vkGetPhysicalDeviceProperties2KHR = (PFN_vkGetPhysicalDeviceProperties2KHR)ICHECK_NOTNULL( vkGetInstanceProcAddr(instance, "vkGetPhysicalDeviceProperties2KHR")); - vkGetPhysicalDeviceProperties2KHR(phy_dev, &properties); + vkGetPhysicalDeviceProperties2KHR(device, &properties); auto vkGetPhysicalDeviceFeatures2KHR = (PFN_vkGetPhysicalDeviceFeatures2KHR)ICHECK_NOTNULL( vkGetInstanceProcAddr(instance, "vkGetPhysicalDeviceFeatures2KHR")); - vkGetPhysicalDeviceFeatures2KHR(phy_dev, &features); + vkGetPhysicalDeviceFeatures2KHR(device, &features); } else { // Fallback, get as many features as we can from the Vulkan1.0 // API. Corresponding vkGetPhysicalDeviceProperties was already done earlier. - vkGetPhysicalDeviceFeatures(phy_dev, &features.features); + vkGetPhysicalDeviceFeatures(device, &features.features); } /////////////////////////////////////////////////////////////// @@ -120,12 +117,12 @@ VulkanDeviceProperties::VulkanDeviceProperties(const VulkanInstance& instance, supports_8bit_buffer = storage_8bit.storageBuffer8BitAccess; supports_16bit_buffer = storage_16bit.storageBuffer16BitAccess; supports_storage_buffer_storage_class = - has_device_extension("VK_KHR_storage_buffer_storage_class"); + device.HasExtension("VK_KHR_storage_buffer_storage_class"); // Support is available based on these extensions, but allow it to // be disabled based on an environment variable. - supports_push_descriptor = has_device_extension("VK_KHR_push_descriptor") && - has_device_extension("VK_KHR_descriptor_update_template"); + supports_push_descriptor = device.HasExtension("VK_KHR_push_descriptor") && + device.HasExtension("VK_KHR_descriptor_update_template"); { const char* disable = std::getenv("TVM_VULKAN_DISABLE_PUSH_DESCRIPTOR"); if (disable && *disable) { @@ -135,8 +132,8 @@ VulkanDeviceProperties::VulkanDeviceProperties(const VulkanInstance& instance, // Support is available based on these extensions, but allow it to // be disabled based on an environment variable. 
- supports_dedicated_allocation = has_device_extension("VK_KHR_get_memory_requirements2") && - has_device_extension("VK_KHR_dedicated_allocation"); + supports_dedicated_allocation = device.HasExtension("VK_KHR_get_memory_requirements2") && + device.HasExtension("VK_KHR_dedicated_allocation"); { const char* disable = std::getenv("TVM_VULKAN_DISABLE_DEDICATED_ALLOCATION"); if (disable && *disable) { @@ -174,7 +171,7 @@ VulkanDeviceProperties::VulkanDeviceProperties(const VulkanInstance& instance, // only using the api version that passes the vulkan conformance // tests. vulkan_api_version = properties.properties.apiVersion; - if (has_device_extension("VK_KHR_driver_properties")) { + if (device.HasExtension("VK_KHR_driver_properties")) { auto api_major = VK_VERSION_MAJOR(vulkan_api_version); auto api_minor = VK_VERSION_MINOR(vulkan_api_version); if ((api_major > driver.conformanceVersion.major) || @@ -189,7 +186,7 @@ VulkanDeviceProperties::VulkanDeviceProperties(const VulkanInstance& instance, max_spirv_version = 0x10000; if (vulkan_api_version >= VK_API_VERSION_1_2) { max_spirv_version = 0x10500; - } else if (has_device_extension("VK_KHR_spirv_1_4")) { + } else if (device.HasExtension("VK_KHR_spirv_1_4")) { max_spirv_version = 0x10400; } else if (vulkan_api_version >= VK_API_VERSION_1_1) { max_spirv_version = 0x10300; @@ -213,16 +210,300 @@ VulkanGetBufferMemoryRequirements2Functions::VulkanGetBufferMemoryRequirements2F vkGetDeviceProcAddr(device, "vkGetBufferMemoryRequirements2KHR")); } +VulkanDevice::VulkanDevice(const VulkanInstance& instance, VkPhysicalDevice phy_device) + : physical_device_(phy_device) { + vkGetPhysicalDeviceProperties(phy_device, &phy_device_prop); + + queue_family_index = SelectComputeQueueFamily(); + if (queue_family_index == uint32_t(-1)) { + // The GPU doesn't support compute, cannot use + return; + } + + enabled_extensions = SelectEnabledExtensions(); + device_properties = VulkanDeviceProperties(instance, *this); + CreateVkDevice(instance); + + // Currently, any exceptions called after this point will prevent + // vkDestroyDevice from being called in the destructor. If this + // becomes an issue, can split out the VulkanDevice into two + // classes, one of which strictly holds the VkDevice, and one which + // holds the ancillary handles that TVM needs. + + vkGetDeviceQueue(device_, queue_family_index, 0, &queue); + + // Find suitable memory type for staging and compute + // Find suitable compute index. 
+ VkBuffer buffer; + VkMemoryRequirements req_staging, req_compute; + VkBufferCreateInfo info; + info.sType = VK_STRUCTURE_TYPE_BUFFER_CREATE_INFO; + info.pNext = nullptr; + info.flags = 0; + info.size = 1024; + info.queueFamilyIndexCount = 1; + info.pQueueFamilyIndices = &queue_family_index; + info.sharingMode = VK_SHARING_MODE_EXCLUSIVE; + + // get staging requirement + info.usage = VK_BUFFER_USAGE_TRANSFER_SRC_BIT | VK_BUFFER_USAGE_TRANSFER_DST_BIT; + VULKAN_CALL(vkCreateBuffer(device_, &info, nullptr, &buffer)); + vkGetBufferMemoryRequirements(device_, buffer, &req_staging); + vkDestroyBuffer(device_, buffer, nullptr); + // get compute requirement + info.usage = VK_BUFFER_USAGE_TRANSFER_SRC_BIT | VK_BUFFER_USAGE_TRANSFER_DST_BIT | + VK_BUFFER_USAGE_STORAGE_BUFFER_BIT; + VULKAN_CALL(vkCreateBuffer(device_, &info, nullptr, &buffer)); + vkGetBufferMemoryRequirements(device_, buffer, &req_compute); + vkDestroyBuffer(device_, buffer, nullptr); + + // Query physical device property + // find a memory that is host visible, no need to be consistent + int win_rank = -1; + VkPhysicalDeviceMemoryProperties prop; + vkGetPhysicalDeviceMemoryProperties(physical_device_, &prop); + + for (uint32_t k = 0; k < prop.memoryTypeCount; ++k) { + VkMemoryType ty = prop.memoryTypes[k]; + size_t heap_size = prop.memoryHeaps[ty.heapIndex].size; + // host visible + if (!(ty.propertyFlags & VK_MEMORY_PROPERTY_HOST_VISIBLE_BIT)) continue; + // match copy requirement + if (!(req_staging.memoryTypeBits & (1 << k))) continue; + if (heap_size < 1024) continue; + int rank = 0; + rank += ty.propertyFlags & VK_MEMORY_PROPERTY_HOST_CACHED_BIT; + if (rank > win_rank) { + win_rank = rank; + staging_mtype_index = k; + coherent_staging = ty.propertyFlags & VK_MEMORY_PROPERTY_HOST_COHERENT_BIT; + } + } + ICHECK_GE(win_rank, 0) << "Cannot find suitable staging memory on device."; + + win_rank = -1; + for (uint32_t k = 0; k < prop.memoryTypeCount; ++k) { + VkMemoryType ty = prop.memoryTypes[k]; + size_t heap_size = prop.memoryHeaps[ty.heapIndex].size; + // device local + if (!(ty.propertyFlags & VK_MEMORY_PROPERTY_DEVICE_LOCAL_BIT)) continue; + // match copy requirement + if (!(req_staging.memoryTypeBits & (1 << k))) continue; + if (heap_size < 1024) continue; + int rank = 0; + // prefer not host visible + rank += !(ty.propertyFlags & VK_MEMORY_PROPERTY_HOST_VISIBLE_BIT); + if (rank > win_rank) { + win_rank = rank; + compute_mtype_index = k; + } + } + ICHECK_GE(win_rank, 0) << "Cannot find suitable local memory on device."; + + if (device_properties.supports_push_descriptor) { + descriptor_template_khr_functions = + std::make_unique(device_); + } + + if (device_properties.supports_dedicated_allocation) { + get_buffer_memory_requirements_2_functions = + std::make_unique(device_); + } +} + +VulkanDevice::~VulkanDevice() { + if (device_) { + vkDestroyDevice(device_, nullptr); + } +} + +VulkanDevice::VulkanDevice(VulkanDevice&& other) { do_swap(std::move(other)); } + +VulkanDevice& VulkanDevice::operator=(VulkanDevice&& other) { + do_swap(std::move(other)); + return *this; +} + +void VulkanDevice::do_swap(VulkanDevice&& other) { + if (this == &other) { + return; + } + + std::lock(queue_mutex, other.queue_mutex); + std::lock_guard lock_self(queue_mutex, std::adopt_lock); + std::lock_guard lock_other(other.queue_mutex, std::adopt_lock); + + std::swap(device_properties, other.device_properties); + std::swap(phy_device_prop, other.phy_device_prop); + std::swap(staging_mtype_index, other.staging_mtype_index); + std::swap(coherent_staging, 
other.coherent_staging); + std::swap(descriptor_template_khr_functions, other.descriptor_template_khr_functions); + std::swap(get_buffer_memory_requirements_2_functions, + other.get_buffer_memory_requirements_2_functions); + std::swap(compute_mtype_index, other.compute_mtype_index); + std::swap(queue, other.queue); + std::swap(queue_family_index, other.queue_family_index); + std::swap(physical_device_, other.physical_device_); + std::swap(enabled_extensions, other.enabled_extensions); + std::swap(device_, other.device_); +} + +bool VulkanDevice::SupportsCompute() const { return queue_family_index != uint32_t(-1); } + +void VulkanDevice::QueueSubmit(VkSubmitInfo submit_info, VkFence fence) const { + // Multiple streams (on different threads) use the same VulkanDevice + // instance, so we need to externally synchronize accesses. + std::lock_guard lock(queue_mutex); + VULKAN_CALL(vkQueueSubmit(queue, 1, &submit_info, fence)); +} + +uint32_t VulkanDevice::SelectComputeQueueFamily() const { + // Get a queue family that supports compute. We currently only use + // one queue from one family. + uint32_t queue_prop_count = 0; + vkGetPhysicalDeviceQueueFamilyProperties(physical_device_, &queue_prop_count, nullptr); + std::vector queue_props(queue_prop_count); + vkGetPhysicalDeviceQueueFamilyProperties(physical_device_, &queue_prop_count, + dmlc::BeginPtr(queue_props)); + + std::vector result; + // Prefer compute-only queues. On certain devices supporting this (e.g. Mesa RADV), using + // compute-only queues gives better responsiveness for other graphics workload (e.g. desktop). + for (uint32_t i = 0; i != queue_prop_count; ++i) { + if ((VK_QUEUE_COMPUTE_BIT & queue_props[i].queueFlags) != 0 && + (VK_QUEUE_GRAPHICS_BIT & queue_props[i].queueFlags) == 0) { + return i; + } + } + // Now, push the compute queues that we skipped above into the list. + for (uint32_t i = 0; i != queue_prop_count; ++i) { + if ((VK_QUEUE_COMPUTE_BIT & queue_props[i].queueFlags) != 0 && + (VK_QUEUE_GRAPHICS_BIT & queue_props[i].queueFlags) != 0) { + return i; + } + } + + // No queues support compute capability, this GPU cannot be used. + return -1; +} + +std::vector VulkanDevice::SelectEnabledExtensions() const { + std::vector required_extensions{}; + std::vector optional_extensions{ + "VK_KHR_driver_properties", + "VK_KHR_storage_buffer_storage_class", + "VK_KHR_8bit_storage", + "VK_KHR_16bit_storage", + "VK_KHR_shader_float16_int8", + "VK_KHR_push_descriptor", + "VK_KHR_descriptor_update_template", + "VK_KHR_get_memory_requirements2", + "VK_KHR_dedicated_allocation", + "VK_KHR_spirv_1_4", + }; + + uint32_t device_extension_prop_count; + VULKAN_CALL(vkEnumerateDeviceExtensionProperties(physical_device_, nullptr, + &device_extension_prop_count, nullptr)); + std::vector device_extension_prop(device_extension_prop_count); + VULKAN_CALL(vkEnumerateDeviceExtensionProperties( + physical_device_, nullptr, &device_extension_prop_count, device_extension_prop.data())); + + return FindEnabledExtensions(device_extension_prop, required_extensions, optional_extensions); +} + +bool VulkanDevice::HasExtension(const char* query) const { + return std::any_of(enabled_extensions.begin(), enabled_extensions.end(), + [&](const char* extension) { return std::strcmp(query, extension) == 0; }); +} + +void VulkanDevice::CreateVkDevice(const VulkanInstance& instance) { + // Enable all features we may use that a device supports. 
+ VkPhysicalDeviceFeatures2 enabled_features = {VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_FEATURES_2}; + VkPhysicalDevice8BitStorageFeatures storage_8bit = { + VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_8BIT_STORAGE_FEATURES}; + VkPhysicalDevice16BitStorageFeatures storage_16bit = { + VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_16BIT_STORAGE_FEATURES}; + VkPhysicalDeviceShaderFloat16Int8Features float16_int8 = { + VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SHADER_FLOAT16_INT8_FEATURES}; + + void** pp_next = &enabled_features.pNext; + bool needs_float16_int8 = false; + + if (device_properties.supports_float16) { + float16_int8.shaderFloat16 = true; + needs_float16_int8 = true; + } + if (device_properties.supports_float64) { + enabled_features.features.shaderFloat64 = true; + } + if (device_properties.supports_int8) { + float16_int8.shaderInt8 = true; + needs_float16_int8 = true; + } + if (device_properties.supports_int16) { + enabled_features.features.shaderInt16 = true; + } + if (device_properties.supports_int64) { + enabled_features.features.shaderInt64 = true; + } + if (device_properties.supports_8bit_buffer) { + storage_8bit.storageBuffer8BitAccess = true; + *pp_next = &storage_8bit; + pp_next = &storage_8bit.pNext; + } + if (device_properties.supports_16bit_buffer) { + storage_16bit.storageBuffer16BitAccess = true; + *pp_next = &storage_16bit; + pp_next = &storage_16bit.pNext; + } + + if (needs_float16_int8) { + *pp_next = &float16_int8; + pp_next = &float16_int8.pNext; + } + + float priority = 1.0f; + + struct VkDeviceQueueCreateInfo queue_create_info; + queue_create_info.sType = VK_STRUCTURE_TYPE_DEVICE_QUEUE_CREATE_INFO; + queue_create_info.pNext = nullptr; + queue_create_info.flags = 0; + queue_create_info.queueFamilyIndex = queue_family_index; + queue_create_info.queueCount = 1; + queue_create_info.pQueuePriorities = &priority; + + VkDeviceCreateInfo device_create_info; + device_create_info.sType = VK_STRUCTURE_TYPE_DEVICE_CREATE_INFO; + device_create_info.pNext = nullptr; + device_create_info.flags = 0; + device_create_info.queueCreateInfoCount = 1; + device_create_info.pQueueCreateInfos = &queue_create_info; + device_create_info.enabledLayerCount = 0; + device_create_info.ppEnabledLayerNames = nullptr; + device_create_info.enabledExtensionCount = enabled_extensions.size(); + device_create_info.ppEnabledExtensionNames = enabled_extensions.data(); + + if (instance.HasExtension("VK_KHR_get_physical_device_properties2")) { + device_create_info.pEnabledFeatures = nullptr; + device_create_info.pNext = &enabled_features; + } else { + device_create_info.pNext = nullptr; + device_create_info.pEnabledFeatures = &enabled_features.features; + } + VULKAN_CALL(vkCreateDevice(physical_device_, &device_create_info, nullptr, &device_)); +} + uint32_t FindMemoryType(const VulkanDevice& device, VkBufferCreateInfo info, VkMemoryPropertyFlags req_prop) { VkBuffer buffer; - VULKAN_CALL(vkCreateBuffer(device.device, &info, nullptr, &buffer)); + VULKAN_CALL(vkCreateBuffer(device, &info, nullptr, &buffer)); VkMemoryRequirements mem_reqs; - vkGetBufferMemoryRequirements(device.device, buffer, &mem_reqs); + vkGetBufferMemoryRequirements(device, buffer, &mem_reqs); uint32_t type_bits = mem_reqs.memoryTypeBits; VkPhysicalDeviceMemoryProperties phy_mem_prop; - vkGetPhysicalDeviceMemoryProperties(device.phy_device, &phy_mem_prop); + vkGetPhysicalDeviceMemoryProperties(device, &phy_mem_prop); for (uint32_t i = 0; i < phy_mem_prop.memoryTypeCount; i++) { if ((type_bits & 1) == 1 && (phy_mem_prop.memoryTypes[i].propertyFlags & req_prop) == 
req_prop) { @@ -253,7 +534,7 @@ VulkanBuffer* CreateBuffer(const VulkanDevice& device, size_t nbytes, VkBufferUs auto info = MakeBufferCreateInfo(device, nbytes, usage); // create buffer VkBuffer buffer; - VULKAN_CALL(vkCreateBuffer(device.device, &info, nullptr, &buffer)); + VULKAN_CALL(vkCreateBuffer(device, &info, nullptr, &buffer)); // bind to memory bool dedicated_allocation = false; @@ -274,7 +555,7 @@ VulkanBuffer* CreateBuffer(const VulkanDevice& device, size_t nbytes, VkBufferUs req2.pNext = &dedicated_req; device.get_buffer_memory_requirements_2_functions->vkGetBufferMemoryRequirements2KHR( - device.device, &req_info2, &req2); + device, &req_info2, &req2); dedicated_allocation = dedicated_req.requiresDedicatedAllocation || dedicated_req.prefersDedicatedAllocation; } @@ -286,7 +567,7 @@ VulkanBuffer* CreateBuffer(const VulkanDevice& device, size_t nbytes, VkBufferUs minfo.pNext = nullptr; minfo.allocationSize = info.size; minfo.memoryTypeIndex = mem_type_index; - VULKAN_CALL(vkAllocateMemory(device.device, &minfo, nullptr, &memory)); + VULKAN_CALL(vkAllocateMemory(device, &minfo, nullptr, &memory)); } else { VkMemoryAllocateInfo minfo; minfo.sType = VK_STRUCTURE_TYPE_MEMORY_ALLOCATE_INFO; @@ -300,9 +581,9 @@ VulkanBuffer* CreateBuffer(const VulkanDevice& device, size_t nbytes, VkBufferUs mdinfo.image = 0; mdinfo.buffer = buffer; minfo.pNext = &mdinfo; - VULKAN_CALL(vkAllocateMemory(device.device, &minfo, nullptr, &memory)); + VULKAN_CALL(vkAllocateMemory(device, &minfo, nullptr, &memory)); } - VULKAN_CALL(vkBindBufferMemory(device.device, buffer, memory, 0)); + VULKAN_CALL(vkBindBufferMemory(device, buffer, memory, 0)); VulkanBuffer* pbuf = new VulkanBuffer(); pbuf->memory = memory; pbuf->buffer = buffer; @@ -335,12 +616,11 @@ VulkanHostVisibleBuffer* GetOrAllocate( const auto& vulkan_device = VulkanDeviceAPI::Global()->device(device_id); if (buf.device == nullptr) { - buf.device = vulkan_device.device; + buf.device = vulkan_device; } if (buf.host_addr == nullptr) { buf.vk_buf = CreateBuffer(vulkan_device, size, usage, mem_type_index); - VULKAN_CALL( - vkMapMemory(vulkan_device.device, buf.vk_buf->memory, 0, size, 0, &(buf.host_addr))); + VULKAN_CALL(vkMapMemory(vulkan_device, buf.vk_buf->memory, 0, size, 0, &(buf.host_addr))); buf.size = size; } return &buf; diff --git a/src/runtime/vulkan/vulkan_device.h b/src/runtime/vulkan/vulkan_device.h index 3e5afc6f0812..a8739cb67af0 100644 --- a/src/runtime/vulkan/vulkan_device.h +++ b/src/runtime/vulkan/vulkan_device.h @@ -35,6 +35,7 @@ namespace runtime { namespace vulkan { class VulkanInstance; +class VulkanDevice; struct VulkanDescriptorTemplateKHRFunctions { explicit VulkanDescriptorTemplateKHRFunctions(VkDevice device); @@ -61,8 +62,7 @@ struct VulkanGetBufferMemoryRequirements2Functions { */ struct VulkanDeviceProperties { VulkanDeviceProperties() {} - VulkanDeviceProperties(const VulkanInstance& instance, VkPhysicalDevice phy_dev, - const std::vector device_extensions); + VulkanDeviceProperties(const VulkanInstance& instance, const VulkanDevice& device); bool supports_float16{false}; bool supports_float32{true}; @@ -93,15 +93,74 @@ struct VulkanDeviceProperties { uint32_t max_spirv_version{0x10000}; }; -struct VulkanDevice { - // physical device - VkPhysicalDevice phy_device{nullptr}; +/*! \brief Handle to the Vulkan API's VkDevice + * + * Handles all setup and teardown of the class. 
The owner of the + * VulkanDevice object is responsible for ensuring that it remains + * alive as long as any object that accesses that device is used. + */ +class VulkanDevice { + public: + VulkanDevice(const VulkanInstance& instance, VkPhysicalDevice phy_dev); + ~VulkanDevice(); + + // Allow move constructor/assignment + VulkanDevice(VulkanDevice&&); + VulkanDevice& operator=(VulkanDevice&&); + + // Disable copy constructor/assignment + VulkanDevice(const VulkanDevice&) = delete; + VulkanDevice& operator=(const VulkanDevice&) = delete; + + /*! \brief Expose the internal VkDevice + * + * Allows the managed class to be passed to Vulkan APIs as if it + * were the VkDevice handle itself. + */ + operator VkDevice() const { return device_; } + + /*! \brief Expose the internal VkPhysicalDevice + * + * Allows the managed class to be passed to Vulkan APIs as if it + * were the VkPhysicalDevice handle itself. + */ + operator VkPhysicalDevice() const { return physical_device_; } + + /*! \brief Returns whether this device supports Vulkan compute operations. + * + * If the device does not support Vulkan compute operations, it + * should not be used any further. + */ + bool SupportsCompute() const; + + /*! \brief Calls vkQueueSubmit to run work on the GPU + * + * Currently only supports submitting a single VkSubmitInfo at a + * time. Handles mutexing internally, safe to call from multiple + * CPU threads. + * + * \param submit_info The job submission information to be passed to + * vkQueueSubmit. + * + * \param fence Optional fence to be passed to vkQueueSubmit, + * signals once the command buffers submitted have completed. + */ + void QueueSubmit(VkSubmitInfo submit_info, VkFence fence) const; + + /*! \brief Checks if the device has an extension enabled + * + * Returns true if the device was initialized with the extension + * given. + * + * \param query The name of the extension to check. + */ + bool HasExtension(const char* query) const; // Cached device properties, queried through Vulkan API. - VulkanDeviceProperties device_properties; + VulkanDeviceProperties device_properties{}; // Phyiscal device property - VkPhysicalDeviceProperties phy_device_prop; + VkPhysicalDeviceProperties phy_device_prop{}; // Memory type index for staging. uint32_t staging_mtype_index{0}; // whether staging is coherent @@ -112,18 +171,47 @@ struct VulkanDevice { get_buffer_memory_requirements_2_functions{nullptr}; // Memory type index for compute uint32_t compute_mtype_index{0}; - // The logical device - VkDevice device{nullptr}; - // command queue - std::unique_ptr queue_mutex; - VkQueue queue{nullptr}; // queue family_index; - uint32_t queue_family_index{0}; - // Queue family index. - VkQueueFamilyProperties queue_prop; + uint32_t queue_family_index{uint32_t(-1)}; bool UseImmediate() const { return descriptor_template_khr_functions != nullptr; } + + private: + /*! \brief Helper function for move assignment/construction + * + * Named "do_swap" instead of "swap" because otherwise cpplint.py + * thinks that it needs the header include. + */ + void do_swap(VulkanDevice&& other); + + uint32_t SelectComputeQueueFamily() const; + std::vector SelectEnabledExtensions() const; + void CreateVkDevice(const VulkanInstance& instance); + + //! \brief Handle to the Vulkan API physical device + VkPhysicalDevice physical_device_{nullptr}; + + /*! \brief Extensions enabled for this device + * + * Based on supported extensions queried from physical_device_ prior + * to creating device_. 
Contains only statically allocated string + * literals, no cleanup required. + */ + std::vector enabled_extensions; + + //! \brief Handle to the Vulkan API logical device + VkDevice device_{nullptr}; + + //! \brief Mutex to protect access to queue + mutable std::mutex queue_mutex; + + /*! \brief Handle to Vulkan API VkQueue. + * + * Work can be executed by submitting to this queue using + * VulkanDevice::QueueSubmit. + */ + VkQueue queue{nullptr}; }; uint32_t FindMemoryType(const VulkanDevice& device, VkBufferCreateInfo info, diff --git a/src/runtime/vulkan/vulkan_device_api.cc b/src/runtime/vulkan/vulkan_device_api.cc index cf2a9d0f515b..bc25f25e7e12 100644 --- a/src/runtime/vulkan/vulkan_device_api.cc +++ b/src/runtime/vulkan/vulkan_device_api.cc @@ -45,216 +45,15 @@ VulkanDeviceAPI* VulkanDeviceAPI::Global() { VulkanDeviceAPI::VulkanDeviceAPI() { std::vector vulkan_physical_devices = instance_.GetPhysicalDevices(); for (VkPhysicalDevice phy_dev : vulkan_physical_devices) { - // Get a list of queue families supporting compute, in order of preference. We currently only - // make use of the most preferred one family. - std::vector queue_family_indexes = GetComputeQueueFamilies(phy_dev); - if (queue_family_indexes.empty()) continue; - uint32_t queue_family_index = queue_family_indexes[0]; - float priority = 1.0f; - - struct VkDeviceQueueCreateInfo queue_create_info; - queue_create_info.sType = VK_STRUCTURE_TYPE_DEVICE_QUEUE_CREATE_INFO; - queue_create_info.pNext = nullptr; - queue_create_info.flags = 0; - queue_create_info.queueFamilyIndex = queue_family_index; - queue_create_info.queueCount = 1; - queue_create_info.pQueuePriorities = &priority; - - VulkanDevice device; - device.phy_device = phy_dev; - vkGetPhysicalDeviceProperties(device.phy_device, &(device.phy_device_prop)); - - const auto device_extensions = [&]() { - std::vector required_extensions{}; - std::vector optional_extensions{ - "VK_KHR_driver_properties", - "VK_KHR_storage_buffer_storage_class", - "VK_KHR_8bit_storage", - "VK_KHR_16bit_storage", - "VK_KHR_shader_float16_int8", - "VK_KHR_push_descriptor", - "VK_KHR_descriptor_update_template", - "VK_KHR_get_memory_requirements2", - "VK_KHR_dedicated_allocation", - "VK_KHR_spirv_1_4", - }; - - uint32_t device_extension_prop_count; - VULKAN_CALL(vkEnumerateDeviceExtensionProperties(device.phy_device, nullptr, - &device_extension_prop_count, nullptr)); - std::vector device_extension_prop(device_extension_prop_count); - VULKAN_CALL(vkEnumerateDeviceExtensionProperties( - device.phy_device, nullptr, &device_extension_prop_count, device_extension_prop.data())); - - return FindEnabledExtensions(device_extension_prop, required_extensions, optional_extensions); - }(); - - device.device_properties = VulkanDeviceProperties(instance_, phy_dev, device_extensions); - - { - // Enable all features we may use that a device supports. 
- VkPhysicalDeviceFeatures2 enabled_features = {VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_FEATURES_2}; - VkPhysicalDevice8BitStorageFeatures storage_8bit = { - VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_8BIT_STORAGE_FEATURES}; - VkPhysicalDevice16BitStorageFeatures storage_16bit = { - VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_16BIT_STORAGE_FEATURES}; - VkPhysicalDeviceShaderFloat16Int8Features float16_int8 = { - VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SHADER_FLOAT16_INT8_FEATURES}; - - void** pp_next = &enabled_features.pNext; - bool needs_float16_int8 = false; - - if (device.device_properties.supports_float16) { - float16_int8.shaderFloat16 = true; - needs_float16_int8 = true; - } - if (device.device_properties.supports_float64) { - enabled_features.features.shaderFloat64 = true; - } - if (device.device_properties.supports_int8) { - float16_int8.shaderInt8 = true; - needs_float16_int8 = true; - } - if (device.device_properties.supports_int16) { - enabled_features.features.shaderInt16 = true; - } - if (device.device_properties.supports_int64) { - enabled_features.features.shaderInt64 = true; - } - if (device.device_properties.supports_8bit_buffer) { - storage_8bit.storageBuffer8BitAccess = true; - *pp_next = &storage_8bit; - pp_next = &storage_8bit.pNext; - } - if (device.device_properties.supports_16bit_buffer) { - storage_16bit.storageBuffer16BitAccess = true; - *pp_next = &storage_16bit; - pp_next = &storage_16bit.pNext; - } - - if (needs_float16_int8) { - *pp_next = &float16_int8; - pp_next = &float16_int8.pNext; - } - - VkDeviceCreateInfo device_create_info; - device_create_info.sType = VK_STRUCTURE_TYPE_DEVICE_CREATE_INFO; - device_create_info.pNext = nullptr; - device_create_info.flags = 0; - device_create_info.queueCreateInfoCount = 1; - device_create_info.pQueueCreateInfos = &queue_create_info; - device_create_info.enabledLayerCount = 0; - device_create_info.ppEnabledLayerNames = nullptr; - device_create_info.enabledExtensionCount = device_extensions.size(); - device_create_info.ppEnabledExtensionNames = device_extensions.data(); - - if (instance_.HasExtension("VK_KHR_get_physical_device_properties2")) { - device_create_info.pEnabledFeatures = nullptr; - device_create_info.pNext = &enabled_features; - } else { - device_create_info.pNext = nullptr; - device_create_info.pEnabledFeatures = &enabled_features.features; - } - VULKAN_CALL(vkCreateDevice(phy_dev, &device_create_info, nullptr, &(device.device))); - } + VulkanDevice device(instance_, phy_dev); - device.queue_mutex.reset(new std::mutex()); - vkGetDeviceQueue(device.device, queue_family_index, 0, &(device.queue)); - device.queue_family_index = queue_family_index; - // Find suitable memory type for staging and compute - // Find suitable compute index. 
- VkBuffer buffer; - VkMemoryRequirements req_staging, req_compute; - VkBufferCreateInfo info; - info.sType = VK_STRUCTURE_TYPE_BUFFER_CREATE_INFO; - info.pNext = nullptr; - info.flags = 0; - info.size = 1024; - info.queueFamilyIndexCount = 1; - info.pQueueFamilyIndices = &(device.queue_family_index); - info.sharingMode = VK_SHARING_MODE_EXCLUSIVE; - - // get staging requirement - info.usage = VK_BUFFER_USAGE_TRANSFER_SRC_BIT | VK_BUFFER_USAGE_TRANSFER_DST_BIT; - VULKAN_CALL(vkCreateBuffer(device.device, &info, nullptr, &buffer)); - vkGetBufferMemoryRequirements(device.device, buffer, &req_staging); - vkDestroyBuffer(device.device, buffer, nullptr); - // get compute requirement - info.usage = VK_BUFFER_USAGE_TRANSFER_SRC_BIT | VK_BUFFER_USAGE_TRANSFER_DST_BIT | - VK_BUFFER_USAGE_STORAGE_BUFFER_BIT; - VULKAN_CALL(vkCreateBuffer(device.device, &info, nullptr, &buffer)); - vkGetBufferMemoryRequirements(device.device, buffer, &req_compute); - vkDestroyBuffer(device.device, buffer, nullptr); - - // Query phyiscal device property - // find a memory that is host visible, no need to be consistent - int win_rank = -1; - VkPhysicalDeviceMemoryProperties prop; - vkGetPhysicalDeviceMemoryProperties(device.phy_device, &prop); - - for (uint32_t k = 0; k < prop.memoryTypeCount; ++k) { - VkMemoryType ty = prop.memoryTypes[k]; - size_t heap_size = prop.memoryHeaps[ty.heapIndex].size; - // host visible - if (!(ty.propertyFlags & VK_MEMORY_PROPERTY_HOST_VISIBLE_BIT)) continue; - // match copy requirment - if (!(req_staging.memoryTypeBits & (1 << k))) continue; - if (heap_size < 1024) continue; - int rank = 0; - rank += ty.propertyFlags & VK_MEMORY_PROPERTY_HOST_CACHED_BIT; - if (rank > win_rank) { - win_rank = rank; - device.staging_mtype_index = k; - device.coherent_staging = ty.propertyFlags & VK_MEMORY_PROPERTY_HOST_COHERENT_BIT; - } - } - ICHECK_GE(win_rank, 0) << "Cannot find suitable staging memory on device."; - - win_rank = -1; - for (uint32_t k = 0; k < prop.memoryTypeCount; ++k) { - VkMemoryType ty = prop.memoryTypes[k]; - size_t heap_size = prop.memoryHeaps[ty.heapIndex].size; - // host visible - if (!(ty.propertyFlags & VK_MEMORY_PROPERTY_DEVICE_LOCAL_BIT)) continue; - // match copy requirment - if (!(req_staging.memoryTypeBits & (1 << k))) continue; - if (heap_size < 1024) continue; - int rank = 0; - // prefer not host visible - rank += !(ty.propertyFlags & VK_MEMORY_PROPERTY_HOST_VISIBLE_BIT); - if (rank > win_rank) { - win_rank = rank; - device.compute_mtype_index = k; - } + if (device.SupportsCompute()) { + devices_.push_back(std::move(device)); } - ICHECK_GE(win_rank, 0) << "Cannot find suitable local memory on device."; - - if (device.device_properties.supports_push_descriptor) { - device.descriptor_template_khr_functions = - std::make_unique(device.device); - } - - if (device.device_properties.supports_dedicated_allocation) { - device.get_buffer_memory_requirements_2_functions = - std::make_unique(device.device); - } - - devices_.push_back(std::move(device)); - } - - LOG(INFO) << "Initialize Vulkan with " << devices_.size() << " devices.."; - for (size_t i = 0; i < devices_.size(); ++i) { - LOG(INFO) << "vulkan(" << i << ")=\'" << devices_[i].phy_device_prop.deviceName - << "\' phy_dev_id=" << devices_[i].phy_device - << " use_immediate=" << devices_[i].UseImmediate(); } } -VulkanDeviceAPI::~VulkanDeviceAPI() { - for (auto& device : devices_) { - vkDestroyDevice(device.device, nullptr); - } -} +VulkanDeviceAPI::~VulkanDeviceAPI() {} void VulkanDeviceAPI::SetDevice(Device dev) { 
VulkanThreadEntry::ThreadLocal()->device = dev; } @@ -436,8 +235,8 @@ void VulkanDeviceAPI::FreeDataSpace(Device dev, void* ptr) { const auto& device = this->device(dev.device_id); auto* pbuf = static_cast(ptr); - vkDestroyBuffer(device.device, pbuf->buffer, nullptr); - vkFreeMemory(device.device, pbuf->memory, nullptr); + vkDestroyBuffer(device, pbuf->buffer, nullptr); + vkFreeMemory(device, pbuf->memory, nullptr); delete pbuf; } @@ -530,7 +329,7 @@ void VulkanDeviceAPI::CopyDataFromTo(const void* from, size_t from_offset, void* mrange.memory = temp->vk_buf->memory; mrange.offset = 0; mrange.size = VK_WHOLE_SIZE; // size; - VULKAN_CALL(vkInvalidateMappedMemoryRanges(device.device, 1, &mrange)); + VULKAN_CALL(vkInvalidateMappedMemoryRanges(device, 1, &mrange)); } memcpy(static_cast(to) + to_offset, static_cast(temp->host_addr), size); } else if (from_dev_type == kDLCPU && to_dev_type == kDLVulkan) { @@ -548,7 +347,7 @@ void VulkanDeviceAPI::CopyDataFromTo(const void* from, size_t from_offset, void* mrange.memory = temp->vk_buf->memory; mrange.offset = 0; mrange.size = VK_WHOLE_SIZE; // size; - VULKAN_CALL(vkFlushMappedMemoryRanges(device.device, 1, &mrange)); + VULKAN_CALL(vkFlushMappedMemoryRanges(device, 1, &mrange)); } VulkanThreadEntry::ThreadLocal() @@ -585,31 +384,6 @@ const VulkanDevice& VulkanDeviceAPI::device(size_t device_id) const { return devices_[device_id]; } -std::vector VulkanDeviceAPI::GetComputeQueueFamilies(VkPhysicalDevice phy_dev) { - uint32_t queue_prop_count = 0; - vkGetPhysicalDeviceQueueFamilyProperties(phy_dev, &queue_prop_count, nullptr); - std::vector queue_props(queue_prop_count); - vkGetPhysicalDeviceQueueFamilyProperties(phy_dev, &queue_prop_count, dmlc::BeginPtr(queue_props)); - - std::vector result; - // Prefer compute-only queues. On certain devices supporting this (e.g. Mesa RADV), using - // compute-only queues gives better responsiveness for other graphics workload (e.g. desktop). - for (uint32_t i = 0; i != queue_prop_count; ++i) { - if ((VK_QUEUE_COMPUTE_BIT & queue_props[i].queueFlags) != 0 && - (VK_QUEUE_GRAPHICS_BIT & queue_props[i].queueFlags) == 0) { - result.push_back(i); - } - } - // Now, push the compute queues that we skipped above into the list. 
- for (uint32_t i = 0; i != queue_prop_count; ++i) { - if ((VK_QUEUE_COMPUTE_BIT & queue_props[i].queueFlags) != 0 && - (VK_QUEUE_GRAPHICS_BIT & queue_props[i].queueFlags) != 0) { - result.push_back(i); - } - } - return result; -} - TVM_REGISTER_GLOBAL("device_api.vulkan").set_body([](TVMArgs args, TVMRetValue* rv) { DeviceAPI* ptr = VulkanDeviceAPI::Global(); *rv = static_cast(ptr); diff --git a/src/runtime/vulkan/vulkan_stream.cc b/src/runtime/vulkan/vulkan_stream.cc index 05befd0630b3..9784ee78503d 100644 --- a/src/runtime/vulkan/vulkan_stream.cc +++ b/src/runtime/vulkan/vulkan_stream.cc @@ -31,7 +31,7 @@ VulkanStream::VulkanStream(const VulkanDevice* device) cmd_pool_cinfo.pNext = nullptr; cmd_pool_cinfo.flags = VK_COMMAND_POOL_CREATE_RESET_COMMAND_BUFFER_BIT; cmd_pool_cinfo.queueFamilyIndex = device_->queue_family_index; - VULKAN_CALL(vkCreateCommandPool(device_->device, &cmd_pool_cinfo, nullptr, &cmd_pool_)); + VULKAN_CALL(vkCreateCommandPool(*device_, &cmd_pool_cinfo, nullptr, &cmd_pool_)); VkCommandBufferAllocateInfo buffer_alloc_info; buffer_alloc_info.sType = VK_STRUCTURE_TYPE_COMMAND_BUFFER_ALLOCATE_INFO; @@ -39,14 +39,13 @@ VulkanStream::VulkanStream(const VulkanDevice* device) buffer_alloc_info.commandPool = cmd_pool_; buffer_alloc_info.level = VK_COMMAND_BUFFER_LEVEL_PRIMARY; buffer_alloc_info.commandBufferCount = 1; - VULKAN_CALL( - vkAllocateCommandBuffers(device_->device, &buffer_alloc_info, &(state_->cmd_buffer_))); + VULKAN_CALL(vkAllocateCommandBuffers(*device_, &buffer_alloc_info, &(state_->cmd_buffer_))); VkFenceCreateInfo fence_cinfo; fence_cinfo.sType = VK_STRUCTURE_TYPE_FENCE_CREATE_INFO; fence_cinfo.pNext = nullptr; fence_cinfo.flags = 0; // VK_FENCE_CREATE_SIGNALED_BIT; - VULKAN_CALL(vkCreateFence(device_->device, &fence_cinfo, nullptr, &(state_->fence_))); + VULKAN_CALL(vkCreateFence(*device_, &fence_cinfo, nullptr, &(state_->fence_))); VkCommandBufferBeginInfo cb_begin; cb_begin.sType = VK_STRUCTURE_TYPE_COMMAND_BUFFER_BEGIN_INFO; @@ -57,8 +56,8 @@ VulkanStream::VulkanStream(const VulkanDevice* device) } VulkanStream::~VulkanStream() { - vkDestroyFence(device_->device, state_->fence_, nullptr); - vkDestroyCommandPool(device_->device, cmd_pool_, nullptr); + vkDestroyFence(*device_, state_->fence_, nullptr); + vkDestroyCommandPool(*device_, cmd_pool_, nullptr); } void VulkanStream::Launch(const std::function& kernel) { @@ -131,20 +130,16 @@ void VulkanStream::Synchronize() { cb_submit.signalSemaphoreCount = 0; cb_submit.pSignalSemaphores = nullptr; - { - // Multiple streams (on different threads) use the same VulkanDevice - // instance, so we need to externally synchronize accesses. 
- std::lock_guard g(*(device_->queue_mutex)); - VULKAN_CALL(vkQueueSubmit(device_->queue, 1, &cb_submit, state_->fence_)); - } + device_->QueueSubmit(cb_submit, state_->fence_); + uint64_t timeout = 1UL << 30UL; VkResult res; do { - res = vkWaitForFences(device_->device, 1, &(state_->fence_), 0, timeout); + res = vkWaitForFences(*device_, 1, &(state_->fence_), 0, timeout); } while (res == VK_TIMEOUT); VULKAN_CHECK_ERROR(res); VULKAN_CALL(vkResetCommandBuffer(state_->cmd_buffer_, 0)); - VULKAN_CALL(vkResetFences(device_->device, 1, &(state_->fence_))); + VULKAN_CALL(vkResetFences(*device_, 1, &(state_->fence_))); // Re-initialize the command buffer VkCommandBufferBeginInfo cb_begin; diff --git a/src/runtime/vulkan/vulkan_wrapped_func.cc b/src/runtime/vulkan/vulkan_wrapped_func.cc index 51bf5d486b9c..12a5f99ed8e6 100644 --- a/src/runtime/vulkan/vulkan_wrapped_func.cc +++ b/src/runtime/vulkan/vulkan_wrapped_func.cc @@ -128,8 +128,8 @@ void VulkanWrappedFunc::operator()(TVMArgs args, TVMRetValue* rv, write_descriptor_sets[i].descriptorType = VK_DESCRIPTOR_TYPE_STORAGE_BUFFER; } } - vkUpdateDescriptorSets(device.device, write_descriptor_sets.size(), - write_descriptor_sets.data(), 0, 0); + vkUpdateDescriptorSets(device, write_descriptor_sets.size(), write_descriptor_sets.data(), 0, + 0); }; const auto& deferred_kernel = [this, pipeline, wl, pack_args_storage, nbytes_scalars, device_id](VulkanStreamState* state) { @@ -178,13 +178,13 @@ VulkanModuleNode::~VulkanModuleNode() { if (pe->descriptor_update_template != VK_NULL_HANDLE) { device.descriptor_template_khr_functions->vkDestroyDescriptorUpdateTemplateKHR( - device.device, pe->descriptor_update_template, nullptr); + device, pe->descriptor_update_template, nullptr); } - vkDestroyPipeline(device.device, pe->pipeline, nullptr); - vkDestroyPipelineLayout(device.device, pe->pipeline_layout, nullptr); - vkDestroyDescriptorPool(device.device, pe->descriptor_pool, nullptr); - vkDestroyDescriptorSetLayout(device.device, pe->descriptor_set_layout, nullptr); - vkDestroyShaderModule(device.device, pe->shader, nullptr); + vkDestroyPipeline(device, pe->pipeline, nullptr); + vkDestroyPipelineLayout(device, pe->pipeline_layout, nullptr); + vkDestroyDescriptorPool(device, pe->descriptor_pool, nullptr); + vkDestroyDescriptorSetLayout(device, pe->descriptor_set_layout, nullptr); + vkDestroyShaderModule(device, pe->shader, nullptr); } } } @@ -226,7 +226,7 @@ std::shared_ptr VulkanModuleNode::GetPipeline(size_t device_id, shader_cinfo.flags = 0; shader_cinfo.codeSize = data.size() * sizeof(uint32_t); shader_cinfo.pCode = data.data(); - VULKAN_CALL(vkCreateShaderModule(device.device, &shader_cinfo, nullptr, &(pe->shader))); + VULKAN_CALL(vkCreateShaderModule(device, &shader_cinfo, nullptr, &(pe->shader))); } std::vector arg_binding; std::vector arg_template; @@ -299,8 +299,8 @@ std::shared_ptr VulkanModuleNode::GetPipeline(size_t device_id, } descrip_cinfo.bindingCount = arg_binding.size(); descrip_cinfo.pBindings = arg_binding.data(); - VULKAN_CALL(vkCreateDescriptorSetLayout(device.device, &descrip_cinfo, nullptr, - &(pe->descriptor_set_layout))); + VULKAN_CALL( + vkCreateDescriptorSetLayout(device, &descrip_cinfo, nullptr, &(pe->descriptor_set_layout))); } if (!device.UseImmediate()) { @@ -311,8 +311,8 @@ std::shared_ptr VulkanModuleNode::GetPipeline(size_t device_id, descrip_pool_cinfo.maxSets = 1; descrip_pool_cinfo.poolSizeCount = descriptor_set_pool_sizes.size(); descrip_pool_cinfo.pPoolSizes = descriptor_set_pool_sizes.data(); - 
VULKAN_CALL(vkCreateDescriptorPool(device.device, &descrip_pool_cinfo, nullptr, - &(pe->descriptor_pool))); + VULKAN_CALL( + vkCreateDescriptorPool(device, &descrip_pool_cinfo, nullptr, &(pe->descriptor_pool))); VkDescriptorSetAllocateInfo alloc_info; alloc_info.sType = VK_STRUCTURE_TYPE_DESCRIPTOR_SET_ALLOCATE_INFO; @@ -320,7 +320,7 @@ std::shared_ptr VulkanModuleNode::GetPipeline(size_t device_id, alloc_info.descriptorPool = pe->descriptor_pool; alloc_info.descriptorSetCount = 1; alloc_info.pSetLayouts = &(pe->descriptor_set_layout); - VULKAN_CALL(vkAllocateDescriptorSets(device.device, &alloc_info, &(pe->descriptor_set))); + VULKAN_CALL(vkAllocateDescriptorSets(device, &alloc_info, &(pe->descriptor_set))); } VkPushConstantRange crange; @@ -344,8 +344,7 @@ std::shared_ptr VulkanModuleNode::GetPipeline(size_t device_id, playout_cinfo.pPushConstantRanges = nullptr; } - VULKAN_CALL( - vkCreatePipelineLayout(device.device, &playout_cinfo, nullptr, &(pe->pipeline_layout))); + VULKAN_CALL(vkCreatePipelineLayout(device, &playout_cinfo, nullptr, &(pe->pipeline_layout))); VkComputePipelineCreateInfo pipeline_cinfo; pipeline_cinfo.sType = VK_STRUCTURE_TYPE_COMPUTE_PIPELINE_CREATE_INFO; @@ -361,7 +360,7 @@ std::shared_ptr VulkanModuleNode::GetPipeline(size_t device_id, pipeline_cinfo.layout = pe->pipeline_layout; pipeline_cinfo.basePipelineHandle = VK_NULL_HANDLE; pipeline_cinfo.basePipelineIndex = 0; - VULKAN_CALL(vkCreateComputePipelines(device.device, VK_NULL_HANDLE, 1, &pipeline_cinfo, nullptr, + VULKAN_CALL(vkCreateComputePipelines(device, VK_NULL_HANDLE, 1, &pipeline_cinfo, nullptr, &(pe->pipeline))); if (device.UseImmediate()) { @@ -377,7 +376,7 @@ std::shared_ptr VulkanModuleNode::GetPipeline(size_t device_id, descrip_template_cinfo.pipelineLayout = pe->pipeline_layout; descrip_template_cinfo.set = 0; VULKAN_CALL(device.descriptor_template_khr_functions->vkCreateDescriptorUpdateTemplateKHR( - device.device, &descrip_template_cinfo, 0, &(pe->descriptor_update_template))); + device, &descrip_template_cinfo, 0, &(pe->descriptor_update_template))); } ecache_[device_id][func_name] = pe; return pe; From 6e96f3daaee97aea8f3656d03338070dc95315b6 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Wed, 2 Jun 2021 13:52:54 -0700 Subject: [PATCH 4/4] [Vulkan][Refactor] Removed the VkPhysicalDeviceProperties member variable from VulkanDevice - Now that there is a separate VulkanDeviceProperties class, the redundant VkPhysicalDeviceProperties can be removed. 
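
The user-visible effect is limited to call sites that previously read limits from the
raw VkPhysicalDeviceProperties; these now read the same values from the TVM-side
VulkanDeviceProperties. For illustration, the push-constant check in GetPipeline
changes as follows (taken from the diff below):

    // Before: raw Vulkan property struct cached on the device
    ICHECK_LE(crange.size, device.phy_device_prop.limits.maxPushConstantsSize);
    // After: the same limit exposed through the cached VulkanDeviceProperties
    ICHECK_LE(crange.size, device.device_properties.max_push_constants_size);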
--- src/runtime/vulkan/vulkan_device.cc | 3 --- src/runtime/vulkan/vulkan_device.h | 2 -- src/runtime/vulkan/vulkan_wrapped_func.cc | 8 +++++++- 3 files changed, 7 insertions(+), 6 deletions(-) diff --git a/src/runtime/vulkan/vulkan_device.cc b/src/runtime/vulkan/vulkan_device.cc index f8c9df1ff061..e92b566e0aab 100644 --- a/src/runtime/vulkan/vulkan_device.cc +++ b/src/runtime/vulkan/vulkan_device.cc @@ -212,8 +212,6 @@ VulkanGetBufferMemoryRequirements2Functions::VulkanGetBufferMemoryRequirements2F VulkanDevice::VulkanDevice(const VulkanInstance& instance, VkPhysicalDevice phy_device) : physical_device_(phy_device) { - vkGetPhysicalDeviceProperties(phy_device, &phy_device_prop); - queue_family_index = SelectComputeQueueFamily(); if (queue_family_index == uint32_t(-1)) { // The GPU doesn't support compute, cannot use @@ -334,7 +332,6 @@ void VulkanDevice::do_swap(VulkanDevice&& other) { std::lock_guard lock_other(other.queue_mutex, std::adopt_lock); std::swap(device_properties, other.device_properties); - std::swap(phy_device_prop, other.phy_device_prop); std::swap(staging_mtype_index, other.staging_mtype_index); std::swap(coherent_staging, other.coherent_staging); std::swap(descriptor_template_khr_functions, other.descriptor_template_khr_functions); diff --git a/src/runtime/vulkan/vulkan_device.h b/src/runtime/vulkan/vulkan_device.h index a8739cb67af0..b55eb8a3d9e0 100644 --- a/src/runtime/vulkan/vulkan_device.h +++ b/src/runtime/vulkan/vulkan_device.h @@ -159,8 +159,6 @@ class VulkanDevice { // Cached device properties, queried through Vulkan API. VulkanDeviceProperties device_properties{}; - // Phyiscal device property - VkPhysicalDeviceProperties phy_device_prop{}; // Memory type index for staging. uint32_t staging_mtype_index{0}; // whether staging is coherent diff --git a/src/runtime/vulkan/vulkan_wrapped_func.cc b/src/runtime/vulkan/vulkan_wrapped_func.cc index 12a5f99ed8e6..86c3ffe23f7d 100644 --- a/src/runtime/vulkan/vulkan_wrapped_func.cc +++ b/src/runtime/vulkan/vulkan_wrapped_func.cc @@ -338,7 +338,13 @@ std::shared_ptr VulkanModuleNode::GetPipeline(size_t device_id, if (0 < nbytes_scalars && !pe->use_ubo) { playout_cinfo.pushConstantRangeCount = 1; playout_cinfo.pPushConstantRanges = &crange; - ICHECK_LE(crange.size, device.phy_device_prop.limits.maxPushConstantsSize); + ICHECK_LE(crange.size, device.device_properties.max_push_constants_size) + << "The Vulkan shader uses " << crange.size + << " bytes of push constants, but the device only supports " + << device.device_properties.max_push_constants_size << " bytes. " + << "Please rebuild the shader using a smaller limit on push constants size " + << "by passing -max_push_constants_size=N in the Target string, " + << "or pass -from_device=0 to query all device parameters."; } else { playout_cinfo.pushConstantRangeCount = 0; playout_cinfo.pPushConstantRanges = nullptr;