diff --git a/cmake/modules/Vulkan.cmake b/cmake/modules/Vulkan.cmake index 095790f08547..4dc9bd664d8a 100644 --- a/cmake/modules/Vulkan.cmake +++ b/cmake/modules/Vulkan.cmake @@ -18,14 +18,6 @@ # Be compatible with older version of CMake find_vulkan(${USE_VULKAN}) -# Extra Vulkan runtime options, exposed for advanced users. -tvm_option(USE_VULKAN_IMMEDIATE_MODE "Use Vulkan Immediate mode -(KHR_push_descriptor extension)" ON IF USE_VULKAN) -tvm_option(USE_VULKAN_DEDICATED_ALLOCATION "Use Vulkan dedicated allocations" ON -IF USE_VULKAN) -tvm_option(USE_VULKAN_VALIDATION "Enable Vulkan API validation layers" OFF - IF USE_VULKAN) - if(USE_VULKAN) if(NOT Vulkan_FOUND) message(FATAL_ERROR "Cannot find Vulkan, USE_VULKAN=" ${USE_VULKAN}) @@ -38,17 +30,4 @@ if(USE_VULKAN) list(APPEND COMPILER_SRCS ${COMPILER_VULKAN_SRCS}) list(APPEND TVM_LINKER_LIBS ${Vulkan_SPIRV_TOOLS_LIBRARY}) list(APPEND TVM_RUNTIME_LINKER_LIBS ${Vulkan_LIBRARY}) - - if(USE_VULKAN_IMMEDIATE_MODE) - message(STATUS "Build with Vulkan immediate mode") - add_definitions(-DUSE_VULKAN_IMMEDIATE_MODE=1) - endif() - if(USE_VULKAN_DEDICATED_ALLOCATION) - message(STATUS "Build with Vulkan dedicated allocation") - add_definitions(-DUSE_VULKAN_DEDICATED_ALLOCATION=1) - endif() - if(USE_VULKAN_VALIDATION) - message(STATUS "Build with Vulkan API validation") - add_definitions(-DUSE_VULKAN_VALIDATION=1) - endif() endif(USE_VULKAN) diff --git a/python/tvm/_ffi/runtime_ctypes.py b/python/tvm/_ffi/runtime_ctypes.py index 4eda5e8cc332..efea47752f6d 100644 --- a/python/tvm/_ffi/runtime_ctypes.py +++ b/python/tvm/_ffi/runtime_ctypes.py @@ -376,7 +376,7 @@ def api_version(self): The version of the SDK """ - return self._GetDeviceAttr(self.device_type, self.device_id, 12) + return self._GetDeviceAttr(self.device_type, self.device_id, 11) @property def driver_version(self): diff --git a/src/runtime/vulkan/vulkan.cc b/src/runtime/vulkan/vulkan.cc index b7fe2b1ceb21..8982ea32648b 100644 --- a/src/runtime/vulkan/vulkan.cc +++ b/src/runtime/vulkan/vulkan.cc @@ -21,11 +21,13 @@ #include #include #include +#include #include #include #include #include +#include #include #include "../file_utils.h" @@ -262,6 +264,10 @@ class VulkanDeviceAPI final : public DeviceAPI { delete pbuf; } + Target GetDeviceDescription(VkInstance instance, VkPhysicalDevice dev, + const std::vector& instance_extensions, + const std::vector& device_extensions); + protected: void CopyDataFromTo(const void* from, size_t from_offset, void* to, size_t to_offset, size_t size, Device dev_from, Device dev_to, DLDataType type_hint, @@ -426,12 +432,216 @@ class VulkanDeviceAPI final : public DeviceAPI { return context_[device_id]; } + Target GenerateTarget(size_t device_id) const { return context(device_id).target; } + private: + std::vector find_enabled_extensions( + const std::vector& ext_prop, + const std::vector& required_extensions, + const std::vector& optional_extensions) { + std::set available_extensions; + for (const auto& prop : ext_prop) { + if (prop.specVersion > 0) { + available_extensions.insert(prop.extensionName); + } + } + + std::vector enabled_extensions; + for (const auto& ext : required_extensions) { + ICHECK(available_extensions.count(ext)) + << "Required vulkan extension \"" << ext << "\" not supported by driver"; + enabled_extensions.push_back(ext); + } + + for (const auto& ext : optional_extensions) { + if (available_extensions.count(ext)) { + enabled_extensions.push_back(ext); + } + } + + return enabled_extensions; + } + VkInstance instance_{nullptr}; // The physical devices, have 1 to 1 mapping to devices std::vector context_; }; +Target VulkanDeviceAPI::GetDeviceDescription(VkInstance instance, VkPhysicalDevice dev, + const std::vector& instance_extensions, + const std::vector& device_extensions) { + auto has_extension = [&](const char* query) { + return std::any_of(device_extensions.begin(), device_extensions.end(), + [&](const char* extension) { return std::strcmp(query, extension) == 0; }) || + std::any_of(instance_extensions.begin(), instance_extensions.end(), + [&](const char* extension) { return std::strcmp(query, extension) == 0; }); + }; + + // Declare output locations for properties + VkPhysicalDeviceProperties2 properties = {VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_PROPERTIES_2}; + VkPhysicalDeviceDriverProperties driver = {VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_DRIVER_PROPERTIES}; + VkPhysicalDeviceSubgroupProperties subgroup = { + VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SUBGROUP_PROPERTIES}; + + // Need to do initial query in order to check the apiVersion. + vkGetPhysicalDeviceProperties(dev, &properties.properties); + + // Set up linked list for property query + { + void** pp_next = &properties.pNext; + if (has_extension("VK_KHR_driver_properties")) { + *pp_next = &driver; + pp_next = &driver.pNext; + } + if (properties.properties.apiVersion >= VK_API_VERSION_1_1) { + *pp_next = &subgroup; + pp_next = &subgroup.pNext; + } + } + + // Declare output locations for features + VkPhysicalDeviceFeatures2 features = {VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_FEATURES_2}; + VkPhysicalDevice8BitStorageFeatures storage_8bit = { + VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_8BIT_STORAGE_FEATURES}; + VkPhysicalDevice16BitStorageFeatures storage_16bit = { + VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_16BIT_STORAGE_FEATURES}; + VkPhysicalDeviceShaderFloat16Int8Features float16_int8 = { + VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SHADER_FLOAT16_INT8_FEATURES}; + + // Set up linked list for feature query + { + void** pp_next = &features.pNext; + if (has_extension("VK_KHR_8bit_storage")) { + *pp_next = &storage_8bit; + pp_next = &storage_8bit.pNext; + } + if (has_extension("VK_KHR_16bit_storage")) { + *pp_next = &storage_16bit; + pp_next = &storage_16bit.pNext; + } + if (has_extension("VK_KHR_shader_float16_int8")) { + *pp_next = &float16_int8; + pp_next = &float16_int8.pNext; + } + } + + if (has_extension("VK_KHR_get_physical_device_properties2")) { + // Preferred method, call to get all properties that can be queried. + auto vkGetPhysicalDeviceProperties2KHR = (PFN_vkGetPhysicalDeviceProperties2KHR)ICHECK_NOTNULL( + vkGetInstanceProcAddr(instance, "vkGetPhysicalDeviceProperties2KHR")); + vkGetPhysicalDeviceProperties2KHR(dev, &properties); + + auto vkGetPhysicalDeviceFeatures2KHR = (PFN_vkGetPhysicalDeviceFeatures2KHR)ICHECK_NOTNULL( + vkGetInstanceProcAddr(instance, "vkGetPhysicalDeviceFeatures2KHR")); + vkGetPhysicalDeviceFeatures2KHR(dev, &features); + } else { + // Fallback, get as many features as we can from the Vulkan1.0 + // API. Corresponding vkGetPhysicalDeviceProperties was already done earlier. + vkGetPhysicalDeviceFeatures(dev, &features.features); + } + + //// Now, extracting all the information from the vulkan query. + + // Not technically needed, because VK_SHADER_STAGE_COMPUTE_BIT will + // be set so long at least one queue has VK_QUEUE_COMPUTE_BIT, but + // preferring the explicit check. + uint32_t supported_subgroup_operations = + (subgroup.supportedStages & VK_SHADER_STAGE_COMPUTE_BIT) ? subgroup.supportedOperations : 0; + + // Even if we can't query it, warp size must be at least 1. Must + // also be defined, as `transpose` operation requires it. + uint32_t thread_warp_size = std::max(subgroup.subgroupSize, 1U); + + // By default, use the maximum API version that the driver allows, + // so that any supported features can be used by TVM shaders. + // However, if we can query the conformance version, then limit to + // only using the api version that passes the vulkan conformance + // tests. + uint32_t vulkan_api_version = properties.properties.apiVersion; + if (has_extension("VK_KHR_driver_properties")) { + auto api_major = VK_VERSION_MAJOR(vulkan_api_version); + auto api_minor = VK_VERSION_MINOR(vulkan_api_version); + if ((api_major > driver.conformanceVersion.major) || + ((api_major == driver.conformanceVersion.major) && + (api_minor > driver.conformanceVersion.minor))) { + vulkan_api_version = + VK_MAKE_VERSION(driver.conformanceVersion.major, driver.conformanceVersion.minor, 0); + } + } + + // From "Versions and Formats" section of Vulkan spec. + uint32_t max_spirv_version = 0x10000; + if (vulkan_api_version >= VK_API_VERSION_1_2) { + max_spirv_version = 0x10500; + } else if (has_extension("VK_KHR_spirv_1_4")) { + max_spirv_version = 0x10400; + } else if (vulkan_api_version >= VK_API_VERSION_1_1) { + max_spirv_version = 0x10300; + } + + // Support is available based on these extensions, but allow it to + // be disabled based on an environment variable. + bool supports_push_descriptor = + has_extension("VK_KHR_push_descriptor") && has_extension("VK_KHR_descriptor_update_template"); + { + const char* disable = std::getenv("TVM_VULKAN_DISABLE_PUSH_DESCRIPTOR"); + if (disable && *disable) { + supports_push_descriptor = false; + } + } + + // Support is available based on these extensions, but allow it to + // be disabled based on an environment variable. + bool supports_dedicated_allocation = has_extension("VK_KHR_get_memory_requirements2") && + has_extension("VK_KHR_dedicated_allocation"); + { + const char* disable = std::getenv("TVM_VULKAN_DISABLE_DEDICATED_ALLOCATION"); + if (disable && *disable) { + supports_dedicated_allocation = false; + } + } + + Map config = { + {"kind", String("vulkan")}, + // Feature support + {"supports_float16", Bool(float16_int8.shaderFloat16)}, + {"supports_float32", Bool(true)}, + {"supports_float64", Bool(features.features.shaderFloat64)}, + {"supports_int8", Bool(float16_int8.shaderInt8)}, + {"supports_int16", Bool(features.features.shaderInt16)}, + {"supports_int32", Bool(true)}, + {"supports_int64", Bool(features.features.shaderInt64)}, + {"supports_8bit_buffer", Bool(storage_8bit.storageBuffer8BitAccess)}, + {"supports_16bit_buffer", Bool(storage_16bit.storageBuffer16BitAccess)}, + {"supports_storage_buffer_storage_class", + Bool(has_extension("VK_KHR_storage_buffer_storage_class"))}, + {"supports_push_descriptor", Bool(supports_push_descriptor)}, + {"supports_dedicated_allocation", Bool(supports_dedicated_allocation)}, + {"supported_subgroup_operations", Integer(supported_subgroup_operations)}, + // Physical device limits + {"max_num_threads", Integer(properties.properties.limits.maxComputeWorkGroupInvocations)}, + {"thread_warp_size", Integer(thread_warp_size)}, + {"max_block_size_x", Integer(properties.properties.limits.maxComputeWorkGroupSize[0])}, + {"max_block_size_y", Integer(properties.properties.limits.maxComputeWorkGroupSize[1])}, + {"max_block_size_z", Integer(properties.properties.limits.maxComputeWorkGroupSize[2])}, + {"max_push_constants_size", Integer(properties.properties.limits.maxPushConstantsSize)}, + {"max_uniform_buffer_range", Integer(properties.properties.limits.maxUniformBufferRange)}, + {"max_storage_buffer_range", + Integer(IntImm(DataType::UInt(32), properties.properties.limits.maxStorageBufferRange))}, + {"max_per_stage_descriptor_storage_buffer", + Integer(properties.properties.limits.maxPerStageDescriptorStorageBuffers)}, + {"max_shared_memory_per_block", + Integer(properties.properties.limits.maxComputeSharedMemorySize)}, + // Other device properties + {"device_name", String(properties.properties.deviceName)}, + {"driver_version", Integer(properties.properties.driverVersion)}, + {"vulkan_api_version", Integer(vulkan_api_version)}, + {"max_spirv_version", Integer(max_spirv_version)}, + }; + + return Target(config); +} + void VulkanDeviceAPI::GetAttr(Device dev, DeviceAttrKind kind, TVMRetValue* rv) { size_t index = static_cast(dev.device_id); if (kind == kExist) { @@ -439,39 +649,24 @@ void VulkanDeviceAPI::GetAttr(Device dev, DeviceAttrKind kind, TVMRetValue* rv) return; } ICHECK_LT(index, context_.size()) << "Invalid device id " << index; - const auto& vctx = context(index); - VkPhysicalDeviceProperties phy_prop; - vkGetPhysicalDeviceProperties(vctx.phy_device, &phy_prop); + + const auto& target = context(index).target; switch (kind) { case kMaxThreadsPerBlock: { - int64_t value = phy_prop.limits.maxComputeWorkGroupInvocations; - *rv = value; + *rv = target->GetAttr("max_num_threads").value(); break; } case kMaxSharedMemoryPerBlock: { - int64_t value = phy_prop.limits.maxComputeSharedMemorySize; - *rv = value; + *rv = target->GetAttr("max_shared_memory_per_block"); break; } case kWarpSize: { - VkPhysicalDeviceSubgroupProperties subgroup_prop; - subgroup_prop.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SUBGROUP_PROPERTIES; - subgroup_prop.pNext = NULL; - - VkPhysicalDeviceProperties2 phy_prop2; - phy_prop2.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_PROPERTIES_2; - phy_prop2.pNext = &subgroup_prop; - - vkGetPhysicalDeviceProperties2(vctx.phy_device, &phy_prop2); - int64_t subgroup_size = subgroup_prop.subgroupSize; - ICHECK(subgroup_size >= 1); - - *rv = subgroup_size; + *rv = target->GetAttr("thread_warp_size").value(); break; } case kComputeVersion: { - int64_t value = phy_prop.apiVersion; + int64_t value = target->GetAttr("vulkan_api_version").value(); std::ostringstream os; os << VK_VERSION_MAJOR(value) << "." << VK_VERSION_MINOR(value) << "." << VK_VERSION_PATCH(value); @@ -479,33 +674,39 @@ void VulkanDeviceAPI::GetAttr(Device dev, DeviceAttrKind kind, TVMRetValue* rv) break; } case kDeviceName: - *rv = std::string(phy_prop.deviceName); + *rv = target->GetAttr("device_name").value(); break; + case kMaxClockRate: break; + case kMultiProcessorCount: break; + case kExist: break; + case kMaxThreadDimensions: { - int64_t dims[3]; - dims[0] = phy_prop.limits.maxComputeWorkGroupSize[0]; - dims[1] = phy_prop.limits.maxComputeWorkGroupSize[1]; - dims[2] = phy_prop.limits.maxComputeWorkGroupSize[2]; std::stringstream ss; // use json string to return multiple int values; - ss << "[" << dims[0] << ", " << dims[1] << ", " << dims[2] << "]"; + ss << "[" << target->GetAttr("max_block_size_x").value() << ", " + << target->GetAttr("max_block_size_y").value() << ", " + << target->GetAttr("max_block_size_z").value() << "]"; *rv = ss.str(); break; } + case kMaxRegistersPerBlock: break; + case kGcnArch: break; + case kApiVersion: *rv = VK_HEADER_VERSION; break; + case kDriverVersion: { - int64_t value = phy_prop.driverVersion; + int64_t value = target->GetAttr("driver_version").value(); std::ostringstream os; os << VK_VERSION_MAJOR(value) << "." << VK_VERSION_MINOR(value) << "." << VK_VERSION_PATCH(value); @@ -516,67 +717,86 @@ void VulkanDeviceAPI::GetAttr(Device dev, DeviceAttrKind kind, TVMRetValue* rv) } VulkanDeviceAPI::VulkanDeviceAPI() { - VkApplicationInfo app_info; - app_info.sType = VK_STRUCTURE_TYPE_APPLICATION_INFO; - app_info.pNext = nullptr; - app_info.pApplicationName = "TVM"; - app_info.applicationVersion = 0; - app_info.pEngineName = ""; - app_info.engineVersion = 0; - app_info.apiVersion = VK_MAKE_VERSION(1, 0, 0); - - VkInstanceCreateInfo inst_info; - inst_info.sType = VK_STRUCTURE_TYPE_INSTANCE_CREATE_INFO; - inst_info.pNext = nullptr; - inst_info.flags = 0; - const auto layers = []() -> std::vector { uint32_t inst_layer_prop_count; VULKAN_CALL(vkEnumerateInstanceLayerProperties(&inst_layer_prop_count, nullptr)); std::vector inst_layer_prop(inst_layer_prop_count); VULKAN_CALL(vkEnumerateInstanceLayerProperties(&inst_layer_prop_count, inst_layer_prop.data())); std::vector l; - for (const auto& lp : inst_layer_prop) { - // TODO(tulloch): add CMAKE options. - (void)lp; // suppress unused variable warning. -#ifdef USE_VULKAN_VALIDATION - if (std::strcmp(lp.layerName, "VK_LAYER_LUNARG_standard_validation") == 0) { - l.push_back("VK_LAYER_LUNARG_standard_validation"); - } - if (std::strcmp(lp.layerName, "VK_LAYER_LUNARG_parameter_validation") == 0) { - l.push_back("VK_LAYER_LUNARG_parameter_validation"); - } - if (std::strcmp(lp.layerName, "VK_LAYER_KHRONOS_validation") == 0) { - l.push_back("VK_LAYER_KHRONOS_validation"); + + const char* enable = std::getenv("TVM_VULKAN_ENABLE_VALIDATION_LAYERS"); + bool validation_enabled = enable && *enable; + if (validation_enabled) { + for (const auto& lp : inst_layer_prop) { + if (std::strcmp(lp.layerName, "VK_LAYER_LUNARG_standard_validation") == 0) { + l.push_back("VK_LAYER_LUNARG_standard_validation"); + } + if (std::strcmp(lp.layerName, "VK_LAYER_LUNARG_parameter_validation") == 0) { + l.push_back("VK_LAYER_LUNARG_parameter_validation"); + } + if (std::strcmp(lp.layerName, "VK_LAYER_KHRONOS_validation") == 0) { + l.push_back("VK_LAYER_KHRONOS_validation"); + } } -#endif } return l; }(); - const auto instance_extensions = []() -> std::vector { + const auto instance_extensions = [this]() { + std::vector required_extensions{}; + std::vector optional_extensions{"VK_KHR_get_physical_device_properties2"}; + uint32_t inst_extension_prop_count; VULKAN_CALL( vkEnumerateInstanceExtensionProperties(nullptr, &inst_extension_prop_count, nullptr)); std::vector inst_extension_prop(inst_extension_prop_count); VULKAN_CALL(vkEnumerateInstanceExtensionProperties(nullptr, &inst_extension_prop_count, inst_extension_prop.data())); - std::vector extensions; - for (const auto& ip : inst_extension_prop) { - if (std::strcmp(ip.extensionName, "VK_KHR_get_physical_device_properties2") == 0) { - extensions.push_back("VK_KHR_get_physical_device_properties2"); - } - } - return extensions; + + return find_enabled_extensions(inst_extension_prop, required_extensions, optional_extensions); }(); - inst_info.pApplicationInfo = &app_info; - inst_info.enabledLayerCount = layers.size(); - inst_info.ppEnabledLayerNames = layers.data(); - inst_info.enabledExtensionCount = instance_extensions.size(); - inst_info.ppEnabledExtensionNames = instance_extensions.data(); + auto has_instance_extension = [&instance_extensions](const char* query) { + return std::any_of(instance_extensions.begin(), instance_extensions.end(), + [&](const char* extension) { return std::strcmp(query, extension) == 0; }); + }; + + const auto instance_api_version = []() { + uint32_t api_version = VK_MAKE_VERSION(1, 0, 0); + + // Result from vkGetInstanceProcAddr is NULL if driver only + // supports vulkan 1.0. + auto vkEnumerateInstanceVersion = + (PFN_vkEnumerateInstanceVersion)vkGetInstanceProcAddr(NULL, "vkEnumerateInstanceVersion"); + if (vkEnumerateInstanceVersion) { + vkEnumerateInstanceVersion(&api_version); + } + + return api_version; + }(); - VULKAN_CALL(vkCreateInstance(&inst_info, nullptr, &instance_)); + { + VkApplicationInfo app_info; + app_info.sType = VK_STRUCTURE_TYPE_APPLICATION_INFO; + app_info.pNext = nullptr; + app_info.pApplicationName = "TVM"; + app_info.applicationVersion = 0; + app_info.pEngineName = ""; + app_info.engineVersion = 0; + app_info.apiVersion = instance_api_version; + + VkInstanceCreateInfo inst_info; + inst_info.sType = VK_STRUCTURE_TYPE_INSTANCE_CREATE_INFO; + inst_info.pNext = nullptr; + inst_info.flags = 0; + inst_info.pApplicationInfo = &app_info; + inst_info.enabledLayerCount = layers.size(); + inst_info.ppEnabledLayerNames = layers.data(); + inst_info.enabledExtensionCount = instance_extensions.size(); + inst_info.ppEnabledExtensionNames = instance_extensions.data(); + + VULKAN_CALL(vkCreateInstance(&inst_info, nullptr, &instance_)); + } uint32_t phy_dev_count = 0; VULKAN_CALL(vkEnumeratePhysicalDevices(instance_, &phy_dev_count, nullptr)); @@ -603,51 +823,102 @@ VulkanDeviceAPI::VulkanDeviceAPI() { ctx.phy_device = phy_dev; vkGetPhysicalDeviceProperties(ctx.phy_device, &(ctx.phy_device_prop)); - const auto extensions = [&]() { + const auto device_extensions = [&]() { + std::vector required_extensions{}; + std::vector optional_extensions{ + "VK_KHR_driver_properties", + "VK_KHR_storage_buffer_storage_class", + "VK_KHR_8bit_storage", + "VK_KHR_16bit_storage", + "VK_KHR_shader_float16_int8", + "VK_KHR_push_descriptor", + "VK_KHR_descriptor_update_template", + "VK_KHR_get_memory_requirements2", + "VK_KHR_dedicated_allocation", + "VK_KHR_spirv_1_4", + }; + uint32_t device_extension_prop_count; VULKAN_CALL(vkEnumerateDeviceExtensionProperties(ctx.phy_device, nullptr, &device_extension_prop_count, nullptr)); std::vector device_extension_prop(device_extension_prop_count); VULKAN_CALL(vkEnumerateDeviceExtensionProperties( ctx.phy_device, nullptr, &device_extension_prop_count, device_extension_prop.data())); - std::vector extensions; - for (const auto& dp : device_extension_prop) { - if ((std::strcmp(dp.extensionName, "VK_KHR_push_descriptor") == 0) && dp.specVersion > 0) { - extensions.push_back("VK_KHR_push_descriptor"); - } - if ((std::strcmp(dp.extensionName, "VK_KHR_descriptor_update_template") == 0) && - dp.specVersion > 0) { - extensions.push_back("VK_KHR_descriptor_update_template"); - } - if ((std::strcmp(dp.extensionName, "VK_KHR_get_memory_requirements2") == 0) && - dp.specVersion > 0) { - extensions.push_back("VK_KHR_get_memory_requirements2"); - } - if ((std::strcmp(dp.extensionName, "VK_KHR_dedicated_allocation") == 0) && - dp.specVersion > 0) { - extensions.push_back("VK_KHR_dedicated_allocation"); - } - } - return extensions; + + return find_enabled_extensions(device_extension_prop, required_extensions, + optional_extensions); }(); - // All TVM-generated spirv shaders are marked as requiring int64 - // support, so we need to request it from the device, too. - VkPhysicalDeviceFeatures enabled_features = {}; - enabled_features.shaderInt64 = VK_TRUE; - - VkDeviceCreateInfo device_create_info; - device_create_info.sType = VK_STRUCTURE_TYPE_DEVICE_CREATE_INFO; - device_create_info.pNext = nullptr; - device_create_info.flags = 0; - device_create_info.queueCreateInfoCount = 1; - device_create_info.pQueueCreateInfos = &queue_create_info; - device_create_info.enabledLayerCount = 0; - device_create_info.ppEnabledLayerNames = nullptr; - device_create_info.enabledExtensionCount = extensions.size(); - device_create_info.ppEnabledExtensionNames = extensions.data(); - device_create_info.pEnabledFeatures = &enabled_features; - VULKAN_CALL(vkCreateDevice(phy_dev, &device_create_info, nullptr, &(ctx.device))); + ctx.target = GetDeviceDescription(instance_, phy_dev, instance_extensions, device_extensions); + + { + // Enable all features we may use that a device supports. + VkPhysicalDeviceFeatures2 enabled_features = {VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_FEATURES_2}; + VkPhysicalDevice8BitStorageFeatures storage_8bit = { + VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_8BIT_STORAGE_FEATURES}; + VkPhysicalDevice16BitStorageFeatures storage_16bit = { + VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_16BIT_STORAGE_FEATURES}; + VkPhysicalDeviceShaderFloat16Int8Features float16_int8 = { + VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SHADER_FLOAT16_INT8_FEATURES}; + + void** pp_next = &enabled_features.pNext; + bool needs_float16_int8 = false; + + auto has_support = [&](const char* name) { return ctx.target->GetAttr(name).value(); }; + if (has_support("supports_float16")) { + float16_int8.shaderFloat16 = true; + needs_float16_int8 = true; + } + if (has_support("supports_float64")) { + enabled_features.features.shaderFloat64 = true; + } + if (has_support("supports_int8")) { + float16_int8.shaderInt8 = true; + needs_float16_int8 = true; + } + if (has_support("supports_int16")) { + enabled_features.features.shaderInt16 = true; + } + if (has_support("supports_int64")) { + enabled_features.features.shaderInt64 = true; + } + if (has_support("supports_8bit_buffer")) { + storage_8bit.storageBuffer8BitAccess = true; + *pp_next = &storage_8bit; + pp_next = &storage_8bit.pNext; + } + if (has_support("supports_16bit_buffer")) { + storage_16bit.storageBuffer16BitAccess = true; + *pp_next = &storage_16bit; + pp_next = &storage_16bit.pNext; + } + + if (needs_float16_int8) { + *pp_next = &float16_int8; + pp_next = &float16_int8.pNext; + } + + VkDeviceCreateInfo device_create_info; + device_create_info.sType = VK_STRUCTURE_TYPE_DEVICE_CREATE_INFO; + device_create_info.pNext = nullptr; + device_create_info.flags = 0; + device_create_info.queueCreateInfoCount = 1; + device_create_info.pQueueCreateInfos = &queue_create_info; + device_create_info.enabledLayerCount = 0; + device_create_info.ppEnabledLayerNames = nullptr; + device_create_info.enabledExtensionCount = device_extensions.size(); + device_create_info.ppEnabledExtensionNames = device_extensions.data(); + + if (has_instance_extension("VK_KHR_get_physical_device_properties2")) { + device_create_info.pEnabledFeatures = nullptr; + device_create_info.pNext = &enabled_features; + } else { + device_create_info.pNext = nullptr; + device_create_info.pEnabledFeatures = &enabled_features.features; + } + VULKAN_CALL(vkCreateDevice(phy_dev, &device_create_info, nullptr, &(ctx.device))); + } + ctx.queue_mutex.reset(new std::mutex()); vkGetDeviceQueue(ctx.device, queue_family_index, 0, &(ctx.queue)); ctx.queue_family_index = queue_family_index; @@ -718,42 +989,17 @@ VulkanDeviceAPI::VulkanDeviceAPI() { } } ICHECK_GE(win_rank, 0) << "Cannot find suitable local memory on device."; - auto has_extension = [&extensions](const char* query) { - return std::any_of(extensions.begin(), extensions.end(), - [&](const char* extension) { return std::strcmp(query, extension) == 0; }); - }; -#ifdef USE_VULKAN_IMMEDIATE_MODE - if (has_extension("VK_KHR_push_descriptor") && - has_extension("VK_KHR_descriptor_update_template")) { - ctx.descriptor_template_khr_functions = std::unique_ptr( - new VulkanDescriptorTemplateKHRFunctions()); - ctx.descriptor_template_khr_functions->vkCreateDescriptorUpdateTemplateKHR = - CHECK_NOTNULL((PFN_vkCreateDescriptorUpdateTemplateKHR)vkGetDeviceProcAddr( - ctx.device, "vkCreateDescriptorUpdateTemplateKHR")); - ctx.descriptor_template_khr_functions->vkDestroyDescriptorUpdateTemplateKHR = - CHECK_NOTNULL((PFN_vkDestroyDescriptorUpdateTemplateKHR)vkGetDeviceProcAddr( - ctx.device, "vkDestroyDescriptorUpdateTemplateKHR")); - ctx.descriptor_template_khr_functions->vkUpdateDescriptorSetWithTemplateKHR = - CHECK_NOTNULL((PFN_vkUpdateDescriptorSetWithTemplateKHR)vkGetDeviceProcAddr( - ctx.device, "vkUpdateDescriptorSetWithTemplateKHR")); - ctx.descriptor_template_khr_functions->vkCmdPushDescriptorSetWithTemplateKHR = - CHECK_NOTNULL((PFN_vkCmdPushDescriptorSetWithTemplateKHR)vkGetDeviceProcAddr( - ctx.device, "vkCmdPushDescriptorSetWithTemplateKHR")); + if (ctx.target->GetAttr("supports_push_descriptor").value()) { + ctx.descriptor_template_khr_functions = + std::make_unique(ctx.device); } -#endif -#ifdef USE_VULKAN_DEDICATED_ALLOCATION - if (has_extension("VK_KHR_get_memory_requirements2") && - has_extension("VK_KHR_dedicated_allocation")) { + if (ctx.target->GetAttr("supports_dedicated_allocation").value()) { ctx.get_buffer_memory_requirements_2_functions = - std::unique_ptr( - new VulkanGetBufferMemoryRequirements2Functions()); - ctx.get_buffer_memory_requirements_2_functions->vkGetBufferMemoryRequirements2KHR = - CHECK_NOTNULL((PFN_vkGetBufferMemoryRequirements2KHR)vkGetDeviceProcAddr( - ctx.device, "vkGetBufferMemoryRequirements2KHR")); + std::make_unique(ctx.device); } -#endif + context_.push_back(std::move(ctx)); } @@ -1335,6 +1581,10 @@ TVM_REGISTER_GLOBAL("device_api.vulkan").set_body([](TVMArgs args, TVMRetValue* *rv = static_cast(ptr); }); +TVM_REGISTER_GLOBAL("device_api.vulkan.generate_target").set_body_typed([](int device_id) { + return VulkanDeviceAPI::Global()->GenerateTarget(device_id); +}); + } // namespace vulkan } // namespace runtime } // namespace tvm diff --git a/src/runtime/vulkan/vulkan_common.h b/src/runtime/vulkan/vulkan_common.h index 2ef879a487a6..14ecdba6ca40 100644 --- a/src/runtime/vulkan/vulkan_common.h +++ b/src/runtime/vulkan/vulkan_common.h @@ -24,6 +24,7 @@ #include #include #include +#include #include #include @@ -87,10 +88,10 @@ inline const char* VKGetErrorString(VkResult error) { * \brief Protected Vulkan call * \param func Expression to call. */ -#define VULKAN_CHECK_ERROR(__e) \ - { \ - ICHECK(__e == VK_SUCCESS) << "Vulan Error, code=" << __e << ": " \ - << vulkan::VKGetErrorString(__e); \ +#define VULKAN_CHECK_ERROR(__e) \ + { \ + ICHECK(__e == VK_SUCCESS) << "Vulkan Error, code=" << __e << ": " \ + << vulkan::VKGetErrorString(__e); \ } #define VULKAN_CALL(func) \ @@ -100,6 +101,18 @@ inline const char* VKGetErrorString(VkResult error) { } struct VulkanDescriptorTemplateKHRFunctions { + explicit VulkanDescriptorTemplateKHRFunctions(VkDevice device) { + vkCreateDescriptorUpdateTemplateKHR = (PFN_vkCreateDescriptorUpdateTemplateKHR)ICHECK_NOTNULL( + vkGetDeviceProcAddr(device, "vkCreateDescriptorUpdateTemplateKHR")); + vkDestroyDescriptorUpdateTemplateKHR = (PFN_vkDestroyDescriptorUpdateTemplateKHR)ICHECK_NOTNULL( + vkGetDeviceProcAddr(device, "vkDestroyDescriptorUpdateTemplateKHR")); + vkUpdateDescriptorSetWithTemplateKHR = (PFN_vkUpdateDescriptorSetWithTemplateKHR)ICHECK_NOTNULL( + vkGetDeviceProcAddr(device, "vkUpdateDescriptorSetWithTemplateKHR")); + vkCmdPushDescriptorSetWithTemplateKHR = + (PFN_vkCmdPushDescriptorSetWithTemplateKHR)ICHECK_NOTNULL( + vkGetDeviceProcAddr(device, "vkCmdPushDescriptorSetWithTemplateKHR")); + } + PFN_vkCreateDescriptorUpdateTemplateKHR vkCreateDescriptorUpdateTemplateKHR{nullptr}; PFN_vkDestroyDescriptorUpdateTemplateKHR vkDestroyDescriptorUpdateTemplateKHR{nullptr}; PFN_vkUpdateDescriptorSetWithTemplateKHR vkUpdateDescriptorSetWithTemplateKHR{nullptr}; @@ -107,14 +120,22 @@ struct VulkanDescriptorTemplateKHRFunctions { }; struct VulkanGetBufferMemoryRequirements2Functions { + explicit VulkanGetBufferMemoryRequirements2Functions(VkDevice device) { + vkGetBufferMemoryRequirements2KHR = (PFN_vkGetBufferMemoryRequirements2KHR)ICHECK_NOTNULL( + vkGetDeviceProcAddr(device, "vkGetBufferMemoryRequirements2KHR")); + } + PFN_vkGetBufferMemoryRequirements2KHR vkGetBufferMemoryRequirements2KHR{nullptr}; }; struct VulkanContext { - // phyiscal device + // physical device VkPhysicalDevice phy_device{nullptr}; + // Phyiscal device property VkPhysicalDeviceProperties phy_device_prop; + // Target that best represents this physical device + Target target; // Memory type index for staging. uint32_t staging_mtype_index{0}; // whether staging is coherent @@ -136,7 +157,7 @@ struct VulkanContext { // Queue family index. VkQueueFamilyProperties queue_prop; - bool UseImmediate() const { return descriptor_template_khr_functions.get() != nullptr; } + bool UseImmediate() const { return descriptor_template_khr_functions != nullptr; } }; } // namespace vulkan diff --git a/src/target/spirv/spirv_support.cc b/src/target/spirv/spirv_support.cc index ff9aee406574..e06bde08895d 100644 --- a/src/target/spirv/spirv_support.cc +++ b/src/target/spirv/spirv_support.cc @@ -35,17 +35,45 @@ SPIRVSupport::SPIRVSupport(tvm::Target target) { ICHECK_EQ(target->kind->device_type, kDLVulkan) << "SPIRVSupport can only be checked for vulkan device type"; - // Currently, this codifies the assumptions that were present and - // implicit in previous implementations. In the future, this will - // pull information from the specified `Target`. - - supports_storage_buffer_storage_class = (SPV_VERSION >= 0x10300); - supports_storage_buffer_8bit_access = true; - supports_storage_buffer_16bit_access = true; - supports_float16 = true; - supports_int8 = true; - supports_int16 = true; - supports_int64 = true; + if (target->GetAttr("supported_subgroup_operations")) { + supported_subgroup_operations = + target->GetAttr("supported_subgroup_operations").value(); + } + if (target->GetAttr("max_push_constants_size")) { + max_push_constants_size = target->GetAttr("max_push_constants_size").value(); + } + if (target->GetAttr("max_uniform_buffer_range")) { + max_uniform_buffer_range = target->GetAttr("max_uniform_buffer_range").value(); + } + if (target->GetAttr("max_storage_buffer_range")) { + max_storage_buffer_range = target->GetAttr("max_storage_buffer_range").value(); + } + if (target->GetAttr("max_per_stage_descriptor_storage_buffer")) { + max_per_stage_descriptor_storage_buffers = + target->GetAttr("max_per_stage_descriptor_storage_buffer").value(); + } + if (target->GetAttr("supports_storage_buffer_storage_class")) { + supports_storage_buffer_storage_class = + target->GetAttr("supports_storage_buffer_storage_class").value(); + } + if (target->GetAttr("supports_8bit_buffer")) { + supports_storage_buffer_8bit_access = target->GetAttr("supports_8bit_buffer").value(); + } + if (target->GetAttr("supports_16bit_buffer")) { + supports_storage_buffer_16bit_access = target->GetAttr("supports_16bit_buffer").value(); + } + if (target->GetAttr("supports_float16")) { + supports_float16 = target->GetAttr("supports_float16").value(); + } + if (target->GetAttr("supports_int8")) { + supports_int8 = target->GetAttr("supports_int8").value(); + } + if (target->GetAttr("supports_int16")) { + supports_int16 = target->GetAttr("supports_int16").value(); + } + if (target->GetAttr("supports_int64")) { + supports_int64 = target->GetAttr("supports_int64").value(); + } } } // namespace codegen diff --git a/src/target/target_kind.cc b/src/target/target_kind.cc index e06b2c05d3bf..08e998e0f035 100644 --- a/src/target/target_kind.cc +++ b/src/target/target_kind.cc @@ -209,6 +209,45 @@ Map UpdateROCmAttrs(Map attrs) { return attrs; } +/*! + * \brief Update the attributes in the Vulkan target. + * \param attrs The original attributes + * \return The updated attributes + */ +Map UpdateVulkanAttrs(Map attrs) { + if (attrs.count("from_device")) { + int device_id = Downcast(attrs.at("from_device")); + const PackedFunc* generate_target = runtime::Registry::Get("device_api.vulkan.generate_target"); + ICHECK(generate_target) + << "Requested to read Vulkan parameters from device, but no Vulkan runtime available"; + Target target = (*generate_target)(device_id).AsObjectRef(); + for (auto& kv : target->Export()) { + if (!attrs.count(kv.first)) { + attrs.Set(kv.first, kv.second); + } + } + + attrs.erase("from_device"); + } + + // Set defaults here, rather than in the .add_attr_option() calls. + // The priority should be user-specified > device-query > default, + // but defaults defined in .add_attr_option() are already applied by + // this point. Longer-term, would be good to add a + // `DeviceAPI::GenerateTarget` function and extend "from_device" to + // work for all runtimes. + std::unordered_map defaults = {{"supports_float32", Bool(true)}, + {"supports_int32", Bool(true)}, + {"max_num_threads", Integer(256)}, + {"thread_warp_size", Integer(1)}}; + for (const auto& kv : defaults) { + if (!attrs.count(kv.first)) { + attrs.Set(kv.first, kv.second); + } + } + return attrs; +} + /********** Register Target kinds and attributes **********/ TVM_REGISTER_TARGET_KIND("llvm", kDLCPU) @@ -273,9 +312,40 @@ TVM_REGISTER_TARGET_KIND("metal", kDLMetal) TVM_REGISTER_TARGET_KIND("vulkan", kDLVulkan) .add_attr_option("system-lib") - .add_attr_option("max_num_threads", Integer(256)) - .add_attr_option("thread_warp_size", Integer(1)) - .set_default_keys({"vulkan", "gpu"}); + .add_attr_option("from_device") + // Feature support + .add_attr_option("supports_float16") + .add_attr_option("supports_float32") + .add_attr_option("supports_float64") + .add_attr_option("supports_int8") + .add_attr_option("supports_int16") + .add_attr_option("supports_int32") + .add_attr_option("supports_int64") + .add_attr_option("supports_8bit_buffer") + .add_attr_option("supports_16bit_buffer") + .add_attr_option("supports_storage_buffer_storage_class") + .add_attr_option("supports_push_descriptor") + .add_attr_option("supports_dedicated_allocation") + .add_attr_option("supported_subgroup_operations") + // Physical device limits + .add_attr_option("max_num_threads") + .add_attr_option("thread_warp_size") + .add_attr_option("max_block_size_x") + .add_attr_option("max_block_size_y") + .add_attr_option("max_block_size_z") + .add_attr_option("max_push_constants_size") + .add_attr_option("max_uniform_buffer_range") + .add_attr_option("max_storage_buffer_range") + .add_attr_option("max_per_stage_descriptor_storage_buffer") + .add_attr_option("max_shared_memory_per_block") + // Other device properties + .add_attr_option("device_name") + .add_attr_option("driver_version") + .add_attr_option("vulkan_api_version") + .add_attr_option("max_spirv_version") + // Tags + .set_default_keys({"vulkan", "gpu"}) + .set_attrs_preprocessor(UpdateVulkanAttrs); TVM_REGISTER_TARGET_KIND("webgpu", kDLWebGPU) .add_attr_option("system-lib")