Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 19 additions & 2 deletions src/target/spirv/build_vulkan.cc
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,24 @@ namespace codegen {

class SPIRVTools {
public:
SPIRVTools() { ctx_ = spvContextCreate(SPV_ENV_VULKAN_1_0); }
explicit SPIRVTools(Target target) {
uint32_t vulkan_version =
target->GetAttr<Integer>("vulkan_api_version").value_or(VK_API_VERSION_1_0);
uint32_t spirv_version = target->GetAttr<Integer>("max_spirv_version").value_or(0x10000);

spv_target_env validation_version;
if (vulkan_version >= VK_API_VERSION_1_2) {
validation_version = SPV_ENV_VULKAN_1_2;
} else if (vulkan_version >= VK_API_VERSION_1_1 && spirv_version >= 0x10400) {
validation_version = SPV_ENV_VULKAN_1_1_SPIRV_1_4;
} else if (vulkan_version >= VK_API_VERSION_1_1) {
validation_version = SPV_ENV_VULKAN_1_1;
} else {
validation_version = SPV_ENV_VULKAN_1_0;
}

ctx_ = spvContextCreate(validation_version);
}
~SPIRVTools() { spvContextDestroy(ctx_); }
std::string BinaryToText(const std::vector<uint32_t>& bin) {
spv_text text = nullptr;
Expand Down Expand Up @@ -80,7 +97,7 @@ runtime::Module BuildSPIRV(IRModule mod, Target target, bool webgpu_restriction)
using tvm::runtime::VulkanShader;

std::ostringstream code_data;
static SPIRVTools spirv_tools;
SPIRVTools spirv_tools(target);
std::unordered_map<std::string, VulkanShader> smap;

const auto* postproc = Registry::Get("tvm_callback_vulkan_postproc");
Expand Down
29 changes: 18 additions & 11 deletions src/target/spirv/codegen_spirv.cc
Original file line number Diff line number Diff line change
Expand Up @@ -140,20 +140,27 @@ spirv::Value CodeGenSPIRV::GetThreadIndex(const IterVar& iv, const PrimExpr& ext
spirv::Value CodeGenSPIRV::CreateStorageSync(const CallNode* op) {
const std::string& sync = op->args[0].as<StringImmNode>()->value;
spirv::Value value;
if (sync == "warp") {
return value;
} else if (sync == "shared") {
auto type_int = builder_->GetSType(DataType::Int(32));
builder_->MakeInst(
spv::OpControlBarrier,
builder_->IntImm(type_int, static_cast<int64_t>(spv::ScopeWorkgroup)),
builder_->IntImm(type_int, static_cast<int64_t>(spv::ScopeWorkgroup)),
builder_->IntImm(type_int,
static_cast<int64_t>(spv::MemorySemanticsSequentiallyConsistentMask |
spv::MemorySemanticsWorkgroupMemoryMask)));

uint32_t vulkan_api_version = spirv_support_.vulkan_api_version;

int64_t sync_scope;
int64_t memory_semantics;
if ((sync == "warp") && (vulkan_api_version >= VK_API_VERSION_1_1)) {
sync_scope = spv::ScopeSubgroup;
memory_semantics =
spv::MemorySemanticsSequentiallyConsistentMask | spv::MemorySemanticsSubgroupMemoryMask;
} else if ((sync == "shared") || (sync == "warp")) {
sync_scope = spv::ScopeWorkgroup;
memory_semantics =
spv::MemorySemanticsSequentiallyConsistentMask | spv::MemorySemanticsWorkgroupMemoryMask;
} else {
LOG(FATAL) << "Do not support sync " << sync;
}

auto type_int = builder_->GetSType(DataType::Int(32));
builder_->MakeInst(spv::OpControlBarrier, builder_->IntImm(type_int, sync_scope),
builder_->IntImm(type_int, sync_scope),
builder_->IntImm(type_int, memory_semantics));
return value;
}

Expand Down
4 changes: 4 additions & 0 deletions src/target/spirv/spirv_support.cc
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,10 @@ SPIRVSupport::SPIRVSupport(tvm::Target target) {
ICHECK_EQ(target->kind->device_type, kDLVulkan)
<< "SPIRVSupport can only be checked for vulkan device type";

if (target->GetAttr<Integer>("vulkan_api_version")) {
vulkan_api_version = target->GetAttr<Integer>("vulkan_api_version").value();
}

if (target->GetAttr<Integer>("supported_subgroup_operations")) {
supported_subgroup_operations =
target->GetAttr<Integer>("supported_subgroup_operations").value();
Expand Down
14 changes: 14 additions & 0 deletions src/target/spirv/spirv_support.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
#define TVM_TARGET_SPIRV_SPIRV_SUPPORT_H_

#include <tvm/target/target.h>
#include <vulkan/vulkan_core.h>

namespace tvm {
namespace codegen {
Expand All @@ -37,6 +38,19 @@ struct SPIRVSupport {
*/
explicit SPIRVSupport(Target target);

/*! \brief The Vulkan API version supported by the device.
*
* Vulkan struct: VkPhysicalDeviceProperties
* Device property: apiVersion
*
* If VK_KHR_driver_properties is present, will also check the
* driver conformance version. If the version advertised does not
* pass the Vulkan conformance test, vulkan_api_version will be the
* latest Vulkan version that does pass the conformance test
* instead.
*/
uint32_t vulkan_api_version{VK_MAKE_VERSION(1, 0, 0)};

/*!
* \brief The supported subgroup operations
*
Expand Down
2 changes: 1 addition & 1 deletion src/tir/transforms/lower_thread_allreduce.cc
Original file line number Diff line number Diff line change
Expand Up @@ -425,7 +425,7 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
while (reduce_align > 1) {
reduce_align = reduce_align >> 1;
in_warp_seq.emplace_back(freduce(reduce_align));
seq.emplace_back(SyncThread("warp"));
in_warp_seq.emplace_back(SyncThread("warp"));
}
if (in_warp_seq.size() != 0) {
Stmt warp_body = SeqStmt::Flatten(in_warp_seq);
Expand Down