diff --git a/src/runtime/vulkan/vulkan.cc b/src/runtime/vulkan/vulkan.cc index 568672591497..5b630337acbb 100644 --- a/src/runtime/vulkan/vulkan.cc +++ b/src/runtime/vulkan/vulkan.cc @@ -117,6 +117,7 @@ class VulkanDeviceAPI final : public DeviceAPI { } void SetDevice(TVMContext ctx) final { VulkanThreadEntry::ThreadLocal()->ctx = ctx; } void GetAttr(TVMContext ctx, DeviceAttrKind kind, TVMRetValue* rv) final; + std::vector GetComputeQueueFamilies(VkPhysicalDevice phy_dev); void* AllocDataSpace(TVMContext ctx, size_t nbytes, size_t alignment, DLDataType type_hint) final { const auto& vctx = context(ctx.device_id); @@ -490,33 +491,20 @@ VulkanDeviceAPI::VulkanDeviceAPI() { std::vector all_phy_devs(phy_dev_count); VULKAN_CALL(vkEnumeratePhysicalDevices(instance_, &phy_dev_count, dmlc::BeginPtr(all_phy_devs))); for (VkPhysicalDevice phy_dev : all_phy_devs) { - uint32_t queue_prop_count = 0; - vkGetPhysicalDeviceQueueFamilyProperties(phy_dev, &queue_prop_count, nullptr); - std::vector queue_props(queue_prop_count); - vkGetPhysicalDeviceQueueFamilyProperties(phy_dev, &queue_prop_count, - dmlc::BeginPtr(queue_props)); - uint32_t queue_family_index = 0; - std::vector queue_create_info; + // Get a list of queue families supporting compute, in order of preference. We currently only + // make use of the single most preferred family. 
+ std::vector queue_family_indexes = GetComputeQueueFamilies(phy_dev); + if (queue_family_indexes.empty()) continue; + uint32_t queue_family_index = queue_family_indexes[0]; float priority = 1.0f; - for (uint32_t i = 0; i < queue_props.size(); i++) { - // find queues that support compute - if (VK_QUEUE_COMPUTE_BIT & queue_props[i].queueFlags) { - VkDeviceQueueCreateInfo info; - info.sType = VK_STRUCTURE_TYPE_DEVICE_QUEUE_CREATE_INFO; - info.pNext = nullptr; - info.flags = 0; - info.queueFamilyIndex = i; - info.queueCount = 1; - info.pQueuePriorities = &priority; - - queue_create_info.push_back(info); - // only use the first available queue for now - if (queue_create_info.size() == 0) { - queue_family_index = i; - } - } - } - if (queue_create_info.size() == 0) continue; + + struct VkDeviceQueueCreateInfo queue_create_info; + queue_create_info.sType = VK_STRUCTURE_TYPE_DEVICE_QUEUE_CREATE_INFO; + queue_create_info.pNext = nullptr; + queue_create_info.flags = 0; + queue_create_info.queueFamilyIndex = queue_family_index; + queue_create_info.queueCount = 1; + queue_create_info.pQueuePriorities = &priority; VulkanContext ctx; // setup context @@ -554,8 +542,8 @@ VulkanDeviceAPI::VulkanDeviceAPI() { device_create_info.sType = VK_STRUCTURE_TYPE_DEVICE_CREATE_INFO; device_create_info.pNext = nullptr; device_create_info.flags = 0; - device_create_info.queueCreateInfoCount = static_cast(queue_create_info.size()); - device_create_info.pQueueCreateInfos = queue_create_info.data(); + device_create_info.queueCreateInfoCount = 1; + device_create_info.pQueueCreateInfos = &queue_create_info; device_create_info.enabledLayerCount = 0; device_create_info.ppEnabledLayerNames = nullptr; device_create_info.enabledExtensionCount = extensions.size(); @@ -677,7 +665,34 @@ VulkanDeviceAPI::VulkanDeviceAPI() { << "\' phy_dev_id=" << context_[i].phy_device << " use_immediate=" << context_[i].UseImmediate(); } -} // namespace vulkan +} + +std::vector 
VulkanDeviceAPI::GetComputeQueueFamilies(VkPhysicalDevice phy_dev) { + uint32_t queue_prop_count = 0; + vkGetPhysicalDeviceQueueFamilyProperties(phy_dev, &queue_prop_count, nullptr); + std::vector queue_props(queue_prop_count); + vkGetPhysicalDeviceQueueFamilyProperties(phy_dev, &queue_prop_count, dmlc::BeginPtr(queue_props)); + + std::vector result; + // Prefer compute-only queues. On certain devices supporting this (e.g. Mesa RADV), using + // compute-only queues gives better responsiveness for other graphics workloads (e.g. desktop). + for (uint32_t i = 0; i != queue_prop_count; ++i) { + if ((VK_QUEUE_COMPUTE_BIT & queue_props[i].queueFlags) != 0 && + (VK_QUEUE_GRAPHICS_BIT & queue_props[i].queueFlags) == 0) { + result.push_back(i); + } + } + // Now, append the compute queue families that also support graphics, which we skipped above. + for (uint32_t i = 0; i != queue_prop_count; ++i) { + if ((VK_QUEUE_COMPUTE_BIT & queue_props[i].queueFlags) != 0 && + (VK_QUEUE_GRAPHICS_BIT & queue_props[i].queueFlags) != 0) { + result.push_back(i); + } + } + return result; +} + +// namespace vulkan class VulkanModuleNode; // a wrapped function class to get packed func.