diff --git a/src/auto_scheduler/search_task.cc b/src/auto_scheduler/search_task.cc
index 22c2893141cf..f25e581dbf24 100755
--- a/src/auto_scheduler/search_task.cc
+++ b/src/auto_scheduler/search_task.cc
@@ -106,6 +106,29 @@ HardwareParams HardwareParamsNode::GetDefaultHardwareParams(const Target& target
       auto target_device = target->GetAttr<String>("device", "");
       LOG(FATAL) << "No default hardware parameters for opencl target device: " << target_device;
     }
+  } else if (device_type == kDLVulkan) {
+    auto ctx = TVMContext{static_cast<DLDeviceType>(device_type), 0};
+    auto device_name = "device_api.vulkan";
+    auto func = tvm::runtime::Registry::Get(device_name);
+    ICHECK(func != nullptr) << "Cannot find Vulkan device_api in registry";
+    auto device_api = static_cast<tvm::runtime::DeviceAPI*>(((*func)()).operator void*());
+
+    tvm::runtime::TVMRetValue ret;
+    device_api->GetAttr(ctx, tvm::runtime::DeviceAttrKind::kMaxSharedMemoryPerBlock, &ret);
+    int max_shared_memory_per_block = ret;
+
+    int max_local_memory_per_block = INT32_MAX;
+
+    device_api->GetAttr(ctx, tvm::runtime::DeviceAttrKind::kMaxThreadsPerBlock, &ret);
+    int max_threads_per_block = ret;
+
+    device_api->GetAttr(ctx, tvm::runtime::DeviceAttrKind::kWarpSize, &ret);
+    int warp_size = ret;
+
+    int max_vthread_extent = std::max(1, warp_size / 4);
+
+    return HardwareParams(-1, 16, 64, max_shared_memory_per_block, max_local_memory_per_block,
+                          max_threads_per_block, max_vthread_extent, warp_size);
   } else {
     LOG(FATAL) << "No default hardware parameters for target: " << target;
   }
diff --git a/src/runtime/vulkan/vulkan.cc b/src/runtime/vulkan/vulkan.cc
index 794f3c570f96..ff1b82f930d7 100644
--- a/src/runtime/vulkan/vulkan.cc
+++ b/src/runtime/vulkan/vulkan.cc
@@ -367,28 +367,37 @@ void VulkanDeviceAPI::GetAttr(TVMContext ctx, DeviceAttrKind kind, TVMRetValue*
   }
   ICHECK_LT(index, context_.size()) << "Invalid device id " << index;
   const auto& vctx = context(index);
+  VkPhysicalDeviceProperties phy_prop;
+  vkGetPhysicalDeviceProperties(vctx.phy_device, &phy_prop);
+
   switch (kind) {
     case kMaxThreadsPerBlock: {
-      VkPhysicalDeviceProperties phy_prop;
-      vkGetPhysicalDeviceProperties(vctx.phy_device, &phy_prop);
       int64_t value = phy_prop.limits.maxComputeWorkGroupInvocations;
       *rv = value;
       break;
     }
     case kMaxSharedMemoryPerBlock: {
-      VkPhysicalDeviceProperties phy_prop;
-      vkGetPhysicalDeviceProperties(vctx.phy_device, &phy_prop);
       int64_t value = phy_prop.limits.maxComputeSharedMemorySize;
       *rv = value;
       break;
     }
     case kWarpSize: {
-      *rv = 1;
+      VkPhysicalDeviceSubgroupProperties subgroup_prop;
+      subgroup_prop.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SUBGROUP_PROPERTIES;
+      subgroup_prop.pNext = NULL;
+
+      VkPhysicalDeviceProperties2 phy_prop2;
+      phy_prop2.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_PROPERTIES_2;
+      phy_prop2.pNext = &subgroup_prop;
+
+      vkGetPhysicalDeviceProperties2(vctx.phy_device, &phy_prop2);
+      int64_t subgroup_size = subgroup_prop.subgroupSize;
+      ICHECK(subgroup_size >= 1);
+
+      *rv = subgroup_size;
       break;
     }
     case kComputeVersion: {
-      VkPhysicalDeviceProperties phy_prop;
-      vkGetPhysicalDeviceProperties(vctx.phy_device, &phy_prop);
       int64_t value = phy_prop.apiVersion;
       std::ostringstream os;
       os << VK_VERSION_MAJOR(value) << "." << VK_VERSION_MINOR(value) << "."
@@ -405,8 +414,6 @@ void VulkanDeviceAPI::GetAttr(TVMContext ctx, DeviceAttrKind kind, TVMRetValue*
     case kExist:
       break;
     case kMaxThreadDimensions: {
-      VkPhysicalDeviceProperties phy_prop;
-      vkGetPhysicalDeviceProperties(vctx.phy_device, &phy_prop);
       int64_t dims[3];
       dims[0] = phy_prop.limits.maxComputeWorkGroupSize[0];
       dims[1] = phy_prop.limits.maxComputeWorkGroupSize[1];